@@ -11,6 +11,7 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.transformers_utils.configs import MLPSpeculatorConfig
 
 
 class MLPSpeculatorLayerNorm(nn.Module):
@@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
 
 class MLPSpeculator(nn.Module):
 
-    def __init__(self, config, **kwargs) -> None:
+    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
         super().__init__()
         self.n_predict = config.n_predict
         self.vocab_size = config.vocab_size
@@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
         self.inner_dim = config.inner_dim if config.inner_dim != 0 \
             else config.emb_dim
 
-        self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
-                                              self.n_predict)
+        self.max_speculative_tokens = config.num_lookahead_tokens
 
         self.emb = nn.ModuleList([
             VocabParallelEmbedding(config.vocab_size,
@@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
-            param = params_dict[name.replace("speculator.", "")]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            param = params_dict.get(name.replace("speculator.", ""))
+            if param is not None:
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
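
Note on the load_weights hunk: the old dict indexing raised a KeyError for any checkpoint tensor without a matching module parameter, whereas the new .get() lookup simply skips such tensors. A minimal standalone sketch of that pattern follows; the parameter names and tensor shapes are hypothetical, for illustration only, and the in-place copy is a stand-in for what default_weight_loader essentially does:

    import torch

    # Stand-ins for self.named_parameters() and the incoming checkpoint stream.
    params_dict = {"emb.0.weight": torch.zeros(8, 4)}
    weights = [
        ("speculator.emb.0.weight", torch.ones(8, 4)),  # has a match -> loaded
        ("speculator.extra.bias", torch.ones(8)),       # no match -> skipped, no KeyError
    ]

    for name, loaded_weight in weights:
        param = params_dict.get(name.replace("speculator.", ""))
        if param is not None:
            param.copy_(loaded_weight)  # default_weight_loader copies in place like this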