medusa.py

"""Medusa draft model: lightweight MLP heads that propose speculative tokens
from the base model's final hidden states."""
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.common.sequence import SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.transformers_utils.configs.medusa import MedusaConfig


class ResidualBlock(nn.Module):
    """A stack of residual feed-forward layers: x = x + SiLU(Linear(x))."""

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()

        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size, bias=False)
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):
    """Medusa draft-token heads.

    Each of ``config.num_heads`` heads is an independent ResidualBlock followed
    by its own LM head, and each head proposes one additional future token, so
    a single call yields ``num_heads`` speculative tokens per sequence.
    """

    def __init__(self, config: MedusaConfig, **_) -> None:
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        # One LM head per Medusa head, sized for the (possibly truncated)
        # vocabulary the draft heads were trained on.
        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # Maps truncated-vocabulary indices back to full-vocabulary token ids;
        # populated from the checkpoint in load_weights when the vocabulary
        # was truncated.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        # Every head sees the same base-model hidden states.
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor],
            sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
        logits_lst: List[torch.Tensor] = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            _logits = self.logits_processor(lm_head, hs, sampling_metadata)

            if _logits is None:
                # _logits is only None on ranks other than 0, and in that case
                # it is None for every lm_head, so nothing should have been
                # appended yet.
                assert len(logits_lst) == 0
                continue

            if self.token_map is None:
                logits_lst.append(_logits)
            else:
                # The heads were trained on a truncated vocabulary: scatter
                # their logits back into a full-vocabulary tensor and mask
                # the unmapped token ids with -inf.
                logits_lst.append(-torch.inf * torch.ones(
                    size=(*_logits.shape[:-1], self.orig_vocab_size),
                    device=_logits.device,
                    dtype=_logits.dtype))

                logits_lst[-1][..., self.token_map] = _logits

        return logits_lst

    def sample(
        self,
        logits: List[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        # Stack per-head logits into [num_heads, num_tokens, vocab_size].
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)

        token_id_list = []
        token_prob_list = []
        token_logprob_list = []

        for seq_group in sampling_metadata.seq_groups:
            token_id_list.append(token_ids[:, seq_group.sample_indices])
            token_prob_list.append(probs[:, seq_group.sample_indices])
            token_logprob_list.append(logprobs[:, seq_group.sample_indices])

        outputs: List[SamplerOutput] = []
        for idx in range(len(sampling_metadata.seq_groups)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_prob_list[idx].squeeze(1),
                    logprobs=token_logprob_list[idx].squeeze(1),
                    sampled_token_ids=token_id_list[idx].squeeze(1),
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        return self.sample(
            logits=self.compute_logits(
                hidden_states=self.forward(previous_hidden_states),
                sampling_metadata=sampling_metadata,
            ),
            sampling_metadata=sampling_metadata,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())

        weights_map = {}

        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                # Only keep the token map when the checkpoint actually uses a
                # truncated vocabulary.
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        for name, loaded_weight in weights_map.items():
            # Checkpoints may ship full-vocabulary LM heads; slice them down
            # to the truncated vocabulary via the token map.
            if "lm_head" in name and self.token_map is not None and \
                    loaded_weight.shape[0] > self.token_map.shape[0]:
                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.token_map is not None:
            # Move the map onto the same device as the LM heads (Tensor.to is
            # not in-place, so assign through .data).
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None)
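

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the model): running this file directly
# exercises the two ideas above with plain PyTorch. It uses toy sizes, an
# ordinary nn.Linear in place of ParallelLMHead (the real heads need the
# engine's distributed setup and SamplingMetadata), and a synthetic token_map;
# all names and sizes below are made up for the demo.
if __name__ == "__main__":
    torch.manual_seed(0)

    hidden_size, full_vocab, kept_vocab, num_heads, batch = 16, 32, 8, 3, 2

    # Per-head residual MLPs, reusing the ResidualBlock defined above.
    blocks = nn.ModuleList([
        ResidualBlock(hidden_size=hidden_size, num_layers=1)
        for _ in range(num_heads)
    ])
    # Stand-ins for the per-head LM heads over the truncated vocabulary.
    heads = nn.ModuleList([
        nn.Linear(hidden_size, kept_vocab, bias=False)
        for _ in range(num_heads)
    ])
    # Hypothetical token_map: which full-vocab ids the truncated vocab covers.
    token_map = torch.randperm(full_vocab)[:kept_vocab]

    hidden_states = torch.randn(batch, hidden_size)

    logits_lst = []
    for block, head in zip(blocks, heads):
        _logits = head(block(hidden_states))  # [batch, kept_vocab]
        # Scatter truncated-vocab logits back into a full-vocab tensor,
        # mirroring compute_logits: unmapped ids get -inf.
        full = -torch.inf * torch.ones(batch, full_vocab)
        full[..., token_map] = _logits
        logits_lst.append(full)

    # Greedy top-1 per head, as in sample(): [num_heads, batch] proposals.
    token_ids = torch.stack(logits_lst).argmax(-1)
    print("proposed token ids per head:\n", token_ids)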