1
0

shm_broadcast.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. import pickle
  2. import time
  3. from contextlib import contextmanager
  4. from dataclasses import dataclass, field
  5. from multiprocessing import shared_memory
  6. from typing import List, Optional
  7. from unittest.mock import patch
  8. import torch
  9. import torch.distributed as dist
  10. from loguru import logger
  11. from torch.distributed import ProcessGroup
  12. from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
  13. import aphrodite.common.envs as envs
  14. from aphrodite.common.utils import get_ip, get_open_port
  15. APHRODITE_RINGBUFFER_WARNING_INTERVAL = (
  16. envs.APHRODITE_RINGBUFFER_WARNING_INTERVAL)
  17. # time to wait if the queue is full or empty
  18. # if we sleep for too short, it will consume too much CPU
  19. # if we sleep for too long, it will slow down the writer/reader
  20. # 0.1 us is a good balance
  21. RINGBUFFER_SLEEP_INTERVAL = 1e-7
  22. class ShmRingBuffer:
  23. def __init__(self,
  24. n_reader: int,
  25. max_chunk_bytes: int,
  26. max_chunks: int,
  27. name: Optional[str] = None):
  28. """
  29. A shared memory ring buffer implementation for broadcast communication.
  30. Essentially, it is a queue where only one will `enqueue` and multiple
  31. will `dequeue`. The max size of each item, together with the max number
  32. of items that can be stored in the buffer are known in advance.
  33. In this case, we don't need to synchronize the access to
  34. the buffer.
  35. Buffer memory layout:
  36. data metadata
  37. | |
  38. | (current_idx) | (current_idx)
  39. v v
  40. +-------------------------------+----------------------------------------+
  41. | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
  42. +-------------------------------+----------------------------------------+
  43. | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
  44. metadata memory layout: each byte is a flag, the first byte is the written
  45. flag, and the rest are reader flags. The flags are set to 0 by default.
  46. +--------------+--------------+--------------+-----+--------------+
  47. | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
  48. +--------------+--------------+--------------+-----+--------------+
  49. The state of metadata is as follows:
  50. (case 1) 0???...???: the block is not written yet, cannot read, can write
  51. (case 2) 1000...000: the block is just written, can read, cannot write
  52. (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
  53. (case 4) 1111...111: the block is written and read by all readers, cannot read, can write
  54. State transition for readers:
  55. When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
  56. Only after the caller finishes reading the block, the reader can mark the block as read.
  57. Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
  58. State transition for writer:
  59. When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
  60. to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
  61. can reset the reader flags to 0, and mark the block as written (from 0 to 1).
  62. NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
  63. During creation, `name` is None and the buffer is created. We can pass the
  64. created object to other processes by pickling it. The other processes will
  65. get the name of the shared memory and open it, so that they can access the
  66. same shared memory buffer.
  67. """# noqa
  68. self.n_reader = n_reader
  69. self.metadata_size = 1 + n_reader
  70. self.max_chunk_bytes = max_chunk_bytes
  71. self.max_chunks = max_chunks
  72. self.total_bytes_of_buffer = (self.max_chunk_bytes +
  73. self.metadata_size) * self.max_chunks
  74. self.data_offset = 0
  75. self.metadata_offset = self.max_chunk_bytes * self.max_chunks
  76. if name is None:
  77. # we are creating a buffer
  78. self.is_creator = True
  79. self.shared_memory = shared_memory.SharedMemory(
  80. create=True, size=self.total_bytes_of_buffer)
  81. # initialize the metadata section to 0
  82. with memoryview(self.shared_memory.buf[self.metadata_offset:]
  83. ) as metadata_buffer:
  84. torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
  85. else:
  86. # we are opening an existing buffer
  87. self.is_creator = False
  88. # fix to https://stackoverflow.com/q/62748654/9191338
  89. # Python incorrectly tracks shared memory even if it is not
  90. # created by the process. The following patch is a workaround.
  91. with patch("multiprocessing.resource_tracker.register",
  92. lambda *args, **kwargs: None):
  93. try:
  94. self.shared_memory = shared_memory.SharedMemory(name=name)
  95. assert self.shared_memory.size == self.total_bytes_of_buffer # noqa
  96. except FileNotFoundError:
  97. # we might deserialize the object in a different node
  98. # in this case, this object is not used,
  99. # and we should suppress the error
  100. pass
  101. def __reduce__(self):
  102. return (
  103. self.__class__,
  104. (self.n_reader, self.max_chunk_bytes, self.max_chunks,
  105. self.shared_memory.name),
  106. )
  107. def __del__(self):
  108. if hasattr(self, "shared_memory"):
  109. self.shared_memory.close()
  110. if self.is_creator:
  111. self.shared_memory.unlink()
  112. @contextmanager
  113. def get_data(self, current_idx: int):
  114. start = self.data_offset + current_idx * self.max_chunk_bytes
  115. end = start + self.max_chunk_bytes
  116. with memoryview(self.shared_memory.buf[start:end]) as buf:
  117. yield buf
  118. @contextmanager
  119. def get_metadata(self, current_idx: int):
  120. start = self.metadata_offset + current_idx * self.metadata_size
  121. end = start + self.metadata_size
  122. with memoryview(self.shared_memory.buf[start:end]) as buf:
  123. yield buf
  124. @dataclass
  125. class Handle:
  126. connect_ip: str
  127. local_reader_ranks: List[int] = field(default_factory=list)
  128. buffer: Optional[ShmRingBuffer] = None
  129. local_subscribe_port: Optional[int] = None
  130. remote_subscribe_port: Optional[int] = None
  131. class MessageQueue:
  132. def __init__(
  133. self,
  134. n_reader, # number of all readers
  135. n_local_reader, # number of local readers through shared memory
  136. local_reader_ranks: Optional[List[int]] = None,
  137. max_chunk_bytes: int = 1024 * 1024 * 10,
  138. max_chunks: int = 10,
  139. connect_ip: Optional[str] = None,
  140. ):
  141. if local_reader_ranks is None:
  142. local_reader_ranks = list(range(n_local_reader))
  143. else:
  144. assert len(local_reader_ranks) == n_local_reader
  145. self.n_local_reader = n_local_reader
  146. n_remote_reader = n_reader - n_local_reader
  147. self.n_remote_reader = n_remote_reader
  148. if connect_ip is None:
  149. connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
  150. context = Context()
  151. if n_local_reader > 0:
  152. # for local readers, we will:
  153. # 1. create a shared memory ring buffer to communicate small data
  154. # 2. create a publish-subscribe socket to communicate large data
  155. self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
  156. max_chunks)
  157. # XPUB is very similar to PUB,
  158. # except that it can receive subscription messages
  159. # to confirm the number of subscribers
  160. self.local_socket = context.socket(XPUB)
  161. # set the verbose option so that we can receive every subscription
  162. # message. otherwise, we will only receive the first subscription
  163. # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
  164. self.local_socket.setsockopt(XPUB_VERBOSE, True)
  165. local_subscribe_port = get_open_port()
  166. socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
  167. logger.debug(f"Binding to {socket_addr}")
  168. self.local_socket.bind(socket_addr)
  169. self.current_idx = 0
  170. else:
  171. self.buffer = None # type: ignore
  172. local_subscribe_port = None
  173. self.local_socket = None
  174. self.current_idx = -1
  175. if n_remote_reader > 0:
  176. # for remote readers, we will:
  177. # create a publish-subscribe socket to communicate large data
  178. self.remote_socket = context.socket(XPUB)
  179. self.remote_socket.setsockopt(XPUB_VERBOSE, True)
  180. remote_subscribe_port = get_open_port()
  181. socket_addr = f"tcp://*:{remote_subscribe_port}"
  182. self.remote_socket.bind(socket_addr)
  183. else:
  184. remote_subscribe_port = None
  185. self.remote_socket = None
  186. self._is_writer = True
  187. self._is_local_reader = False
  188. self.local_reader_rank = -1
  189. # rank does not matter for remote readers
  190. self._is_remote_reader = False
  191. self.handle = Handle(
  192. connect_ip=connect_ip,
  193. local_reader_ranks=local_reader_ranks,
  194. buffer=self.buffer,
  195. local_subscribe_port=local_subscribe_port,
  196. remote_subscribe_port=remote_subscribe_port,
  197. )
  198. logger.debug("Aphrodite message queue communication handle: "
  199. f"{self.handle}")
  200. def export_handle(self) -> Handle:
  201. return self.handle
  202. @staticmethod
  203. def create_from_handle(handle: Handle, rank) -> "MessageQueue":
  204. self = MessageQueue.__new__(MessageQueue)
  205. self.handle = handle
  206. self._is_writer = False
  207. context = Context()
  208. if rank in handle.local_reader_ranks:
  209. assert handle.buffer is not None
  210. self.buffer = handle.buffer
  211. self.current_idx = 0
  212. self.local_reader_rank = handle.local_reader_ranks.index(rank)
  213. self._is_local_reader = True
  214. self._is_remote_reader = False
  215. self.local_socket = context.socket(SUB)
  216. self.local_socket.setsockopt_string(SUBSCRIBE, "")
  217. socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
  218. logger.debug(f"Connecting to {socket_addr}")
  219. self.local_socket.connect(socket_addr)
  220. self.remote_socket = None
  221. else:
  222. self.buffer = None # type: ignore
  223. self.current_idx = -1
  224. self.local_reader_rank = -1
  225. self._is_local_reader = False
  226. self._is_remote_reader = True
  227. self.local_socket = None
  228. self.remote_socket = context.socket(SUB)
  229. self.remote_socket.setsockopt_string(SUBSCRIBE, "")
  230. socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
  231. logger.debug(f"Connecting to {socket_addr}")
  232. self.remote_socket.connect(socket_addr)
  233. return self
  234. def wait_until_ready(self):
  235. """This is a collective operation. All processes (including the
  236. readers and the writer) should call this function.
  237. """
  238. if self._is_writer:
  239. # wait for all readers to connect
  240. # local readers
  241. for i in range(self.n_local_reader):
  242. # wait for subscription messages from all local readers
  243. self.local_socket.recv()
  244. if self.n_local_reader > 0:
  245. # send a message to all local readers
  246. # to make sure the publish channel is working
  247. self.local_socket.send(b"READY")
  248. # remote readers
  249. for i in range(self.n_remote_reader):
  250. # wait for subscription messages from all remote readers
  251. self.remote_socket.recv()
  252. if self.n_remote_reader > 0:
  253. # send a message to all remote readers
  254. # to make sure the publish channel is working
  255. self.remote_socket.send(b"READY")
  256. elif self._is_local_reader:
  257. # wait for the writer to send a message
  258. recv = self.local_socket.recv()
  259. assert recv == b"READY"
  260. elif self._is_remote_reader:
  261. # wait for the writer to send a message
  262. recv = self.remote_socket.recv()
  263. assert recv == b"READY"
  264. @contextmanager
  265. def acquire_write(self):
  266. assert self._is_writer, "Only writers can acquire write"
  267. start_time = time.monotonic()
  268. n_warning = 1
  269. while True:
  270. with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
  271. read_count = sum(metadata_buffer[1:])
  272. written_flag = metadata_buffer[0]
  273. if written_flag and read_count != self.buffer.n_reader:
  274. # this block is written and not read by all readers
  275. # for writers, `self.current_idx` is the next block to write
  276. # if this block is not ready to write,
  277. # we need to wait until it is read by all readers
  278. # wait for a while
  279. time.sleep(RINGBUFFER_SLEEP_INTERVAL)
  280. # if we wait for a long time, we should warn the user
  281. if time.monotonic(
  282. ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning: # type: ignore # noqa
  283. logger.warning(
  284. "No available block found in "
  285. f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds")
  286. n_warning += 1
  287. continue
  288. # found a block that is either
  289. # (1) not written
  290. # (2) read by all readers
  291. # mark the block as not written
  292. metadata_buffer[0] = 0
  293. # let caller write to the buffer
  294. with self.buffer.get_data(self.current_idx) as buf:
  295. yield buf
  296. # caller has written to the buffer
  297. # NOTE: order is important here
  298. # first set the read flags to 0
  299. # then set the written flag to 1
  300. # otherwise, the readers may think they already read the block
  301. for i in range(1, self.buffer.n_reader + 1):
  302. # set read flag to 0, meaning it is not read yet
  303. metadata_buffer[i] = 0
  304. # mark the block as written
  305. metadata_buffer[0] = 1
  306. self.current_idx = (self.current_idx +
  307. 1) % self.buffer.max_chunks
  308. break
  309. @contextmanager
  310. def acquire_read(self):
  311. assert self._is_local_reader, "Only readers can acquire read"
  312. start_time = time.monotonic()
  313. n_warning = 1
  314. while True:
  315. with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
  316. read_flag = metadata_buffer[self.local_reader_rank + 1]
  317. written_flag = metadata_buffer[0]
  318. if not written_flag or read_flag:
  319. # this block is either
  320. # (1) not written
  321. # (2) already read by this reader
  322. # for readers, `self.current_idx` is the next block to read
  323. # if this block is not ready,
  324. # we need to wait until it is written
  325. # wait for a while
  326. time.sleep(RINGBUFFER_SLEEP_INTERVAL)
  327. # if we wait for a long time, we should warn the user
  328. if time.monotonic(
  329. ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning: # type: ignore # noqa
  330. logger.warning(
  331. "No available block found in "
  332. f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds."
  333. )
  334. n_warning += 1
  335. continue
  336. # found a block that is not read by this reader
  337. # let caller read from the buffer
  338. with self.buffer.get_data(self.current_idx) as buf:
  339. yield buf
  340. # caller has read from the buffer
  341. # set the read flag
  342. metadata_buffer[self.local_reader_rank + 1] = 1
  343. self.current_idx = (self.current_idx +
  344. 1) % self.buffer.max_chunks
  345. break
  346. def enqueue(self, obj):
  347. assert self._is_writer, "Only writers can enqueue"
  348. serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
  349. if self.n_local_reader > 0:
  350. if len(serialized_obj) >= self.buffer.max_chunk_bytes:
  351. with self.acquire_write() as buf:
  352. buf[0] = 1 # overflow
  353. self.local_socket.send(serialized_obj)
  354. else:
  355. with self.acquire_write() as buf:
  356. buf[0] = 0 # not overflow
  357. buf[1:len(serialized_obj) + 1] = serialized_obj
  358. if self.n_remote_reader > 0:
  359. self.remote_socket.send(serialized_obj)
  360. def dequeue(self):
  361. if self._is_local_reader:
  362. with self.acquire_read() as buf:
  363. overflow = buf[0] == 1
  364. if not overflow:
  365. # no need to know the size of serialized object
  366. # pickle format contains the size information internally
  367. # see https://docs.python.org/3/library/pickle.html
  368. obj = pickle.loads(buf[1:])
  369. if overflow:
  370. recv = self.local_socket.recv()
  371. obj = pickle.loads(recv)
  372. elif self._is_remote_reader:
  373. recv = self.remote_socket.recv()
  374. obj = pickle.loads(recv)
  375. else:
  376. raise RuntimeError("Only readers can dequeue")
  377. return obj
  378. def broadcast_object(self, obj=None):
  379. if self._is_writer:
  380. self.enqueue(obj)
  381. return obj
  382. else:
  383. return self.dequeue()
  384. @staticmethod
  385. def create_from_process_group(pg: ProcessGroup,
  386. max_chunk_bytes,
  387. max_chunks,
  388. writer_rank=0) -> "MessageQueue":
  389. group_rank = dist.get_rank(pg)
  390. group_world_size = dist.get_world_size(pg)
  391. global_ranks = dist.get_process_group_ranks(pg)
  392. from aphrodite.distributed.parallel_state import in_the_same_node_as
  393. status = in_the_same_node_as(pg, source_rank=writer_rank)
  394. same_node_ranks = [i for i, s in enumerate(status) if s]
  395. n_reader = group_world_size - 1
  396. n_local_reader = len(same_node_ranks) - 1
  397. local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
  398. buffer_io: MessageQueue
  399. if group_rank == writer_rank:
  400. buffer_io = MessageQueue(
  401. n_reader=n_reader,
  402. n_local_reader=n_local_reader,
  403. local_reader_ranks=local_reader_ranks,
  404. max_chunk_bytes=max_chunk_bytes,
  405. max_chunks=max_chunks,
  406. )
  407. handle = buffer_io.export_handle()
  408. dist.broadcast_object_list([handle],
  409. src=global_ranks[writer_rank],
  410. group=pg)
  411. else:
  412. recv = [None]
  413. dist.broadcast_object_list(recv,
  414. src=global_ranks[writer_rank],
  415. group=pg)
  416. handle = recv[0] # type: ignore
  417. buffer_io = MessageQueue.create_from_handle(handle, group_rank)
  418. buffer_io.wait_until_ready()
  419. return buffer_io