소스 검색

add new logits processor

AlpinDale 11 달 전
부모
커밋
fa6af97a5a
1개의 변경된 파일103개의 추가작업 그리고 0개의 파일을 삭제
  1. 103 0
      aphrodite/modeling/layers/logits_processor.py

+ 103 - 0
aphrodite/modeling/layers/logits_processor.py

@@ -0,0 +1,103 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from aphrodite.common.utils import is_neuron
+from aphrodite.modeling.megatron.communication_op import (
+    tensor_model_parallel_gather)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+
+
+class LogitsProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: Optional[float] = 1.0) -> None:
+        """
+        Args:
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        self.vocab_size = vocab_size
+        # Transformers-neuronx generate outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
+        # original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+
+        if logits is not None:
+            logits *= self.scale
+
+            # Apply logits processors (if any).
+            logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    logits_row_idx = 0
+    found_logits_processors = False
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        logits_processors = sampling_params.logits_processors
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        assert logits_row_idx == logits.shape[0]
+    return logits