logits_processor.py

  1. """A layer that compute logits from hidden_stats."""
import inspect
from typing import Optional

import torch
import torch.nn as nn

from aphrodite.distributed import (tensor_model_parallel_all_gather,
                                   tensor_model_parallel_gather)
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.platforms import current_platform


class LogitsProcessor(nn.Module):
    """Process logits and apply logits processors from sampling metadata.

    This layer does the following:
    1. Gather logits from model hidden_states.
    2. Scale logits if needed.
    3. Apply logits processors (if any).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 scale: float = 1.0,
                 logits_as_input: bool = False,
                 soft_cap: Optional[float] = None) -> None:
        """
        Args:
            scale: A scaling factor to apply to the logits.
        """
        super().__init__()
        self.scale = scale
        self.vocab_size = vocab_size
        # Whether the input is logits (default is hidden states).
        self.logits_as_input = logits_as_input
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size
        # Soft cap the logits. Used in Gemma 2.
        self.soft_cap = soft_cap
        # Whether to use gather or all-gather to gather the logits.
        self.use_gather = not current_platform.is_tpu()

    def forward(
        self,
        lm_head: VocabParallelEmbedding,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        if self.logits_as_input:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
        if logits is not None:
            if self.soft_cap is not None:
                logits = logits / self.soft_cap
                logits = torch.tanh(logits)
                logits = logits * self.soft_cap

            if self.scale != 1.0:
                logits *= self.scale

            # Apply logits processors (if any).
            logits = _apply_logits_processors(logits, sampling_metadata)

        return logits

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = lm_head.linear_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
        if self.use_gather:
            # None may be returned for rank > 0.
            logits = tensor_model_parallel_gather(logits)
        else:
            # Gather is not supported for some devices such as TPUs.
            # Use all-gather instead.
            # NOTE: Here, the outputs of every device should not be None
            # because XLA requires strict SPMD among all devices. Every device
            # should execute the same operations after gathering the logits.
            logits = tensor_model_parallel_all_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[..., :self.org_vocab_size]
        return logits

    def extra_repr(self) -> str:
        s = f"vocab_size={self.vocab_size}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
        return s


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    found_logits_processors = False
    logits_processed = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True

            for seq_id, logits_row_idx in zip(seq_ids,
                                              seq_group.sample_indices):
                logits_row = logits[logits_row_idx]
                past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
                prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

                for logits_processor in logits_processors:
                    parameters = inspect.signature(logits_processor).parameters
                    if len(parameters) == 3:
                        logits_row = logits_processor(prompt_tokens_ids,
                                                      past_tokens_ids,
                                                      logits_row)
                    else:
                        logits_row = logits_processor(past_tokens_ids,
                                                      logits_row)

                logits[logits_row_idx] = logits_row

        logits_processed += len(seq_group.sample_indices) + len(
            seq_group.prompt_logprob_indices)

    if found_logits_processors:
        # Verify that no rows in logits were missed unexpectedly.
        assert logits_processed == logits.shape[0]

    return logits
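

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the logits processors
# applied above are plain callables supplied through SamplingParams. Based on
# the signature dispatch in _apply_logits_processors, a processor may take
# either (past_token_ids, logits_row) or (prompt_token_ids, past_token_ids,
# logits_row). The helper names below are hypothetical examples of each form,
# showing how the dispatch resolves them.
if __name__ == "__main__":

    def _ban_token_42(past_token_ids, logits_row):
        # Two-argument form: sees only the tokens generated so far.
        logits_row[42] = -float("inf")
        return logits_row

    def _penalize_prompt_tokens(prompt_token_ids, past_token_ids, logits_row):
        # Three-argument form: also receives the prompt token IDs.
        for token_id in set(prompt_token_ids):
            logits_row[token_id] -= 1.0
        return logits_row

    row = torch.zeros(128)
    for processor in (_ban_token_42, _penalize_prompt_tokens):
        num_params = len(inspect.signature(processor).parameters)
        if num_params == 3:
            row = processor([1, 2, 3], [7, 8], row)
        else:
            row = processor([7, 8], row)
    assert row[42] == -float("inf")
    assert row[1] == -1.0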