vocab_parallel_embedding.py

from typing import Optional, Sequence

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs

DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Pad the vocab size up so that it is divisible by `pad_to`."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to
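
# Padding arithmetic, illustratively (values below are examples only):
#   pad_vocab_size(32000)               -> 32000  (already a multiple of 64)
#   pad_vocab_size(32003)               -> 32064  (rounded up to the next multiple)
#   pad_vocab_size(50257, pad_to=128)   -> 50304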


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
                                              rank: int) -> Sequence[int]:
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int) -> Sequence[int]:
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                     rank)
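
# Example split (illustrative): with a padded vocab of 32064 and world_size=2,
# rank 0 owns the half-open range [0, 16032) and rank 1 owns [16032, 32064);
# each rank stores only its 16032-row slice of the embedding table.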


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Adapted from torch.nn.Embedding; note that we pad the vocabulary size to
    make sure it is divisible by the number of model-parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
        super().__init__()

        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.org_vocab_size = org_num_embeddings or num_embeddings
        self.num_embeddings_padded = pad_vocab_size(num_embeddings,
                                                    padding_size)
        self.embedding_dim = embedding_dim
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.tp_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = (
            vocab_range_from_global_vocab_size(
                self.num_embeddings_padded, get_tensor_model_parallel_rank(),
                self.tp_size))
        self.num_embeddings_per_partition = (self.vocab_end_index -
                                             self.vocab_start_index)
        self.weight = Parameter(
            torch.empty(self.num_embeddings_per_partition,
                        self.embedding_dim,
                        dtype=params_dtype))
        set_weight_attrs(self.weight, {
            "parallel_dim": 0,
            "weight_loader": self.weight_loader
        })

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        parallel_dim = param.parallel_dim
        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
        # Copy only this rank's slice of the original (unpadded) vocabulary;
        # rows beyond the original vocab size are the padded tail and receive
        # no checkpoint data.
        loaded_weight = loaded_weight[self.vocab_start_index:self.
                                      vocab_end_index]
        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)

    def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            input_mask = ((input_ < self.vocab_start_index) |
                          (input_ >= self.vocab_end_index))
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight)
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        return output
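
    # Masking example (illustrative, tp_size=2, vocab_start_index=16032):
    # token id 17000 maps to local row 968 on rank 1 and is zeroed out on
    # rank 0; the all-reduce then combines the per-rank partial embeddings
    # into the full result on every rank.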

    def extra_repr(self) -> str:
        s = f"num_embeddings={self.num_embeddings_per_partition}"
        s += f", embedding_dim={self.embedding_dim}"
        s += f", org_vocab_size={self.org_vocab_size}"
        s += f", num_embeddings_padded={self.num_embeddings_padded}"
        s += f", tp_size={self.tp_size}"
        return s
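

# Minimal usage sketch (assumes the tensor-parallel process group has already
# been initialized through aphrodite.distributed; names and sizes below are
# illustrative only):
#
#     embed_tokens = VocabParallelEmbedding(num_embeddings=32000,
#                                           embedding_dim=4096)
#     token_ids = torch.randint(0, 32000, (2, 128))
#     hidden = embed_tokens(token_ids)   # (2, 128, 4096), already all-reduced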


class ParallelLMHead(VocabParallelEmbedding):
    """Parallelized LM head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are padded to make sure they are divisible by the number of
    model parallel GPUs.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without LoRA).
        padding_size: padding size for the vocabulary.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size)
        if bias:
            self.bias = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "parallel_dim": 0,
                "weight_loader": self.weight_loader
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        # NOTE: `linear_method` and `linear_weights` are not defined in this
        # module; they appear to be attached to the head elsewhere (compare
        # ParallelTWEHead below, which mirrors them from an embedding).
        logits = self.linear_method.apply_weights(self.linear_weights, input_)
        if self.bias is not None:
            logits += self.bias
        return logits
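

# Construction sketch (illustrative names): an untied LM head is typically
# built as
#     lm_head = ParallelLMHead(config.vocab_size, config.hidden_size,
#                              bias=False)
# and its `weight` is filled per rank by `weight_loader` during checkpoint
# loading; padded rows beyond the original vocab receive no checkpoint data.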


class ParallelTWEHead(torch.nn.Module):
    """Parallelized tied word embeddings head.

    Output logits weight matrices used in the Sampler. The weight and bias
    tensors are read from a VocabParallelEmbedding.

    Args:
        embeddings: the VocabParallelEmbedding to mirror
    """

    def __init__(self, embeddings: VocabParallelEmbedding):
        super().__init__()
        self.linear_method = embeddings.linear_method
        self.linear_weights = embeddings.linear_weights

    def forward(self, input_):
        logits = self.linear_method.apply_weights(self.linear_weights, input_)
        return logits
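

# Tied-embeddings wiring sketch (hypothetical model code; assumes the
# embedding's `linear_method`/`linear_weights` have been attached by the
# model or quantization setup before the head is used):
#
#     embed_tokens = VocabParallelEmbedding(config.vocab_size,
#                                           config.hidden_size)
#     lm_head = ParallelTWEHead(embed_tokens)   # reuses the embedding weights
#     logits = lm_head(hidden_states)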