123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- import math
- from typing import Iterable, List, Tuple
- import torch
- import torch.nn as nn
- from aphrodite.common.sequence import SamplerOutput
- from aphrodite.modeling import SamplingMetadata
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.sampler import Sampler
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- ParallelLMHead, VocabParallelEmbedding)
- from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
- from aphrodite.transformers_utils.configs import MLPSpeculatorConfig
# Square root of two as a plain Python float. generate_proposals divides the
# normalized base-model hidden states by it when ``scale_input`` is enabled.
SQRT2: float = 2.0 ** 0.5
class MLPSpeculatorLayerNorm(nn.Module):
    """
    RMS-style normalization used by the MLP speculator.

    Each vector along the last axis is divided by the root-mean-square of
    its entries, ``x / sqrt(mean(x**2, dim=-1) + eps)``. This is optionally
    followed by a learned elementwise scale (``weight``) and shift
    (``bias``).

    Unlike standard LayerNorm, the mean is not subtracted first. Unlike a
    plain L2 normalization, the divisor uses the *mean* (not the *sum*) of
    the squares, so the output has unit RMS rather than unit norm.

    Args
    ----
    normalized_shape : int
        Dimensionality of input data (size of final tensor axis).
    eps : float
        Safety term to prevent division by zero. Make sure the chosen value
        fits in the range of your encoding scheme
        (i.e. fp16 requires eps >= 6e-8).
    elementwise_scale_and_shift : bool
        Include a learned scaling and shift term after normalization.
        The parameters are allocated with ``torch.empty`` (uninitialized),
        so they must be loaded or initialized before use.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_scale_and_shift=True,
    ):
        super().__init__()
        self.elementwise_scale_and_shift = elementwise_scale_and_shift
        if self.elementwise_scale_and_shift:
            # Uninitialized storage: values are expected to come from a
            # checkpoint (or an explicit init) before the first forward.
            self.weight = nn.Parameter(torch.empty(normalized_shape))
            self.bias = nn.Parameter(torch.empty(normalized_shape))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its last axis, then optionally scale/shift.

        The statistic is computed in ``x``'s own dtype (there is no upcast),
        which is why ``eps`` must be representable in that dtype.
        """
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Defensive cast back to the input dtype (a no-op for float inputs).
        x = (x * inv_rms).type_as(x)
        if self.elementwise_scale_and_shift:
            x = self.weight * x + self.bias
        return x
class MLPSpeculator(nn.Module):
    """
    Multi-stage MLP speculator that drafts tokens for speculative decoding.

    Each stage ``i`` performs the following steps:

    1. Embed the most recently proposed token.
    2. Add the embedding to a projection of the running hidden state.
    3. Normalize and activate the result.
    4. Feed it to that stage's LM head.

    The token sampled from stage ``i`` becomes the input of stage ``i + 1``.
    When ``config.tie_weights`` is set, all stages share one embedding, one
    LM head and one layer norm, and every stage after the first shares one
    projection.
    """

    def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
        super().__init__()
        self.n_predict = config.n_predict
        self.vocab_size = config.vocab_size
        self.emb_dim = config.emb_dim
        # An inner_dim of 0 means "reuse emb_dim".
        self.inner_dim = (config.inner_dim
                          if config.inner_dim != 0 else config.emb_dim)
        # Number of stages actually built; generate_proposals refuses to run
        # more than this many.
        self.max_speculative_tokens = config.num_lookahead_tokens

        self.tie_weights = config.tie_weights
        self.scale_input = config.scale_input

        if self.tie_weights:
            # Raise explicitly rather than ``assert`` so this configuration
            # check is not stripped when Python runs with ``-O``.
            if self.n_predict <= 1:
                raise ValueError(
                    "You cannot tie weights between stages when only 1 exists")

            # List multiplication repeats the *same* module object, so every
            # entry of each ModuleList below shares one set of parameters.
            embedding = VocabParallelEmbedding(
                config.vocab_size,
                self.inner_dim,
                org_num_embeddings=config.vocab_size)
            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)

            # the initial projection from the base model may
            # have a different size, so that stays separate.
            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
                                      (self.max_speculative_tokens - 1))

            head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
            self.head = nn.ModuleList([head] * self.max_speculative_tokens)

            ln = MLPSpeculatorLayerNorm(self.inner_dim,
                                        elementwise_scale_and_shift=True)
            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
        else:
            # Independent modules per stage.
            self.emb = nn.ModuleList([
                VocabParallelEmbedding(config.vocab_size,
                                       self.inner_dim,
                                       org_num_embeddings=config.vocab_size)
                for _ in range(self.max_speculative_tokens)
            ])

            # Stage 0 projects from the base model's emb_dim; later stages
            # project from the speculator's own inner_dim.
            self.proj = nn.ModuleList([
                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                          self.inner_dim,
                          bias=False)
                for i in range(self.max_speculative_tokens)
            ])

            self.head = nn.ModuleList([
                ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
                for _ in range(self.max_speculative_tokens)
            ])
            self.ln = nn.ModuleList([
                MLPSpeculatorLayerNorm(self.inner_dim,
                                       elementwise_scale_and_shift=True)
                for _ in range(self.max_speculative_tokens)
            ])
        if self.scale_input:
            # Parameter-free normalization applied to the incoming base-model
            # hidden states before the first stage.
            self.ln0 = MLPSpeculatorLayerNorm(
                self.emb_dim, elementwise_scale_and_shift=False)

        # state_weight ** (2 * n_predict) == 0.5 by construction.
        # NOTE(review): this uses n_predict, not max_speculative_tokens;
        # confirm that is intended when the two differ.
        self.state_weight = 0.5**(0.5 / config.n_predict)
        self.emb_weight = math.sqrt(
            (1 - self.state_weight**2) * (self.inner_dim / 2))
        self.activation = nn.GELU()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.vocab_size, 1.0)
        self.sampler = Sampler()

    def generate_proposals(
        self,
        input_ids: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        num_predict_tokens: int,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        """Run ``num_predict_tokens`` speculator stages autoregressively.

        Args:
            input_ids: last accepted token ids; presumably shape (batch,)
                (a dim is inserted at axis 1) - TODO confirm with caller.
            previous_hidden_states: base-model hidden states for those
                tokens; presumably shape (batch, emb_dim) - TODO confirm.
            num_predict_tokens: number of stages to run; must not exceed
                ``max_speculative_tokens``.
            sampling_metadata: forwarded to the logits processor and sampler
                at every stage.

        Returns:
            One ``SamplerOutput`` per stage, in stage order. Each stage's
            ``sampled_token_ids`` feed the next stage.

        Raises:
            ValueError: if more stages are requested than were built.
        """
        if num_predict_tokens > self.max_speculative_tokens:
            raise ValueError(f"Max speculative tokens for model is "
                             f"{self.max_speculative_tokens}, but "
                             f"{num_predict_tokens} were requested")

        # b x 1 x d
        previous_hidden_states = previous_hidden_states.unsqueeze(1)

        if self.scale_input:
            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2

        # b x 1
        last_tokens = input_ids.unsqueeze(1)

        # Weighted add of state_weight*state and emb_weight*z, divided
        # through by state_weight; the subsequent RMS layer norm removes that
        # uniform scale. state_weight is close to 1, so there should be no
        # precision issues. The ratio is loop-invariant, so compute it once.
        emb_alpha = self.emb_weight / self.state_weight

        next_tokens = []
        for head_index in range(num_predict_tokens):

            # Project and predict
            z = self.emb[head_index](last_tokens)  # b k d
            states = self.proj[head_index](previous_hidden_states)
            states.add_(z, alpha=emb_alpha)

            states = self.activation(self.ln[head_index](states))  # b k d
            previous_hidden_states = states
            # TODO: not yet supporting top_k_tokens_per_head
            states = states.flatten(0, 1)

            logits = self.logits_processor(self.head[head_index], states,
                                           sampling_metadata)

            output = self.sampler(logits, sampling_metadata)
            last_tokens = output.sampled_token_ids
            next_tokens.append(output)

        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Every occurrence of ``"speculator."`` is removed from each checkpoint
        name before lookup. Entries with no matching parameter are skipped
        silently. A parameter's own ``weight_loader`` is used when present;
        otherwise ``default_weight_loader`` is used.

        With ``tie_weights``, ``named_parameters()`` reports each shared
        parameter only once (under its first registered name). Checkpoint
        entries for the other stage indices therefore find no match and are
        skipped.
        """
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict.get(name.replace("speculator.", ""))
            if param is not None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
|