123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- import math
- import random
- import pytest
- import torch
- import torch.nn.functional as F
- from einops import rearrange
- from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
- from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
- from flash_attn.bert_padding import pad_input, unpad_input
# True when a CUDA device is present and its compute capability is at least 8.0.
# The parametrized tests use this flag to decide whether bfloat16 joins float16
# in the dtype list.
# The is_available() guard keeps module import (and therefore pytest collection)
# from raising on a machine without CUDA; the flag is simply False there.
is_sm8x = torch.cuda.is_available() and torch.cuda.get_device_capability("cuda") >= (8, 0)
def generate_cos_sin(seqlen, rotary_dim, device, dtype):
    """Build random cos/sin lookup tables for the rotary tests.

    Angles are sampled uniformly from [0, 2*pi). The tables hold ``2 * seqlen``
    rows so that positions shifted by an offset of up to ``seqlen`` (see
    ``generate_seqlen_offsets``) still index inside the table.

    Args:
        seqlen: Sequence length the test runs with.
        rotary_dim: Number of rotated feature dimensions; must be even.
        device: Device on which the random angles are created.
        dtype: dtype the returned tables are cast to.

    Returns:
        ``(cos, sin)``, each of shape ``(2 * seqlen, rotary_dim // 2)``.

    Raises:
        AssertionError: If ``rotary_dim`` is odd.
    """
    assert rotary_dim % 2 == 0
    num_rows, half_dim = seqlen * 2, rotary_dim // 2
    theta = torch.rand(num_rows, half_dim, device=device) * 2 * math.pi
    # Trig is evaluated at the sampling precision, then cast down to `dtype`.
    return theta.cos().to(dtype=dtype), theta.sin().to(dtype=dtype)
def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
    """Produce a sequence-length offset of the requested flavour.

    Args:
        seqlen_offsets_type: Selector for the kind of offset to produce: the
            literal ``0``, the type ``int``, or the type ``torch.Tensor``.
        batch_size: Number of per-sequence offsets for the tensor flavour.
        seqlen: Upper bound (inclusive) of the sampled offsets.
        device: Device on which the tensor flavour is created.

    Returns:
        * ``0`` when the selector equals ``0``;
        * a random Python ``int`` in ``[0, seqlen]`` when the selector is ``int``;
        * a random ``int32`` tensor of shape ``(batch_size,)`` on ``device`` with
          values in ``[0, seqlen]`` when the selector is ``torch.Tensor``.

    Raises:
        ValueError: If the selector is none of the supported values. Previously
            this case fell through and returned ``None`` silently.
    """
    if seqlen_offsets_type == 0:
        return 0
    if seqlen_offsets_type is int:
        # A single shared offset, returned as a plain Python int.
        return torch.randint(0, seqlen + 1, (1,)).item()
    if seqlen_offsets_type is torch.Tensor:
        # One independent offset per batch element.
        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
    raise ValueError(
        f"Unsupported seqlen_offsets_type {seqlen_offsets_type!r}; "
        "expected 0, int or torch.Tensor"
    )
def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
    """Select the cos/sin rows for ``seqlen`` positions shifted by an offset.

    Args:
        cos: Cosine table, indexed by position along dim 0.
        sin: Sine table, indexed by position along dim 0.
        seqlen_offsets: Either a scalar offset shared by all sequences, or a
            1-D tensor holding one offset per batch element.
        seqlen: Number of consecutive positions to select.

    Returns:
        ``(cos_pt, sin_pt)``. For a scalar offset these are the row slices
        ``[offset, offset + seqlen)`` of each table. For a tensor of offsets
        they gain a leading batch dimension, with entry ``[b, s]`` taken from
        row ``seqlen_offsets[b] + s``.
    """
    if not isinstance(seqlen_offsets, torch.Tensor):
        window = slice(seqlen_offsets, seqlen_offsets + seqlen)
        return cos[window], sin[window]

    batch_size = seqlen_offsets.shape[0]
    positions = torch.arange(seqlen, device=cos.device).reshape(1, seqlen)
    # (batch, 1) + (1, seqlen) broadcasts to one absolute row index per (b, s).
    row_idx = seqlen_offsets.reshape(batch_size, 1) + positions
    # Indexing with the 2-D index tensor yields the (batch, seqlen, d) result directly.
    return cos[row_idx], sin[row_idx]
# To debug a single configuration, narrow any of the parametrize lists below
# to one value.
@pytest.mark.parametrize(
    "dtype", ([torch.float16, torch.bfloat16] if is_sm8x else [torch.float16])
)
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
@pytest.mark.parametrize("interleaved", [False, True])
@pytest.mark.parametrize("inplace", [False, True])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    """Check ``apply_rotary_emb`` against ``apply_rotary_emb_torch``.

    The reference is evaluated on float32 copies of the inputs and cast back
    to ``dtype``. Both the forward output and the gradient w.r.t. the input
    are compared.
    """

    def roundoff(t):
        # Error introduced by a trivial add/subtract in this dtype; used as
        # the yardstick for the absolute tolerance.
        return ((t + 0.3 - 0.3) - t).abs().max().item()

    device = "cuda"
    rtol = 1e-3
    batch_size, seqlen, nheads, headdim = 32, 217, 4, 128
    rotary_dim = int(rotary_fraction * headdim)

    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    # Independent leaf copy that feeds the reference implementation.
    x_ref = x.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)

    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )

    cos_ref, sin_ref = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_ref_fp32 = apply_rotary_emb_torch(
        x_ref.float(), cos_ref.float(), sin_ref.float(), interleaved=interleaved
    )
    out_ref = out_ref_fp32.to(dtype=dtype)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")

    grad_out = torch.randn_like(out)
    # Separate copy for the reference: with inplace=True the upstream gradient
    # might be modified in place.
    grad_out_ref = grad_out.clone()
    out.backward(grad_out)
    out_ref.backward(grad_out_ref)
    print(f"Grad max diff: {(x.grad - x_ref.grad).abs().max().item()}")

    if not inplace:
        # The out-of-place path must leave its input untouched.
        assert torch.equal(x, x_ref)

    out_atol = roundoff(out_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=2 * out_atol)
    grad_atol = roundoff(x_ref.grad)
    assert torch.allclose(x.grad, x_ref.grad, rtol=rtol, atol=2 * grad_atol)
# To debug a single configuration, narrow any of the parametrize lists below
# to one value.
@pytest.mark.parametrize(
    "dtype", ([torch.float16, torch.bfloat16] if is_sm8x else [torch.float16])
)
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
@pytest.mark.parametrize("interleaved", [False, True])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    """Check ``apply_rotary_emb_qkv_`` against ``apply_rotary_emb_torch``.

    The reference rotates slices 0 and 1 along dim 2 of the packed tensor
    (in float32, cast back to ``dtype``) and passes slice 2 through unchanged.
    Forward output and the gradient w.r.t. the packed input are compared.
    """

    def roundoff(t):
        # Error introduced by a trivial add/subtract in this dtype; used as
        # the yardstick for the absolute tolerance.
        return ((t + 0.3 - 0.3) - t).abs().max().item()

    device = "cuda"
    rtol = 1e-3
    batch_size, seqlen, nheads, headdim = 32, 512, 4, 128
    rotary_dim = int(rotary_fraction * headdim)

    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    # Independent leaf copy, taken before qkv is handed to the kernel.
    qkv_ref = qkv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)

    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )

    cos_ref, sin_ref = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    cos_fp32, sin_fp32 = cos_ref.float(), sin_ref.float()
    q_ref = apply_rotary_emb_torch(
        qkv_ref[:, :, 0].float(), cos_fp32, sin_fp32, interleaved=interleaved
    ).to(dtype=dtype)
    k_ref = apply_rotary_emb_torch(
        qkv_ref[:, :, 1].float(), cos_fp32, sin_fp32, interleaved=interleaved
    ).to(dtype=dtype)
    v_ref = qkv_ref[:, :, 2]
    out_ref = torch.stack([q_ref, k_ref, v_ref], dim=2)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")

    grad_out = torch.randn_like(out)
    # Separate copy for the reference: the in-place op modifies the gradient
    # in place.
    grad_out_ref = grad_out.clone()
    out.backward(grad_out)
    out_ref.backward(grad_out_ref)
    print(f"Grad max diff: {(qkv.grad - qkv_ref.grad).abs().max().item()}")

    out_atol = roundoff(out_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=2 * out_atol)
    grad_atol = roundoff(qkv_ref.grad)
    assert torch.allclose(qkv.grad, qkv_ref.grad, rtol=rtol, atol=2 * grad_atol)
# To debug a single configuration, narrow any of the parametrize lists below
# to one value.
@pytest.mark.parametrize(
    "dtype", ([torch.float16, torch.bfloat16] if is_sm8x else [torch.float16])
)
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
@pytest.mark.parametrize("interleaved", [False, True])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    """Check ``apply_rotary_emb_kv_`` against ``apply_rotary_emb_torch``.

    The reference rotates slice 0 along dim 2 of the packed tensor (in
    float32, cast back to ``dtype``) and passes slice 1 through unchanged.
    Forward output and the gradient w.r.t. the packed input are compared.
    """

    def roundoff(t):
        # Error introduced by a trivial add/subtract in this dtype; used as
        # the yardstick for the absolute tolerance.
        return ((t + 0.3 - 0.3) - t).abs().max().item()

    device = "cuda"
    rtol = 1e-3
    batch_size, seqlen, nheads, headdim = 32, 781, 4, 64
    rotary_dim = int(rotary_fraction * headdim)

    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    # Independent leaf copy, taken before kv is handed to the kernel.
    kv_ref = kv.detach().clone().requires_grad_()
    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)

    out = apply_rotary_emb_kv_(kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved)

    cos_ref, sin_ref = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    k_ref = apply_rotary_emb_torch(
        kv_ref[:, :, 0].float(), cos_ref.float(), sin_ref.float(), interleaved=interleaved
    ).to(dtype=dtype)
    v_ref = kv_ref[:, :, 1]
    out_ref = torch.stack([k_ref, v_ref], dim=2)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")

    grad_out = torch.randn_like(out)
    # Separate copy for the reference: the in-place op modifies the gradient
    # in place.
    grad_out_ref = grad_out.clone()
    out.backward(grad_out)
    out_ref.backward(grad_out_ref)
    print(f"Grad max diff: {(kv.grad - kv_ref.grad).abs().max().item()}")

    out_atol = roundoff(out_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=2 * out_atol)
    grad_atol = roundoff(kv_ref.grad)
    assert torch.allclose(kv.grad, kv_ref.grad, rtol=rtol, atol=2 * grad_atol)
# To debug a single configuration, narrow any of the parametrize lists below
# to one value.
@pytest.mark.parametrize(
    "dtype", ([torch.float16, torch.bfloat16] if is_sm8x else [torch.float16])
)
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
@pytest.mark.parametrize("interleaved", [False, True])
@pytest.mark.parametrize("inplace", [False, True])
def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    """Check ``apply_rotary_emb`` on unpadded (``cu_seqlens``) input.

    A padded batch with random per-sequence lengths is unpadded with
    ``unpad_input``, run through ``apply_rotary_emb``, and re-padded with
    ``pad_input``. The result is compared with ``apply_rotary_emb_torch``
    applied to the padded batch, whose padding positions are zeroed.
    Forward output and input gradients are both compared.
    """

    def roundoff(t):
        # Error introduced by a trivial add/subtract in this dtype; used as
        # the yardstick for the absolute tolerance.
        return ((t + 0.3 - 0.3) - t).abs().max().item()

    device = "cuda"
    rtol = 1e-3
    batch_size, seqlen, nheads, headdim = 32, 217, 4, 128
    rotary_dim = int(rotary_fraction * headdim)

    torch.manual_seed(42)
    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
    x_ref = x.detach().clone().requires_grad_()

    # Each sequence keeps between max(1, seqlen - 20) and seqlen tokens.
    min_len = max(1, seqlen - 20)
    lengths = torch.randint(min_len, seqlen + 1, (batch_size, 1), device=device)
    positions = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
    padding_mask = positions < lengths

    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
    # Snapshot used to verify that the out-of-place path leaves its input alone.
    x_unpad_before = x_unpad.clone()
    x_unpad = x_unpad.requires_grad_()

    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)

    varlen_kwargs = dict(
        seqlen_offsets=seqlen_offsets,
        interleaved=interleaved,
        inplace=inplace,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    out_unpad = apply_rotary_emb(x_unpad, cos, sin, **varlen_kwargs)
    out = pad_input(out_unpad, indices, batch_size, seqlen)

    cos_ref, sin_ref = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
    out_ref = apply_rotary_emb_torch(
        x_ref.float(), cos_ref.float(), sin_ref.float(), interleaved=interleaved
    ).to(dtype=dtype)
    # Zero the padding positions of the reference before comparing.
    is_padding = rearrange(~padding_mask, "b s -> b s 1 1")
    out_ref = out_ref.masked_fill(is_padding, 0.0)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")

    grad_out = torch.randn_like(out)
    # Separate copy for the reference: with inplace=True the upstream gradient
    # might be modified in place.
    grad_out_ref = grad_out.clone()
    out.backward(grad_out)
    out_ref.backward(grad_out_ref)
    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
    print(f"Grad max diff: {(x_grad - x_ref.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x_unpad, x_unpad_before)

    out_atol = roundoff(out_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=2 * out_atol)
    grad_atol = roundoff(x_ref.grad)
    assert torch.allclose(x_grad, x_ref.grad, rtol=rtol, atol=2 * grad_atol)
def test_compilation_count():
    """Check that changing seqlen / nheads does not cause extra kernel compilations.

    ``JITFunction.cache_hook`` is temporarily replaced with a counter while
    forward and backward passes are run over several (seqlen, nheads)
    combinations. The original hook is restored afterwards, even on failure.
    """
    batch_size, headdim = 1, 128
    device, dtype = "cuda", torch.float16
    torch.manual_seed(42)

    from triton.runtime.jit import JITFunction
    from flash_attn.ops.triton.rotary import rotary_kernel

    num_compilations = 0

    def bump_counter(*args, **kwargs):
        nonlocal num_compilations
        num_compilations += 1

    shapes = [(seqlen, nheads) for seqlen in (128, 256) for nheads in (4, 32)]

    saved_hook = JITFunction.cache_hook
    try:
        # Start from an empty kernel cache so every compilation is observed.
        rotary_kernel.cache.clear()
        JITFunction.cache_hook = bump_counter

        for seqlen, nheads in shapes:
            x = torch.randn(
                batch_size, seqlen, nheads, headdim, dtype=dtype, device=device
            ).requires_grad_()
            cos, sin = generate_cos_sin(seqlen, headdim, device, dtype)
            out = apply_rotary_emb(x, cos, sin)
            out.backward(torch.randn_like(out))

        # Exactly two kernels are expected:
        #   * forward pass  (conjugate=False)
        #   * backward pass (conjugate=True)
        assert num_compilations == 2
    finally:
        JITFunction.cache_hook = saved_hook
|