fix: sync CPU delay in sampler (#93)

* sync CPU delay in sampler

* fix formatting
Author: AlpinDale
Commit: ce66e1df56
3 changed files with 66 additions and 46 deletions:

  1. aphrodite/modeling/layers/sampler.py (+3, -33)
  2. aphrodite/modeling/metadata.py (+18, -11)
  3. aphrodite/task_handler/worker.py (+45, -2)
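
In short: the commit moves the sampler's per-step index bookkeeping (selected_token_indices and categorized_sample_indices) into Worker._prepare_inputs, where the indices are built once per step and shipped to the GPU together with the other input tensors. Previously the sampler rebuilt these indices as Python lists and called torch.tensor(..., device=...) on every call, putting a host-to-device copy on the sampling hot path, which is the CPU delay the title refers to. A minimal sketch of the before/after pattern (function names here are illustrative, not from the codebase):

    import torch
    from typing import List

    # Before: building the index tensor inside the sampler adds a
    # host-to-device copy (and a potential sync point) to every step.
    def prune_before(hidden_states: torch.Tensor,
                     indices: List[int]) -> torch.Tensor:
        index_tensor = torch.tensor(indices, dtype=torch.long,
                                    device=hidden_states.device)
        return hidden_states.index_select(0, index_tensor)

    # After: the worker hands the sampler an index tensor that already
    # lives on the right device, so the gather stays asynchronous.
    def prune_after(hidden_states: torch.Tensor,
                    indices: torch.Tensor) -> torch.Tensor:
        return hidden_states.index_select(0, indices)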

aphrodite/modeling/layers/sampler.py (+3, -33)

@@ -144,29 +144,8 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            prompt_len = input_metadata.prompt_lens[i]
-            if sampling_params.prompt_logprobs is not None:
-                selected_token_indices.extend(
-                    range(start_idx, start_idx + prompt_len - 1))
-            selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += input_metadata.max_prompt_len
-        else:
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(start_idx, start_idx + num_seqs))
-            start_idx += num_seqs
-
-    selected_token_indices = torch.tensor(selected_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, selected_token_indices)
+    return hidden_states.index_select(0, input_metadata.selected_token_indices)
 
 
 def _get_penalties(
@@ -657,20 +636,11 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = {t: [] for t in SamplingType}
-    start_idx = 0
+    categorized_sample_indices = input_metadata.categorized_sample_indices
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+        _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
-        if (i < input_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
-            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
-        num_seqs = len(seq_ids)
-        categorized_sample_indices[sampling_type].extend(
-            range(start_idx, start_idx + num_seqs))
-        start_idx += num_seqs
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
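
To make the precomputed layout concrete: prompts are padded, so prompt rows sit at stride max_seq_len in the flattened hidden states, and only each prompt's last real token is selected (plus all earlier prompt positions when prompt_logprobs is requested). A worked example with illustrative lengths, mirroring the logic that now lives in Worker._prepare_inputs:

    # Two prompts of lengths 3 and 5, padded to max_seq_len = 5, with no
    # prompt_logprobs requested (illustrative numbers only).
    prompt_lens = [3, 5]
    max_seq_len = max(prompt_lens)               # 5
    selected_token_indices, start_idx = [], 0
    for prompt_len in prompt_lens:
        # Only the last real prompt token feeds the sampler.
        selected_token_indices.append(start_idx + prompt_len - 1)
        start_idx += max_seq_len                 # step over the padding
    print(selected_token_indices)                # [2, 9]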

aphrodite/modeling/metadata.py (+18, -11)

@@ -2,7 +2,7 @@ from typing import Dict, List, Tuple, Optional
 import torch
 from xformers.ops import AttentionBias
 
-from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SequenceData
 
 
@@ -28,6 +28,8 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        selected_token_indices: torch.Tensor,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
@@ -37,6 +39,8 @@ class InputMetadata:
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.selected_token_indices = selected_token_indices
+        self.categorized_sample_indices = categorized_sample_indices
 
         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
@@ -72,13 +76,16 @@ class InputMetadata:
 
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (f'InputMetadata('
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'num_prompts={self.num_prompts}, '
-                f'prompt_lens={self.prompt_lens}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'context_lens={self.context_lens}, '
-                f'max_context_len={self.max_context_len}), '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'block_tables={self.block_tables}), '
-                f'slot_mapping={self.slot_mapping}')
+        return (
+            f'InputMetadata('
+            f'num_prompt_tokens={self.num_prompt_tokens}, '
+            f'num_prompts={self.num_prompts}, '
+            f'prompt_lens={self.prompt_lens}, '
+            f'num_generation_tokens={self.num_generation_tokens}, '
+            f'context_lens={self.context_lens}, '
+            f'max_context_len={self.max_context_len}, '
+            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+            f'block_tables={self.block_tables}, '
+            f'selected_token_indices={self.selected_token_indices}, '
+            f'categorized_sample_indices={self.categorized_sample_indices}, '
+            f'slot_mapping={self.slot_mapping})')
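
For orientation, the two new attributes carry differently shaped data: selected_token_indices is a single flat torch.long tensor of row indices into the pruned hidden states, while categorized_sample_indices holds one index tensor per sampling strategy. A minimal sketch of that structure (the two-member enum is an inline stand-in for the real SamplingType; values continue the example above):

    import torch
    from enum import IntEnum

    class SamplingType(IntEnum):   # stand-in; the real enum lives in
        GREEDY = 0                 # aphrodite.common.sampling_params
        RANDOM = 1

    # One flat gather index for pruning, one bucket of sample indices per
    # strategy; in the worker both are created directly on "cuda".
    selected_token_indices = torch.tensor([2, 9], dtype=torch.long)
    categorized_sample_indices = {
        SamplingType.GREEDY: torch.tensor([0], dtype=torch.int),
        SamplingType.RANDOM: torch.tensor([1], dtype=torch.int),
    }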

aphrodite/task_handler/worker.py (+45, -2)

@@ -10,7 +10,7 @@ from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
 from aphrodite.modeling import get_model, InputMetadata, set_random_seed
 from aphrodite.modeling.megatron.parallel_state import (
     initialize_model_parallel)
-from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from aphrodite.task_handler.cache_engine import CacheEngine
 from aphrodite.common.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -163,6 +163,10 @@ class Worker:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        selected_token_indices: List[int] = []
+        selected_token_start_idx = 0
+        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices_start_idx = 0
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -182,6 +186,14 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
+            if sampling_params.prompt_logprobs is not None:
+                # NOTE: prompt positions only need logprobs, not samples,
+                # so skip past them when assigning sample indices.
+                categorized_sample_indices_start_idx += prompt_len - 1
+
+            categorized_sample_indices[sampling_params.sampling_type].append(
+                categorized_sample_indices_start_idx)
+            categorized_sample_indices_start_idx += 1
+
             input_tokens.append(prompt_tokens)
             # NOTE: Here we assume that the first token in the prompt
             # is always the first token in the sequence.
@@ -207,14 +219,37 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
+                # Computed in this second pass because max_seq_len (the
+                # padded prompt stride) is only known after the first loop.
+                assert len(
+                    seq_ids) == 1, "Prompt input should have only one seq."
+                sampling_params = seq_group_metadata.sampling_params
+                if sampling_params.prompt_logprobs is not None:
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + prompt_len - 1))
+                selected_token_indices.append(selected_token_start_idx +
+                                              prompt_len - 1)
+                selected_token_start_idx += max_seq_len
                 continue
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
 
+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(selected_token_start_idx,
+                      selected_token_start_idx + num_seqs))
+            selected_token_start_idx += num_seqs
+
+            categorized_sample_indices[sampling_params.sampling_type].extend(
+                range(categorized_sample_indices_start_idx,
+                      categorized_sample_indices_start_idx + num_seqs))
+            categorized_sample_indices_start_idx += num_seqs
+
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -246,7 +281,6 @@ class Worker:
                 generation_block_tables.append(block_table)
 
         # NOTE: This part was optimized!
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]
@@ -276,6 +310,13 @@ class Worker:
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
+        selected_token_indices = torch.tensor(selected_token_indices,
+                                              dtype=torch.long,
+                                              device="cuda")
+        categorized_sample_indices = {
+            t: torch.tensor(indices, dtype=torch.int, device="cuda")
+            for t, indices in categorized_sample_indices.items()
+        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -292,6 +333,8 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
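
One subtlety behind the NOTE above: when prompt_logprobs is requested, the pruned hidden states keep every prompt position so their logprobs can be computed, but only the final prompt token actually yields a sample, so the index cursor first skips the prompt_len - 1 logprob-only rows. A worked trace of that bookkeeping, with illustrative values and the same inline stand-in enum:

    from enum import IntEnum

    class SamplingType(IntEnum):   # stand-in for the real enum
        GREEDY = 0
        RANDOM = 1

    # One greedy prompt of length 4 with prompt_logprobs enabled, followed
    # by a random-sampling group of two generation sequences.
    categorized = {t: [] for t in SamplingType}
    start_idx = 0

    prompt_len = 4
    start_idx += prompt_len - 1    # rows 0..2 carry logprobs only
    categorized[SamplingType.GREEDY].append(start_idx)   # sample row 3
    start_idx += 1

    num_seqs = 2
    categorized[SamplingType.RANDOM].extend(
        range(start_idx, start_idx + num_seqs))          # rows 4 and 5

    # categorized: GREEDY -> [3], RANDOM -> [4, 5]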