
[Rotary] Support qkv block layout from GQA

Tri Dao, 6 months ago
commit 8c20cfef49
3 changed files with 62 additions and 17 deletions
  1. flash_attn/layers/rotary.py (+61 -14)
  2. flash_attn/ops/triton/cross_entropy.py (+0 -2)
  3. flash_attn/ops/triton/rotary.py (+1 -1)

flash_attn/layers/rotary.py (+61 -14)

@@ -139,22 +139,37 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         sin_k=None,
         interleaved=False,
         seqlen_offsets: Union[int, torch.Tensor] = 0,
+        num_heads_q: Optional[int] = None,
     ):
-        batch, seqlen, three, nheads, headdim = qkv.shape
-        assert three == 3
         if cos_k is None and sin_k is None and qkv.is_contiguous():
             # Call 1 kernel instead of 2 kernels
             # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
-            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            if qkv.dim() == 5:
+                batch, seqlen, three, nheads, headdim = qkv.shape
+                assert three == 3
+                # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+                qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            else:
+                assert qkv.dim() == 4
+                assert num_heads_q is not None
+                num_heads_k = (qkv.shape[2] - num_heads_q) // 2
+                assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
+                qk = qkv[:, :, :num_heads_q + num_heads_k]
             apply_rotary(
                 qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
             )
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
-            q, k = qkv[:, :, 0], qkv[:, :, 1]
+            if qkv.dim() == 5:
+                q, k = qkv[:, :, 0], qkv[:, :, 1]
+            else:
+                assert qkv.dim() == 4
+                assert num_heads_q is not None
+                num_heads_k = (qkv.shape[2] - num_heads_q) // 2
+                assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
+                q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k]
             apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
             apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
             ctx.save_for_backward(cos, sin, cos_k, sin_k)
@@ -165,6 +180,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
             ctx.seqlen_offsets = None
         ctx.interleaved = interleaved
+        ctx.num_heads_q = num_heads_q
         return qkv
 
     @staticmethod
@@ -178,7 +194,14 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
             # Call 1 kernel instead of 2 kernels
             # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
             # dimensions, we get the same tensor
-            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            if dqkv.dim() == 5:
+                dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            else:
+                assert dqkv.dim() == 4
+                assert ctx.num_heads_q is not None
+                num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
+                assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
+                dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k]
             apply_rotary(
                 dqk,
                 cos,
@@ -191,9 +214,23 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         else:
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
-            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            if dqkv.dim() == 5:
+                dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            else:
+                assert dqkv.dim() == 4
+                assert ctx.num_heads_q is not None
+                num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
+                assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
+                dq = dqkv[:, :, : ctx.num_heads_q]
+                dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k]
             apply_rotary(
-                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
+                dq,
+                cos,
+                sin,
+                seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
             )
             apply_rotary(
                 dk,
@@ -204,7 +241,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
                 inplace=True,
                 conjugate=True,
             )
-        return dqkv, None, None, None, None, None, None
+        return dqkv, None, None, None, None, None, None, None
 
 
 def apply_rotary_emb_qkv_(
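
Before the wrapper below, a minimal indexing sketch (not part of the commit) of the head-axis bookkeeping that the new 4-D branches above rely on; names mirror the diff, and plain slicing is all that is involved:

# Hedged sketch of the packed-layout slicing: for a qkv tensor of shape
# (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim), Q, K and V occupy
# consecutive ranges of the single head axis.
import torch

batch, seqlen, headdim = 2, 16, 64
num_heads_q, num_heads_k = 8, 2
qkv = torch.randn(batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)

# Same arithmetic as the forward/backward branches: recover num_heads_k from the
# head axis, then slice Q and K (rotated in place) and V (left untouched).
num_heads_k_inferred = (qkv.shape[2] - num_heads_q) // 2
assert qkv.shape[2] == num_heads_q + 2 * num_heads_k_inferred
q = qkv[:, :, :num_heads_q]
k = qkv[:, :, num_heads_q : num_heads_q + num_heads_k_inferred]
v = qkv[:, :, num_heads_q + num_heads_k_inferred :]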
@@ -215,10 +252,13 @@ def apply_rotary_emb_qkv_(
     sin_k=None,
     interleaved=False,
     seqlen_offsets: Union[int, torch.Tensor] = 0,
+    num_heads_q: Optional[int] = None,
 ):
     """
     Arguments:
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim).
+            If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
+            then num_heads_q must be provided.
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
@@ -226,11 +266,13 @@ def apply_rotary_emb_qkv_(
         seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
             Most commonly used in inference when we have KV cache.
     Return:
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim)
     rotary_dim must be <= headdim
     Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
     """
-    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
+    return ApplyRotaryEmbQKV_.apply(
+        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q
+    )
 
 
 class ApplyRotaryEmbKV_(torch.autograd.Function):
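
The docstring above describes the new packed GQA call path. Here is a minimal usage sketch (not from the commit); it assumes flash-attn with its Triton kernels is installed and a CUDA device is available, and the shapes and RoPE table construction are illustrative only:

# Hedged sketch: apply_rotary_emb_qkv_ with the packed 4-D GQA layout.
import torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_

batch, seqlen, headdim = 2, 128, 64
num_heads_q, num_heads_k = 8, 2          # GQA: 8 query heads share 2 KV heads
rotary_dim = headdim                     # rotate the full head dimension

# Packed layout: Q, K, V heads concatenated along a single head axis,
# (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
qkv = torch.randn(
    batch, seqlen, num_heads_q + 2 * num_heads_k, headdim,
    device="cuda", dtype=torch.float16,
)

# cos / sin tables of shape (seqlen, rotary_dim / 2), built the standard RoPE way
inv_freq = 1.0 / (
    10000 ** (torch.arange(0, rotary_dim, 2, device="cuda").float() / rotary_dim)
)
freqs = torch.outer(torch.arange(seqlen, device="cuda").float(), inv_freq)
cos, sin = freqs.cos().to(qkv.dtype), freqs.sin().to(qkv.dtype)

# num_heads_q is required for the 4-D layout so the kernel knows where Q ends and
# K begins; the trailing num_heads_k heads (V) are left untouched. In-place update.
qkv = apply_rotary_emb_qkv_(qkv, cos, sin, num_heads_q=num_heads_q)

With the original 5-D layout (batch, seqlen, 3, nheads, headdim), num_heads_q can stay None and the behavior is unchanged.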
@@ -417,10 +459,13 @@ class RotaryEmbedding(torch.nn.Module):
         kv: Optional[torch.Tensor] = None,
         seqlen_offset: Union[int, torch.Tensor] = 0,
         max_seqlen: Optional[int] = None,
+        num_heads_q: Optional[int] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
-        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
-             else it's just q of shape (batch, seqlen, nheads, headdim)
+        qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
+            if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim).
+            If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
+            then num_heads_q must be provided.
         kv: (batch, seqlen, 2, nheads, headdim)
         seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
             Most commonly used in inference when we have KV cache.
@@ -441,6 +486,7 @@ class RotaryEmbedding(torch.nn.Module):
                     self._sin_cached,
                     interleaved=self.interleaved,
                     seqlen_offsets=seqlen_offset,
+                    num_heads_q=num_heads_q,
                 )
             else:
                 return apply_rotary_emb_qkv_(
@@ -451,6 +497,7 @@ class RotaryEmbedding(torch.nn.Module):
                     self._sin_k_cached,
                     interleaved=self.interleaved,
                     seqlen_offsets=seqlen_offset,
+                    num_heads_q=num_heads_q,
                 )
         else:
             q = qkv
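
The RotaryEmbedding module path gains the same num_heads_q pass-through. A short sketch, assuming the module's usual constructor arguments (dim, interleaved, device), which are not shown in this diff:

# Hedged sketch: RotaryEmbedding.forward with the packed GQA layout.
# Only the num_heads_q argument is introduced by this commit.
import torch
from flash_attn.layers.rotary import RotaryEmbedding

rotary = RotaryEmbedding(dim=64, interleaved=False, device="cuda")
qkv = torch.randn(2, 128, 8 + 2 * 2, 64, device="cuda", dtype=torch.float16)

# With the 4-D packed layout and kv=None, num_heads_q tells forward() where the
# query heads end; the cos/sin caches are maintained internally by the module.
qkv = rotary(qkv, seqlen_offset=0, num_heads_q=8)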

flash_attn/ops/triton/cross_entropy.py (+0 -2)

@@ -245,8 +245,6 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx.total_classes = total_classes
         ctx.class_start_idx = class_start_idx
         ctx.inplace_backward = inplace_backward
-
-
         return losses, z_losses
 
     @staticmethod

flash_attn/ops/triton/rotary.py (+1 -1)

@@ -194,7 +194,7 @@ def apply_rotary(
         else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
     )
     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
-    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
+    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)
 
     # Need this, otherwise Triton tries to launch from cuda:0 and we get
     # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
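
The one-line BLOCK_M fix aligns the row-block heuristic with the BLOCK_K choice in the context lines above, which already covers rotary_dim up to 128 with a single block. A small sketch of the resulting heuristic; the rotary_dim <= 32 branch of BLOCK_K is assumed from the surrounding source and is not part of the diff:

# Hedged sketch of the launch heuristic in the Triton apply_rotary wrapper.
def pick_blocks(rotary_dim, interleaved):
    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    # After this commit, BLOCK_M only drops to 4 once rotary_dim exceeds 128,
    # matching the point where BLOCK_K stops covering the rotary dim in one block.
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)
    return BLOCK_M, BLOCK_K

# e.g. non-interleaved rotary_dim=96: the old heuristic gave BLOCK_M=4, the new one gives 8.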