|
@@ -1,8 +1,11 @@
|
|
|
from typing import Optional, Sequence
|
|
|
|
|
|
import torch
|
|
|
+import torch.nn.functional as F
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
|
|
+from aphrodite.modeling.megatron.communication_op import (
|
|
|
+ tensor_model_parallel_gather)
|
|
|
from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
|
|
|
from aphrodite.modeling.megatron.parallel_state import (
|
|
|
get_tensor_model_parallel_rank,
|
|
@@ -54,7 +57,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
num_embeddings: int,
|
|
|
embedding_dim: int,
|
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
|
- linear_method=None,
|
|
|
+ linear_method = None,
|
|
|
org_num_embeddings: Optional[int] = None,
|
|
|
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
|
|
super().__init__()
|
|
@@ -75,28 +78,26 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
self.tp_size))
|
|
|
self.num_embeddings_per_partition = (self.vocab_end_index -
|
|
|
self.vocab_start_index)
|
|
|
- if linear_method is None or not linear_method.quant_config.quant_vocab(
|
|
|
- ):
|
|
|
+ if linear_method is None or not linear_method.quant_config.quant_vocab():
|
|
|
linear_method = UnquantizedLinearMethod()
|
|
|
self.linear_method = linear_method
|
|
|
self.linear_weights = self.linear_method.create_weights(
|
|
|
- self.embedding_dim, self.num_embeddings_per_partition,
|
|
|
- self.embedding_dim, self.num_embeddings_padded, params_dtype)
|
|
|
+ self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
|
|
|
+ self.num_embeddings_padded, params_dtype)
|
|
|
for name, weight in self.linear_weights.items():
|
|
|
if isinstance(weight, torch.nn.parameter.Parameter):
|
|
|
self.register_parameter(name, weight)
|
|
|
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
|
|
|
|
|
+
|
|
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
|
|
output_dim = getattr(param, "output_dim", None)
|
|
|
if output_dim is not None:
|
|
|
- assert loaded_weight.shape[output_dim] == self.num_embeddings
|
|
|
+ assert loaded_weight.shape[output_dim] == self.org_vocab_size
|
|
|
loaded_weight = loaded_weight[self.vocab_start_index:self.
|
|
|
vocab_end_index]
|
|
|
if isinstance(param, torch.nn.parameter.UninitializedParameter):
|
|
|
- param.materialize(
|
|
|
- (self.num_embeddings_per_partition, loaded_weight.shape[1]),
|
|
|
- dtype=loaded_weight.dtype)
|
|
|
+ param.materialize((self.num_embeddings_per_partition, loaded_weight.shape[1]), dtype=loaded_weight.dtype)
|
|
|
if output_dim is not None:
|
|
|
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
|
|
else:
|
|
@@ -113,8 +114,8 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
|
else:
|
|
|
masked_input = input_
|
|
|
# Get the embeddings.
|
|
|
- output_parallel = self.linear_method.apply_embedding(
|
|
|
- self.linear_weights, masked_input)
|
|
|
+ output_parallel = self.linear_method.apply_embedding(self.linear_weights, masked_input)
|
|
|
+ # output_parallel = F.embedding(masked_input, self.weight)
|
|
|
# Mask the output embedding.
|
|
|
if self.tp_size > 1:
|
|
|
output_parallel[input_mask, :] = 0.0
|
|
@@ -144,11 +145,11 @@ class ParallelLMHead(VocabParallelEmbedding):
|
|
|
embedding_dim: int,
|
|
|
bias: bool = False,
|
|
|
params_dtype: Optional[torch.dtype] = None,
|
|
|
- linear_method=None,
|
|
|
+ linear_method = None,
|
|
|
org_num_embeddings: Optional[int] = None,
|
|
|
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
|
|
|
- super().__init__(num_embeddings, embedding_dim, params_dtype,
|
|
|
- linear_method, org_num_embeddings, padding_size)
|
|
|
+ super().__init__(num_embeddings, embedding_dim, params_dtype, linear_method,
|
|
|
+ org_num_embeddings, padding_size)
|
|
|
if bias:
|
|
|
self.bias = Parameter(
|
|
|
torch.empty(self.num_embeddings_per_partition,
|
|
@@ -161,7 +162,8 @@ class ParallelLMHead(VocabParallelEmbedding):
|
|
|
self.register_parameter("bias", None)
|
|
|
|
|
|
def forward(self, input_):
|
|
|
- logits = self.linear_method.apply_weights(self.linear_weights, input_)
|
|
|
+ logits = self.linear_method.apply_weights(
|
|
|
+ self.linear_weights, input_)
|
|
|
if self.bias is not None:
|
|
|
logits += self.bias
|
|
|
return logits
|