Bladeren bron

chore: mamba cache single buffer (#673)

AlpinDale 6 maanden geleden
bovenliggende
commit
8583aefed7
2 gewijzigde bestanden met toevoegingen van 148 en 124 verwijderingen
  1. 148 121
      aphrodite/modeling/models/jamba.py
  2. 0 3
      aphrodite/task_handler/model_runner.py

+ 148 - 121
aphrodite/modeling/models/jamba.py

@@ -611,12 +611,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        # Current step used indices
-        self.current_indices: List[int] = []
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
         # Maps between the request id and a dict that maps between the seq_id
         # and its index inside the self.mamba_cache
         self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
@@ -646,95 +642,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
                 batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
+            mamba_cache = self._prepare_current_run_mamba_cache(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
         else:
             # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-        self.current_indices = indices
+            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
 
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata,
-                                   current_seqlen_agnostic_cache[0],
-                                   current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
-
+                                   attn_metadata, mamba_cache[0],
+                                   mamba_cache[1])
         return hidden_states
 
-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
+    def _swap_mamba_cache(self, from_index: int, to_index: int):
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, [to_index,from_index]] = \
+             cache_t[:, [from_index,to_index]]
 
-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
+    def _copy_mamba_cache(self, from_index: int, to_index: int):
         assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
+        for cache_t in self.mamba_cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                        non_blocking=True)
 
-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
+    def _move_out_if_already_occupied(self, index: int,
+                                      all_occupied_indices: List[int]):
+        if index in all_occupied_indices:
+            first_free_index = self._first_free_index_in_mamba_cache()
+            # In case occupied, move the occupied to a new empty block
+            self._move_cache_index_and_mappings(from_index=index,
+                                                to_index=first_free_index)
+
+    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
+                                                       seq_id: int,
+                                                       destination_index: int):
+        """
+        Assign (req_id,seq_id) pair to a `destination_index` index, if
+        already occupied, move the occupying index to a free index.
+        """
+        all_occupied_indices = self._get_all_occupied_indices()
+        if cur_rid not in self.mamba_cache_indices_mapping:
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            self.mamba_cache_indices_mapping[cur_rid] = {
+                seq_id: destination_index
+            }
+        elif seq_id not in (seq_ids2indices :=
+                            self.mamba_cache_indices_mapping[cur_rid]):
+            # parallel sampling , where n > 1, assume prefill have
+            # already happened now we only need to copy the already
+            # existing cache into the siblings seq_ids caches
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            index_exists = list(seq_ids2indices.values())[0]
+            # case of decoding n>1, copy prefill cache to decoding indices
+            self._copy_mamba_cache(from_index=index_exists,
+                                   to_index=destination_index)
+            self.mamba_cache_indices_mapping[cur_rid][
+                seq_id] = destination_index
+        else:
+            # already exists
+            cache_index_already_exists = self.mamba_cache_indices_mapping[
+                cur_rid][seq_id]
+            if cache_index_already_exists != destination_index:
+                # In case the seq id already exists but not in
+                # the right destination, swap it with what's occupying it
+                self._swap_pair_indices_and_mappings(
+                    from_index=cache_index_already_exists,
+                    to_index=destination_index)
 
     def _prepare_current_run_mamba_cache(
-        self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-        finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
+            self, request_ids_to_seq_ids: Dict[str, list[int]],
+            batch_size: int, finished_requests_ids: List[str]):
+        running_indices = []
+        request_ids_to_seq_ids_flatten = [
+            (req_id, seq_id)
+            for req_id, seq_ids in request_ids_to_seq_ids.items()
+            for seq_id in seq_ids
+        ]
+        for dest_index, (request_id,
+                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
             if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
+                # Do not allocate cache index for requests that run
                 # and finish right after
                 continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
+            self._assign_seq_id_to_mamba_cache_in_specific_dest(
+                request_id, seq_id, dest_index)
+            running_indices.append(dest_index)
 
-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
+        self._clean_up_first_bs_blocks(batch_size, running_indices)
+        conv_state = self.mamba_cache[0][:, :batch_size]
+        temporal_state = self.mamba_cache[1][:, :batch_size]
 
-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
+        return (conv_state, temporal_state)
 
-        return (conv_state, temporal_state), indices_for_current_run
+    def _get_all_occupied_indices(self):
+        return [
+            cache_idx
+            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
+            for cache_idx in seq_ids2indices.values()
+        ]
+
+    def _clean_up_first_bs_blocks(self, batch_size: int,
+                                  indices_for_current_run: List[int]):
+        # move out all of the occupied but currently not running blocks
+        # outside of the first n blocks
+        destination_indices = set([range(batch_size)])
+        max_possible_batch_size = self.mamba_cache[0].shape[1]
+        for destination_index in destination_indices:
+            if destination_index in self._get_all_occupied_indices() and  \
+               destination_index not in indices_for_current_run:
+                # move not running indices outside of the batch
+                all_other_indices = list(
+                    range(batch_size, max_possible_batch_size))
+                first_avail_index = self._first_free_index_in_mamba_cache(
+                    all_other_indices)
+                self._swap_indices(from_index=destination_index,
+                                   to_index=first_avail_index)
+
+    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
+        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
+        self._update_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
+        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
+        self._swap_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                elif to_index == index:
+                    seq_ids2index.update({seq_id: from_index})
+
+    def _update_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                    return
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
@@ -749,28 +798,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the JambaForCausalLM.mamba_cache after CUDA 
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
+        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
+                                              cg_batch_size,
+                                              finished_requests_ids)
 
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
@@ -778,26 +808,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         The buffer is used to maintain the Mamba Cache during the CUDA graph 
         replay runs.
         """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
+        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
 
     def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
         for req_id in finished_seq_groups_req_ids:
             if req_id in self.mamba_cache_indices_mapping:
                 self.mamba_cache_indices_mapping.pop(req_id)
 
-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
+    def _first_free_index_in_mamba_cache(
+            self, indices_range: Optional[List[int]] = None) -> int:
+        assert self.mamba_cache is not None
+        if indices_range is None:
             max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
+            indices_range = list(range(max_possible_batch_size))
+        all_occupied_indices = self._get_all_occupied_indices()
+        for i in indices_range:
+            if i not in all_occupied_indices:
+                return i
+        raise Exception("Couldn't find a free spot in the mamba cache! This"
+                        "should never happen")
 
     def _get_mamba_cache_shape(
             self
@@ -821,20 +850,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             [layer_type == "mamba" for layer_type in layers_type])
         max_batch_size = (_get_graph_batch_size(
             self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
+                          max(_BATCH_SIZES_TO_CAPTURE) + 2)
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None
 
-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+        self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
+                                        conv_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"),
+                            torch.empty(size=(mamba_layers, max_batch_size) +
+                                        temporal_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"))
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:

+ 0 - 3
aphrodite/task_handler/model_runner.py

@@ -1740,9 +1740,6 @@ class CUDAGraphRunner:
                                               non_blocking=True)
         # Run the graph.
         self.graph.replay()
-        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
-            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
-                                                      **kwargs)
 
         # Return the output tensor.
         if get_pp_group().is_last_rank: