浏览代码

bwd benchmark + small fixes (#1129)

Ying Zhang 7 月之前
父节点
当前提交
3669b25206
共有 2 个文件被更改,包括 40 次插入19 次删除
  1. 39 18
      hopper/benchmark_attn.py
  2. 1 1
      hopper/epilogue_fwd_sm90_tma.hpp

+ 39 - 18
hopper/benchmark_attn.py

@@ -48,14 +48,13 @@ def convert_to_cudnn_type(torch_type):
         raise ValueError("Unsupported tensor data type.")
 
 
-def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
+def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None):
     b, nheads, seqlen_q, headdim = q.shape
-    _, _, seqlen_k, _ = k.shape
-    assert v.shape == (b, nheads, seqlen_k, headdim)
+    _, nheads_kv, seqlen_k, _ = k.shape
+    assert v.shape == (b, nheads_kv, seqlen_k, headdim)
     assert cudnn is not None, 'CUDNN is not available'
     q_gpu, k_gpu, v_gpu = q, k, v
-    o_gpu = torch.empty_like(q_gpu)
-    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
+    o_gpu, stats_gpu = o, stats
     graph_forward = cudnn.pygraph(
         io_data_type=convert_to_cudnn_type(q.dtype),
         intermediate_data_type=cudnn.data_type.FLOAT,
@@ -65,7 +64,7 @@ def cudnn_sdpa_setup(q, k, v, grad, causal=False, varlen=False, seqlens=None):
     k_forward = graph_forward.tensor_like(k_gpu.detach())
     v_forward = graph_forward.tensor_like(v_gpu.detach())
 
-    seqlens_reshaped = seqlens.reshape(b, 1, 1, 1).contiguous().cuda() if varlen else None
+    seqlens_reshaped = seqlens if varlen else None
     seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
     seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None
 
@@ -193,8 +192,8 @@ dim = 2048
 # headdim = 64
 headdim = 256
 
-# for mode in ['fwd', 'bwd']:
-for mode in ['fwd']:
+for mode in ['fwd', 'bwd']:
+# for mode in ['bwd']:
     for headdim in [64, 128, 256]:
     # for headdim in [128]:
         for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]:
@@ -206,31 +205,38 @@ for mode in ['fwd']:
             # seqlen = 512
             # nheads = 8
             # headdim = 128
+            # nheads = 16
+            # headdim = 128
             nheads_kv = nheads
+            # nheads_kv = 1
     
             qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                             requires_grad=True)
             q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
-            v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
+            k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
+            v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True)
             q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
             k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             v_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
             grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
             grad_t = grad.transpose(1, 2).contiguous()
+            o_t = torch.empty_like(q.transpose(1, 2))
+            stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device)
     
             bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad)
 
             for causal in [False, True]:
             # for causal in [True]:
-                print(f"\n### {headdim = }, {seqlen = }, {causal = } ###")
+                print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###")
                 # For var-seq-len
                 lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
+                seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda()
                 cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()
                 if headdim <= 128 and cudnn is not None:
-                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal)
-                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), causal=causal, varlen=True, seqlens=lens)
+                    cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal)
+                    cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn)
                 f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode)
+                ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal)
                 _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2')
                 if mode == 'bwd':
                     ref_dv, v.grad = v.grad.clone(), None
@@ -238,7 +244,7 @@ for mode in ['fwd']:
                     ref_dq, q.grad = q.grad.clone(), None
                 # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False)
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if triton_attention is not None and nheads_kv == nheads:
                         if mode == 'fwd':
                             time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
                             _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
@@ -255,22 +261,31 @@ for mode in ['fwd']:
                         if mode == 'fwd':
                             _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                             _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
+                            cudnn_sdpa_fwd()
+                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
+                            cudnn_sdpa_fwd_varlen()
+                            torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05)
                         else:
                             cudnn_sdpa_fwd()
                             _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
+                            _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN')
                             dq, dk, dv = cudnn_sdpa_bwd()
                             torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
                             torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
                             torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
+                            dq, dk, dv = cudnn_sdpa_bwd_varlen()
+                            torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05)
+                            torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05)
+                            torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05)
                         # pytorch_profiler(cudnn_sdpa, backward=False)
-                if headdim == 128 or mode == 'fwd':
+
+                if headdim <= 128 or mode == 'fwd':
                     time.sleep(1)
                     _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3')
                     q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
                     k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
                     v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
                     time.sleep(1)
-                    _, m1_var = bench_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
                     if mode == 'bwd':
                         dv, v.grad = v.grad.clone(), None
                         dk, k.grad = k.grad.clone(), None
@@ -279,15 +294,21 @@ for mode in ['fwd']:
                         torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05)
                         torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05)
  
+                    bench_var_fn = bench_fn
+                    if mode == 'bwd':
+                        grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
+                        bench_var_fn = partial(benchmark_backward, grad=grad_var)
+                    _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len')
+
                 # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False)
                 print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS')
                 if headdim <= 128:
-                    if triton_attention is not None:
+                    if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads:
                         print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS')
                     if cudnn is not None:
                         print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS')
                         print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS')
-                if headdim == 128 or mode == 'fwd':
+                if headdim <= 128 or mode == 'fwd':
                     print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS')
                     print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS')
     

+ 1 - 1
hopper/epilogue_fwd_sm90_tma.hpp

@@ -288,7 +288,7 @@ struct CollectiveEpilogueFwd {
             gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; }
+        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
     }
 
 };