# mlp_speculator.py

import math
from typing import Iterable, List, Tuple

import torch
import torch.nn as nn

from aphrodite.common.sequence import SamplerOutput
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.transformers_utils.configs import MLPSpeculatorConfig


class MLPSpeculatorLayerNorm(nn.Module):
    """
    An L2 normalization implementation
    ...
    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis)
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
    ):
        super(MLPSpeculatorLayerNorm, self).__init__()
        # Learned elementwise scale and shift; values are populated later by
        # the weight loader.
        self.weight = nn.Parameter(torch.empty(normalized_shape))
        self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        xf = x
        # Normalize by the root-mean-square of the last axis, then apply the
        # learned scale and shift.
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        x = xf.type_as(x)
        x = self.weight * x
        x = x + self.bias
        return x
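

# Illustrative sketch, not used by the model above: MLPSpeculatorLayerNorm is
# effectively an RMS-style norm with a learned bias. The reference function
# below (its name and signature are assumptions for illustration only) shows
# the same computation in functional form.
def _mlp_speculator_layer_norm_ref(x: torch.Tensor,
                                   weight: torch.Tensor,
                                   bias: torch.Tensor,
                                   eps: float = 1e-06) -> torch.Tensor:
    # Scale each vector by the reciprocal of its root-mean-square over the
    # last axis, then apply the learned elementwise scale and shift.
    xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * xf.type_as(x) + bias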


class MLPSpeculator(nn.Module):

    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        self.inner_dim = config.inner_dim if config.inner_dim != 0 \
            else config.emb_dim
        self.max_speculative_tokens = config.num_lookahead_tokens

        # One embedding, projection, head, and norm per speculative position.
        self.emb = nn.ModuleList([
            VocabParallelEmbedding(config.vocab_size,
                                   self.inner_dim,
                                   org_num_embeddings=config.vocab_size)
            for _ in range(self.max_speculative_tokens)
        ])
        self.proj = nn.ModuleList([
            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                      self.inner_dim,
                      bias=False) for i in range(self.max_speculative_tokens)
        ])
        self.head = nn.ModuleList([
            nn.Linear(self.inner_dim, self.vocab_size, bias=False)
            for _ in range(self.max_speculative_tokens)
        ])
        self.ln = nn.ModuleList([
            MLPSpeculatorLayerNorm(self.inner_dim)
            for _ in range(self.max_speculative_tokens)
        ])

        # Fixed mixing weights for combining the previous hidden state with
        # the token embedding before each prediction head.
        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        next_tokens = []

        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)

            # Weighted add of state_weight*state and emb_weight*z
            # Let subsequent LN take care of denominator
            # state_weight is close to 1, so shouldn't be any precision issues
            states.add_(z, alpha=self.emb_weight / self.state_weight)

            states = self.activation(self.ln[head_index](states))  # b k d
            # TODO: not yet supporting top_k_tokens_per_head
            previous_hidden_states = states

            logits = self.logits_processor(self.head[head_index].weight,
                                           states, sampling_metadata)

            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Checkpoint names carry a "speculator." prefix that is not
            # present on this module's parameters; strip it before the lookup.
            param = params_dict.get(name.replace("speculator.", ""))
            if param is not None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
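

# Illustrative usage sketch (an assumption, not part of this module): a
# speculative decoding worker would typically call generate_proposals() with
# the most recently sampled token ids and the base model's final hidden
# states. The helper name and default below are hypothetical and exist only
# to show the call pattern.
def _example_generate_proposals(
        speculator: MLPSpeculator,
        last_token_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        num_predict_tokens: int = 3) -> List[SamplerOutput]:
    # last_token_ids: shape (batch,); hidden_states: shape (batch, emb_dim).
    # Returns one SamplerOutput per prediction head, each holding the
    # speculative token sampled for every sequence in the batch.
    return speculator.generate_proposals(
        input_ids=last_token_ids,
        previous_hidden_states=hidden_states,
        num_predict_tokens=num_predict_tokens,
        sampling_metadata=sampling_metadata,
    )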