Ver código fonte

fix: `MLPSpeculator` handling of `num_speculative_tokens`

AlpinDale 7 meses atrás
pai
commit
51cfadeb29

+ 7 - 3
aphrodite/common/config.py

@@ -1002,15 +1002,19 @@ class SpeculativeConfig:
                 max_logprobs=target_model_config.max_logprobs,
             )
 
-            if (draft_model_config.hf_config.model_type == "mlp_speculator"
+            draft_hf_config = draft_model_config.hf_config
+            if (draft_hf_config.model_type == "mlp_speculator"
                     and target_parallel_config.world_size != 1):
                 # MLPSpeculator TP support will be added very soon
                 raise ValueError(
                     "Speculative decoding with mlp_speculator models does not "
                     "yet support distributed inferencing (TP > 1).")
 
-            n_predict = getattr(draft_model_config.hf_config, "n_predict",
-                                None)
+            if (num_speculative_tokens is not None
+                    and hasattr(draft_hf_config, "num_lookahead_tokens")):
+                draft_hf_config.num_lookahead_tokens = num_speculative_tokens
+
+            n_predict = getattr(draft_hf_config, "n_predict", None)
             if n_predict is not None:
                 if num_speculative_tokens is None:
                     # Default to max value defined in draft model config.

+ 8 - 7
aphrodite/modeling/models/mlp_speculator.py

@@ -11,6 +11,7 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.transformers_utils.configs import MLPSpeculatorConfig
 
 
 class MLPSpeculatorLayerNorm(nn.Module):
@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
 
 class MLPSpeculator(nn.Module):
 
-    def __init__(self, config, **kwargs) -> None:
+    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
         super().__init__()
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
         self.inner_dim = config.inner_dim if config.inner_dim != 0 \
             else config.emb_dim
 
-        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
-                                              self.n_predict)
+        self.max_speculative_tokens = config.num_lookahead_tokens
 
         self.emb = nn.ModuleList([
             VocabParallelEmbedding(config.vocab_size,
@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            param = params_dict[name.replace("speculator.", "")]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            param = params_dict.get(name.replace("speculator.", ""))
+            if param is not None:
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

+ 2 - 0
aphrodite/transformers_utils/configs/mlp_speculator.py

@@ -34,6 +34,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
                 candidate tree.
                 For each candidate branch in the tree, head n produces topk[n]
                 additional sub-branches.
+                NOTE: This parameter is currently unused.
             n_candidates: int
                 number of child candidates to create per sequence
         """
@@ -46,4 +47,5 @@ class MLPSpeculatorConfig(PretrainedConfig):
         self.n_predict = n_predict
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
+        self.num_lookahead_tokens = n_predict
         super().__init__(**kwargs)