vocab_parallel_embedding.py

from typing import Optional, Sequence

import torch
from torch.nn.parameter import Parameter

from aphrodite.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
from aphrodite.modeling.utils import set_weight_attrs

DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size up to the nearest multiple of `pad_to`."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
                                              rank: int) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank)
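
# Illustrative example (not used by the module; numbers are hypothetical and
# chosen only to show the arithmetic): with a raw vocabulary of 32003 tokens
# and the default padding of 64, pad_vocab_size(32003) returns 32064, the next
# multiple of 64. With a tensor-parallel world size of 4, each rank then owns
# 32064 // 4 = 8016 consecutive indices; for instance,
# vocab_range_from_global_vocab_size(32064, 1, 4) yields (8016, 16032).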


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding; note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 linear_method=None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 is_input_emb: bool = True):
        super().__init__()

        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.org_vocab_size = org_num_embeddings or num_embeddings
        self.num_embeddings_padded = pad_vocab_size(num_embeddings,
                                                    padding_size)
        self.embedding_dim = embedding_dim
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.tp_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = (
            vocab_range_from_global_vocab_size(
                self.num_embeddings_padded, get_tensor_model_parallel_rank(),
                self.tp_size))
        self.num_embeddings_per_partition = (self.vocab_end_index -
                                             self.vocab_start_index)
        # Fall back to an unquantized method unless the quant config enables
        # quantized vocab weights for this slot (0: input embedding, 1: head).
        idx = 0 if is_input_emb else 1
        if linear_method is None or not linear_method.quant_config.quant_vocab(
        )[idx]:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.embedding_dim, [self.num_embeddings_per_partition],
            self.embedding_dim, self.num_embeddings_padded, params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        output_dim = getattr(param, "output_dim", None)
        packed_dim = getattr(param, "packed_dim", None)
        if output_dim is not None:
            # Narrow the checkpoint tensor to this rank's vocabulary shard.
            # Clamping to the original (unpadded) vocab size ensures we never
            # read past the rows that actually exist in the checkpoint.
            shard_offset = self.vocab_start_index
            shard_size = min(self.vocab_end_index,
                             self.org_vocab_size) - shard_offset
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
            loaded_weight = loaded_weight.narrow(output_dim, shard_offset,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            # Lazily materialize the parameter with the per-partition shape.
            vocab_shape = list(loaded_weight.shape)
            if output_dim is not None:
                if packed_dim == output_dim:
                    vocab_shape[output_dim] = (
                        self.num_embeddings_per_partition // param.pack_factor)
                else:
                    vocab_shape[output_dim] = self.num_embeddings_per_partition
            param.materialize(vocab_shape, dtype=loaded_weight.dtype)
        if output_dim is not None:
            param.data.narrow(
                output_dim, 0,
                loaded_weight.shape[output_dim]).copy_(loaded_weight)
        else:
            param.data.copy_(loaded_weight)
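
    # Sketch of the sharding arithmetic above (illustrative numbers only):
    # with org_vocab_size=32003 padded to 32064 and tp_size=4, rank 3 owns
    # indices [24048, 32064), but the checkpoint only has 32003 rows, so
    # shard_size = min(32064, 32003) - 24048 = 7955. Only those real rows are
    # copied; the trailing padded rows are left untouched by the loader.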

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            input_mask = ((input_ < self.vocab_start_index) |
                          (input_ >= self.vocab_end_index))
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = self.linear_method.apply_embedding(
            self.linear_weights, masked_input)
        # output_parallel = F.embedding(masked_input, self.weight)
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output
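
# A minimal sketch of what forward() does across ranks, with the all-reduce
# simulated by a plain sum. Shapes and values are hypothetical and the snippet
# is not part of this module:
#
#     full = torch.randn(8, 4)                # padded vocab of 8, hidden 4
#     shards = full.chunk(2, dim=0)           # tp_size = 2, 4 rows per rank
#     tokens = torch.tensor([1, 5, 6])
#     partials = []
#     for rank, shard in enumerate(shards):
#         start, end = 4 * rank, 4 * (rank + 1)
#         mask = (tokens < start) | (tokens >= end)
#         local = (tokens - start).masked_fill(mask, 0)
#         out = shard[local]
#         out[mask, :] = 0.0                  # zero rows for out-of-range ids
#         partials.append(out)
#     # partials[0] + partials[1] equals full[tokens]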


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 linear_method=None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         linear_method, org_num_embeddings, padding_size,
                         False)
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        logits = self.linear_method.apply_weights(self.linear_weights, input_)
        if self.bias is not None:
            logits += self.bias
        return logits
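
# Usage note (a sketch, assuming apply_weights behaves like F.linear on the
# per-partition weight): each rank's ParallelLMHead.forward returns logits
# only for its own vocabulary shard, shape [..., num_embeddings_per_partition].
# Unlike VocabParallelEmbedding.forward, no all-reduce happens here, so the
# caller (e.g. the Sampler) is expected to gather the per-rank shards before
# computing probabilities over the full vocabulary.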