# shm_broadcast.py

import os
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from loguru import logger
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context  # type: ignore

from aphrodite.common.utils import get_ip, get_open_port

# `os.getenv` returns a string when the variable is set, so parse it to an
# int; otherwise the `>` comparisons below would fail at runtime.
APHRODITE_RINGBUFFER_WARNING_INTERVAL = int(
    os.getenv("APHRODITE_RINGBUFFER_WARNING_INTERVAL", "60"))

# Time to wait if the queue is full or empty: if we sleep for too short, it
# will consume too much CPU; if we sleep for too long, it will slow down the
# writer/reader. 0.1 us is a good balance.
RINGBUFFER_SLEEP_INTERVAL = 1e-7


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast
        communication. Essentially, it is a queue where only one writer will
        `enqueue` and multiple readers will `dequeue`. The max size of each
        item, together with the max number of items that can be stored in
        the buffer, are known in advance. In this case, we don't need to
        synchronize the access to the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the
        written flag, and the rest are reader flags. The flags are set to 0
        by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read,
            can write
        (case 2) 1000...000: the block is just written, can read,
            cannot write
        (case 3) 1???...???: the block is written and read by some readers,
            can read if not read yet, cannot write
        (case 4) 1111...111: the block is written and read by all readers,
            cannot read, can write

        State transition for readers:
        When a reader finds a block that it can read (case 2 or 3), it can
        yield the block for the caller to read. Only after the caller
        finishes reading the block can the reader mark the block as read.
        Readers only mark the block as read (from 0 to 1); the writer marks
        the block as ready to read (from 0 to 1).

        State transition for the writer:
        When the writer writes to a block (case 1 or 4), it first resets the
        written flag to 0, converting either case to case 1. Then it can
        yield the block for the caller to write. After the caller finishes
        writing the block, the writer can reset the reader flags to 0, and
        mark the block as written (from 0 to 1).
        NOTE: the order is important here. First reset the reader flags (so
        that we are still in case 1), then mark the block as written. The
        state transition is atomic. If we do it in the reverse order, it
        will go through case 3 and then back to case 2, and readers might
        read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can
        pass the created object to other processes by pickling it. The other
        processes will get the name of the shared memory and open it, so
        that they can access the same shared memory buffer.
        """  # noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    assert self.shared_memory.size == self.total_bytes_of_buffer  # noqa
                except FileNotFoundError:
                    # we might deserialize the object on a different node,
                    # in which case this object is not used and we should
                    # suppress the error
                    pass

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf
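
# A minimal usage sketch of the ring buffer on its own (illustrative only;
# the supported entry point is MessageQueue below). Unpickling re-runs
# __init__ with the shared memory name, so readers attach to the same region:
#
#     buffer = ShmRingBuffer(n_reader=2, max_chunk_bytes=64, max_chunks=4)
#     payload = pickle.dumps(buffer)      # send to each reader process
#     # in a reader process:
#     view = pickle.loads(payload)        # opens the existing shared memory
#     with view.get_metadata(0) as meta:
#         written_flag = meta[0]          # 1 once the writer publishes chunk 0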


@dataclass
class Handle:
    connect_ip: str
    local_reader_ranks: List[int] = field(default_factory=list)

    buffer: Optional[ShmRingBuffer] = None
    local_subscribe_port: Optional[int] = None
    local_sync_port: Optional[int] = None
    remote_subscribe_port: Optional[int] = None
    remote_sync_port: Optional[int] = None


class MessageQueue:

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        if connect_ip is None:
            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            self.local_socket = context.socket(PUB)
            local_subscribe_port = get_open_port()
            self.local_socket.bind(f"tcp://*:{local_subscribe_port}")

            self.local_sync_socket = context.socket(REP)
            local_sync_port = get_open_port()
            self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_port = None
            local_sync_port = None
            self.local_socket = None
            self.local_sync_socket = None
            self.current_idx = -1

        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            self.remote_socket = context.socket(PUB)
            remote_subscribe_port = get_open_port()
            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")

            self.remote_sync_socket = context.socket(REP)
            remote_sync_port = get_open_port()
            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
        else:
            remote_subscribe_port = None
            remote_sync_port = None
            self.remote_socket = None
            self.remote_sync_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            connect_ip=connect_ip,
            local_reader_ranks=local_reader_ranks,
            buffer=self.buffer,
            local_subscribe_port=local_subscribe_port,
            local_sync_port=local_sync_port,
            remote_subscribe_port=remote_subscribe_port,
            remote_sync_port=remote_sync_port,
        )

        logger.info("Aphrodite message queue communication handle: "
                    f"{self.handle}")

    def export_handle(self) -> Handle:
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer is not None
            self.buffer = handle.buffer
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            self.local_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")

            self.local_sync_socket = context.socket(REQ)
            self.local_sync_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.local_sync_port}")

            self.remote_socket = None
            self.remote_sync_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None
            self.local_sync_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            self.remote_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")

            self.remote_sync_socket = context.socket(REQ)
            self.remote_sync_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")

        return self
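
    # A sketch of manual wiring without a torch ProcessGroup (hypothetical
    # two-local-reader setup; create_from_process_group below automates the
    # handle exchange):
    #
    #     # writer process
    #     queue = MessageQueue(n_reader=2, n_local_reader=2)
    #     handle = queue.export_handle()   # pickle and ship to the readers
    #     # reader process with rank 0 (and similarly rank 1)
    #     reader = MessageQueue.create_from_handle(handle, rank=0)
    #     # all three processes must then call wait_until_ready()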

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for i in range(self.n_local_reader):
                recv = self.local_sync_socket.recv()
                assert recv == b"READY"
                self.local_sync_socket.send(b"READY")
            if self.n_local_reader > 0:
                self.local_socket.send(b"READY")

            # remote readers
            for i in range(self.n_remote_reader):
                recv = self.remote_sync_socket.recv()
                assert recv == b"READY"
                self.remote_sync_socket.send(b"READY")
            if self.n_remote_reader > 0:
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            self.local_sync_socket.send(b"READY")
            recv = self.local_sync_socket.recv()
            assert recv == b"READY"
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            self.remote_sync_socket.send(b"READY")
            recv = self.remote_sync_socket.recv()
            assert recv == b"READY"
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(
                    self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to
                    # write; if this block is not ready to write, we need to
                    # wait until it is read by all readers

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if (time.monotonic() - start_time >
                            APHRODITE_RINGBUFFER_WARNING_INTERVAL *
                            n_warning):
                        logger.warning(
                            "No available block found in "
                            f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} "
                            "seconds.")
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break
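
    # Metadata trace for one block with two readers (illustrative):
    #     [0, 0, 0]  fresh block: not written, writable          (case 1)
    #     [1, 0, 0]  writer done: readable by both readers       (case 2)
    #     [1, 1, 0]  reader 0 done: reader 1 may still read      (case 3)
    #     [1, 1, 1]  read by all: writable again                 (case 4)
    # acquire_write first flips byte 0 back to 0 (case 1), lets the caller
    # write, then clears the reader bytes before setting byte 0 to 1.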

    @contextmanager
    def acquire_read(self):
        assert self._is_local_reader, "Only local readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(
                    self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # for readers, `self.current_idx` is the next block to
                    # read; if this block is not ready, we need to wait
                    # until it is written

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if (time.monotonic() - start_time >
                            APHRODITE_RINGBUFFER_WARNING_INTERVAL *
                            n_warning):
                        logger.warning(
                            "No available block found in "
                            f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} "
                            "seconds.")
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj):
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write() as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write() as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self):
        if self._is_local_reader:
            with self.acquire_read() as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of the serialized object:
                    # the pickle format contains the size information
                    # internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj
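
    # An illustrative round trip between the writer and one local reader
    # (assumes both ends were built from the same handle, have already
    # called wait_until_ready, and that `big_obj` pickles to at least
    # max_chunk_bytes):
    #
    #     writer.enqueue({"step": 1})   # small: inlined in the ring buffer
    #     writer.enqueue(big_obj)       # large: sent over the PUB socket
    #     assert reader.dequeue() == {"step": 1}
    #     reader.dequeue()              # returns a copy of big_obj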

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        global_ranks = dist.get_process_group_ranks(pg)

        from aphrodite.distributed.parallel_state import in_the_same_node_as
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            dist.broadcast_object_list([handle],
                                       src=global_ranks[writer_rank],
                                       group=pg)
        else:
            recv = [None]
            dist.broadcast_object_list(recv,
                                       src=global_ranks[writer_rank],
                                       group=pg)
            handle = recv[0]  # type: ignore
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io
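

# End-to-end sketch under torch.distributed (assumes a process group `pg` is
# already initialized and that rank 0 is the writer; every other rank in the
# group becomes a reader):
#
#     queue = MessageQueue.create_from_process_group(
#         pg, max_chunk_bytes=1024 * 1024, max_chunks=10)
#     if dist.get_rank(pg) == 0:
#         queue.broadcast_object({"weights_version": 42})
#     else:
#         obj = queue.broadcast_object()  # same dict on every reader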