# mamba_cache.py
  1. from typing import Dict, List, Optional
  2. import torch
  3. class MambaCacheManager:
  4. def __init__(self, dtype, num_mamba_layers, max_batch_size,
  5. conv_state_shape, temporal_state_shape):
  6. conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
  7. conv_state_shape,
  8. dtype=dtype,
  9. device="cuda")
  10. temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
  11. temporal_state_shape,
  12. dtype=dtype,
  13. device="cuda")
  14. self.mamba_cache = (conv_state, temporal_state)
  15. # Maps between the request id and a dict that maps between the seq_id
  16. # and its index inside the self.mamba_cache
  17. self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
  18. def _swap_mamba_cache(self, from_index: int, to_index: int):
  19. assert len(self.mamba_cache) > 0
  20. for cache_t in self.mamba_cache:
  21. cache_t[:, [to_index,from_index]] = \
  22. cache_t[:, [from_index,to_index]]
  23. def _copy_mamba_cache(self, from_index: int, to_index: int):
  24. assert len(self.mamba_cache) > 0
  25. for cache_t in self.mamba_cache:
  26. cache_t[:, to_index].copy_(cache_t[:, from_index],
  27. non_blocking=True)
  28. def _move_out_if_already_occupied(self, index: int,
  29. all_occupied_indices: List[int]):
  30. if index in all_occupied_indices:
  31. first_free_index = self._first_free_index_in_mamba_cache()
  32. # In case occupied, move the occupied to a new empty block
  33. self._move_cache_index_and_mappings(from_index=index,
  34. to_index=first_free_index)
  35. def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
  36. seq_id: int,
  37. destination_index: int):
  38. """
  39. Assign (req_id,seq_id) pair to a `destination_index` index, if
  40. already occupied, move the occupying index to a free index.
  41. """
  42. all_occupied_indices = self._get_all_occupied_indices()
  43. if cur_rid not in self.mamba_cache_indices_mapping:
  44. self._move_out_if_already_occupied(
  45. index=destination_index,
  46. all_occupied_indices=all_occupied_indices)
  47. self.mamba_cache_indices_mapping[cur_rid] = {
  48. seq_id: destination_index
  49. }
  50. elif seq_id not in (seq_ids2indices :=
  51. self.mamba_cache_indices_mapping[cur_rid]):
  52. # parallel sampling , where n > 1, assume prefill have
  53. # already happened now we only need to copy the already
  54. # existing cache into the siblings seq_ids caches
  55. self._move_out_if_already_occupied(
  56. index=destination_index,
  57. all_occupied_indices=all_occupied_indices)
  58. index_exists = list(seq_ids2indices.values())[0]
  59. # case of decoding n>1, copy prefill cache to decoding indices
  60. self._copy_mamba_cache(from_index=index_exists,
  61. to_index=destination_index)
  62. self.mamba_cache_indices_mapping[cur_rid][
  63. seq_id] = destination_index
  64. else:
  65. # already exists
  66. cache_index_already_exists = self.mamba_cache_indices_mapping[
  67. cur_rid][seq_id]
  68. if cache_index_already_exists != destination_index:
  69. # In case the seq id already exists but not in
  70. # the right destination, swap it with what's occupying it
  71. self._swap_pair_indices_and_mappings(
  72. from_index=cache_index_already_exists,
  73. to_index=destination_index)
  74. def prepare_current_run_state(self,
  75. request_ids_to_seq_ids: Dict[str, list[int]],
  76. batch_size: int,
  77. finished_requests_ids: List[str]):
  78. running_indices = []
  79. request_ids_to_seq_ids_flatten = [
  80. (req_id, seq_id)
  81. for req_id, seq_ids in request_ids_to_seq_ids.items()
  82. for seq_id in seq_ids
  83. ]
  84. for dest_index, (request_id,
  85. seq_id) in enumerate(request_ids_to_seq_ids_flatten):
  86. if request_id in finished_requests_ids:
  87. # Do not allocate cache index for requests that run
  88. # and finish right after
  89. continue
  90. self._assign_seq_id_to_mamba_cache_in_specific_dest(
  91. request_id, seq_id, dest_index)
  92. running_indices.append(dest_index)
  93. self._clean_up_first_bs_blocks(batch_size, running_indices)
  94. conv_state = self.mamba_cache[0][:, :batch_size]
  95. temporal_state = self.mamba_cache[1][:, :batch_size]
  96. return (conv_state, temporal_state)
  97. def _get_all_occupied_indices(self):
  98. return [
  99. cache_idx
  100. for seq_ids2indices in self.mamba_cache_indices_mapping.values()
  101. for cache_idx in seq_ids2indices.values()
  102. ]
  103. def _clean_up_first_bs_blocks(self, batch_size: int,
  104. indices_for_current_run: List[int]):
  105. # move out all of the occupied but currently not running blocks
  106. # outside of the first n blocks
  107. destination_indices = range(batch_size)
  108. max_possible_batch_size = self.mamba_cache[0].shape[1]
  109. for destination_index in destination_indices:
  110. if destination_index in self._get_all_occupied_indices() and \
  111. destination_index not in indices_for_current_run:
  112. # move not running indices outside of the batch
  113. all_other_indices = list(
  114. range(batch_size, max_possible_batch_size))
  115. first_avail_index = self._first_free_index_in_mamba_cache(
  116. all_other_indices)
  117. self._swap_indices(from_index=destination_index,
  118. to_index=first_avail_index)
  119. def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
  120. self._copy_mamba_cache(from_index=from_index, to_index=to_index)
  121. self._update_mapping_index(from_index=from_index, to_index=to_index)
  122. def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
  123. self._swap_mamba_cache(from_index=from_index, to_index=to_index)
  124. self._swap_mapping_index(from_index=from_index, to_index=to_index)
  125. def _swap_mapping_index(self, from_index: int, to_index: int):
  126. for seq_ids2index in self.mamba_cache_indices_mapping.values():
  127. for seq_id, index in seq_ids2index.items():
  128. if from_index == index:
  129. seq_ids2index.update({seq_id: to_index})
  130. elif to_index == index:
  131. seq_ids2index.update({seq_id: from_index})
  132. def _update_mapping_index(self, from_index: int, to_index: int):
  133. for seq_ids2index in self.mamba_cache_indices_mapping.values():
  134. for seq_id, index in seq_ids2index.items():
  135. if from_index == index:
  136. seq_ids2index.update({seq_id: to_index})
  137. return
  138. def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
  139. """
  140. Copy the relevant Mamba cache into the CUDA graph input buffer
  141. that was provided during the capture runs
  142. (JambaForCausalLM.mamba_gc_cache_buffer).
  143. """
  144. assert all(
  145. key in kwargs
  146. for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
  147. finished_requests_ids = kwargs["finished_requests_ids"]
  148. self.release_finished_requests(finished_requests_ids)
  149. request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
  150. cg_batch_size = input_buffers['input_ids'].shape[0]
  151. self.prepare_current_run_state(request_ids_to_seq_ids, cg_batch_size,
  152. finished_requests_ids)
  153. def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
  154. """
  155. Provide the CUDA graph capture runs with a buffer in adjusted size.
  156. The buffer is used to maintain the Mamba Cache during the CUDA graph
  157. replay runs.
  158. """
  159. return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
  160. def release_finished_requests(self,
  161. finished_seq_groups_req_ids: List[str]):
  162. for req_id in finished_seq_groups_req_ids:
  163. if req_id in self.mamba_cache_indices_mapping:
  164. self.mamba_cache_indices_mapping.pop(req_id)
  165. def _first_free_index_in_mamba_cache(
  166. self, indices_range: Optional[List[int]] = None) -> int:
  167. assert self.mamba_cache is not None
  168. if indices_range is None:
  169. max_possible_batch_size = self.mamba_cache[0].shape[1]
  170. indices_range = list(range(max_possible_batch_size))
  171. all_occupied_indices = self._get_all_occupied_indices()
  172. for i in indices_range:
  173. if i not in all_occupied_indices:
  174. return i
  175. raise Exception("Couldn't find a free spot in the mamba cache! This"
  176. "should never happen")
  177. def initialize_tensors(self, dtype, num_mamba_layers, max_batch_size,
  178. conv_state_shape, temporal_state_shape):
  179. conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
  180. conv_state_shape,
  181. dtype=dtype,
  182. device="cuda")
  183. temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
  184. temporal_state_shape,
  185. dtype=dtype,
  186. device="cuda")
  187. self.mamba_cache = (conv_state, temporal_state)
  188. self.initialized = True