"""A layer that compute logits from hidden_stats.""" from typing import Optional import torch import torch.nn as nn from aphrodite.distributed import tensor_model_parallel_gather from aphrodite.modeling.sampling_metadata import SamplingMetadata class LogitsProcessor(nn.Module): """Process logits and apply logits processors from sampling metadata. This layer does the following: 1. Gather logits from model hidden_states. 2. Scale logits if needed. 3. Apply logits processors (if any). """ def __init__(self, vocab_size: int, org_vocab_size: Optional[int] = None, scale: Optional[float] = 1.0, logits_as_input: bool = False) -> None: """ Args: scale: A scaling factor to apply to the logits. """ super().__init__() self.scale = scale self.vocab_size = vocab_size # Whether the input is logits (default is hidden states). self.logits_as_input = logits_as_input # original vocabulary size (without LoRA). if org_vocab_size is not None: self.org_vocab_size = min(org_vocab_size, vocab_size) else: self.org_vocab_size = vocab_size def forward( self, lm_head: nn.Module, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.logits_as_input: logits = hidden_states else: hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: logits *= self.scale # Apply logits processors (if any). logits = _apply_logits_processors(logits, sampling_metadata) return logits def _get_logits(self, hidden_states: torch.Tensor, lm_head: nn.Module, embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: # Get the logits for the next tokens. logits = lm_head(hidden_states) if embedding_bias is not None: logits += embedding_bias logits = tensor_model_parallel_gather(logits) # Remove paddings in vocab (if any). if logits is not None: logits = logits[:, :self.org_vocab_size] return logits def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: return hidden_states.index_select(0, sampling_metadata.selected_token_indices) def _apply_logits_processors( logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: logits_row_idx = 0 found_logits_processors = False for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group logits_processors = sampling_params.logits_processors # handle prompt_logprobs by skipping rows in logits added for # the prompt tokens (prompt logprobs are not processed) if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): assert len(seq_ids) == 1 logits_row_idx += sampling_metadata.prompt_lens[i] - 1 if logits_processors: found_logits_processors = True for seq_id in seq_ids: logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids for logits_processor in logits_processors: logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 else: logits_row_idx += len(seq_ids) if found_logits_processors: # Ensure that no rows in logits were unexpectedly skipped. assert logits_row_idx == logits.shape[0] return logits