
[Rotary] Add test for rotary when qkv are packed and there's GQA

Tri Dao, 6 months ago
Commit cc1690d9d6
1 changed file with 24 additions and 8 deletions
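For context (not part of the commit): a minimal sketch of the two packed-QKV layouts this test now covers. Shapes follow the test code; the seqlen and headdim values below are arbitrary, and nheads_k mirrors the test's choice of nheads // 2.

import torch

batch_size, seqlen, nheads, headdim = 32, 512, 4, 64  # seqlen/headdim arbitrary here
nheads_k = nheads // 2  # GQA: fewer key/value heads than query heads

# MHA packing: q, k, v stacked along a dedicated dim of size 3
qkv_mha = torch.randn(batch_size, seqlen, 3, nheads, headdim)
q, k, v = qkv_mha.unbind(2)  # each (batch, seqlen, nheads, headdim)

# GQA packing: head counts differ, so all heads share one dim
qkv_gqa = torch.randn(batch_size, seqlen, nheads + 2 * nheads_k, headdim)
q, k, v = qkv_gqa.split([nheads, nheads_k, nheads_k], dim=2)
# q: (32, 512, 4, 64); k and v: (32, 512, 2, 64)

This is why the GQA path in the test passes num_heads_q to apply_rotary_emb_qkv_: without a dedicated dim of size 3, the kernel needs the query head count to know where q ends and k begins.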

+ 24 - 8
tests/test_rotary.py

@@ -97,13 +97,15 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t
     "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
 )
 # @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize("gqa", [False, True])
+# @pytest.mark.parametrize("gqa", [False])
 @pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
 # @pytest.mark.parametrize("seqlen_offsets_type", [0])
 @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
 # @pytest.mark.parametrize('rotary_fraction', [1.0])
 @pytest.mark.parametrize("interleaved", [False, True])
 # @pytest.mark.parametrize('interleaved', [False])
-def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
+def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, dtype):
     rtol = 1e-3
     batch_size = 32
     nheads = 4
@@ -112,23 +114,37 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype
     device = "cuda"
     rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
-    qkv = torch.randn(
-        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
-    )
+    if not gqa:
+        qkv = torch.randn(
+            batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
+        )
+    else:
+        nheads_k = nheads // 2
+        qkv = torch.randn(
+            batch_size, seqlen, nheads + nheads_k * 2, headdim, dtype=dtype, device=device, requires_grad=True
+        )
     qkv_pt = qkv.detach().clone().requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
     seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb_qkv_(
-        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
+        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved,
+        num_heads_q=None if not gqa else nheads
     )
     cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
+    if not gqa:
+        q_pt, k_pt, v_pt = qkv_pt.unbind(2)
+    else:
+        q_pt, k_pt, v_pt = qkv_pt.split([nheads, nheads_k, nheads_k], dim=2)
     q_pt = apply_rotary_emb_torch(
-        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+        q_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
     k_pt = apply_rotary_emb_torch(
-        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+        k_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
-    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
+    if not gqa:
+        out_pt = torch.stack([q_pt, k_pt, v_pt], dim=2)
+    else:
+        out_pt = torch.cat([q_pt, k_pt, v_pt], dim=2)
     print(f"Output max diff: {(out - out_pt).abs().max().item()}")
 
     g = torch.randn_like(out)
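As a side note, a hedged sketch of what the non-interleaved reference rotary presumably computes, assuming the usual GPT-NeoX-style convention of rotating the first and second halves of the rotary dims; this is an illustration, not the actual body of apply_rotary_emb_torch.

import torch

def rotary_ref(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2)
    rotary_dim = 2 * cos.shape[-1]  # only rotary_fraction * headdim dims rotate
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over the head dim
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

Note that the test applies this reference only to q and k (computed in float32, then cast back to dtype), while v passes through unchanged, which matches the stack/cat of q_pt, k_pt, v_pt above.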