@@ -97,13 +97,15 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_t
     "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
 )
 # @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize("gqa", [False, True])
+# @pytest.mark.parametrize("gqa", [False])
 @pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
 # @pytest.mark.parametrize("seqlen_offsets_type", [0])
 @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
 # @pytest.mark.parametrize('rotary_fraction', [1.0])
 @pytest.mark.parametrize("interleaved", [False, True])
 # @pytest.mark.parametrize('interleaved', [False])
-def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
+def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, dtype):
     rtol = 1e-3
     batch_size = 32
     nheads = 4
@@ -112,23 +114,37 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype
     device = "cuda"
     rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
-    qkv = torch.randn(
-        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
-    )
+    if not gqa:
+        qkv = torch.randn(
+            batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
+        )
+    else:
+        nheads_k = nheads // 2
+        qkv = torch.randn(
+            batch_size, seqlen, nheads + nheads_k * 2, headdim, dtype=dtype, device=device, requires_grad=True
+        )
     qkv_pt = qkv.detach().clone().requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
     seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb_qkv_(
-        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
+        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved,
+        num_heads_q=None if not gqa else nheads
     )
     cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
+    if not gqa:
+        q_pt, k_pt, v_pt = qkv_pt.unbind(2)
+    else:
+        q_pt, k_pt, v_pt = qkv_pt.split([nheads, nheads_k, nheads_k], dim=2)
     q_pt = apply_rotary_emb_torch(
-        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+        q_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
     k_pt = apply_rotary_emb_torch(
-        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+        k_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
-    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
+    if not gqa:
+        out_pt = torch.stack([q_pt, k_pt, v_pt], dim=2)
+    else:
+        out_pt = torch.cat([q_pt, k_pt, v_pt], dim=2)
     print(f"Output max diff: {(out - out_pt).abs().max().item()}")

     g = torch.randn_like(out)
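
For context, the GQA branch above packs query, key, and value heads along a single head axis rather than a separate qkv axis of size 3. A minimal standalone sketch of that packed layout and the reference split the test performs (sizes here are illustrative, not the test's actual sizes):

import torch

# Illustrative sizes; under GQA, K and V share nheads_k heads while Q keeps nheads.
batch_size, seqlen, nheads, headdim = 2, 8, 4, 16
nheads_k = nheads // 2

# Packed GQA layout: (batch, seqlen, nheads + 2 * nheads_k, headdim),
# versus the non-GQA layout (batch, seqlen, 3, nheads, headdim).
qkv = torch.randn(batch_size, seqlen, nheads + 2 * nheads_k, headdim)

# The same split the test uses to build its PyTorch reference before
# applying rotary to q and k; v passes through unchanged.
q, k, v = qkv.split([nheads, nheads_k, nheads_k], dim=2)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 8, 4, 16]) torch.Size([2, 8, 2, 16]) torch.Size([2, 8, 2, 16])

This is why the test passes num_heads_q=nheads to apply_rotary_emb_qkv_ only in the GQA case: with the 4-D packed layout the q/k boundary is no longer implied by the tensor shape, so the kernel needs the query-head count to locate it.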