# mlp_speculator.py

import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.transformers_utils.configs import MLPSpeculatorConfig

SQRT2 = 2**0.5


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if self.elementwise_scale_and_shift:
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x
            x = x + self.bias
        return x
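
# A minimal sketch of what MLPSpeculatorLayerNorm.forward computes: an
# RMS-style normalization over the last axis (divide by the root-mean-square,
# no mean subtraction), optionally followed by the learned scale and shift.
# The values below are chosen for illustration only:
#
#     ln = MLPSpeculatorLayerNorm(4, elementwise_scale_and_shift=False)
#     x = torch.tensor([3.0, 4.0, 0.0, 0.0])
#     y = ln(x)
#     # RMS of x is 2.5, so y ~= [1.2, 1.6, 0.0, 0.0]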


class MLPSpeculator(nn.Module):

    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim

        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        if self.tie_weights:
            assert (
                self.n_predict >
                1), "You cannot tie weights between stages when only 1 exists"
            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                      (self.max_speculative_tokens - 1))

            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)

        else:
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(config.vocab_size,
                                       self.inner_dim,
                                       org_num_embeddings=config.vocab_size)
                for _ in range(self.max_speculative_tokens)
            ])

            self.proj = nn.ModuleList([
                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                          self.inner_dim,
                          bias=False)
                for i in range(self.max_speculative_tokens)
            ])

            self.head = nn.ModuleList([
                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
                for _ in range(self.max_speculative_tokens)
            ])

            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(self.max_speculative_tokens)
            ])

        if self.scale_input:
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()
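
    # Illustrative note on the residual weighting above, assuming a config
    # with n_predict=3 and inner_dim=4096 (values not taken from any real
    # checkpoint):
    #     state_weight = 0.5 ** (0.5 / 3)                 ~= 0.89
    #     emb_weight   = sqrt((1 - 0.89 ** 2) * 4096 / 2) ~= 20.6
    # generate_proposals folds both into a single in-place add with
    # alpha = emb_weight / state_weight and relies on the following
    # MLPSpeculatorLayerNorm to absorb the overall scale.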

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        if self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            previous_hidden_states = states
            # TODO: not yet supporting top_k_tokens_per_head
            states = states.flatten(0, 1)

            logits = self.logits_processor(self.head[head_index], states,
                                           sampling_metadata)

            output = self.sampler(logits, sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens
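
    # Hypothetical call sketch (variable names and concrete values are
    # assumptions, not taken from the engine): for a batch of size b, the
    # caller passes the base model's last sampled token ids and final hidden
    # states and receives one SamplerOutput per speculative head:
    #
    #     proposals = speculator.generate_proposals(
    #         input_ids=last_token_ids,              # shape: [b]
    #         previous_hidden_states=hidden_states,  # shape: [b, emb_dim]
    #         num_predict_tokens=3,
    #         sampling_metadata=sampling_metadata,
    #     )
    #     # len(proposals) == 3; each output's sampled_token_ids seeds the
    #     # next head's embedding lookup.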

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict.get(name.replace("speculator.", ""))
            if param is not None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
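
    # Name-mapping sketch: load_weights removes the "speculator." prefix from
    # incoming tensor names, so an assumed checkpoint entry such as
    #     ("speculator.emb.0.weight", tensor)
    # would be routed to this module's "emb.0.weight" parameter; entries that
    # match no local parameter are silently skipped.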