@@ -10,7 +10,7 @@ from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
 from aphrodite.modeling import get_model, InputMetadata, set_random_seed
 from aphrodite.modeling.megatron.parallel_state import (
     initialize_model_parallel)
-from aphrodite.common.sampling_params import SamplingParams, SamplingType
+from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from aphrodite.task_handler.cache_engine import CacheEngine
 from aphrodite.common.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -163,10 +163,6 @@ class Worker:
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
-        selected_token_indices: List[int] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
-        categorized_sample_indices_start_idx = 0
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -186,14 +182,6 @@ class Worker:
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
-            if sampling_params.prompt_logprobs is not None:
-                # NOTE: prompt token positions do not need sample, skip
-                categorized_sample_indices_start_idx += prompt_len - 1
-
-            categorized_sample_indices[sampling_params.sampling_type].append(
-                categorized_sample_indices_start_idx)
-            categorized_sample_indices_start_idx += 1
-
             input_tokens.append(prompt_tokens)
             # NOTE: Here we assume that the first token in the prompt
             # is always the first token in the sequence.
@@ -219,37 +207,14 @@ class Worker:
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
-                # We need to do this in this loop as we need to know max_seq_len
-                assert len(
-                    seq_ids) == 1, "Prompt input should have only one seq."
-                sampling_params = seq_group_metadata.sampling_params
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_seq_len
                 continue
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
 
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(selected_token_start_idx,
-                      selected_token_start_idx + num_seqs))
-            selected_token_start_idx += num_seqs
-
-            categorized_sample_indices[sampling_params.sampling_type].extend(
-                range(categorized_sample_indices_start_idx,
-                      categorized_sample_indices_start_idx + num_seqs))
-            categorized_sample_indices_start_idx += num_seqs
-
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -281,6 +246,7 @@ class Worker:
             generation_block_tables.append(block_table)
 
         # NOTE: This part was optimized!
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]
@@ -310,13 +276,6 @@ class Worker:
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
-        categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
-            for t, seq_ids in categorized_sample_indices.items()
-        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -333,8 +292,6 @@ class Worker:
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
         return tokens_tensor, positions_tensor, input_metadata
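
For reviewers tracking what the deleted bookkeeping did: `selected_token_indices` collects the row indices in the flattened (and, for prompts, padded) logits batch that the sampler actually consumes, and `categorized_sample_indices` buckets those sample slots by `SamplingType` so sequences sharing a sampling strategy can be batched together in the sampler. Below is a minimal, self-contained sketch of that index computation for the decode-only case; `compute_sample_indices` and the two-member `SamplingType` are hypothetical stand-ins for illustration, not the aphrodite API, and the real prompt path additionally advances the selected-index cursor by the padded prompt length (which is why the removed code kept two separate counters).

```python
from enum import Enum
from typing import Dict, List, Tuple


class SamplingType(Enum):
    # Hypothetical two-member stand-in for aphrodite's SamplingType.
    GREEDY = 0
    RANDOM = 1


def compute_sample_indices(
    seq_groups: List[Tuple[List[int], SamplingType]],
) -> Tuple[List[int], Dict[SamplingType, List[int]]]:
    """Decode-only distillation of the index bookkeeping this patch removes."""
    selected_token_indices: List[int] = []
    categorized: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType}
    cursor = 0
    for seq_ids, sampling_type in seq_groups:
        # Every decode sequence contributes exactly one logits row, so both
        # index lists advance by the group's sequence count.
        num_seqs = len(seq_ids)
        selected_token_indices.extend(range(cursor, cursor + num_seqs))
        categorized[sampling_type].extend(range(cursor, cursor + num_seqs))
        cursor += num_seqs
    return selected_token_indices, categorized


# Example: a two-sequence RANDOM group followed by a one-sequence GREEDY group.
selected, by_type = compute_sample_indices([([7, 8], SamplingType.RANDOM),
                                            ([9], SamplingType.GREEDY)])
assert selected == [0, 1, 2]
assert by_type[SamplingType.RANDOM] == [0, 1]
assert by_type[SamplingType.GREEDY] == [2]
```

Relatedly, the lone `+` hunk (old line 281) just relocates `max_seq_len = max(prompt_lens) if prompt_lens else 1` from the top of the decode loop (old line 219) to immediately before padding: once the prompt branch no longer consumes `max_seq_len`, padding `input_tokens` is its only remaining use.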