AlpinDale 11 months ago
parent commit f01c668259

+0 −158
aphrodite/modeling/layers/sampler.py

@@ -7,14 +7,11 @@ import torch.nn as nn
 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
                                                   SamplingTensors)
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                                        SamplerOutput, SequenceData,
                                        SequenceGroupOutput, SequenceOutput)
 from aphrodite.modeling.layers.ops.sample import sample as sample_triton
-from aphrodite.common.utils import is_neuron
 
 
 class Sampler(nn.Module):
@@ -32,118 +29,22 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """
 
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generates outputs as logits directly
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
-    def forward(
-        self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
-        return _perform_sampling(logits, sampling_metadata)
-
-
-# FIXME: This is a hack for the missing GPU blocks. This should be removed
-# once a proper fix is implemented.
-class QuantSampler(nn.Module):
-    """Samples the next tokens from the model's outputs.
-
-    This layer does the following:
-    1. Discard the hidden states that are not used for sampling (i.e., all
-        tokens except the final one in each prompt).
-    2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
-    4. Apply temperature scaling.
-    5. Apply top-p and top-k truncation.
-    6. Sample the next tokens.
-    Here, each sequence group within the batch can have different sampling
-    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
-    """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
     def forward(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        logits = _prune_hidden_states(logits, sampling_metadata)
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.vocab_size]
 
         return _perform_sampling(logits, sampling_metadata)
 
-
 def _perform_sampling(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-    # Only perform sampling in the driver worker.
-    # Note: `_get_logits` is still distributed across TP workers because
-    # the `embedding` weight is distributed across TP workers.
-    # TODO: Change the get_logits part to a separate stage.
-    if not sampling_metadata.perform_sampling:
-        return None
-
     assert logits is not None
     _, vocab_size = logits.shape
 
     output_metadata = OutputMetadata()
 
-    # Apply logits processors (if any)
-    logits = _apply_logits_processors(logits, sampling_metadata)
-
     # Prepare sampling tensors with pinned memory to avoid blocking.
     (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
      do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
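
The removed `_get_logits` above was a plain vocabulary projection followed by a tensor-parallel gather and a trim of vocab padding. Here is a minimal single-GPU sketch of the same computation with toy shapes; the distributed gather is only noted in a comment, and `compute_logits` is a name invented for this sketch:

```python
import torch

def compute_logits(hidden_states: torch.Tensor,
                   embedding: torch.Tensor,
                   org_vocab_size: int) -> torch.Tensor:
    # Project hidden states onto the (possibly padded) vocabulary:
    # [num_tokens, hidden] @ [hidden, padded_vocab] -> [num_tokens, padded_vocab]
    logits = torch.matmul(hidden_states, embedding.t())
    # In the removed code, each tensor-parallel rank held only a vocab shard,
    # and the shards were combined here with tensor_model_parallel_gather.
    # Trim the padding columns back to the original vocabulary size.
    return logits[:, :org_vocab_size]

# Toy shapes: 4 tokens, hidden size 8, vocab padded from 50 to 64.
hidden = torch.randn(4, 8)
emb = torch.randn(64, 8)
print(compute_logits(hidden, emb, org_vocab_size=50).shape)  # torch.Size([4, 50])
```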
@@ -207,15 +108,6 @@ def _perform_sampling(
                                  prompt_logprobs, sample_logprobs,
                                  output_metadata)
 
-
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
-
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
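
The `_prune_hidden_states` helper removed in the hunk above was a single `index_select` over the rows named by `sampling_metadata.selected_token_indices`. A toy illustration with made-up indices:

```python
import torch

hidden_states = torch.randn(6, 8)              # 6 token positions, hidden size 8
selected_token_indices = torch.tensor([2, 5])  # e.g. last prompt token + a decode token
pruned = hidden_states.index_select(0, selected_token_indices)
print(pruned.shape)  # torch.Size([2, 8])
```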
@@ -247,56 +139,6 @@ def _get_custom_token_bans(
     return banned_tokens
 
 
-# def _apply_logits_processors(
-#     logits: torch.Tensor,
-#     metadata: SamplingMetadata,
-# ) -> torch.Tensor:
-#     seq_offset = 0
-#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
-#         seq_size = len(seq_ids)
-#         output_tokens = []
-#         if (i < metadata.num_prompts
-#                 and sampling_params.prompt_logprobs is not None):
-#             prompt_seqs = metadata.prompt_lens[i] - 1
-#             seq_size += prompt_seqs
-#             output_tokens.extend([[]] * prompt_seqs)
-#         seq_end = seq_offset + seq_size
-
-#         if sampling_params.logits_processors:
-#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
-#                                  for sid in seq_ids)
-#             for proc in sampling_params.logits_processors:
-#                 proc(logits[seq_offset:seq_end], output_tokens)
-
-#         seq_offset = seq_end
-
-#     return logits
-
-
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
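
The removed `_apply_logits_processors` walked the batch row by row and called each processor in `sampling_params.logits_processors` as `processor(output_token_ids, logits_row) -> logits_row`. A hypothetical processor matching that convention (`ban_repeats` and its behavior are invented for illustration):

```python
import torch

def ban_repeats(token_ids: list, logits: torch.Tensor) -> torch.Tensor:
    # Forbid re-generating any token that already appeared in the output
    # by forcing its logit to -inf.
    for t in set(token_ids):
        logits[t] = float("-inf")
    return logits

row = torch.zeros(16)                # one row of next-token logits
row = ban_repeats([3, 7], row)
print(row[3].item(), row[7].item())  # -inf -inf
```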
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,