Kaynağa Gözat

Revert "fix: sync CPU delay in sampler (#93)"

This reverts commit ce66e1df56ac3cba35878821b692d101bfb0a3cf.
AlpinDale 1 yıl önce
ebeveyn
işleme
69204736de

+ 33 - 3
aphrodite/modeling/layers/sampler.py

@@ -144,8 +144,29 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
+    selected_token_indices: List[int] = []
+    start_idx = 0
+    for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
+        if i < input_metadata.num_prompts:
+            assert len(seq_ids) == 1, "Prompt input should have only one seq."
+            prompt_len = input_metadata.prompt_lens[i]
+            if sampling_params.prompt_logprobs is not None:
+                selected_token_indices.extend(
+                    range(start_idx, start_idx + prompt_len - 1))
+            selected_token_indices.append(start_idx + prompt_len - 1)
+            start_idx += input_metadata.max_prompt_len
+        else:
+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(start_idx, start_idx + num_seqs))
+            start_idx += num_seqs
+
+    selected_token_indices = torch.tensor(selected_token_indices,
+                                          dtype=torch.long,
+                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, input_metadata.selected_token_indices)
+    return hidden_states.index_select(0, selected_token_indices)
 
 
 def _get_penalties(
@@ -636,11 +657,20 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = input_metadata.categorized_sample_indices
+    categorized_sample_indices = {t: [] for t in SamplingType}
+    start_idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        _, sampling_params = seq_group
+        seq_ids, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
+        if (i < input_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            prompt_len = input_metadata.prompt_lens[i]
+            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
+        num_seqs = len(seq_ids)
+        categorized_sample_indices[sampling_type].extend(
+            range(start_idx, start_idx + num_seqs))
+        start_idx += num_seqs
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:

+ 11 - 18
aphrodite/modeling/metadata.py

@@ -2,7 +2,7 @@ from typing import Dict, List, Tuple, Optional
 import torch
 from xformers.ops import AttentionBias
 
-from aphrodite.common.sampling_params import SamplingParams, SamplingType
+from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import SequenceData
 
 
@@ -28,8 +28,6 @@ class InputMetadata:
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
-        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
@@ -39,8 +37,6 @@ class InputMetadata:
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
-        self.selected_token_indices = selected_token_indices
-        self.categorized_sample_indices = categorized_sample_indices
 
         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
@@ -76,16 +72,13 @@ class InputMetadata:
 
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (
-            f'InputMetadata('
-            f'num_prompt_tokens={self.num_prompt_tokens}, '
-            f'num_prompts={self.num_prompts}, '
-            f'prompt_lens={self.prompt_lens}, '
-            f'num_generation_tokens={self.num_generation_tokens}, '
-            f'context_lens={self.context_lens}, '
-            f'max_context_len={self.max_context_len}), '
-            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-            f'block_tables={self.block_tables}, '
-            f'selected_token_indices={self.selected_token_indices}, '
-            f'categorized_sample_indices={self.categorized_sample_indices}, '
-            f'slot_mapping={self.slot_mapping})')
+        return (f'InputMetadata('
+                f'num_prompt_tokens={self.num_prompt_tokens}, '
+                f'num_prompts={self.num_prompts}, '
+                f'prompt_lens={self.prompt_lens}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
+                f'context_lens={self.context_lens}, '
+                f'max_context_len={self.max_context_len}), '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'block_tables={self.block_tables}), '
+                f'slot_mapping={self.slot_mapping}')

+ 2 - 45
aphrodite/task_handler/worker.py

@@ -10,7 +10,7 @@ from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
 from aphrodite.modeling import get_model, InputMetadata, set_random_seed
 from aphrodite.modeling.megatron.parallel_state import (
     initialize_model_parallel)
-from aphrodite.common.sampling_params import SamplingParams, SamplingType
+from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from aphrodite.task_handler.cache_engine import CacheEngine
 from aphrodite.common.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -163,10 +163,6 @@ class Worker:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
-        selected_token_indices: List[int] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
-        categorized_sample_indices_start_idx = 0
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -186,14 +182,6 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
-            if sampling_params.prompt_logprobs is not None:
-                # NOTE: prompt token positions do not need sample, skip
-                categorized_sample_indices_start_idx += prompt_len - 1
-
-            categorized_sample_indices[sampling_params.sampling_type].append(
-                categorized_sample_indices_start_idx)
-            categorized_sample_indices_start_idx += 1
-
             input_tokens.append(prompt_tokens)
             # NOTE: Here we assume that the first token in the prompt
             # is always the first token in the sequence.
@@ -219,37 +207,14 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
-                # We need to do this in this loop as we need to know max_seq_len
-                assert len(
-                    seq_ids) == 1, "Prompt input should have only one seq."
-                sampling_params = seq_group_metadata.sampling_params
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_seq_len
                 continue
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
 
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(selected_token_start_idx,
-                      selected_token_start_idx + num_seqs))
-            selected_token_start_idx += num_seqs
-
-            categorized_sample_indices[sampling_params.sampling_type].extend(
-                range(categorized_sample_indices_start_idx,
-                      categorized_sample_indices_start_idx + num_seqs))
-            categorized_sample_indices_start_idx += num_seqs
-
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -281,6 +246,7 @@ class Worker:
                 generation_block_tables.append(block_table)
 
         # NOTE: This part was optimized!
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]
@@ -310,13 +276,6 @@ class Worker:
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
-        categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
-            for t, seq_ids in categorized_sample_indices.items()
-        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -333,8 +292,6 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata