eagle.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from typing import Iterable, List, Optional, Tuple
  2. import torch
  3. import torch.nn as nn
  4. from aphrodite.attention.backends.abstract import AttentionMetadata
  5. from aphrodite.common.sequence import IntermediateTensors
  6. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  7. from aphrodite.modeling.layers.sampler import SamplerOutput
  8. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  9. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
  10. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  11. from aphrodite.modeling.models import ModelRegistry
  12. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  13. from aphrodite.transformers_utils.configs.eagle import EAGLEConfig
  14. class EAGLE(nn.Module):
  15. """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
  16. Reference implementation: https://github.com/SafeAILab/EAGLE
  17. Differences from reference implementation:
  18. 1. In reference, LlamaDecoderLayer implementation doesn't have
  19. input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
  20. but we do as HF implementation also does.
  21. 2. We allow any decoder layer to be used in EAGLE whereas in reference
  22. decoder layer is fixed to be LlamaDecoderLayer.
  23. 3. We have an optional token_map which reduces draft vocab to most
  24. frequently used tokens to give some additional speed-up by reducing
  25. sampling overhead. This is disabled unless the checkpoint file has
  26. explicit token_map tensor and config has an optional attribute
  27. truncated_vocab_size < vocab_size. To use this technique, one has to find
  28. the top-k most frequent tokens in target dataset and add that as a tensor
  29. in the draft checkpoint (using key token_map). Also, the draft config
  30. needs to have truncated_vocab_size (=k) as an attribute."""
  31. def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
  32. super().__init__()
  33. self.config = config
  34. architectures = getattr(self.config.model, "architectures", [])
  35. model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
  36. self.model = model_cls(self.config.model, *args, **kwargs)
  37. self.fc = nn.Linear(
  38. config.model.hidden_size * 2, config.model.hidden_size, bias=False
  39. )
  40. self.orig_vocab_size = config.vocab_size
  41. self.truncated_vocab_size = config.truncated_vocab_size
  42. self.unpadded_vocab_size = self.truncated_vocab_size
  43. self.lm_head = ParallelLMHead(
  44. self.unpadded_vocab_size,
  45. config.hidden_size,
  46. org_num_embeddings=self.truncated_vocab_size,
  47. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  48. )
  49. logit_scale = getattr(config, "logit_scale", 1.0)
  50. self.logits_processor = LogitsProcessor(
  51. self.unpadded_vocab_size, self.truncated_vocab_size, logit_scale
  52. )
  53. # Token map is a idx to token mapping to reduce the vocab size for
  54. # the draft model. Using smaller vocab size for draft, containing
  55. # only most frequent tokens reduces the speculation overhead. This
  56. # doesn't affect the acceptance rate much and thus gives more speed
  57. # -up. By default, this is disabled and is only used if the EAGLE
  58. # checkpoint file has token_map tensor.
  59. self.token_map = None
  60. @property
  61. def sampler(self):
  62. return self.model.sampler
  63. def forward(
  64. self,
  65. input_ids: torch.Tensor,
  66. positions: torch.Tensor,
  67. kv_caches: List[torch.Tensor],
  68. attn_metadata: AttentionMetadata,
  69. previous_hidden_states: torch.Tensor,
  70. intermediate_tensors: Optional[IntermediateTensors] = None,
  71. ) -> torch.Tensor:
  72. tok_embeds = self.model.model.embed_tokens(input_ids)
  73. inputs_embeds = self.fc(
  74. torch.cat([tok_embeds, previous_hidden_states], dim=-1)
  75. )
  76. inputs_embeds[positions == 0] = 0 # masking inputs at position=0
  77. hidden_states = self.model.model(
  78. input_ids=None,
  79. inputs_embeds=inputs_embeds,
  80. positions=positions,
  81. kv_caches=kv_caches,
  82. attn_metadata=attn_metadata,
  83. intermediate_tensors=intermediate_tensors,
  84. )
  85. return hidden_states
  86. def compute_logits(
  87. self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
  88. ) -> torch.Tensor:
  89. logits = self.logits_processor(
  90. self.lm_head, hidden_states, sampling_metadata
  91. )
  92. if self.token_map is not None:
  93. _logits = logits
  94. logits = -torch.inf * torch.ones(
  95. size=(*_logits.shape[:-1], self.orig_vocab_size),
  96. device=_logits.device,
  97. dtype=_logits.dtype,
  98. )
  99. logits[..., self.token_map] = _logits
  100. return logits
  101. def sample(
  102. self,
  103. logits: torch.Tensor,
  104. sampling_metadata: SamplingMetadata,
  105. ) -> Optional[SamplerOutput]:
  106. next_tokens = self.sampler(logits, sampling_metadata)
  107. return next_tokens
  108. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  109. # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
  110. # due to missing lm_head weights and its config being that of a
  111. # Llama model. Here's a compatible version with the same weights:
  112. # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
  113. # Also, here's an example script for converting trained EAGLE
  114. # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
  115. model_weights = {}
  116. for name, loaded_weight in weights:
  117. if name == "token_map":
  118. if self.config.truncated_vocab_size < self.config.vocab_size:
  119. self.token_map = nn.Parameter(
  120. loaded_weight, requires_grad=False
  121. )
  122. elif name.startswith("fc."):
  123. weight_loader = getattr(
  124. self.fc.weight, "weight_loader", default_weight_loader
  125. )
  126. weight_loader(self.fc.weight, loaded_weight)
  127. elif name.startswith("model.lm_head.") or name.startswith(
  128. "model.model."
  129. ):
  130. model_weights[name.split("model.", 1)[-1]] = loaded_weight
  131. elif name.startswith("lm_head.") or name.startswith("model."):
  132. model_weights[name] = loaded_weight
  133. else:
  134. model_weights[f"model.{name}"] = loaded_weight
  135. lm_head_weight = model_weights.pop("lm_head.weight")
  136. if (
  137. self.token_map is not None
  138. and lm_head_weight.shape[0] > self.token_map.shape[0]
  139. ):
  140. lm_head_weight = lm_head_weight[self.token_map]
  141. weight_loader = getattr(
  142. self.lm_head.weight, "weight_loader", default_weight_loader
  143. )
  144. weight_loader(self.lm_head.weight, lm_head_weight)
  145. self.model.load_weights(model_weights.items())