mlp_speculator.py

import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from aphrodite.common.sequence import SamplerOutput
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.transformers_utils.configs import MLPSpeculatorConfig

SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if self.elementwise_scale_and_shift:
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x
            x = x + self.bias
        return x
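

# A minimal reference sketch of what MLPSpeculatorLayerNorm computes when
# elementwise_scale_and_shift is False: an RMS-style normalization that
# rescales by the root mean square of the final axis. _rms_norm_reference is
# an illustrative helper only and is not used elsewhere in this module.
def _rms_norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Divide by sqrt(mean(x**2) + eps) over the last dimension, mirroring the
    # forward pass above before any learned scale or shift is applied.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)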


class MLPSpeculator(nn.Module):

    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        if self.tie_weights:
            assert (
                self.n_predict >
                1), "You cannot tie weights between stages when only 1 exists"
            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                      (self.max_speculative_tokens - 1))

            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)

        else:
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(config.vocab_size,
                                       self.inner_dim,
                                       org_num_embeddings=config.vocab_size)
                for _ in range(self.max_speculative_tokens)
            ])
            self.proj = nn.ModuleList([
                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                          self.inner_dim,
                          bias=False)
                for i in range(self.max_speculative_tokens)
            ])
            self.head = nn.ModuleList([
                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
                for _ in range(self.max_speculative_tokens)
            ])
            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(self.max_speculative_tokens)
            ])

        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()
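
        # Each speculative position ("head") gets its own embedding,
        # projection, layer norm, and LM head. With tie_weights=True a single
        # module instance is reused across positions (only the first
        # projection stays separate, since it maps from the base model's
        # emb_dim rather than inner_dim); without tying, each position owns
        # independent copies. state_weight and emb_weight are the blending
        # coefficients used in generate_proposals to mix the running hidden
        # state with the embedding of the most recently sampled token.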

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        if self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            previous_hidden_states = states
            # TODO: not yet supporting top_k_tokens_per_head
            states = states.flatten(0, 1)

            logits = self.logits_processor(self.head[head_index], states,
                                           sampling_metadata)

            output = self.sampler(logits, sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens
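
    # generate_proposals returns one SamplerOutput per requested speculative
    # position. Each head embeds the most recent tokens (the caller's
    # input_ids for the first head, the previous head's samples afterwards),
    # blends that embedding with the projected running hidden state,
    # normalizes and activates the result, and samples the next proposal, so
    # the k-th proposal is conditioned on the k-1 proposals before it.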

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict.get(name.replace("speculator.", ""))
            if param is not None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
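

# Illustrative sketch of the blending trick in generate_proposals; the helper
# below is not used elsewhere in this module and assumes only plain PyTorch.
# Adding z in place with alpha = emb_weight / state_weight leaves the sum
# scaled by 1 / state_weight, and the subsequent normalization cancels that
# uniform factor (up to its eps term), so the result matches normalizing
# state_weight * state + emb_weight * z directly. The two returned tensors
# are therefore expected to agree to within a small numerical tolerance.
def _blend_equivalence_sketch(
        state: torch.Tensor, z: torch.Tensor, state_weight: float,
        emb_weight: float) -> Tuple[torch.Tensor, torch.Tensor]:
    ln = MLPSpeculatorLayerNorm(state.shape[-1],
                                elementwise_scale_and_shift=False)
    # Path taken by generate_proposals: add with a rescaled alpha.
    fast = ln(state + z * (emb_weight / state_weight))
    # Explicit weighted sum for comparison.
    exact = ln(state * state_weight + z * emb_weight)
    return fast, exact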