Browse Source

[TP] Put parallel embeddings in separate modules

Tri Dao 2 years ago
parent
commit
4cab4de5ea
1 changed files with 63 additions and 37 deletions
  1. 63 37
      flash_attn/modules/embedding.py

+ 63 - 37
flash_attn/modules/embedding.py

@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+from torch import Tensor
 
 from einops import rearrange
 
@@ -81,6 +82,51 @@ class BertEmbeddings(nn.Module):
         return embeddings
 
 
+class VocabParallelEmbedding(nn.Embedding):
+
+    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
+        self.process_group = process_group
+        if process_group is not None:
+            world_size = torch.distributed.get_world_size(process_group)
+            if num_embeddings % world_size != 0:
+                raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by '
+                                 f'world_size ({world_size})')
+            if world_size > 1 and padding_idx is not None:
+                raise RuntimeError('ParallelEmbedding does not support padding_idx')
+        else:
+            world_size = 1
+        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self.process_group is None:
+            return super().forward(input)
+        else:
+            rank = torch.distributed.get_rank(self.process_group)
+            vocab_size = self.num_embeddings
+            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
+            # Create a mask of valid vocab ids (1 means it needs to be masked).
+            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
+            input = input - vocab_start_index
+            input[input_ids_mask] = 0
+            embeddings = super().forward(input)
+            embeddings[input_ids_mask] = 0.0
+            return embeddings
+
+
+class ColumnParallelEmbedding(nn.Embedding):
+
+    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
+        self.process_group = process_group
+        if process_group is not None:
+            world_size = torch.distributed.get_world_size(process_group)
+            if embedding_dim % world_size != 0:
+                raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by '
+                                 f'world_size ({world_size})')
+        else:
+            world_size = 1
+        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
+
+
 class ParallelGPT2Embeddings(nn.Module):
 
     def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group,
@@ -88,22 +134,17 @@ class ParallelGPT2Embeddings(nn.Module):
         """
             If max_position_embeddings <= 0, there's no position embeddings
         """
-        world_size = torch.distributed.get_world_size(process_group)
-        if vocab_size % world_size != 0:
-            raise ValueError(f'vocab_size ({vocab_size}) must be divisible by '
-                             f'world_size ({world_size})')
-        if embed_dim % world_size != 0:
-            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by '
-                             f'world_size ({world_size})')
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
-        self.word_embeddings = nn.Embedding(vocab_size // world_size, embed_dim,
-                                            padding_idx=padding_idx, **factory_kwargs)
+        self.word_embeddings = VocabParallelEmbedding(
+            vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group,
+            **factory_kwargs
+        )
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
-            self.position_embeddings = nn.Embedding(
-                max_position_embeddings, embed_dim // world_size, **factory_kwargs
+            self.position_embeddings = ColumnParallelEmbedding(
+                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
             )
 
     def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
@@ -113,32 +154,17 @@ class ParallelGPT2Embeddings(nn.Module):
         """
         batch_size, seqlen = input_ids.shape
         world_size = torch.distributed.get_world_size(self.process_group)
-        if world_size <= 1:
-            embeddings = self.word_embeddings(input_ids)
-            if self.max_position_embeddings > 0:
-                if position_ids is None:
-                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                position_embeddings = self.position_embeddings(position_ids)
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
+            position_embeddings = self.position_embeddings(position_ids)
+            if world_size <= 1:
                 embeddings = embeddings + position_embeddings
-            if combine_batch_seqlen_dim:
-                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-            return embeddings
-        else:
-            rank = torch.distributed.get_rank(self.process_group)
-            vocab_size = self.word_embeddings.num_embeddings
-            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
-            # Create a mask of valid vocab ids (1 means it needs to be masked).
-            input_ids_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
-            input_ids = input_ids - vocab_start_index
-            input_ids[input_ids_mask] = 0
-            embeddings = self.word_embeddings(input_ids)
-            embeddings[input_ids_mask] = 0.0
-            if self.max_position_embeddings > 0:
-                if position_ids is None:
-                    position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
-                position_embeddings = self.position_embeddings(position_ids)
+            else:
                 partition_dim = self.position_embeddings.embedding_dim
+                rank = torch.distributed.get_rank(self.process_group)
                 embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings
-            if combine_batch_seqlen_dim:
-                embeddings = rearrange(embeddings, 'b s d -> (b s) d')
-            return reduce_scatter(embeddings, self.process_group)
+        if combine_batch_seqlen_dim:
+            embeddings = rearrange(embeddings, 'b s d -> (b s) d')
+        return embeddings if world_size <= 1 else reduce_scatter(embeddings, self.process_group)