eagle.py

from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models import ModelRegistry
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.transformers_utils.configs.eagle import EAGLEConfig


class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper:
    https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from the reference implementation:
    1. In the reference, the LlamaDecoderLayer implementation has no
       input_layernorm for the 1st decoder layer
       (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427),
       whereas we keep it, as the HF implementation also does.
    2. We allow any decoder layer to be used in EAGLE, whereas the reference
       fixes the decoder layer to LlamaDecoderLayer.
    3. We support an optional token_map that truncates the draft vocabulary
       to the most frequently used tokens, giving some additional speed-up by
       reducing sampling overhead. This is disabled unless the checkpoint
       file has an explicit token_map tensor and the config has an optional
       attribute truncated_vocab_size < vocab_size. To use this technique,
       find the top-k most frequent tokens in the target dataset, add them as
       a tensor in the draft checkpoint (using the key token_map), and set
       truncated_vocab_size (=k) in the draft config.
    """
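
    # Hypothetical sketch (not part of this module) of how a token_map for a
    # truncated draft vocabulary could be produced. `target_token_ids` is
    # assumed to be a 1-D tensor of token ids drawn from the target dataset,
    # and `k` the desired truncated_vocab_size:
    #
    #     counts = torch.bincount(target_token_ids, minlength=vocab_size)
    #     token_map = torch.argsort(counts, descending=True)[:k]
    #
    # The resulting tensor would then be stored in the draft checkpoint under
    # the key "token_map", with truncated_vocab_size = k in the draft config.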

    def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
        super().__init__()
        self.config = config

        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(self.config.model, *args, **kwargs)
        self.fc = nn.Linear(
            config.model.hidden_size * 2, config.model.hidden_size, bias=False
        )

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale
        )

        # token_map is an index-to-token mapping used to reduce the vocab
        # size of the draft model. Using a smaller draft vocab that contains
        # only the most frequent tokens reduces the speculation overhead
        # without affecting the acceptance rate much, and thus gives more
        # speed-up. By default this is disabled and is only used if the EAGLE
        # checkpoint file has a token_map tensor.
        self.token_map = None

    @property
    def sampler(self):
        return self.model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        tok_embeds = self.model.model.embed_tokens(input_ids)
        inputs_embeds = self.fc(
            torch.cat([tok_embeds, previous_hidden_states], dim=-1)
        )

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
        )
        return hidden_states
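
    # Illustrative shape walk-through for forward(), assuming a flattened
    # batch of `num_tokens` tokens and hidden size H:
    #   tok_embeds:              [num_tokens, H]
    #   previous_hidden_states:  [num_tokens, H]
    #   torch.cat(..., dim=-1):  [num_tokens, 2 * H]
    #   self.fc(...):            [num_tokens, H]
    # Entries where positions == 0 are zeroed (the masking above), presumably
    # because the first position has no previous-step hidden state to
    # condition on.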

    def compute_logits(
        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
    ) -> torch.Tensor:
        logits = self.logits_processor(
            self.lm_head, hidden_states, sampling_metadata
        )

        if self.token_map is not None:
            _logits = logits
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype,
            )
            logits[..., self.token_map] = _logits

        return logits
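
    # Toy example of the token_map expansion above: with orig_vocab_size = 6,
    # token_map = [0, 5, 2] and truncated logits [a, b, c], the expanded
    # logits are [a, -inf, c, -inf, -inf, b], i.e. draft scores are scattered
    # back to their original vocabulary indices and every token outside the
    # truncated vocab is masked with -inf.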

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # This implementation is incompatible with
        # https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting a trained EAGLE
        # checkpoint to a vLLM-compatible version:
        # https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(
                        loaded_weight, requires_grad=False
                    )
            elif name.startswith("fc."):
                weight_loader = getattr(
                    self.fc.weight, "weight_loader", default_weight_loader
                )
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("model.lm_head.") or name.startswith(
                "model.model."
            ):
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        lm_head_weight = model_weights.pop("lm_head.weight")

        if (
            self.token_map is not None
            and lm_head_weight.shape[0] > self.token_map.shape[0]
        ):
            lm_head_weight = lm_head_weight[self.token_map]

        weight_loader = getattr(
            self.lm_head.weight, "weight_loader", default_weight_loader
        )
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())
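

# Checkpoint layout accepted by load_weights above, as implied by its
# branches: an optional "token_map" tensor, "fc.*" weights for the fusion
# layer, and decoder/lm_head weights under "model.lm_head.*"/"model.model.*",
# "lm_head.*"/"model.*", or bare names (which get a "model." prefix). A
# sketch of a compatible flat state dict, with hypothetical decoder key
# names, might look like:
#
#     {
#         "token_map": LongTensor of shape [k],        # optional
#         "fc.weight": Tensor of shape [H, 2 * H],
#         "lm_head.weight": Tensor of shape [V, H],
#         "model.embed_tokens.weight": ...,
#         "model.layers.0.self_attn.qkv_proj.weight": ...,
#     }
#
# The exact decoder key names depend on the wrapped model class.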