vocab_parallel_embedding.py

from typing import Optional, Sequence

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from aphrodite.modeling.megatron.utils import divide
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs


def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    """Pad the vocab size to the given value."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
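

# With the default pad_to=64, the padded size is the smallest multiple of 64
# that is >= vocab_size, e.g.:
#
#     pad_vocab_size(32000)  # -> 32000 (already a multiple of 64)
#     pad_vocab_size(32001)  # -> 32064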


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
                                              rank: int) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank)
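

# For example, with a padded vocab of 32064 entries split across a world size
# of 2, divide() gives 16032 entries per partition and the ranges are:
#
#     vocab_range_from_global_vocab_size(32064, rank=0, world_size=2)  # (0, 16032)
#     vocab_range_from_global_vocab_size(32064, rank=1, world_size=2)  # (16032, 32064)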


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None):
        super().__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.num_embeddings_padded = pad_vocab_size(num_embeddings)
        self.embedding_dim = embedding_dim
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.tp_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = (
            vocab_range_from_global_vocab_size(
                self.num_embeddings_padded, get_tensor_model_parallel_rank(),
                self.tp_size))
        self.num_embeddings_per_partition = (self.vocab_end_index -
                                             self.vocab_start_index)
        self.weight = Parameter(
            torch.empty(self.num_embeddings_per_partition,
                        self.embedding_dim,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype))
        set_weight_attrs(self.weight, {
            "parallel_dim": 0,
            "weight_loader": self.weight_loader
        })

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        parallel_dim = param.parallel_dim
        assert loaded_weight.shape[parallel_dim] == self.num_embeddings
        loaded_weight = loaded_weight[self.vocab_start_index:
                                      self.vocab_end_index]
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
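
    # For example, with num_embeddings=32000 (padded to 32064) and tp_size=2,
    # rank 1 owns rows [16032, 32064) but the checkpoint only provides
    # 32000 - 16032 = 15968 of them, so the copy above fills the first 15968
    # rows of the 16032-row partition and leaves the padding rows untouched.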

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            input_mask = ((input_ < self.vocab_start_index) |
                          (input_ >= self.vocab_end_index))
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight)
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output
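

# Minimal usage sketch (assumes the tensor-model-parallel process groups from
# aphrodite.modeling.megatron.parallel_state are already initialized and a
# CUDA device is available; shapes are illustrative):
#
#     embedding = VocabParallelEmbedding(num_embeddings=32000,
#                                        embedding_dim=4096)
#     token_ids = torch.randint(0, 32000, (1, 16), device="cuda")
#     hidden_states = embedding(token_ids)  # shape: (1, 16, 4096)
#
# Each rank embeds only the token ids in its vocab slice, zeroes out the rest,
# and the all-reduce in forward() sums the partial results into the full
# embedding.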


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None):
        super().__init__(num_embeddings, embedding_dim, params_dtype)
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            device=torch.cuda.current_device(),
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "parallel_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
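

# Note: forward() is intentionally disabled. The Sampler is expected to read
# this module's ``weight`` (and ``bias``, if present) directly and compute the
# logits itself, conceptually something like
#
#     logits = hidden_states @ lm_head.weight.t()  # per-rank vocab shard
#
# followed by combining the per-rank shards across tensor-parallel GPUs. The
# exact call sequence lives in the Sampler, not in this class.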