embedding.py

# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from flash_attn.utils.distributed import all_reduce, reduce_scatter


class GPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        padding_idx=None,
        word_embed_proj_dim=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings.
        If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension,
        then project up to embed_dim.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if word_embed_proj_dim is None:
            self.word_embeddings = nn.Embedding(
                vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = None
        else:
            self.word_embeddings = nn.Embedding(
                vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs
            )
            self.project_in = nn.Linear(
                word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs
            )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.project_in is not None:
            embeddings = self.project_in(embeddings)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        return embeddings


class BertEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        padding_idx=None,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings.
        If type_vocab_size <= 0, there's no token type embeddings.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(
            vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
        )
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                max_position_embeddings, embed_dim, **factory_kwargs
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            embeddings = embeddings + token_type_embeddings
        return embeddings


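# Tensor-parallel word embedding: the vocabulary is sharded across the ranks of
# `process_group`, so each rank stores only `num_embeddings // world_size` rows of
# the embedding table.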
class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Mask token ids outside this rank's vocabulary shard
            # (True means the id belongs to another rank).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            # Zero out rows for out-of-shard ids, so that a subsequent reduction across
            # ranks (as in ParallelGPT2Embeddings) recovers the full embedding.
            embeddings[input_ids_mask] = 0.0
            return embeddings


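# Tensor-parallel feature embedding: the embedding dimension (columns) is sharded
# across ranks, so each rank stores only `embedding_dim // world_size` columns.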
class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)


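# GPT-2 embeddings for tensor parallelism: word embeddings are vocab-parallel and
# position embeddings are column-parallel; the partial results are summed across ranks
# with an all-reduce (or a reduce-scatter when sequence_parallel is enabled).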
class ParallelGPT2Embeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there's no position embeddings.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                # Each rank holds only a slice of the position-embedding dimension, so add
                # it to the matching columns of the word embeddings; the reduction below
                # sums the partial results across ranks.
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
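

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library's API surface):
    # embed a random batch of token ids with the non-parallel GPT2Embeddings on CPU.
    # The sizes below (GPT-2 small vocab and hidden dim) are assumptions for the example.
    emb = GPT2Embeddings(embed_dim=768, vocab_size=50257, max_position_embeddings=1024)
    input_ids = torch.randint(0, 50257, (2, 16))
    out = emb(input_ids)
    print(out.shape)  # torch.Size([2, 16, 768])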