medusa.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. from typing import Iterable, List, Optional, Tuple
  2. import torch
  3. import torch.nn as nn
  4. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  5. from aphrodite.modeling.layers.sampler import SamplerOutput
  6. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  7. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
  8. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  9. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  10. from aphrodite.transformers_utils.configs.medusa import MedusaConfig
  11. class ResidualBlock(nn.Module):
  12. def __init__(self, hidden_size: int, num_layers: int) -> None:
  13. super().__init__()
  14. self.layers = nn.ModuleList([
  15. nn.Linear(hidden_size, hidden_size, bias=False)
  16. for _ in range(num_layers)
  17. ])
  18. self.act = nn.SiLU()
  19. def forward(self, x: torch.Tensor) -> torch.Tensor:
  20. for layer in self.layers:
  21. x = x + self.act(layer(x))
  22. return x
  23. class Medusa(nn.Module):
  24. """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
  25. Reference implementation: https://github.com/FasterDecoding/Medusa
  26. Differences from reference implementation:
  27. 1. Currently this only supports generating proposals from top-1 tokens.
  28. 2. We have an optional token_map which reduces draft vocab to most
  29. frequently used tokens to give some additional speed-up by reducing
  30. sampling overhead. This is disabled unless the checkpoint file has
  31. explicit token_map tensor and config has an optional attribute
  32. truncated_vocab_size < vocab_size. To use this technique, one has to find
  33. the top-k most frequent tokens in target dataset and add that as a tensor
  34. in the draft checkpoint (using key token_map). Also, the draft config
  35. needs to have truncated_vocab_size (=k) as an attribute."""
  36. def __init__(self, config: MedusaConfig, **_) -> None:
  37. super().__init__()
  38. self.config = config
  39. self.blocks = nn.ModuleList([
  40. ResidualBlock(hidden_size=self.config.hidden_size,
  41. num_layers=self.config.num_hidden_layers)
  42. for _ in range(self.config.num_heads)
  43. ])
  44. self.orig_vocab_size = config.vocab_size
  45. self.truncated_vocab_size = config.truncated_vocab_size
  46. self.unpadded_vocab_size = self.truncated_vocab_size
  47. self.lm_heads = nn.ModuleList([
  48. ParallelLMHead(
  49. self.unpadded_vocab_size,
  50. config.hidden_size,
  51. org_num_embeddings=self.truncated_vocab_size,
  52. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  53. ) for _ in range(self.config.num_heads)
  54. ])
  55. logit_scale = getattr(config, "logit_scale", 1.0)
  56. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  57. self.truncated_vocab_size,
  58. logit_scale)
  59. # Token map is a idx to token mapping to reduce the vocab size for
  60. # the draft model. Using smaller vocab size for draft, containing
  61. # only most frequent tokens reduces the speculation overhead. This
  62. # doesn't affect the acceptance rate much and thus gives more speed
  63. # -up. By default, this is disabled and is only used if the EAGLE
  64. # checkpoint file has token_map tensor.
  65. self.token_map = None
  66. def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
  67. return [block(hidden_states) for block in self.blocks]
  68. def compute_logits(
  69. self, hidden_states: List[torch.Tensor],
  70. sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
  71. logits_lst: List[torch.Tensor] = []
  72. for hs, lm_head in zip(hidden_states, self.lm_heads):
  73. _logits = self.logits_processor(lm_head, hs, sampling_metadata)
  74. if _logits is None:
  75. # _logits should only be None on rank > 0, in which case
  76. # it should remain true for every lm_head
  77. assert len(logits_lst) == 0
  78. continue
  79. if self.token_map is None:
  80. logits_lst.append(_logits)
  81. else:
  82. logits_lst.append(-torch.inf * torch.ones(
  83. size=(*_logits.shape[:-1], self.orig_vocab_size),
  84. device=_logits.device,
  85. dtype=_logits.dtype))
  86. logits_lst[-1][..., self.token_map] = _logits
  87. return logits_lst
  88. def sample(
  89. self,
  90. logits: List[torch.Tensor],
  91. sampling_metadata: SamplingMetadata,
  92. ) -> List[SamplerOutput]:
  93. logits = torch.stack(logits, dim=0).float()
  94. logprobs = torch.log_softmax(logits, dim=-1)
  95. token_ids = logits.argmax(-1) # support only top-1 for now
  96. probs = torch.softmax(logits, dim=-1)
  97. token_id_list = []
  98. token_prob_list = []
  99. token_logprob_list = []
  100. for idx, seq_group in enumerate(sampling_metadata.seq_groups):
  101. token_id_list.append(token_ids[:, seq_group.sample_indices])
  102. token_prob_list.append(probs[:, seq_group.sample_indices])
  103. token_logprob_list.append(logprobs[:, seq_group.sample_indices])
  104. outputs: List[Optional[SamplerOutput]] = []
  105. for idx in range(len(sampling_metadata.seq_groups)):
  106. outputs.append(
  107. SamplerOutput(
  108. outputs=None,
  109. sampled_token_probs=token_prob_list[idx].squeeze(1),
  110. logprobs=token_logprob_list[idx].squeeze(1),
  111. sampled_token_ids=token_id_list[idx].squeeze(1),
  112. ))
  113. return outputs
  114. def generate_proposals(
  115. self,
  116. previous_hidden_states: torch.Tensor,
  117. sampling_metadata: SamplingMetadata,
  118. ) -> List[SamplerOutput]:
  119. return self.sample(
  120. logits=self.compute_logits(
  121. hidden_states=self.forward(previous_hidden_states),
  122. sampling_metadata=sampling_metadata,
  123. ),
  124. sampling_metadata=sampling_metadata,
  125. )
  126. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  127. params_dict = dict(self.named_parameters())
  128. weights_map = {}
  129. for name, loaded_weight in weights:
  130. name = name.replace("medusa_heads.", "")
  131. if name == "token_map":
  132. if self.truncated_vocab_size < self.orig_vocab_size:
  133. self.token_map = nn.Parameter(loaded_weight,
  134. requires_grad=False)
  135. elif name in params_dict:
  136. weights_map[name] = loaded_weight
  137. for name, loaded_weight in weights_map.items():
  138. if "lm_head" in name and self.token_map is not None and\
  139. loaded_weight.shape[0] > self.token_map.shape[0]:
  140. loaded_weight = loaded_weight[self.token_map]
  141. param = params_dict[name]
  142. weight_loader = getattr(param, "weight_loader",
  143. default_weight_loader)
  144. weight_loader(param, loaded_weight)
  145. if self.token_map is not None:
  146. self.token_map.to(device=self.lm_heads[0].weight.device)
  147. assert (self.truncated_vocab_size
  148. == self.orig_vocab_size) or (self.token_map is not None)