123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- from typing import Optional, Sequence
- import torch
- from torch.nn.parameter import Parameter
- from aphrodite.distributed import (
- divide,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_reduce,
- )
- from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
- from aphrodite.modeling.utils import set_weight_attrs
- DEFAULT_VOCAB_PADDING_SIZE = 64
- def pad_vocab_size(vocab_size: int,
- pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
- """Pad the vocab size to the given value."""
- return ((vocab_size + pad_to - 1) // pad_to) * pad_to
- def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
- rank: int) -> Sequence[int]:
- index_f = rank * per_partition_vocab_size
- index_l = index_f + per_partition_vocab_size
- return index_f, index_l
- def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
- world_size: int) -> Sequence[int]:
- per_partition_vocab_size = divide(global_vocab_size, world_size)
- return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
- rank)
- class VocabParallelEmbedding(torch.nn.Module):
- """Embedding parallelized in the vocabulary dimension.
- Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
- make sure it is divisible by the number of model parallel GPUs.
- Args:
- num_embeddings: vocabulary size.
- embedding_dim: size of hidden state.
- params_dtype: type of the parameters.
- org_num_embeddings: original vocabulary size (without LoRA).
- padding_size: padding size for the vocabulary.
- """
- def __init__(self,
- num_embeddings: int,
- embedding_dim: int,
- params_dtype: Optional[torch.dtype] = None,
- linear_method=None,
- org_num_embeddings: Optional[int] = None,
- padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
- is_input_emb: bool = True):
- super().__init__()
- # Keep the input dimensions.
- self.num_embeddings = num_embeddings
- self.org_vocab_size = org_num_embeddings or num_embeddings
- self.num_embeddings_padded = pad_vocab_size(num_embeddings,
- padding_size)
- self.embedding_dim = embedding_dim
- if params_dtype is None:
- params_dtype = torch.get_default_dtype()
- self.tp_size = get_tensor_model_parallel_world_size()
- # Divide the weight matrix along the vocaburaly dimension.
- self.vocab_start_index, self.vocab_end_index = (
- vocab_range_from_global_vocab_size(
- self.num_embeddings_padded, get_tensor_model_parallel_rank(),
- self.tp_size))
- self.num_embeddings_per_partition = (self.vocab_end_index -
- self.vocab_start_index)
- idx = 0 if is_input_emb else 1
- if linear_method is None or not linear_method.quant_config.quant_vocab(
- )[idx]:
- linear_method = UnquantizedLinearMethod()
- self.linear_method = linear_method
- self.linear_weights = self.linear_method.create_weights(
- self.embedding_dim, [self.num_embeddings_per_partition],
- self.embedding_dim, self.num_embeddings_padded, params_dtype)
- for name, weight in self.linear_weights.items():
- if isinstance(weight, torch.nn.parameter.Parameter):
- self.register_parameter(name, weight)
- set_weight_attrs(weight, {"weight_loader": self.weight_loader})
- def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
- output_dim = getattr(param, "output_dim", None)
- packed_dim = getattr(param, "packed_dim", None)
- if output_dim is not None:
- shard_offset = self.vocab_start_index
- shard_size = min(self.vocab_end_index,
- self.org_vocab_size) - shard_offset
- if packed_dim == output_dim:
- shard_size = shard_size // param.pack_factor
- shard_offset = shard_offset // param.pack_factor
- loaded_weight = loaded_weight.narrow(output_dim, shard_offset,
- shard_size)
- if isinstance(param, torch.nn.parameter.UninitializedParameter):
- vocab_shape = list(loaded_weight.shape)
- if output_dim is not None:
- if packed_dim == output_dim:
- vocab_shape[output_dim] = (
- self.num_embeddings_per_partition // param.pack_factor)
- else:
- vocab_shape[output_dim] = self.num_embeddings_per_partition
- param.materialize(vocab_shape, dtype=loaded_weight.dtype)
- if output_dim is not None:
- param.data.narrow(
- output_dim, 0,
- loaded_weight.shape[output_dim]).copy_(loaded_weight)
- else:
- param.data.copy_(loaded_weight)
- def forward(self, input_):
- if self.tp_size > 1:
- # Build the mask.
- input_mask = ((input_ < self.vocab_start_index) |
- (input_ >= self.vocab_end_index))
- # Mask the input.
- masked_input = input_.clone() - self.vocab_start_index
- masked_input[input_mask] = 0
- else:
- masked_input = input_
- # Get the embeddings.
- output_parallel = self.linear_method.apply_embedding(
- self.linear_weights, masked_input)
- # output_parallel = F.embedding(masked_input, self.weight)
- # Mask the output embedding.
- if self.tp_size > 1:
- output_parallel[input_mask, :] = 0.0
- # Reduce across all the model parallel GPUs.
- output = tensor_model_parallel_all_reduce(output_parallel)
- return output
- class ParallelLMHead(VocabParallelEmbedding):
- """Parallelized LM head.
- Output logits weight matrices used in the Sampler. The weight and bias
- tensors are padded to make sure they are divisible by the number of
- model parallel GPUs.
- Args:
- num_embeddings: vocabulary size.
- embedding_dim: size of hidden state.
- bias: whether to use bias.
- params_dtype: type of the parameters.
- org_num_embeddings: original vocabulary size (without LoRA).
- padding_size: padding size for the vocabulary.
- """
- def __init__(self,
- num_embeddings: int,
- embedding_dim: int,
- bias: bool = False,
- params_dtype: Optional[torch.dtype] = None,
- linear_method=None,
- org_num_embeddings: Optional[int] = None,
- padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
- super().__init__(num_embeddings, embedding_dim, params_dtype,
- linear_method, org_num_embeddings, padding_size,
- False)
- if bias:
- self.bias = Parameter(
- torch.empty(self.num_embeddings_per_partition,
- dtype=params_dtype))
- set_weight_attrs(self.bias, {
- "output_dim": 0,
- "weight_loader": self.weight_loader
- })
- else:
- self.register_parameter("bias", None)
- def forward(self, input_):
- logits = self.linear_method.apply_weights(self.linear_weights, input_)
- if self.bias is not None:
- logits += self.bias
- return logits
|