mlp_speculator.py

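# Draft model for MLP-based speculative decoding: a stack of small prediction
# heads (token embedding + linear projection + layer norm + LM head), each of
# which proposes one additional token conditioned on the base model's last
# hidden state and the token sampled at the previous step. The design follows
# IBM's MLPSpeculator (fms-extras); configuration comes from
# MLPSpeculatorConfig.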
import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from aphrodite.common.sequence import SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.transformers_utils.configs import MLPSpeculatorConfig
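
# Used in generate_proposals to rescale the normalized input hidden state
# when config.scale_input is enabled.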
SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation.

    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if self.elementwise_scale_and_shift:
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x
            x = x + self.bias
        return x
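

# Each speculative step i applies the recurrence
#     state <- GELU(LN_i(W_proj_i @ state + (emb_weight / state_weight) * E_i(token)))
# and samples the next token from LM head i. state_weight and emb_weight
# (set in __init__ below) balance the carried hidden state against the
# freshly embedded token.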
class MLPSpeculator(nn.Module):

    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input
        if self.tie_weights:
            assert (
                self.n_predict >
                1), "You cannot tie weights between stages when only 1 exists"

            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                      (self.max_speculative_tokens - 1))

            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)

        else:
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(config.vocab_size,
                                       self.inner_dim,
                                       org_num_embeddings=config.vocab_size)
                for _ in range(self.max_speculative_tokens)
            ])
            self.proj = nn.ModuleList([
                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                          self.inner_dim,
                          bias=False)
                for i in range(self.max_speculative_tokens)
            ])
            self.head = nn.ModuleList([
                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
                for _ in range(self.max_speculative_tokens)
            ])
            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(self.max_speculative_tokens)
            ])

        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()
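
    # generate_proposals runs the speculator autoregressively for
    # num_predict_tokens steps: each step embeds the tokens sampled at the
    # previous step, mixes them into the carried hidden state, and samples
    # one more token per sequence, returning one SamplerOutput per step.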
    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        if self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            previous_hidden_states = states

            # TODO: not yet supporting top_k_tokens_per_head
            states = states.flatten(0, 1)

            logits = self.logits_processor(self.head[head_index], states,
                                           sampling_metadata)
            output = self.sampler(logits, sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens
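
    # Checkpoint parameter names may carry a "speculator." prefix; strip it
    # before looking the parameter up, and silently skip weights that have
    # no matching parameter in this module.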
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            param = params_dict.get(name.replace("speculator.", ""))
            if param is not None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
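

if __name__ == "__main__":
    # Minimal, self-contained smoke test for MLPSpeculatorLayerNorm only; it
    # does not exercise the speculator itself, which needs a full Aphrodite
    # runtime. Without the learned scale/shift, the normalized activations
    # should have unit root mean square along the last axis.
    torch.manual_seed(0)
    ln = MLPSpeculatorLayerNorm(8, elementwise_scale_and_shift=False)
    x = torch.randn(2, 3, 8)
    y = ln(x)
    print(y.pow(2).mean(-1))  # each entry should be ~1.0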