mlp_speculator.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import math
  2. from typing import Iterable, List, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.common.sequence import SamplerOutput
  6. from aphrodite.modeling import SamplingMetadata
  7. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  8. from aphrodite.modeling.layers.sampler import Sampler
  9. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  10. ParallelLMHead, VocabParallelEmbedding)
  11. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  12. from aphrodite.transformers_utils.configs import MLPSpeculatorConfig
# Square root of two. MLPSpeculator.generate_proposals divides the
# layer-normalized base-model hidden state by this value when
# ``scale_input`` is enabled.
SQRT2 = 2**0.5
  14. class MLPSpeculatorLayerNorm(nn.Module):
  15. """
  16. A L2 normalization implementation
  17. ...
  18. Args
  19. ----
  20. normalized_shape : int
  21. Dimensionality of input data (size of final tensor axis)
  22. eps : float
  23. Safety term to prevent division by zero. Make sure the chosen value
  24. fits in the range of your encoding scheme
  25. (i.e. fp16 requires eps >= 6e-8).
  26. elementwise_scale_and_shift : bool
  27. Include a learned scaling and shift term after normalization.
  28. """
  29. def __init__(
  30. self,
  31. normalized_shape,
  32. eps=1e-06,
  33. elementwise_scale_and_shift=True,
  34. ):
  35. super(MLPSpeculatorLayerNorm, self).__init__()
  36. self.elementwise_scale_and_shift = elementwise_scale_and_shift
  37. if self.elementwise_scale_and_shift:
  38. self.weight = nn.Parameter(torch.empty(normalized_shape))
  39. self.bias = nn.Parameter(torch.empty(normalized_shape))
  40. self.eps = eps
  41. def forward(self, x):
  42. xf = x
  43. xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
  44. x = xf.type_as(x)
  45. if self.elementwise_scale_and_shift:
  46. x = self.weight * x
  47. x = x + self.bias
  48. return x
  49. class MLPSpeculator(nn.Module):
  50. def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
  51. super().__init__()
  52. self.n_predict = config.n_predict
  53. self.vocab_size = config.vocab_size
  54. self.emb_dim = config.emb_dim
  55. self.inner_dim = config.inner_dim if config.inner_dim != 0 \
  56. else config.emb_dim
  57. self.max_speculative_tokens = config.num_lookahead_tokens
  58. self.tie_weights = config.tie_weights
  59. self.scale_input = config.scale_input
  60. if self.tie_weights:
  61. assert (
  62. self.n_predict >
  63. 1), "You cannot tie weights between stages when only 1 exists"
  64. embedding = VocabParallelEmbedding(
  65. config.vocab_size,
  66. self.inner_dim,
  67. org_num_embeddings=config.vocab_size)
  68. self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
  69. # the initial projection from the base model may
  70. # have a different size, so that stays separate.
  71. proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
  72. proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
  73. self.proj = nn.ModuleList([proj_first] + [proj_tied] *
  74. (self.max_speculative_tokens - 1))
  75. head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
  76. self.head = nn.ModuleList([head] * self.max_speculative_tokens)
  77. ln = MLPSpeculatorLayerNorm(self.inner_dim,
  78. elementwise_scale_and_shift=True)
  79. self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
  80. else:
  81. self.emb = nn.ModuleList([
  82. VocabParallelEmbedding(config.vocab_size,
  83. self.inner_dim,
  84. org_num_embeddings=config.vocab_size)
  85. for _ in range(self.max_speculative_tokens)
  86. ])
  87. self.proj = nn.ModuleList([
  88. nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
  89. self.inner_dim,
  90. bias=False)
  91. for i in range(self.max_speculative_tokens)
  92. ])
  93. self.head = nn.ModuleList([
  94. ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
  95. for _ in range(self.max_speculative_tokens)
  96. ])
  97. self.ln = nn.ModuleList([
  98. MLPSpeculatorLayerNorm(self.inner_dim,
  99. elementwise_scale_and_shift=True)
  100. for _ in range(self.max_speculative_tokens)
  101. ])
  102. if self.scale_input:
  103. self.ln0 = MLPSpeculatorLayerNorm(
  104. self.emb_dim, elementwise_scale_and_shift=False)
  105. self.state_weight = 0.5**(0.5 / config.n_predict)
  106. self.emb_weight = math.sqrt(
  107. (1 - self.state_weight**2) * (self.inner_dim / 2))
  108. self.activation = nn.GELU()
  109. self.config = config
  110. self.logits_processor = LogitsProcessor(config.vocab_size,
  111. config.vocab_size, 1.0)
  112. self.sampler = Sampler()
  113. def generate_proposals(
  114. self,
  115. input_ids: torch.Tensor,
  116. previous_hidden_states: torch.Tensor,
  117. num_predict_tokens: int,
  118. sampling_metadata: SamplingMetadata,
  119. ) -> List[SamplerOutput]:
  120. if num_predict_tokens > self.max_speculative_tokens:
  121. raise ValueError(f"Max speculative tokens for model is "
  122. f"{self.max_speculative_tokens}, but "
  123. f"{num_predict_tokens} were requested")
  124. # b x 1 x d
  125. previous_hidden_states = previous_hidden_states.unsqueeze(1)
  126. if self.scale_input:
  127. previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
  128. # b x 1
  129. last_tokens = input_ids.unsqueeze(1)
  130. next_tokens = []
  131. for head_index in range(num_predict_tokens):
  132. # Project and predict
  133. z = self.emb[head_index](last_tokens) # b k d
  134. states = self.proj[head_index](previous_hidden_states)
  135. # Weighted add of state_weight*state and emb_weight*z
  136. # Let subsequent LN take care of denominator
  137. # state_weight is close to 1, so shouldn't be any precision issues
  138. states.add_(z, alpha=self.emb_weight / self.state_weight)
  139. states = self.activation(self.ln[head_index](states)) # b k d
  140. # TODO: not yet supporting top_k_tokens_per_head
  141. previous_hidden_states = states
  142. logits = self.logits_processor(self.head[head_index], states,
  143. sampling_metadata)
  144. output = self.sampler(logits.flatten(0, 1), sampling_metadata)
  145. last_tokens = output.sampled_token_ids
  146. next_tokens.append(output)
  147. return next_tokens
  148. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  149. params_dict = dict(self.named_parameters())
  150. for name, loaded_weight in weights:
  151. param = params_dict.get(name.replace("speculator.", ""))
  152. if param is not None:
  153. weight_loader = getattr(param, "weight_loader",
  154. default_weight_loader)
  155. weight_loader(param, loaded_weight)