@@ -611,12 +611,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        # Current step used indices
-        self.current_indices: List[int] = []
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
         # Maps between the request id and a dict that maps between the seq_id
         # and its index inside the self.mamba_cache
         self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
@@ -646,95 +642,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         batch_size = input_ids.shape[0]
         if attn_metadata.prefill_metadata:
             batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
+            mamba_cache = self._prepare_current_run_mamba_cache(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
         else:
             # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-        self.current_indices = indices
+            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata,
-                                   current_seqlen_agnostic_cache[0],
-                                   current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
-
+                                   attn_metadata, mamba_cache[0],
+                                   mamba_cache[1])
         return hidden_states

-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
+    def _swap_mamba_cache(self, from_index: int, to_index: int):
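+        # Swap the cache content of two batch slots in-place, for both the
+        # conv state and the ssm state, across all mamba layers.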
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, [to_index, from_index]] = \
+                cache_t[:, [from_index, to_index]]

-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
+    def _copy_mamba_cache(self, from_index: int, to_index: int):
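+        # Copy the state of one batch slot onto another within the same
+        # conv/ssm state buffers, across all mamba layers.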
         assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
+        for cache_t in self.mamba_cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                        non_blocking=True)

-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
+    def _move_out_if_already_occupied(self, index: int,
+                                      all_occupied_indices: List[int]):
+        if index in all_occupied_indices:
+            first_free_index = self._first_free_index_in_mamba_cache()
+            # If the index is occupied, move its contents to a free block
+            self._move_cache_index_and_mappings(from_index=index,
+                                                to_index=first_free_index)
+
+    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
+                                                       seq_id: int,
+                                                       destination_index: int):
+        """
+        Assign the (req_id, seq_id) pair to `destination_index`; if that
+        index is already occupied, move its occupant to a free index first.
+        """
+        all_occupied_indices = self._get_all_occupied_indices()
+        if cur_rid not in self.mamba_cache_indices_mapping:
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            self.mamba_cache_indices_mapping[cur_rid] = {
+                seq_id: destination_index
+            }
+        elif seq_id not in (seq_ids2indices :=
+                            self.mamba_cache_indices_mapping[cur_rid]):
+            # Parallel sampling (n > 1): prefill has already happened, so we
+            # only need to copy the existing cache into the sibling seq_ids'
+            # caches.
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            index_exists = list(seq_ids2indices.values())[0]
+            # case of decoding n>1, copy prefill cache to decoding indices
+            self._copy_mamba_cache(from_index=index_exists,
+                                   to_index=destination_index)
+            self.mamba_cache_indices_mapping[cur_rid][
+                seq_id] = destination_index
+        else:
+            # The seq_id already has a cache index assigned
+            cache_index_already_exists = self.mamba_cache_indices_mapping[
+                cur_rid][seq_id]
+            if cache_index_already_exists != destination_index:
+                # In case the seq id already exists but not in
+                # the right destination, swap it with what's occupying it
+                self._swap_pair_indices_and_mappings(
+                    from_index=cache_index_already_exists,
+                    to_index=destination_index)

     def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-            finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
+            self, request_ids_to_seq_ids: Dict[str, list[int]],
+            batch_size: int, finished_requests_ids: List[str]):
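+        # Place every running (request_id, seq_id) pair at its batch position
+        # within the first `batch_size` cache blocks and return contiguous
+        # views over those blocks.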
+        running_indices = []
+        request_ids_to_seq_ids_flatten = [
+            (req_id, seq_id)
+            for req_id, seq_ids in request_ids_to_seq_ids.items()
+            for seq_id in seq_ids
+        ]
+        for dest_index, (request_id,
+                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
             if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
+                # Do not allocate cache index for requests that run
                 # and finish right after
                 continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
+            self._assign_seq_id_to_mamba_cache_in_specific_dest(
+                request_id, seq_id, dest_index)
+            running_indices.append(dest_index)

-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
+        self._clean_up_first_bs_blocks(batch_size, running_indices)
+        conv_state = self.mamba_cache[0][:, :batch_size]
+        temporal_state = self.mamba_cache[1][:, :batch_size]

-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
+        return (conv_state, temporal_state)

-        return (conv_state, temporal_state), indices_for_current_run
+    def _get_all_occupied_indices(self):
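+        # Flatten the request_id -> {seq_id: cache_index} mapping into a list
+        # of all cache indices that are currently assigned.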
+        return [
+            cache_idx
+            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
+            for cache_idx in seq_ids2indices.values()
+        ]
+
+    def _clean_up_first_bs_blocks(self, batch_size: int,
+                                  indices_for_current_run: List[int]):
+        # Move the occupied but currently not running blocks out of the
+        # first `batch_size` blocks.
+        destination_indices = range(batch_size)
+        max_possible_batch_size = self.mamba_cache[0].shape[1]
+        for destination_index in destination_indices:
+            if destination_index in self._get_all_occupied_indices() and \
+                    destination_index not in indices_for_current_run:
+                # Move the not-running index outside of the batch
+                all_other_indices = list(
+                    range(batch_size, max_possible_batch_size))
+                first_avail_index = self._first_free_index_in_mamba_cache(
+                    all_other_indices)
+                self._swap_pair_indices_and_mappings(
+                    from_index=destination_index,
+                    to_index=first_avail_index)
+
+    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
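+        # Copy the cache content and repoint the mapping to the new index.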
+        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
+        self._update_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
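+        # Swap both the cache contents and the mapping entries of two slots.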
+        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
+        self._swap_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_mapping_index(self, from_index: int, to_index: int):
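+        # Exchange the two indices wherever they appear in the mapping.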
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                elif to_index == index:
+                    seq_ids2index.update({seq_id: from_index})
+
+    def _update_mapping_index(self, from_index: int, to_index: int):
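+        # Repoint the single seq_id that currently maps to `from_index`.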
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                    return

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
@@ -749,28 +798,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the JambaForCausalLM.mamba_cache after CUDA
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
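+        # All running requests now occupy the first cg_batch_size cache
+        # blocks in batch order; these are the same buffers the graph was
+        # captured with, so no separate copy into CG input buffers is needed.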
+        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
+                                              cg_batch_size,
+                                              finished_requests_ids)

     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
@@ -778,26 +808,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         The buffer is used to maintain the Mamba Cache during the CUDA graph
         replay runs.
         """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
+        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)

     def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
         for req_id in finished_seq_groups_req_ids:
             if req_id in self.mamba_cache_indices_mapping:
                 self.mamba_cache_indices_mapping.pop(req_id)

-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
+    def _first_free_index_in_mamba_cache(
+            self, indices_range: Optional[List[int]] = None) -> int:
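+        # Return the lowest unoccupied cache index within `indices_range`
+        # (defaults to the whole cache).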
+        assert self.mamba_cache is not None
+        if indices_range is None:
             max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
+            indices_range = list(range(max_possible_batch_size))
+        all_occupied_indices = self._get_all_occupied_indices()
+        for i in indices_range:
+            if i not in all_occupied_indices:
+                return i
+        raise Exception("Couldn't find a free spot in the mamba cache! This "
+                        "should never happen")

     def _get_mamba_cache_shape(
             self
@@ -821,20 +850,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             [layer_type == "mamba" for layer_type in layers_type])
         max_batch_size = (_get_graph_batch_size(
             self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
+                          max(_BATCH_SIZES_TO_CAPTURE) + 2)
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None

-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+        self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
+                                        conv_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"),
+                            torch.empty(size=(mamba_layers, max_batch_size) +
+                                        temporal_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"))

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor: