123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177 |
- from typing import Optional, Sequence
- import torch
- import torch.nn.functional as F
- from torch.nn.parameter import Parameter
- from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_reduce)
- from aphrodite.modeling.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``.

    Args:
        vocab_size: the unpadded vocabulary size.
        pad_to: granularity to round up to (defaults to
            ``DEFAULT_VOCAB_PADDING_SIZE``).

    Returns:
        The smallest multiple of ``pad_to`` that is >= ``vocab_size``
        (for a positive ``pad_to``).
    """
    # Ceiling division expressed with integer arithmetic only.
    num_blocks = (vocab_size + pad_to - 1) // pad_to
    return num_blocks * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
                                              rank: int) -> Sequence[int]:
    """Return the half-open ``(start, end)`` vocab range owned by ``rank``.

    Each rank owns a contiguous slice of ``per_partition_vocab_size``
    entries, so rank ``r`` starts at ``r * per_partition_vocab_size``.

    Args:
        per_partition_vocab_size: number of vocabulary entries per rank.
        rank: index of the partition.

    Returns:
        A ``(start, end)`` tuple; ``end`` is exclusive.
    """
    start = per_partition_vocab_size * rank
    return start, start + per_partition_vocab_size
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int) -> Sequence[int]:
    """Return the half-open ``(start, end)`` vocab range owned by ``rank``.

    The global vocabulary is split into ``world_size`` equal shards using
    ``divide`` and the shard belonging to ``rank`` is returned.

    Args:
        global_vocab_size: total (already padded) vocabulary size.
        rank: index of the partition.
        world_size: number of partitions.

    Returns:
        A ``(start, end)`` tuple; ``end`` is exclusive.
    """
    # NOTE(review): `divide` is imported from aphrodite.distributed;
    # presumably it performs exact integer division -- confirm there.
    shard_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(shard_size, rank)
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding table sharded along the vocabulary dimension.

    Adapted from torch.nn.Embedding. The vocabulary size is first padded
    to a multiple of ``padding_size`` and then split evenly across the
    tensor-model-parallel ranks; each rank stores only its own rows.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
        super().__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Remember the sizes requested by the caller.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # A falsy `org_num_embeddings` (None or 0) falls back to
        # `num_embeddings`.
        self.org_vocab_size = (org_num_embeddings
                               if org_num_embeddings else num_embeddings)
        self.num_embeddings_padded = pad_vocab_size(num_embeddings,
                                                    padding_size)

        # Work out which rows of the padded vocabulary this rank owns.
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        shard_range = vocab_range_from_global_vocab_size(
            self.num_embeddings_padded, tp_rank, self.tp_size)
        self.vocab_start_index, self.vocab_end_index = shard_range
        self.num_embeddings_per_partition = (self.vocab_end_index -
                                             self.vocab_start_index)

        # Local shard of the table; allocated uninitialized and expected
        # to be filled through `weight_loader`.
        local_table = torch.empty(self.num_embeddings_per_partition,
                                  self.embedding_dim,
                                  dtype=params_dtype)
        self.weight = Parameter(local_table)
        set_weight_attrs(self.weight, {
            "parallel_dim": 0,
            "weight_loader": self.weight_loader
        })

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's rows of a full checkpoint tensor into ``param``.

        Args:
            param: destination parameter carrying a ``parallel_dim``
                attribute.
            loaded_weight: full tensor whose size along ``parallel_dim``
                must equal ``org_vocab_size``.
        """
        shard_dim = param.parallel_dim
        assert loaded_weight.shape[shard_dim] == self.org_vocab_size
        # The slice can hold fewer rows than `param` when this rank's
        # range extends past the end of the checkpoint (padding rows);
        # only the rows actually present are overwritten.
        shard = loaded_weight[self.vocab_start_index:self.vocab_end_index]
        num_rows = shard.shape[0]
        param[:num_rows].data.copy_(shard)

    def forward(self, input_):
        """Look up embeddings for ``input_`` and all-reduce across ranks.

        With more than one rank, ids outside this rank's range are looked
        up at local row 0 and their output rows are zeroed, so that the
        all-reduce sums exactly one non-zero contribution per id.
        """
        if self.tp_size > 1:
            # True where the id belongs to some other rank.
            out_of_range = ((input_ < self.vocab_start_index) |
                            (input_ >= self.vocab_end_index))
            # Shift to local row indices; foreign ids are clamped to 0 so
            # the lookup itself stays in bounds.
            local_ids = input_.clone() - self.vocab_start_index
            local_ids[out_of_range] = 0
            local_out = F.embedding(local_ids, self.weight)
            # Erase the placeholder rows fetched for foreign ids.
            local_out[out_of_range, :] = 0.0
        else:
            local_out = F.embedding(input_, self.weight)
        # Combine the per-rank partial results.
        return tensor_model_parallel_all_reduce(local_out)

    def extra_repr(self) -> str:
        """Return the shard/size summary shown in the module repr."""
        fields = (
            f"num_embeddings={self.num_embeddings_per_partition}",
            f"embedding_dim={self.embedding_dim}",
            f"org_vocab_size={self.org_vocab_size}",
            f"num_embeddings_padded={self.num_embeddings_padded}",
            f"tp_size={self.tp_size}",
        )
        return ", ".join(fields)
class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size)
        if bias:
            # One bias entry per locally stored vocabulary row; loaded
            # with the same sharding logic as the weight.
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "parallel_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        """Project hidden states onto this rank's vocabulary shard.

        Args:
            input_: hidden states whose last dimension is
                ``embedding_dim``.

        Returns:
            Logits for the locally stored (padded) vocabulary rows.
        """
        # Neither this class nor VocabParallelEmbedding assigns
        # `linear_method` / `linear_weights`; reading them unconditionally
        # raises AttributeError unless something attached them from the
        # outside. Use them when present, otherwise fall back to a plain
        # dense projection with the module's own weight and bias.
        linear_method = getattr(self, "linear_method", None)
        if linear_method is None:
            return F.linear(input_, self.weight, self.bias)
        logits = linear_method.apply_weights(self.linear_weights, input_)
        if self.bias is not None:
            logits += self.bias
        return logits
class ParallelTWEHead(torch.nn.Module):
    """Parallelized tie word embeddings head.

    Output logits weight matrices used in the Sampler. The weights are
    read from (shared with) an existing VocabParallelEmbedding instead of
    being allocated again.

    Args:
        embeddings: the VocabParallelEmbedding to mirror
    """

    def __init__(self, embeddings: "VocabParallelEmbedding"):
        super().__init__()
        # VocabParallelEmbedding as defined in this file only exposes
        # `weight`; `linear_method` / `linear_weights` exist only if they
        # were attached elsewhere. Mirror them when available, otherwise
        # share the raw embedding weight.
        self.linear_method = getattr(embeddings, "linear_method", None)
        if self.linear_method is None:
            # Keep the shared Parameter inside a dict so that it is not
            # registered a second time as a parameter of this module.
            self.linear_weights = {"weight": embeddings.weight}
        else:
            self.linear_weights = embeddings.linear_weights

    def forward(self, input_):
        """Project hidden states onto the tied embedding rows.

        Args:
            input_: hidden states whose last dimension matches the
                embedding dimension of the mirrored embedding.

        Returns:
            Logits computed with the mirrored embedding's weights.
        """
        if self.linear_method is None:
            return F.linear(input_, self.linear_weights["weight"])
        return self.linear_method.apply_weights(self.linear_weights, input_)
|