logits_processor.py

  1. """A layer that compute logits from hidden_stats."""
  2. import inspect
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from aphrodite.distributed import tensor_model_parallel_gather
  7. from aphrodite.modeling.layers.vocab_parallel_embedding import \
  8. VocabParallelEmbedding
  9. from aphrodite.modeling.sampling_metadata import SamplingMetadata
class LogitsProcessor(nn.Module):
    """Process logits and apply logits processors from sampling metadata.

    This layer does the following:
    1. Gather logits from model hidden_states.
    2. Scale logits if needed.
    3. Apply logits processors (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: float = 1.0,
                 logits_as_input: bool = False,
                 soft_cap: Optional[float] = None) -> None:
        """
        Args:
            vocab_size: Size of the vocabulary used for the logits.
            org_vocab_size: Original vocabulary size (without LoRA).
                Defaults to vocab_size.
            scale: A scaling factor to apply to the logits.
            logits_as_input: Whether the input is already logits
                (default is hidden states).
            soft_cap: If set, soft cap the logits. Used in Gemma 2.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Whether the input is logits (default is hidden states).
        self.logits_as_input = logits_as_input
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.logits_as_input:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, lm_head, embedding_bias)

        if logits is not None:
            if self.soft_cap is not None:
                # Soft cap: soft_cap * tanh(logits / soft_cap).
                logits = logits / self.soft_cap
                logits = torch.tanh(logits)
                logits = logits * self.soft_cap

            if self.scale != 1.0:
                logits *= self.scale

            # Apply logits processors (if any).
            logits = _apply_logits_processors(logits, sampling_metadata)

        return logits

    def _get_logits(self, hidden_states: torch.Tensor,
                    lm_head: VocabParallelEmbedding,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
        # Gather the logits across tensor-parallel ranks (may return None
        # on non-destination ranks).
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
        return s


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    # Keep only the hidden states selected for sampling / prompt logprobs.
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)


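# NOTE: `_apply_logits_processors` accepts user-supplied processors with
# either of two call signatures, dispatched on arity via `inspect.signature`:
#     processor(past_tokens_ids, logits_row) -> logits_row
#     processor(prompt_tokens_ids, past_tokens_ids, logits_row) -> logits_row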
def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    found_logits_processors = False
    logits_processed = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True

            for seq_id, logits_row_idx in zip(seq_ids,
                                              seq_group.sample_indices):
                logits_row = logits[logits_row_idx]
                past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
                prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

                for logits_processor in logits_processors:
                    parameters = inspect.signature(logits_processor).parameters
                    if len(parameters) == 3:
                        logits_row = logits_processor(prompt_tokens_ids,
                                                      past_tokens_ids,
                                                      logits_row)
                    else:
                        logits_row = logits_processor(past_tokens_ids,
                                                      logits_row)

                logits[logits_row_idx] = logits_row

        logits_processed += len(seq_group.sample_indices) + len(
            seq_group.prompt_logprob_indices)

    if found_logits_processors:
        # Verifies that no rows in logits were missed unexpectedly.
        assert logits_processed == logits.shape[0]
    return logits
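

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, self-contained
# demonstration using plain tensors. It avoids building a real
# `VocabParallelEmbedding` or `SamplingMetadata`, so it only exercises the
# soft-cap/scale math from `LogitsProcessor.forward` and the two processor
# signatures dispatched by `_apply_logits_processors`. The processor names
# (`ban_first_token`, `penalize_seen_tokens`) are illustrative, not part of
# any API.
if __name__ == "__main__":
    vocab_size = 8
    logits_row = torch.randn(vocab_size)

    # Soft cap, then scale, mirroring the order used in `forward`:
    # soft_cap * tanh(logits / soft_cap), then * scale.
    soft_cap, scale = 30.0, 0.5
    logits_row = torch.tanh(logits_row / soft_cap) * soft_cap * scale

    # Two-argument processor: (generated token ids, logits row) -> logits row.
    def ban_first_token(past_tokens_ids, row):
        row = row.clone()
        row[0] = float("-inf")  # Never sample token id 0.
        return row

    # Three-argument processor additionally receives the prompt token ids.
    def penalize_seen_tokens(prompt_tokens_ids, past_tokens_ids, row):
        row = row.clone()
        for token_id in set(prompt_tokens_ids) | set(past_tokens_ids):
            row[token_id] -= 1.0  # Simple flat penalty on seen tokens.
        return row

    prompt_tokens_ids = [1, 2, 3]
    past_tokens_ids = [3, 4]
    for processor in (ban_first_token, penalize_seen_tokens):
        # Same arity-based dispatch as `_apply_logits_processors`.
        if len(inspect.signature(processor).parameters) == 3:
            logits_row = processor(prompt_tokens_ids, past_tokens_ids,
                                   logits_row)
        else:
            logits_row = processor(past_tokens_ids, logits_row)

    print(logits_row)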