|
@@ -48,14 +48,13 @@ def convert_to_cudnn_type(torch_type):
|
|
|
raise ValueError("Unsupported tensor data type.")
|
|
|
|
|
|
|
|
|
-def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
|
|
|
+def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
|
|
|
b, nheads, seqlen_q, headdim = q.shape
|
|
|
- _, _, seqlen_k, _ = k.shape
|
|
|
- assert v.shape == (b, nheads, seqlen_k, headdim)
|
|
|
+ _, nheads_kv, seqlen_k, _ = k.shape
|
|
|
+ assert v.shape == (b, nheads_kv, seqlen_k, headdim)
|
|
|
assert cudnn is not None, 'CUDNN is not available'
|
|
|
q_gpu, k_gpu, v_gpu = q, k, v
|
|
|
- o_gpu = torch.empty_like(q_gpu)
|
|
|
- stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
|
|
|
+ o_gpu, stats_gpu = o, stats
|
|
|
graph_forward = cudnn.pygraph(
|
|
|
io_data_type=convert_to_cudnn_type(q.dtype),
|
|
|
intermediate_data_type=cudnn.data_type.FLOAT,
|
|
@@ -65,7 +64,7 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
|
|
|
k_forward = graph_forward.tensor_like(k_gpu.detach())
|
|
|
v_forward = graph_forward.tensor_like(v_gpu.detach())
|
|
|
|
|
|
- seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
|
|
|
+ seqlens_reshaped = seqlens if varlen else None
|
|
|
seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
|
|
seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
|
|
|
|
|
@@ -193,8 +192,8 @@ dim = 2048
|
|
|
# headdim = 64
|
|
|
headdim = 256
|
|
|
|
|
|
-# for mode in ['fwd', 'bwd']:
|
|
|
-for mode in ['fwd']:
|
|
|
+for mode in ['fwd', 'bwd']:
|
|
|
+# for mode in ['bwd']:
|
|
|
for headdim in [64, 128, 256]:
|
|
|
# for headdim in [128]:
|
|
|
for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
|
|
@@ -206,31 +205,38 @@ for mode in ['fwd']:
|
|
|
# seqlen = 512
|
|
|
# nheads = 8
|
|
|
# headdim = 128
|
|
|
+ # nheads = 16
|
|
|
+ # headdim = 128
|
|
|
nheads_kv = nheads
|
|
|
+ # nheads_kv = 1
|
|
|
|
|
|
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
|
|
requires_grad=True)
|
|
|
q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
|
|
- k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
|
|
- v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
|
|
+ k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
|
|
|
+ v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
|
|
|
q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
|
|
|
k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
|
|
|
v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
|
|
|
grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
|
|
|
grad_t = grad.transpose(1, 2).contiguous()
|
|
|
+ o_t = torch.empty_like(q.transpose(1, 2))
|
|
|
+ stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
|
|
|
|
|
|
bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
|
|
|
|
|
|
for causal in [False, True]:
|
|
|
# for causal in [True]:
|
|
|
- print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
|
|
|
+ print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
|
|
|
# For var-seq-len
|
|
|
lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
|
|
|
+ seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
|
|
|
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
|
|
|
if headdim <= 128 and cudnn is not None:
|
|
|
- cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
|
|
|
- cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
|
|
|
+ cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
|
|
|
+ cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
|
|
|
f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
|
|
|
+ ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
|
|
|
_, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
|
|
|
if mode == 'bwd':
|
|
|
ref_dv, v.grad = v.grad.clone(), None
|
|
@@ -238,7 +244,7 @@ for mode in ['fwd']:
|
|
|
ref_dq, q.grad = q.grad.clone(), None
|
|
|
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
|
|
|
if headdim <= 128:
|
|
|
- if triton_attention is not None:
|
|
|
+ if triton_attention is not None and nheads_kv == nheads:
|
|
|
if mode == 'fwd':
|
|
|
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
|
|
_, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
|
|
@@ -255,22 +261,31 @@ for mode in ['fwd']:
|
|
|
if mode == 'fwd':
|
|
|
_, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
|
|
_, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
|
|
|
+ cudnn_sdpa_fwd()
|
|
|
+ torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
+ cudnn_sdpa_fwd_varlen()
|
|
|
+ torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
else:
|
|
|
cudnn_sdpa_fwd()
|
|
|
_, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
|
|
+ _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
|
|
|
dq, dk, dv = cudnn_sdpa_bwd()
|
|
|
torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
+ dq, dk, dv = cudnn_sdpa_bwd_varlen()
|
|
|
+ torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
+ torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
+ torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
|
|
|
# pytorch_profiler(cudnn_sdpa, backward=False)
|
|
|
- if headdim == 128 or mode == 'fwd':
|
|
|
+
|
|
|
+ if headdim <= 128 or mode == 'fwd':
|
|
|
time.sleep(1)
|
|
|
_, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
|
|
|
q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
|
|
|
k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
|
|
|
v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
|
|
|
time.sleep(1)
|
|
|
- _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
|
|
|
if mode == 'bwd':
|
|
|
dv, v.grad = v.grad.clone(), None
|
|
|
dk, k.grad = k.grad.clone(), None
|
|
@@ -279,15 +294,21 @@ for mode in ['fwd']:
|
|
|
torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
|
|
|
torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
|
|
|
|
|
|
+ bench_var_fn = bench_fn
|
|
|
+ if mode == 'bwd':
|
|
|
+ grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
|
|
|
+ bench_var_fn = partial(benchmark_backward, grad=grad_var)
|
|
|
+ _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
|
|
|
+
|
|
|
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
|
|
|
print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
|
|
|
if headdim <= 128:
|
|
|
- if triton_attention is not None:
|
|
|
+ if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
|
|
|
print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
|
|
|
if cudnn is not None:
|
|
|
print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
|
|
|
print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
|
|
|
- if headdim == 128 or mode == 'fwd':
|
|
|
+ if headdim <= 128 or mode == 'fwd':
|
|
|
print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
|
|
|
print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
|
|
|
|