# shm_broadcast.py
  1. import pickle
  2. import time
  3. from contextlib import contextmanager
  4. from dataclasses import dataclass, field
  5. from multiprocessing import shared_memory
  6. from typing import List, Optional
  7. from unittest.mock import patch
  8. import torch
  9. import torch.distributed as dist
  10. from loguru import logger
  11. from torch.distributed import ProcessGroup
  12. from zmq import IPV6 # type: ignore
  13. from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
  14. import aphrodite.common.envs as envs
  15. from aphrodite.common.utils import get_ip, get_open_port, is_valid_ipv6_address
  16. APHRODITE_RINGBUFFER_WARNING_INTERVAL = (
  17. envs.APHRODITE_RINGBUFFER_WARNING_INTERVAL)
  18. # time to wait if the queue is full or empty
  19. # if we sleep for too short, it will consume too much CPU
  20. # if we sleep for too long, it will slow down the writer/reader
  21. # 0.1 us is a good balance
  22. RINGBUFFER_SLEEP_INTERVAL = 1e-7
  23. class ShmRingBuffer:
  24. def __init__(self,
  25. n_reader: int,
  26. max_chunk_bytes: int,
  27. max_chunks: int,
  28. name: Optional[str] = None):
  29. """
  30. A shared memory ring buffer implementation for broadcast communication.
  31. Essentially, it is a queue where only one will `enqueue` and multiple
  32. will `dequeue`. The max size of each item, together with the max number
  33. of items that can be stored in the buffer are known in advance.
  34. In this case, we don't need to synchronize the access to
  35. the buffer.
  36. Buffer memory layout:
  37. data metadata
  38. | |
  39. | (current_idx) | (current_idx)
  40. v v
  41. +-------------------------------+----------------------------------------+
  42. | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
  43. +-------------------------------+----------------------------------------+
  44. | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
  45. metadata memory layout: each byte is a flag, the first byte is the written
  46. flag, and the rest are reader flags. The flags are set to 0 by default.
  47. +--------------+--------------+--------------+-----+--------------+
  48. | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
  49. +--------------+--------------+--------------+-----+--------------+
  50. The state of metadata is as follows:
  51. (case 1) 0???...???: the block is not written yet, cannot read, can write
  52. (case 2) 1000...000: the block is just written, can read, cannot write
  53. (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
  54. (case 4) 1111...111: the block is written and read by all readers, cannot read, can write
  55. State transition for readers:
  56. When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
  57. Only after the caller finishes reading the block, the reader can mark the block as read.
  58. Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
  59. State transition for writer:
  60. When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
  61. to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
  62. can reset the reader flags to 0, and mark the block as written (from 0 to 1).
  63. NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
  64. During creation, `name` is None and the buffer is created. We can pass the
  65. created object to other processes by pickling it. The other processes will
  66. get the name of the shared memory and open it, so that they can access the
  67. same shared memory buffer.
  68. """# noqa
  69. self.n_reader = n_reader
  70. self.metadata_size = 1 + n_reader
  71. self.max_chunk_bytes = max_chunk_bytes
  72. self.max_chunks = max_chunks
  73. self.total_bytes_of_buffer = (self.max_chunk_bytes +
  74. self.metadata_size) * self.max_chunks
  75. self.data_offset = 0
  76. self.metadata_offset = self.max_chunk_bytes * self.max_chunks
  77. if name is None:
  78. # we are creating a buffer
  79. self.is_creator = True
  80. self.shared_memory = shared_memory.SharedMemory(
  81. create=True, size=self.total_bytes_of_buffer)
  82. # initialize the metadata section to 0
  83. with memoryview(self.shared_memory.buf[self.metadata_offset:]
  84. ) as metadata_buffer:
  85. torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
  86. else:
  87. # we are opening an existing buffer
  88. self.is_creator = False
  89. # fix to https://stackoverflow.com/q/62748654/9191338
  90. # Python incorrectly tracks shared memory even if it is not
  91. # created by the process. The following patch is a workaround.
  92. with patch("multiprocessing.resource_tracker.register",
  93. lambda *args, **kwargs: None):
  94. try:
  95. self.shared_memory = shared_memory.SharedMemory(name=name)
  96. assert self.shared_memory.size == self.total_bytes_of_buffer # noqa
  97. except FileNotFoundError:
  98. # we might deserialize the object in a different node
  99. # in this case, this object is not used,
  100. # and we should suppress the error
  101. pass
  102. def __reduce__(self):
  103. return (
  104. self.__class__,
  105. (self.n_reader, self.max_chunk_bytes, self.max_chunks,
  106. self.shared_memory.name),
  107. )
  108. def __del__(self):
  109. if hasattr(self, "shared_memory"):
  110. self.shared_memory.close()
  111. if self.is_creator:
  112. self.shared_memory.unlink()
  113. @contextmanager
  114. def get_data(self, current_idx: int):
  115. start = self.data_offset + current_idx * self.max_chunk_bytes
  116. end = start + self.max_chunk_bytes
  117. with memoryview(self.shared_memory.buf[start:end]) as buf:
  118. yield buf
  119. @contextmanager
  120. def get_metadata(self, current_idx: int):
  121. start = self.metadata_offset + current_idx * self.metadata_size
  122. end = start + self.metadata_size
  123. with memoryview(self.shared_memory.buf[start:end]) as buf:
  124. yield buf
  125. @dataclass
  126. class Handle:
  127. connect_ip: str
  128. local_reader_ranks: List[int] = field(default_factory=list)
  129. buffer: Optional[ShmRingBuffer] = None
  130. local_subscribe_port: Optional[int] = None
  131. remote_subscribe_port: Optional[int] = None
  132. class MessageQueue:
  133. def __init__(
  134. self,
  135. n_reader, # number of all readers
  136. n_local_reader, # number of local readers through shared memory
  137. local_reader_ranks: Optional[List[int]] = None,
  138. max_chunk_bytes: int = 1024 * 1024 * 10,
  139. max_chunks: int = 10,
  140. connect_ip: Optional[str] = None,
  141. ):
  142. if local_reader_ranks is None:
  143. local_reader_ranks = list(range(n_local_reader))
  144. else:
  145. assert len(local_reader_ranks) == n_local_reader
  146. self.n_local_reader = n_local_reader
  147. n_remote_reader = n_reader - n_local_reader
  148. self.n_remote_reader = n_remote_reader
  149. if connect_ip is None:
  150. connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
  151. context = Context()
  152. if n_local_reader > 0:
  153. # for local readers, we will:
  154. # 1. create a shared memory ring buffer to communicate small data
  155. # 2. create a publish-subscribe socket to communicate large data
  156. self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
  157. max_chunks)
  158. # XPUB is very similar to PUB,
  159. # except that it can receive subscription messages
  160. # to confirm the number of subscribers
  161. self.local_socket = context.socket(XPUB)
  162. # set the verbose option so that we can receive every subscription
  163. # message. otherwise, we will only receive the first subscription
  164. # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
  165. self.local_socket.setsockopt(XPUB_VERBOSE, True)
  166. local_subscribe_port = get_open_port()
  167. socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
  168. logger.debug(f"Binding to {socket_addr}")
  169. self.local_socket.bind(socket_addr)
  170. self.current_idx = 0
  171. else:
  172. self.buffer = None # type: ignore
  173. local_subscribe_port = None
  174. self.local_socket = None
  175. self.current_idx = -1
  176. if n_remote_reader > 0:
  177. # for remote readers, we will:
  178. # create a publish-subscribe socket to communicate large data
  179. self.remote_socket = context.socket(XPUB)
  180. self.remote_socket.setsockopt(XPUB_VERBOSE, True)
  181. remote_subscribe_port = get_open_port()
  182. if is_valid_ipv6_address(connect_ip):
  183. self.remote_socket.setsockopt(IPV6, 1)
  184. socket_addr = f"tcp://*:{remote_subscribe_port}"
  185. self.remote_socket.bind(socket_addr)
  186. else:
  187. remote_subscribe_port = None
  188. self.remote_socket = None
  189. self._is_writer = True
  190. self._is_local_reader = False
  191. self.local_reader_rank = -1
  192. # rank does not matter for remote readers
  193. self._is_remote_reader = False
  194. self.handle = Handle(
  195. connect_ip=connect_ip,
  196. local_reader_ranks=local_reader_ranks,
  197. buffer=self.buffer,
  198. local_subscribe_port=local_subscribe_port,
  199. remote_subscribe_port=remote_subscribe_port,
  200. )
  201. logger.debug("Aphrodite message queue communication handle: "
  202. f"{self.handle}")
  203. def export_handle(self) -> Handle:
  204. return self.handle
  205. @staticmethod
  206. def create_from_handle(handle: Handle, rank) -> "MessageQueue":
  207. self = MessageQueue.__new__(MessageQueue)
  208. self.handle = handle
  209. self._is_writer = False
  210. context = Context()
  211. if rank in handle.local_reader_ranks:
  212. assert handle.buffer is not None
  213. self.buffer = handle.buffer
  214. self.current_idx = 0
  215. self.local_reader_rank = handle.local_reader_ranks.index(rank)
  216. self._is_local_reader = True
  217. self._is_remote_reader = False
  218. self.local_socket = context.socket(SUB)
  219. self.local_socket.setsockopt_string(SUBSCRIBE, "")
  220. socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
  221. logger.debug(f"Connecting to {socket_addr}")
  222. self.local_socket.connect(socket_addr)
  223. self.remote_socket = None
  224. else:
  225. self.buffer = None # type: ignore
  226. self.current_idx = -1
  227. self.local_reader_rank = -1
  228. self._is_local_reader = False
  229. self._is_remote_reader = True
  230. self.local_socket = None
  231. self.remote_socket = context.socket(SUB)
  232. self.remote_socket.setsockopt_string(SUBSCRIBE, "")
  233. if is_valid_ipv6_address(handle.connect_ip):
  234. self.remote_socket.setsockopt(IPV6, 1)
  235. socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
  236. logger.debug(f"Connecting to {socket_addr}")
  237. self.remote_socket.connect(socket_addr)
  238. return self
  239. def wait_until_ready(self):
  240. """This is a collective operation. All processes (including the
  241. readers and the writer) should call this function.
  242. """
  243. if self._is_writer:
  244. # wait for all readers to connect
  245. # local readers
  246. for i in range(self.n_local_reader):
  247. # wait for subscription messages from all local readers
  248. self.local_socket.recv()
  249. if self.n_local_reader > 0:
  250. # send a message to all local readers
  251. # to make sure the publish channel is working
  252. self.local_socket.send(b"READY")
  253. # remote readers
  254. for i in range(self.n_remote_reader):
  255. # wait for subscription messages from all remote readers
  256. self.remote_socket.recv()
  257. if self.n_remote_reader > 0:
  258. # send a message to all remote readers
  259. # to make sure the publish channel is working
  260. self.remote_socket.send(b"READY")
  261. elif self._is_local_reader:
  262. # wait for the writer to send a message
  263. recv = self.local_socket.recv()
  264. assert recv == b"READY"
  265. elif self._is_remote_reader:
  266. # wait for the writer to send a message
  267. recv = self.remote_socket.recv()
  268. assert recv == b"READY"
  269. @contextmanager
  270. def acquire_write(self):
  271. assert self._is_writer, "Only writers can acquire write"
  272. start_time = time.monotonic()
  273. n_warning = 1
  274. while True:
  275. with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
  276. read_count = sum(metadata_buffer[1:])
  277. written_flag = metadata_buffer[0]
  278. if written_flag and read_count != self.buffer.n_reader:
  279. # this block is written and not read by all readers
  280. # for writers, `self.current_idx` is the next block to write
  281. # if this block is not ready to write,
  282. # we need to wait until it is read by all readers
  283. # wait for a while
  284. time.sleep(RINGBUFFER_SLEEP_INTERVAL)
  285. # if we wait for a long time, we should warn the user
  286. if time.monotonic(
  287. ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning: # type: ignore # noqa
  288. logger.warning(
  289. "No available block found in "
  290. f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds")
  291. n_warning += 1
  292. continue
  293. # found a block that is either
  294. # (1) not written
  295. # (2) read by all readers
  296. # mark the block as not written
  297. metadata_buffer[0] = 0
  298. # let caller write to the buffer
  299. with self.buffer.get_data(self.current_idx) as buf:
  300. yield buf
  301. # caller has written to the buffer
  302. # NOTE: order is important here
  303. # first set the read flags to 0
  304. # then set the written flag to 1
  305. # otherwise, the readers may think they already read the block
  306. for i in range(1, self.buffer.n_reader + 1):
  307. # set read flag to 0, meaning it is not read yet
  308. metadata_buffer[i] = 0
  309. # mark the block as written
  310. metadata_buffer[0] = 1
  311. self.current_idx = (self.current_idx +
  312. 1) % self.buffer.max_chunks
  313. break
  314. @contextmanager
  315. def acquire_read(self):
  316. assert self._is_local_reader, "Only readers can acquire read"
  317. start_time = time.monotonic()
  318. n_warning = 1
  319. while True:
  320. with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
  321. read_flag = metadata_buffer[self.local_reader_rank + 1]
  322. written_flag = metadata_buffer[0]
  323. if not written_flag or read_flag:
  324. # this block is either
  325. # (1) not written
  326. # (2) already read by this reader
  327. # for readers, `self.current_idx` is the next block to read
  328. # if this block is not ready,
  329. # we need to wait until it is written
  330. # wait for a while
  331. time.sleep(RINGBUFFER_SLEEP_INTERVAL)
  332. # if we wait for a long time, we should warn the user
  333. if time.monotonic(
  334. ) - start_time > APHRODITE_RINGBUFFER_WARNING_INTERVAL * n_warning: # type: ignore # noqa
  335. logger.warning(
  336. "No available block found in "
  337. f"{APHRODITE_RINGBUFFER_WARNING_INTERVAL} seconds."
  338. )
  339. n_warning += 1
  340. continue
  341. # found a block that is not read by this reader
  342. # let caller read from the buffer
  343. with self.buffer.get_data(self.current_idx) as buf:
  344. yield buf
  345. # caller has read from the buffer
  346. # set the read flag
  347. metadata_buffer[self.local_reader_rank + 1] = 1
  348. self.current_idx = (self.current_idx +
  349. 1) % self.buffer.max_chunks
  350. break
  351. def enqueue(self, obj):
  352. assert self._is_writer, "Only writers can enqueue"
  353. serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
  354. if self.n_local_reader > 0:
  355. if len(serialized_obj) >= self.buffer.max_chunk_bytes:
  356. with self.acquire_write() as buf:
  357. buf[0] = 1 # overflow
  358. self.local_socket.send(serialized_obj)
  359. else:
  360. with self.acquire_write() as buf:
  361. buf[0] = 0 # not overflow
  362. buf[1:len(serialized_obj) + 1] = serialized_obj
  363. if self.n_remote_reader > 0:
  364. self.remote_socket.send(serialized_obj)
  365. def dequeue(self):
  366. if self._is_local_reader:
  367. with self.acquire_read() as buf:
  368. overflow = buf[0] == 1
  369. if not overflow:
  370. # no need to know the size of serialized object
  371. # pickle format contains the size information internally
  372. # see https://docs.python.org/3/library/pickle.html
  373. obj = pickle.loads(buf[1:])
  374. if overflow:
  375. recv = self.local_socket.recv()
  376. obj = pickle.loads(recv)
  377. elif self._is_remote_reader:
  378. recv = self.remote_socket.recv()
  379. obj = pickle.loads(recv)
  380. else:
  381. raise RuntimeError("Only readers can dequeue")
  382. return obj
  383. def broadcast_object(self, obj=None):
  384. if self._is_writer:
  385. self.enqueue(obj)
  386. return obj
  387. else:
  388. return self.dequeue()
  389. @staticmethod
  390. def create_from_process_group(pg: ProcessGroup,
  391. max_chunk_bytes,
  392. max_chunks,
  393. writer_rank=0) -> "MessageQueue":
  394. group_rank = dist.get_rank(pg)
  395. group_world_size = dist.get_world_size(pg)
  396. global_ranks = dist.get_process_group_ranks(pg)
  397. from aphrodite.distributed.parallel_state import in_the_same_node_as
  398. status = in_the_same_node_as(pg, source_rank=writer_rank)
  399. same_node_ranks = [i for i, s in enumerate(status) if s]
  400. n_reader = group_world_size - 1
  401. n_local_reader = len(same_node_ranks) - 1
  402. local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
  403. buffer_io: MessageQueue
  404. if group_rank == writer_rank:
  405. buffer_io = MessageQueue(
  406. n_reader=n_reader,
  407. n_local_reader=n_local_reader,
  408. local_reader_ranks=local_reader_ranks,
  409. max_chunk_bytes=max_chunk_bytes,
  410. max_chunks=max_chunks,
  411. )
  412. handle = buffer_io.export_handle()
  413. dist.broadcast_object_list([handle],
  414. src=global_ranks[writer_rank],
  415. group=pg)
  416. else:
  417. recv = [None]
  418. dist.broadcast_object_list(recv,
  419. src=global_ranks[writer_rank],
  420. group=pg)
  421. handle = recv[0] # type: ignore
  422. buffer_io = MessageQueue.create_from_handle(handle, group_rank)
  423. buffer_io.wait_until_ready()
  424. return buffer_io