123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- from dataclasses import dataclass
- from typing import List, Optional, Sequence, Tuple
- import torch
- import torch.nn.functional as F
- from torch.nn.parameter import Parameter, UninitializedParameter
- from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_reduce)
- from aphrodite.modeling.utils import set_weight_attrs
- from aphrodite.quantization.base_config import (
- QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
# Default multiple that vocabulary sizes are rounded up to; it is the default
# value of ``pad_to`` in ``pad_vocab_size``.
DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Full-precision (non-quantized) weight method for embedding layers."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate an uninitialized weight and register it on ``layer``.

        The weight has ``sum(output_partition_sizes)`` rows and
        ``input_size_per_partition`` columns. ``input_size`` and
        ``output_size`` are part of the method interface but unused here.
        Any ``extra_weight_attrs`` (e.g. a ``weight_loader``) are attached
        to the parameter after registration.
        """
        num_rows = sum(output_partition_sizes)
        storage = torch.empty(num_rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
        # Sharding metadata consumed by weight loaders (rows = output dim).
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute ``x @ layer.weight.T (+ bias)``."""
        weight = layer.weight
        return F.linear(x, weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Gather rows of ``layer.weight`` selected by the indices in
        ``input_``."""
        table = layer.weight
        return F.embedding(input_, table)
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the next multiple of ``pad_to``."""
    # Integer ceiling division: number of pad_to-sized blocks needed.
    num_blocks = (vocab_size + pad_to - 1) // pad_to
    return num_blocks * pad_to
def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int,
        rank: int,
        offset: int = 0) -> Sequence[int]:
    """Return the half-open ``(start, end)`` range owned by ``rank``.

    Each rank owns ``per_partition_vocab_size`` consecutive entries; the
    whole range is shifted by ``offset``.
    """
    start = rank * per_partition_vocab_size + offset
    return start, start + per_partition_vocab_size
def vocab_range_from_global_vocab_size(global_vocab_size: int,
                                       rank: int,
                                       world_size: int,
                                       offset: int = 0) -> Sequence[int]:
    """Return the half-open ``(start, end)`` range of a global vocab that
    ``rank`` owns when it is split evenly across ``world_size`` ranks.

    ``divide`` is used for the per-rank size; ``offset`` shifts the range.
    """
    return vocab_range_from_per_partition_vocab_size(
        divide(global_vocab_size, world_size), rank, offset=offset)
@dataclass
class VocabParallelEmbeddingShardIndices:
    """Start/end indices (half-open) describing one rank's vocab shard.

    The ``padded_*`` fields bound the shard including padding; the other
    fields bound only the real (unpadded) vocabulary entries.
    """
    padded_org_vocab_start_index: int
    padded_org_vocab_end_index: int
    padded_added_vocab_start_index: int
    padded_added_vocab_end_index: int

    org_vocab_start_index: int
    org_vocab_end_index: int
    added_vocab_start_index: int
    added_vocab_end_index: int

    @property
    def num_org_elements(self) -> int:
        """Number of real base-vocab entries in this shard."""
        lo, hi = self.org_vocab_start_index, self.org_vocab_end_index
        return hi - lo

    @property
    def num_added_elements(self) -> int:
        """Number of real added-vocab entries in this shard."""
        lo, hi = self.added_vocab_start_index, self.added_vocab_end_index
        return hi - lo

    @property
    def num_org_elements_padded(self) -> int:
        """Number of base-vocab slots in this shard, padding included."""
        lo = self.padded_org_vocab_start_index
        hi = self.padded_org_vocab_end_index
        return hi - lo

    @property
    def num_added_elements_padded(self) -> int:
        """Number of added-vocab slots in this shard, padding included."""
        lo = self.padded_added_vocab_start_index
        hi = self.padded_added_vocab_end_index
        return hi - lo

    @property
    def num_org_vocab_padding(self) -> int:
        """Number of padding slots after the base-vocab entries."""
        return self.num_org_elements_padded - self.num_org_elements

    @property
    def num_added_vocab_padding(self) -> int:
        """Number of padding slots after the added-vocab entries."""
        return self.num_added_elements_padded - self.num_added_elements

    @property
    def num_elements_padded(self) -> int:
        """Total slots in this shard (base + added, padding included)."""
        return self.num_added_elements_padded + self.num_org_elements_padded

    def __post_init__(self):
        # Sanity checks: every (lower, upper) pair must satisfy lower <= upper.
        ordered_pairs = (
            (self.padded_org_vocab_start_index,
             self.padded_org_vocab_end_index),
            (self.padded_added_vocab_start_index,
             self.padded_added_vocab_end_index),
            (self.org_vocab_start_index, self.org_vocab_end_index),
            (self.added_vocab_start_index, self.added_vocab_end_index),
            (self.org_vocab_start_index, self.padded_org_vocab_start_index),
            (self.added_vocab_start_index,
             self.padded_added_vocab_start_index),
            (self.org_vocab_end_index, self.padded_org_vocab_end_index),
            (self.added_vocab_end_index, self.padded_added_vocab_end_index),
            (self.num_org_elements, self.num_org_elements_padded),
            (self.num_added_elements, self.num_added_elements_padded),
        )
        for lower, upper in ordered_pairs:
            assert lower <= upper
@torch.jit.script
def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Translate global token ids into row indices of the local shard.

    Ids in ``[org_vocab_start_index, org_vocab_end_index)`` map to local
    rows starting at 0; ids in ``[added_vocab_start_index,
    added_vocab_end_index)`` map to rows placed after the base entries and
    ``num_org_vocab_padding`` padding rows. Ids outside both ranges become 0.
    Returns ``(local_ids, out_of_shard_mask)``.
    """
    # torch.jit.script will fuse all of the pointwise ops below
    # into a single kernel, making it very fast
    in_org = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    in_added = (input_ >= added_vocab_start_index) & (
        input_ < added_vocab_end_index)
    # Added-vocab rows sit after the base rows and their padding.
    num_org_rows = org_vocab_end_index - org_vocab_start_index
    added_shift = (added_vocab_start_index - num_org_rows -
                   num_org_vocab_padding)
    shift = (org_vocab_start_index * in_org) + (added_shift * in_added)
    in_shard = in_org | in_added
    local_ids = in_shard * (input_ - shift)
    return local_ids, ~in_shard
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    In order to support various loading methods, we ensure that LoRA-added
    embeddings are always at the end of TP-sharded tensors. In other words,
    we shard base embeddings and LoRA embeddings separately (both padded),
    and place them in the same tensor.
    In this example, we will have the original vocab size = 1010,
    added vocab size = 16 and padding to 64. Therefore, the total
    vocab size with padding will be 1088 (because we first pad 1010 to
    1024, add 16, and then pad to 1088).
    Therefore, the tensor format looks like the following:
    TP1, rank 0 (no sharding):
                            |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
    corresponding token_id: |  0  |  1  | ... | 1009 |  -1  | ... |  -1  | 1010 | ... | 1025 |  -1  | ... |  -1  |
                     index: |  0  |  1  | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

    TP2, rank 0:
                            |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
    corresponding token_id: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511  | 1010 | ... | 1025 |  -1  | ... |  -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511  | 512  | ... | 527  |  528 | ... | 543 |
    TP2, rank 1:
                            |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
    corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1  | ...  | -1  |  -1  | ... |  -1  | -1  | ... |   -1 |
                     index: |  0  |  1  |  2  | ... | 497  | 498 | ...  | 511 | 512  | ... | 519  | 520 | ... |  543 |

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary. When falsy (None or
            0), the tensor-parallel world size is used.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict
    """  # noqa: E501

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Without an explicit padding size, pad to a multiple of the
        # tensor-parallel world size.
        padding_size = padding_size or get_tensor_model_parallel_world_size()

        # Keep the input dimensions.
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        num_added_embeddings = num_embeddings - self.org_vocab_size
        # Pad the base vocab first, append the added (LoRA) vocab, then pad
        # the total again (layout shown in the class docstring).
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + num_added_embeddings,
            self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded

        self.shard_indices = self._get_indices(self.num_embeddings_padded,
                                               self.org_vocab_size_padded,
                                               self.num_embeddings,
                                               self.org_vocab_size, tp_rank,
                                               self.tp_size)
        self.embedding_dim = embedding_dim

        linear_method = None
        if quant_config is not None:
            linear_method = quant_config.get_quant_method(self, prefix=prefix)
        if linear_method is None:
            linear_method = UnquantizedEmbeddingMethod()

        # If we are making an embedding layer, then our quantization linear
        # method must implement the embedding operation. If we are another
        # layer type like ParallelLMHead, this is not important.
        # NOTE: this must compare type(self); type(self.__class__) is the
        # metaclass and would make the check always False.
        is_embedding_layer = type(self) is VocabParallelEmbedding
        linear_method_implements_embedding = method_has_implemented_embedding(
            type(linear_method))
        if is_embedding_layer and not linear_method_implements_embedding:
            raise NotImplementedError(
                f"The class {type(linear_method).__name__} must implement "
                "the 'embedding' method, see UnquantizedEmbeddingMethod.")

        self.linear_method: QuantizeMethodBase = linear_method

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        # Divide the weight matrix along the vocabulary dimension.
        self.num_added_embeddings = num_added_embeddings
        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
                                                   self.tp_size)
        assert (self.shard_indices.num_elements_padded ==
                self.num_embeddings_per_partition)
        self.num_org_embeddings_per_partition = (
            self.shard_indices.org_vocab_end_index -
            self.shard_indices.org_vocab_start_index)
        self.num_added_embeddings_per_partition = (
            self.shard_indices.added_vocab_end_index -
            self.shard_indices.added_vocab_start_index)

        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_per_partition],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)

    @classmethod
    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
                     vocab_size: int, org_vocab_size: int, tp_rank: int,
                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
        """Get start and end indices for vocab parallel embedding, following the
        layout outlined in the class docstring, based on the given tp_rank and
        tp_size."""
        num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
        padded_org_vocab_start_index, padded_org_vocab_end_index = (
            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
                                               tp_size))
        # Added-vocab token ids start right after the (unpadded) base vocab.
        padded_added_vocab_start_index, padded_added_vocab_end_index = (
            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
                                               tp_rank,
                                               tp_size,
                                               offset=org_vocab_size))
        # remove padding
        org_vocab_start_index = min(padded_org_vocab_start_index,
                                    org_vocab_size)
        org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
        added_vocab_start_index = min(padded_added_vocab_start_index,
                                      vocab_size)
        added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
        return VocabParallelEmbeddingShardIndices(
            padded_org_vocab_start_index, padded_org_vocab_end_index,
            padded_added_vocab_start_index, padded_added_vocab_end_index,
            org_vocab_start_index, org_vocab_end_index,
            added_vocab_start_index, added_vocab_end_index)

    def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
        """Get a mapping that can be used to reindex the gathered
        logits for sampling.

        During sampling, we gather logits from all ranks. The relationship
        of index->token_id will follow the same format as outlined in the class
        docstring. However, after the gather, we want to reindex the final
        logits tensor to map index->token_id one-to-one (the index is always
        equal the token_id it corresponds to). The indices returned by this
        method allow us to do that.

        Returns None when tp_size < 2 (no resharding needed).
        """
        if self.tp_size < 2:
            return None

        base_embeddings: List[int] = []
        added_embeddings: List[int] = []
        padding: List[int] = []
        for tp_rank in range(self.tp_size):
            shard_indices = self._get_indices(self.num_embeddings_padded,
                                              self.org_vocab_size_padded,
                                              self.num_embeddings,
                                              self.org_vocab_size, tp_rank,
                                              self.tp_size)
            range_start = self.num_embeddings_per_partition * tp_rank
            range_end = self.num_embeddings_per_partition * (tp_rank + 1)
            base_embeddings.extend(
                range(range_start,
                      range_start + shard_indices.num_org_elements))
            padding.extend(
                range(range_start + shard_indices.num_org_elements,
                      range_start + shard_indices.num_org_elements_padded))
            added_embeddings.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements))
            padding.extend(
                range(
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements,
                    range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded))
            assert (range_start + shard_indices.num_org_elements_padded +
                    shard_indices.num_added_elements_padded == range_end)
        ret = base_embeddings + added_embeddings + padding
        assert len(ret) == self.num_embeddings_padded
        return ret

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's slice of a full checkpoint tensor into ``param``.

        Only base-vocab entries are read from ``loaded_weight``; all
        remaining slots along ``output_dim`` (base padding and added-vocab
        rows) are zero-filled.
        """
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)

        # If the parameter is a gguf weight, then load it directly.
        if getattr(param, "is_gguf_weight_type", None):
            param.data.copy_(loaded_weight)
            param.weight_type = loaded_weight.item()
            return
        elif isinstance(param, UninitializedParameter):
            shape = list(loaded_weight.shape)
            if output_dim is not None:
                shape[output_dim] = shape[output_dim] // self.tp_size
            param.materialize(tuple(shape), dtype=loaded_weight.dtype)

        # If parameter does not have output dim, then it should
        # be copied onto all gpus (e.g. g_idx for act_order gptq).
        if output_dim is None:
            assert param.data.shape == loaded_weight.shape
            param.data.copy_(loaded_weight)
            return

        # Shard indexes for loading the weight
        start_idx = self.shard_indices.org_vocab_start_index
        shard_size = self.shard_indices.org_vocab_end_index - start_idx

        # If param packed on the same dim we are sharding on, then
        # need to adjust offsets of loaded weight by pack_factor.
        if packed_dim is not None and packed_dim == output_dim:
            assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
                                                       param.pack_factor)
            start_idx = start_idx // param.pack_factor
            shard_size = shard_size // param.pack_factor
        else:
            assert loaded_weight.shape[output_dim] == self.org_vocab_size

        # Copy the data into the leading slots along output_dim and zero the
        # rest. Indexing along output_dim (not always dim 0) keeps this
        # correct for parameters whose output_dim is not 0.
        loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        num_loaded = loaded_weight.shape[output_dim]
        target = param.data
        target.narrow(output_dim, 0, num_loaded).copy_(loaded_weight)
        target.narrow(output_dim, num_loaded,
                      target.shape[output_dim] - num_loaded).fill_(0)

    def forward(self, input_):
        """Embed token ids, summing per-rank partial results when TP > 1."""
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.linear_method.embedding(self,
                                                       masked_input.long())
        # Mask the output embedding: tokens owned by other ranks contribute 0.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output

    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f', num_embeddings_padded={self.num_embeddings_padded}'
        s += f', tp_size={self.tp_size}'
        return s
class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            # One bias entry per local (padded) vocab slot; loaded with the
            # same sharding logic as the weight.
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        """Compute logits for the vocab slots held by this rank.

        Delegates to the layer's quant method ``apply`` with this layer's
        parameters and optional bias. The result covers only this rank's
        partition (including padding slots); it is not gathered across
        ranks here.
        """
        # The weights live on this module (registered by create_weights), so
        # the layer itself is passed to the quant method.
        logits = self.linear_method.apply(self, input_, bias=self.bias)
        return logits
class ParallelTWEHead(torch.nn.Module):
    """Parallelized tie word embeddings head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are read from a VocabParallelEmbedding, so the head shares
    (ties) the embedding's parameters instead of allocating its own.

    Args:
        embeddings: the VocabParallelEmbedding to mirror
    """

    def __init__(self, embeddings: VocabParallelEmbedding):
        super().__init__()
        self.linear_method = embeddings.linear_method
        # Keep a plain (unregistered) reference to the embedding layer: the
        # quant method's ``apply`` takes the layer and reads its parameters,
        # while ownership and parameter names stay with the embedding layer
        # (the head does not list the tied parameters a second time).
        object.__setattr__(self, "embeddings", embeddings)

    def forward(self, input_):
        """Compute this rank's logits using the tied embedding weights."""
        logits = self.linear_method.apply(self.embeddings, input_)
        return logits
|