shm_broadcast.py

import os
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed import ProcessGroup
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore

from aphrodite.common.utils import get_ip, get_open_port

APHRODITE_RINGBUFFER_WARNING_INTERVAL = int(
    os.getenv("APHRODITE_RINGBUFFER_WARNING_INTERVAL", "60"))

# time to wait if the queue is full or empty
# if we sleep for too short, it will consume too much CPU
# if we sleep for too long, it will slow down the writer/reader
# 0.1 us is a good balance
RINGBUFFER_SLEEP_INTERVAL = 1e-7


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer, are known in advance.
        In this case, we don't need to synchronize the access to the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the
        written flag, and the rest are reader flags. The flags are set to 0
        by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read,
            can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers,
            can read if not read yet, cannot write
        (case 4) 1111...111: the block is written and read by all readers,
            cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can
        yield the block for the caller to read. Only after the caller finishes
        reading the block can the reader mark the block as read.
        Readers only mark the block as read (from 0 to 1); the writer marks
        the block as ready to read (from 0 to 1).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the
        written flag to 0, converting either case to case 1. Then it can yield
        the block for the caller to write. After the caller finishes writing
        the block, the writer resets the reader flags to 0 and marks the block
        as written (from 0 to 1).
        NOTE: the order is important here: first reset the reader flags (so
        that we are still in case 1), then mark the block as written. The
        state transition is atomic. If we do it in the reverse order, it will
        go through case 3 and then back to case 2, and readers might read the
        intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass
        the created object to other processes by pickling it. The other
        processes will get the name of the shared memory and open it, so that
        they can access the same shared memory buffer.
        """  # noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    assert self.shared_memory.size == self.total_bytes_of_buffer  # noqa
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf
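

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): pickling a
# ShmRingBuffer goes through `__reduce__`, so the receiving side re-opens the
# same shared memory segment by name instead of copying it. The helper below
# is hypothetical and never called; it only demonstrates the round trip.
# ---------------------------------------------------------------------------
def _example_ring_buffer_pickle_roundtrip():
    creator = ShmRingBuffer(n_reader=1, max_chunk_bytes=64, max_chunks=2)
    # simulate sending the buffer to another process
    view = pickle.loads(pickle.dumps(creator))
    assert not view.is_creator
    with creator.get_metadata(0) as meta:
        meta[0] = 1  # writer marks chunk 0 as written
    with view.get_metadata(0) as meta:
        assert meta[0] == 1  # visible through the re-opened segment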


@dataclass
class Handle:
    connect_ip: str
    local_reader_ranks: List[int] = field(default_factory=list)

    buffer: Optional[ShmRingBuffer] = None
    local_subscribe_port: Optional[int] = None
    remote_subscribe_port: Optional[int] = None


class MessageQueue:

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        if connect_ip is None:
            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_port = get_open_port()
            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")

            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_port = None
            self.local_socket = None
            self.current_idx = -1

        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")
        else:
            remote_subscribe_port = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            connect_ip=connect_ip,
            local_reader_ranks=local_reader_ranks,
            buffer=self.buffer,
            local_subscribe_port=local_subscribe_port,
            remote_subscribe_port=remote_subscribe_port,
        )

        logger.debug("Aphrodite message queue communication handle: "
                     f"{self.handle}")

    def export_handle(self) -> Handle:
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer is not None
            self.buffer = handle.buffer
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            self.local_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")

            self.remote_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            self.remote_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")

        return self
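
    # -----------------------------------------------------------------------
    # Illustrative sketch (not part of the original module), assuming all
    # processes run on one node: the writer constructs the queue, exports the
    # handle, and every reader rebuilds its endpoint from that handle.
    #
    #   # writer process (rank 0)
    #   queue = MessageQueue(n_reader=2, n_local_reader=2,
    #                        local_reader_ranks=[1, 2])
    #   handle = queue.export_handle()
    #   # ... send `handle` to ranks 1 and 2, e.g. via a pipe or
    #   # dist.broadcast_object_list ...
    #
    #   # reader process (rank 1 or 2)
    #   queue = MessageQueue.create_from_handle(handle, rank)
    #
    #   # afterwards, every process calls queue.wait_until_ready()
    # -----------------------------------------------------------------------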

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for i in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for i in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"
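
    # -----------------------------------------------------------------------
    # Illustrative sketch (not part of the original module) of the ZeroMQ
    # behavior this handshake relies on: an XPUB socket receives one
    # subscription message (0x01 followed by the topic, here empty) per
    # connecting subscriber, which is how the writer counts its readers.
    #
    #   ctx = Context()
    #   pub = ctx.socket(XPUB)
    #   pub.setsockopt(XPUB_VERBOSE, True)
    #   port = pub.bind_to_random_port("tcp://*")
    #   sub = ctx.socket(SUB)
    #   sub.setsockopt_string(SUBSCRIBE, "")
    #   sub.connect(f"tcp://127.0.0.1:{port}")
    #   assert pub.recv() == b"\x01"  # subscription notification
    #   pub.send(b"READY")
    #   assert sub.recv() == b"READY"
    # -----------------------------------------------------------------------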

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if time.monotonic(
                    ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                        logger.warning(
                            "No available block found in "
                            f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds")
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self):
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if time.monotonic(
                    ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                        logger.warning(
                            "No available block found in "
                            f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds."
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break
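
    # -----------------------------------------------------------------------
    # Illustrative walk-through (not part of the original module) of one
    # chunk's metadata with two readers, following the transitions that
    # acquire_write()/acquire_read() perform:
    #
    #   flags = [0, 0, 0]   # [written, reader0, reader1]  case 1: writable
    #   flags[0] = 1        # writer publishes             case 2: readable
    #   flags[1] = 1        # reader 0 finishes reading    case 3
    #   flags[2] = 1        # reader 1 finishes reading    case 4: writable
    #
    #   # next write to the same chunk:
    #   flags[0] = 0              # back to case 1 before the payload changes
    #   # ... caller writes the payload ...
    #   flags[1] = flags[2] = 0   # clear reader flags first
    #   flags[0] = 1              # then mark as written (atomic 0 -> 1)
    # -----------------------------------------------------------------------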

    def enqueue(self, obj):
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write() as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write() as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self):
        if self._is_local_reader:
            with self.acquire_read() as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj
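
    # -----------------------------------------------------------------------
    # Illustrative sketch (not part of the original module) of how a chunk is
    # laid out by enqueue()/dequeue(): the first byte is the overflow flag,
    # so a pickled object fits in-band only if it is at most
    # max_chunk_bytes - 1 bytes long.
    #
    #   payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    #   # small object:  chunk = [0x00][payload...]
    #   #                readers call pickle.loads(buf[1:])
    #   # large object:  chunk = [0x01]
    #   #                payload is sent over the XPUB socket instead
    # -----------------------------------------------------------------------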

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        global_ranks = dist.get_process_group_ranks(pg)

        from aphrodite.distributed.parallel_state import in_the_same_node_as
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            dist.broadcast_object_list([handle],
                                       src=global_ranks[writer_rank],
                                       group=pg)
        else:
            recv = [None]
            dist.broadcast_object_list(recv,
                                       src=global_ranks[writer_rank],
                                       group=pg)
            handle = recv[0]  # type: ignore
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io
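

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): building the queue
# from a torch.distributed process group and broadcasting one object. Assumes
# the script is launched with several ranks (e.g. via `torchrun`) so that a
# "gloo" process group can be initialized from the environment; the helper is
# hypothetical and never called by this module.
# ---------------------------------------------------------------------------
def _example_broadcast_with_process_group():
    dist.init_process_group(backend="gloo")
    pg = dist.group.WORLD
    queue = MessageQueue.create_from_process_group(pg,
                                                   max_chunk_bytes=1024 * 1024,
                                                   max_chunks=4,
                                                   writer_rank=0)
    if dist.get_rank() == 0:
        obj = queue.broadcast_object({"step": 1, "data": list(range(8))})
    else:
        obj = queue.broadcast_object()
    assert obj["step"] == 1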