
Fix backward with softcap

Tri Dao committed 5 months ago
commit ea7a98f15d

+ 0 - 2
hopper/flash_api.cpp

@@ -712,7 +712,6 @@ mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x head_si
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
-    TORCH_CHECK(softcap == 0.0, "Softcap is not yet supported in the backward pass");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
@@ -914,7 +913,6 @@ mha_varlen_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads, x
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
     TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
-    TORCH_CHECK(softcap == 0.0, "Softcap is not yet supported in the backward pass");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
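
With the two TORCH_CHECKs above removed, mha_bwd and mha_varlen_bwd no longer reject softcap != 0.0. The surrounding rounding logic is unchanged; as a side note, a plain-Python transcription of that lambda (not code from the repo) behaves as follows:

    def round_multiple(x: int, m: int) -> int:
        # Round x up to the nearest multiple of m, e.g. round_multiple(72, 32) == 96.
        return (x + m - 1) // m * m

    def head_size_rounded(head_size: int) -> int:
        # Head dims <= 64 are padded to 64; larger ones go to the next multiple of 32.
        return 64 if head_size <= 64 else round_multiple(head_size, 32)

    assert head_size_rounded(48) == 64 and head_size_rounded(72) == 96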

+ 1 - 0
hopper/flash_bwd_launch_template.h

@@ -262,6 +262,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
         if (params.softcap == 0.f) {
             run_mha_bwd_dispatch<T, 128, 128, 64, Is_causal, Is_local, /*Has_softcap=*/false, 2, 2, true, false, false, 2, 1, 2, 2>(params, stream);
         } else {
+            // register spill with 128 x 128
             run_mha_bwd_dispatch<T, 96, 128, 64, Is_causal, Is_local, /*Has_softcap=*/true, 2, 2, true, false, true, 2, 1, 2, 2>(params, stream);
         }
     });
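
The new comment records why the softcap path picks a different tile: presumably the softcap variant spills registers at 128 x 128 for hdim64, so the dispatch drops kBlockM to 96 while keeping kBlockN at 128. A toy sketch of that selection (illustrative only, not the template machinery):

    def bwd_tile_hdim64(softcap: float) -> tuple:
        # (kBlockM, kBlockN): smaller kBlockM for the softcap kernel to avoid register spill.
        return (128, 128) if softcap == 0.0 else (96, 128)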

+ 10 - 1
hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp

@@ -425,6 +425,15 @@ struct CollectiveMainloopBwd {
             assert(args.cu_seqlens_q != nullptr);
             assert(args.cu_seqlens_k != nullptr);
         }
+        // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val.
+        // Right after this, we multiply by log2(e) before applying exp2.
+        // To reduce the number of instructions, we instead pre-multiply softmax_scale / softcap_val
+        // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e)
+        // (assigning it to params.softmax_scale_log2).
+        // In the backward, we need to multiply by
+        // (1 - tanh^2) * softmax_scale / softcap_val * softcap_val = (1 - tanh^2) * softmax_scale.
+        // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale
+        // (the original softmax_scale) at the end.
         return {args.shape_Q, args.shape_K, args.shape_dQaccum,
                 args.ptr_dQaccum, args.stride_dQaccum,
                 cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
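
The comment above documents the scaling trick that the backward now has to undo: the forward folds softmax_scale / softcap_val and softcap_val * log2(e) into the stored parameters, so the backward only multiplies by (1 - tanh^2) per element and applies the original softmax_scale to dK and dV once at the end. A small PyTorch autograd check of the underlying identity (values and names here are illustrative, not the kernel's parameters):

    import torch

    softmax_scale, softcap_val = 0.125, 30.0
    scores = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)

    # Forward softcapping as described: tanh(scores * softmax_scale / softcap_val) * softcap_val
    capped = torch.tanh(scores * softmax_scale / softcap_val) * softcap_val
    g = torch.randn(4, 8, dtype=torch.float64)
    capped.backward(g)

    # The softcap_val factors cancel in the derivative:
    # d(capped)/d(scores) = (1 - tanh^2(...)) * softmax_scale,
    # which is why the kernel can apply only (1 - tanh^2) in the mainloop and
    # fold softmax_scale into dK/dV afterwards.
    t = torch.tanh(scores * softmax_scale / softcap_val)
    assert torch.allclose(scores.grad, g * (1 - t * t) * softmax_scale)
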
@@ -1037,7 +1046,7 @@ struct CollectiveMainloopBwd {
             }
         }
 
-        static constexpr int n_local_bottom_steps = (!Is_local || !SeparateMaskingIterations) ? 0 : cute::ceil_div(kBlockM, kBlockN) + 1;
+        static constexpr int n_local_bottom_steps = (!Is_local || !SeparateMaskingIterations) ? 0 : cute::ceil_div(kBlockN, kBlockM) + 1;
         auto mask_fn = [&](auto& tSrS, int m_block) { causal_local_mask_fn(tSrS, m_block, cute::bool_constant<Is_causal && !SeparateMaskingIterations>{}, cute::bool_constant<Is_local && !SeparateMaskingIterations>{}); };
         CUTLASS_PRAGMA_NO_UNROLL
         for (; m_block < m_block_max - n_local_bottom_steps; ++m_block) {
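
The one-line fix in this hunk flips the ceil_div arguments for the local-attention bottom steps. Reading it as the number of trailing m-blocks that can still intersect the local window of the last kBlockN-wide key block (my interpretation, not stated in the commit), the count should grow with kBlockN / kBlockM, not the reciprocal. With the 96 x 128 softcap tile selected above, the difference is one extra masked iteration:

    def ceil_div(a: int, b: int) -> int:
        return -(-a // b)

    kBlockM, kBlockN = 96, 128
    old = ceil_div(kBlockM, kBlockN) + 1   # = 2, previous formula
    new = ceil_div(kBlockN, kBlockM) + 1   # = 3, fixed formula
    assert (old, new) == (2, 3)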

+ 7 - 9
hopper/test_flash_attn.py

@@ -47,7 +47,7 @@ def generate_qkv(
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask)
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -62,8 +62,8 @@ def generate_qkv(
         )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, *rest = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _, *rest = unpad_input(v, key_padding_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
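
Switching from a fixed five-way unpack to *rest makes generate_qkv indifferent to how many trailing values unpad_input returns, so the test keeps working if the helper's return tuple grows. A toy illustration of the pattern with hypothetical stand-ins (not the real unpad_input):

    def unpad_five(x, mask):
        # hypothetical: five return values, matching the old fixed unpack
        return x, "indices", "cu_seqlens", 4, None

    def unpad_six(x, mask):
        # hypothetical: one extra trailing value appended
        return x, "indices", "cu_seqlens", 4, None, "used_seqlens"

    for unpad in (unpad_five, unpad_six):
        x_unpad, indices, cu_seqlens, max_seqlen, *rest = unpad([0.0], None)
        # *rest absorbs one or two trailing items; the first four names stay stable
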
@@ -404,7 +404,7 @@ def test_flash_attn_output(
     # breakpoint()
 
 
-    if dtype != torch.float8_e4m3fn and not V_colmajor and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn and not V_colmajor:
         g = torch.randn_like(out)
         do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2)
         import flashattn_hopper_cuda
@@ -459,7 +459,7 @@ def test_flash_attn_output(
     multiple = 2
     assert (out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()
 
-    if dtype != torch.float8_e4m3fn and not V_colmajor and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn and not V_colmajor:
         multiple = 2 if softcap == 0.0 else 4
         assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item()
@@ -603,7 +603,7 @@ def test_flash_attn_varlen_output(
     #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
     # breakpoint()
 
-    if dtype != torch.float8_e4m3fn and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn:
         g_unpad = torch.randn_like(out_unpad)
         do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
         import flashattn_hopper_cuda
@@ -667,10 +667,8 @@ def test_flash_attn_varlen_output(
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
 
-    if dtype != torch.float8_e4m3fn and softcap == 0.0:
+    if dtype != torch.float8_e4m3fn:
         multiple = 2 if softcap == 0.0 else 4
         assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item()
-
-
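
Beyond generate_qkv, the test changes drop the softcap == 0.0 conditions so the backward pass is exercised, and its dq/dk/dv compared against the reference, for softcap != 0 as well, using the existing looser 4x multiple in that case. For context, a softcapped attention reference in plain PyTorch looks roughly like the sketch below; this is a simplified stand-in for the repo's attention_ref (no masking, dropout, or FP8 handling), only meant to show where softcap enters the graph that the re-enabled gradient checks differentiate through:

    import math
    import torch

    def attention_softcap_ref(q, k, v, softcap=0.0):
        # q, k, v: (batch, seqlen, nheads, head_dim); simplified, no masking.
        scale = 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bshd,bthd->bhst", q.float(), k.float()) * scale
        if softcap > 0.0:
            scores = torch.tanh(scores / softcap) * softcap  # tanh softcapping
        attn = torch.softmax(scores, dim=-1)
        return torch.einsum("bhst,bthd->bshd", attn, v.float()).to(q.dtype)

    q, k, v = (torch.randn(1, 16, 2, 64, requires_grad=True) for _ in range(3))
    attention_softcap_ref(q, k, v, softcap=30.0).sum().backward()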