|
@@ -8,7 +8,10 @@ from torch.nn.parameter import Parameter
|
|
|
from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
|
|
|
get_tensor_model_parallel_world_size,
|
|
|
tensor_model_parallel_all_reduce)
|
|
|
+from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
|
|
|
from aphrodite.modeling.utils import set_weight_attrs
|
|
|
+from aphrodite.quantization.base_config import (QuantizationConfig,
|
|
|
+ QuantizeMethodBase)
|
|
|
|
|
|
DEFAULT_VOCAB_PADDING_SIZE = 64
|
|
|
|
|
@@ -156,6 +159,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
params_dtype: type of the parameters.
|
|
|
org_num_embeddings: original vocabulary size (without LoRA).
|
|
|
padding_size: padding size for the vocabulary.
|
|
|
+ quant_config: quant config for the layer.
|
|
|
""" # noqa: E501
|
|
|
|
|
|
def __init__(self,
|
|
@@ -163,7 +167,8 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
embedding_dim: int,
|
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
|
org_num_embeddings: Optional[int] = None,
|
|
|
- padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
|
|
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None):
|
|
|
super().__init__()
|
|
|
|
|
|
# Keep the input dimensions.
|
|
@@ -186,6 +191,14 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
self.org_vocab_size, tp_rank,
|
|
|
self.tp_size)
|
|
|
self.embedding_dim = embedding_dim
|
|
|
+
|
|
|
+ linear_method = None
|
|
|
+ if quant_config is not None:
|
|
|
+ linear_method = quant_config.get_quant_method(self)
|
|
|
+ if linear_method is None:
|
|
|
+ linear_method = UnquantizedLinearMethod()
|
|
|
+ self.linear_method: QuantizeMethodBase = linear_method
|
|
|
+
|
|
|
if params_dtype is None:
|
|
|
params_dtype = torch.get_default_dtype()
|
|
|
# Divide the weight matrix along the vocaburaly dimension.
|
|
@@ -200,14 +213,13 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
self.num_added_embeddings_per_partition = (
|
|
|
self.shard_indices.added_vocab_end_index -
|
|
|
self.shard_indices.added_vocab_start_index)
|
|
|
- self.weight = Parameter(
|
|
|
- torch.empty(self.num_embeddings_per_partition,
|
|
|
- self.embedding_dim,
|
|
|
- dtype=params_dtype))
|
|
|
- set_weight_attrs(self.weight, {
|
|
|
- "parallel_dim": 0,
|
|
|
- "weight_loader": self.weight_loader
|
|
|
- })
|
|
|
+ self.linear_method.create_weights(self,
|
|
|
+ self.embedding_dim,
|
|
|
+ [self.num_embeddings_per_partition],
|
|
|
+ self.embedding_dim,
|
|
|
+ self.num_embeddings_padded,
|
|
|
+ params_dtype=params_dtype,
|
|
|
+ weight_loader=self.weight_loader)
|
|
|
|
|
|
@classmethod
|
|
|
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
|
|
@@ -287,10 +299,32 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
return ret
|
|
|
|
|
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
|
|
- parallel_dim = param.parallel_dim
|
|
|
- assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
|
|
|
- loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
|
|
|
- self.shard_indices.org_vocab_end_index]
|
|
|
+ output_dim = getattr(param, "output_dim", None)
|
|
|
+ packed_dim = getattr(param, "packed_dim", None)
|
|
|
+
|
|
|
+ # If parameter does not have output dim, then it should
|
|
|
+ # be copied onto all gpus (e.g. g_idx for act_order gptq).
|
|
|
+ if output_dim is None:
|
|
|
+ assert param.data.shape == loaded_weight.shape
|
|
|
+ param.data.copy_(loaded_weight)
|
|
|
+ return
|
|
|
+
|
|
|
+ # Shard indexes for loading the weight
|
|
|
+ start_idx = self.shard_indices.org_vocab_start_index
|
|
|
+ shard_size = self.shard_indices.org_vocab_end_index - start_idx
|
|
|
+
|
|
|
+ # If param packed on the same dim we are sharding on, then
|
|
|
+ # need to adjust offsets of loaded weight by pack_factor.
|
|
|
+ if packed_dim is not None and packed_dim == output_dim:
|
|
|
+ assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
|
|
|
+ param.pack_factor)
|
|
|
+ start_idx = start_idx // param.pack_factor
|
|
|
+ shard_size = shard_size // param.pack_factor
|
|
|
+ else:
|
|
|
+ assert loaded_weight.shape[output_dim] == self.org_vocab_size
|
|
|
+
|
|
|
+ # Copy the data.
|
|
|
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
|
|
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
|
|
param[loaded_weight.shape[0]:].data.fill_(0)
|
|
|
|
|
@@ -345,16 +379,17 @@ class ParallelLMHead(VocabParallelEmbedding):
|
|
|
bias: bool = False,
|
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
|
org_num_embeddings: Optional[int] = None,
|
|
|
- padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
|
|
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
|
|
+ quant_config: Optional[QuantizationConfig] = None):
|
|
|
super().__init__(num_embeddings, embedding_dim, params_dtype,
|
|
|
- org_num_embeddings, padding_size)
|
|
|
+ org_num_embeddings, padding_size, quant_config)
|
|
|
if bias:
|
|
|
self.bias = Parameter(
|
|
|
torch.empty(self.num_embeddings_per_partition,
|
|
|
dtype=params_dtype))
|
|
|
set_weight_attrs(self.bias, {
|
|
|
- "parallel_dim": 0,
|
|
|
- "weight_loader": self.weight_loader
|
|
|
+ "output_dim": 0,
|
|
|
+ "weight_loader": self.weight_loader,
|
|
|
})
|
|
|
else:
|
|
|
self.register_parameter("bias", None)
|