@@ -0,0 +1,128 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from aphrodite.common.utils import is_neuron
+from aphrodite.modeling.megatron.communication_op import (
+    tensor_model_parallel_gather)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class LogitsProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
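+    # Typical call site (a sketch; the surrounding names are assumptions,
+    # not part of this diff): the model runner computes hidden states, then
+    #   logits = logits_processor(lm_head.weight, hidden_states,
+    #                             sampling_metadata)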
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: float = 1.0) -> None:
+        """
+        Args:
+            vocab_size: The size of the (possibly padded) vocabulary.
+            org_vocab_size: The original vocabulary size, without LoRA or
+                padding. Defaults to vocab_size.
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        self.vocab_size = vocab_size
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
+        # Original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> Optional[torch.Tensor]:
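+        # On Neuron, the model forward pass already returns logits, so the
+        # hidden states can be used as-is. Otherwise, prune the hidden
+        # states down to the positions that are actually sampled from.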
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding,
+                                      embedding_bias)
+
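+        # On ranks other than the first tensor-parallel rank, _get_logits
+        # returns None, so scaling and logits processors run only on the
+        # rank that holds the full vocabulary.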
+        if logits is not None:
+            logits *= self.scale
+
+            # Apply logits processors (if any).
+            logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor,
+                    embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]
+                    ) -> Optional[torch.Tensor]:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
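+        # Gather the vocab-parallel logit shards onto the first
+        # tensor-parallel rank; every other rank receives None.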
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
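+    # Flatten to (num_tokens, hidden_size) and keep only the rows at which
+    # a next token will be sampled.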
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
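+    # Logits rows follow the same sequence-group order as sampling_metadata,
+    # so the row index must advance even for groups with no processors.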
+    logits_row_idx = 0
+    found_logits_processors = False
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        logits_processors = sampling_params.logits_processors
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
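+                # Each processor maps (previously generated token ids,
+                # logits row) to a new logits row, applied in order.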
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        assert logits_row_idx == logits.shape[0]
+    return logits