@@ -7,14 +7,11 @@ import torch.nn as nn
 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
                                                   SamplingTensors)
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                                        SamplerOutput, SequenceData,
                                        SequenceGroupOutput, SequenceOutput)
 from aphrodite.modeling.layers.ops.sample import sample as sample_triton
-from aphrodite.common.utils import is_neuron
 
 
 class Sampler(nn.Module):
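
The import removed above, `tensor_model_parallel_gather`, is what let the
sampler collect vocab-sharded logits onto the driver rank; its removal
suggests the gather now happens before the sampler runs, alongside the logits
computation. As a rough single-process sketch of what the gather achieves
(the helper name and shapes below are illustrative, not from this diff):

    import torch

    def gather_vocab_sharded_logits(shard_logits: list,
                                    org_vocab_size: int) -> torch.Tensor:
        # Each tensor-parallel rank holds logits for one contiguous slice
        # of the vocabulary; gathering concatenates along the vocab dim.
        logits = torch.cat(shard_logits, dim=-1)
        # Drop padding columns added so the vocab divides evenly across ranks.
        return logits[:, :org_vocab_size]

    # Hypothetical setup: two ranks, 3 tokens, 8 vocab entries per shard.
    shards = [torch.randn(3, 8), torch.randn(3, 8)]
    full_logits = gather_vocab_sharded_logits(shards, org_vocab_size=15)
    assert full_logits.shape == (3, 15)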
@@ -32,118 +29,22 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """
 
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generates outputs as logits directly
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
-        return _perform_sampling(logits, sampling_metadata)
-
-
-# FIXME: This is a hack for the missing GPU blocks. This should be removed
-# once a proper fix is implemented.
-class QuantSampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        logits = _prune_hidden_states(logits, sampling_metadata)
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.vocab_size]
 
         return _perform_sampling(logits, sampling_metadata)
 
-
 def _perform_sampling(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-    # Only perform sampling in the driver worker.
-    # Note: `_get_logits` is still distributed across TP workers because
-    # the `embedding` weight is distributed across TP workers.
-    # TODO: Change the get_logits part to a separate stage.
-    if not sampling_metadata.perform_sampling:
-        return None
-
     assert logits is not None
     _, vocab_size = logits.shape
 
     output_metadata = OutputMetadata()
 
-    # Apply logits processors (if any)
-    logits = _apply_logits_processors(logits, sampling_metadata)
-
     # Prepare sampling tensors with pinned memory to avoid blocking.
     (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
      do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
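
Net effect of the hunk above: `Sampler.forward` now receives ready-made
logits instead of hidden states plus the embedding weight, the Neuron special
case and the duplicated `QuantSampler` are gone, and the driver-only
`perform_sampling` short-circuit disappears with them. A minimal sketch of
the new call shape, where `compute_logits` is a hypothetical stand-in for
wherever the projection now lives (it is not an identifier from this diff):

    import torch

    def compute_logits(hidden_states: torch.Tensor,
                       lm_head_weight: torch.Tensor) -> torch.Tensor:
        # The projection that used to run inside Sampler._get_logits now
        # happens before the sampler is invoked.
        return torch.matmul(hidden_states, lm_head_weight.t())

    hidden_states = torch.randn(4, 16)    # 4 selected tokens, hidden size 16
    lm_head_weight = torch.randn(32, 16)  # vocab size 32
    logits = compute_logits(hidden_states, lm_head_weight)
    assert logits.shape == (4, 32)
    # New entry point, assuming `sampler` is the Sampler kept above:
    # output = sampler(logits=logits, sampling_metadata=sampling_metadata)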
@@ -207,15 +108,6 @@ def _perform_sampling(
                                  prompt_logprobs, sample_logprobs,
                                  output_metadata)
 
-
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
-
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
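
The `_prune_hidden_states` helper deleted in this hunk is a one-liner worth
spelling out: an `index_select` along the token axis keeps only the rows
whose next token will actually be sampled. A small worked example with
made-up values:

    import torch

    hidden_states = torch.arange(12, dtype=torch.float32).reshape(6, 2)
    # Hypothetical positions, e.g. the last token of each of two prompts.
    selected_token_indices = torch.tensor([2, 5])
    pruned = hidden_states.index_select(0, selected_token_indices)
    print(pruned)  # tensor([[ 4.,  5.],
                   #         [10., 11.]])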
@@ -247,56 +139,6 @@ def _get_custom_token_bans(
     return banned_tokens
 
 
-# def _apply_logits_processors(
-#     logits: torch.Tensor,
-#     metadata: SamplingMetadata,
-# ) -> torch.Tensor:
-#     seq_offset = 0
-#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
-#         seq_size = len(seq_ids)
-#         output_tokens = []
-#         if (i < metadata.num_prompts
-#                 and sampling_params.prompt_logprobs is not None):
-#             prompt_seqs = metadata.prompt_lens[i] - 1
-#             seq_size += prompt_seqs
-#             output_tokens.extend([[]] * prompt_seqs)
-#         seq_end = seq_offset + seq_size
-
-#         if sampling_params.logits_processors:
-#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
-#                                  for sid in seq_ids)
-#             for proc in sampling_params.logits_processors:
-#                 proc(logits[seq_offset:seq_end], output_tokens)
-
-#         seq_offset = seq_end
-
-#     return logits
-
-
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
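
For reference, the `_apply_logits_processors` deleted above invoked each
processor row by row as `logits_row = processor(token_ids, logits_row)`,
where `token_ids` are the tokens the sequence has generated so far. A
minimal sketch of a conforming processor (`ban_repeats` is a hypothetical
example, not part of the diff):

    import torch
    from typing import List

    def ban_repeats(token_ids: List[int],
                    logits_row: torch.Tensor) -> torch.Tensor:
        # Forbid re-sampling any token this sequence already generated.
        for tok in token_ids:
            logits_row[tok] = -float("inf")
        return logits_row

    row = torch.zeros(8)
    row = ban_repeats([1, 3], row)
    assert row[1].item() == -float("inf") and row[3].item() == -float("inf")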