
build: build flash attention kernels inside aphrodite (#1085)

* build: build flash attention kernels inside aphrodite

* update

* update

* make this crap finally work

* fix codespell
AlpinDale 1 month ago
parent
commit
7fffa507ff
85 changed files with 6024 additions and 15 deletions
  1. CMakeLists.txt (+68 -3)
  2. aphrodite/_custom_ops.py (+122 -0)
  3. aphrodite/attention/backends/flash_attn.py (+6 -7)
  4. aphrodite/attention/ops/aphrodite_flash_attn.py (+287 -0)
  5. aphrodite/attention/selector.py (+3 -4)
  6. kernels/flash_attn/block_info.h (+65 -0)
  7. kernels/flash_attn/dropout.h (+121 -0)
  8. kernels/flash_attn/flash.h (+156 -0)
  9. kernels/flash_attn/flash_api.h (+1167 -0)
  10. kernels/flash_attn/flash_fwd_hdim128_bf16_causal_sm80.cu (+11 -0)
  11. kernels/flash_attn/flash_fwd_hdim128_bf16_sm80.cu (+11 -0)
  12. kernels/flash_attn/flash_fwd_hdim128_fp16_causal_sm80.cu (+11 -0)
  13. kernels/flash_attn/flash_fwd_hdim128_fp16_sm80.cu (+11 -0)
  14. kernels/flash_attn/flash_fwd_hdim160_bf16_causal_sm80.cu (+11 -0)
  15. kernels/flash_attn/flash_fwd_hdim160_bf16_sm80.cu (+11 -0)
  16. kernels/flash_attn/flash_fwd_hdim160_fp16_causal_sm80.cu (+11 -0)
  17. kernels/flash_attn/flash_fwd_hdim160_fp16_sm80.cu (+11 -0)
  18. kernels/flash_attn/flash_fwd_hdim192_bf16_causal_sm80.cu (+11 -0)
  19. kernels/flash_attn/flash_fwd_hdim192_bf16_sm80.cu (+11 -0)
  20. kernels/flash_attn/flash_fwd_hdim192_fp16_causal_sm80.cu (+11 -0)
  21. kernels/flash_attn/flash_fwd_hdim192_fp16_sm80.cu (+11 -0)
  22. kernels/flash_attn/flash_fwd_hdim224_bf16_causal_sm80.cu (+11 -0)
  23. kernels/flash_attn/flash_fwd_hdim224_bf16_sm80.cu (+11 -0)
  24. kernels/flash_attn/flash_fwd_hdim224_fp16_causal_sm80.cu (+11 -0)
  25. kernels/flash_attn/flash_fwd_hdim224_fp16_sm80.cu (+11 -0)
  26. kernels/flash_attn/flash_fwd_hdim256_bf16_causal_sm80.cu (+11 -0)
  27. kernels/flash_attn/flash_fwd_hdim256_bf16_sm80.cu (+11 -0)
  28. kernels/flash_attn/flash_fwd_hdim256_fp16_causal_sm80.cu (+11 -0)
  29. kernels/flash_attn/flash_fwd_hdim256_fp16_sm80.cu (+11 -0)
  30. kernels/flash_attn/flash_fwd_hdim32_bf16_causal_sm80.cu (+11 -0)
  31. kernels/flash_attn/flash_fwd_hdim32_bf16_sm80.cu (+11 -0)
  32. kernels/flash_attn/flash_fwd_hdim32_fp16_causal_sm80.cu (+11 -0)
  33. kernels/flash_attn/flash_fwd_hdim32_fp16_sm80.cu (+11 -0)
  34. kernels/flash_attn/flash_fwd_hdim64_bf16_causal_sm80.cu (+11 -0)
  35. kernels/flash_attn/flash_fwd_hdim64_bf16_sm80.cu (+11 -0)
  36. kernels/flash_attn/flash_fwd_hdim64_fp16_causal_sm80.cu (+11 -0)
  37. kernels/flash_attn/flash_fwd_hdim64_fp16_sm80.cu (+11 -0)
  38. kernels/flash_attn/flash_fwd_hdim96_bf16_causal_sm80.cu (+11 -0)
  39. kernels/flash_attn/flash_fwd_hdim96_bf16_sm80.cu (+11 -0)
  40. kernels/flash_attn/flash_fwd_hdim96_fp16_causal_sm80.cu (+11 -0)
  41. kernels/flash_attn/flash_fwd_hdim96_fp16_sm80.cu (+11 -0)
  42. kernels/flash_attn/flash_fwd_kernel.h (+1715 -0)
  43. kernels/flash_attn/flash_fwd_launch_template.h (+356 -0)
  44. kernels/flash_attn/flash_fwd_split_hdim128_bf16_causal_sm80.cu (+7 -0)
  45. kernels/flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu (+7 -0)
  46. kernels/flash_attn/flash_fwd_split_hdim128_fp16_causal_sm80.cu (+7 -0)
  47. kernels/flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu (+7 -0)
  48. kernels/flash_attn/flash_fwd_split_hdim160_bf16_causal_sm80.cu (+7 -0)
  49. kernels/flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu (+7 -0)
  50. kernels/flash_attn/flash_fwd_split_hdim160_fp16_causal_sm80.cu (+7 -0)
  51. kernels/flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu (+7 -0)
  52. kernels/flash_attn/flash_fwd_split_hdim192_bf16_causal_sm80.cu (+7 -0)
  53. kernels/flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu (+7 -0)
  54. kernels/flash_attn/flash_fwd_split_hdim192_fp16_causal_sm80.cu (+7 -0)
  55. kernels/flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu (+7 -0)
  56. kernels/flash_attn/flash_fwd_split_hdim224_bf16_causal_sm80.cu (+7 -0)
  57. kernels/flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu (+7 -0)
  58. kernels/flash_attn/flash_fwd_split_hdim224_fp16_causal_sm80.cu (+7 -0)
  59. kernels/flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu (+7 -0)
  60. kernels/flash_attn/flash_fwd_split_hdim256_bf16_causal_sm80.cu (+7 -0)
  61. kernels/flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu (+7 -0)
  62. kernels/flash_attn/flash_fwd_split_hdim256_fp16_causal_sm80.cu (+7 -0)
  63. kernels/flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu (+7 -0)
  64. kernels/flash_attn/flash_fwd_split_hdim32_bf16_causal_sm80.cu (+7 -0)
  65. kernels/flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu (+7 -0)
  66. kernels/flash_attn/flash_fwd_split_hdim32_fp16_causal_sm80.cu (+7 -0)
  67. kernels/flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu (+7 -0)
  68. kernels/flash_attn/flash_fwd_split_hdim64_bf16_causal_sm80.cu (+7 -0)
  69. kernels/flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu (+7 -0)
  70. kernels/flash_attn/flash_fwd_split_hdim64_fp16_causal_sm80.cu (+7 -0)
  71. kernels/flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu (+7 -0)
  72. kernels/flash_attn/flash_fwd_split_hdim96_bf16_causal_sm80.cu (+7 -0)
  73. kernels/flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu (+7 -0)
  74. kernels/flash_attn/flash_fwd_split_hdim96_fp16_causal_sm80.cu (+7 -0)
  75. kernels/flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu (+7 -0)
  76. kernels/flash_attn/kernel_traits.h (+180 -0)
  77. kernels/flash_attn/mask.h (+213 -0)
  78. kernels/flash_attn/philox.cuh (+51 -0)
  79. kernels/flash_attn/registration.h (+22 -0)
  80. kernels/flash_attn/rotary.h (+152 -0)
  81. kernels/flash_attn/softmax.h (+188 -0)
  82. kernels/flash_attn/static_switch.h (+117 -0)
  83. kernels/flash_attn/utils.h (+440 -0)
  84. kernels/torch_bindings.cpp (+19 -0)
  85. requirements-cuda.txt (+0 -1)

+ 68 - 3
CMakeLists.txt

@@ -222,7 +222,72 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
     "kernels/quantization/fp8/fp8_marlin.cu"
     "kernels/all_reduce/custom_all_reduce.cu"
     "kernels/permute_cols.cu"
-    "kernels/sampling/sampling.cu")
+    "kernels/sampling/sampling.cu"
+    "kernels/flash_attn/flash_fwd_hdim32_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim32_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim32_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim32_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim64_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim64_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim64_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim64_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim96_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim96_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim96_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim96_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim128_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim128_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim128_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim128_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim160_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim160_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim160_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim160_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim192_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim192_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim192_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim192_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim224_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim224_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim224_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim224_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim256_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim256_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim256_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_hdim256_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim32_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim32_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim64_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim64_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim96_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim96_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim128_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim128_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim160_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim160_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim192_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim192_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim224_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim224_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim256_bf16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim256_fp16_causal_sm80.cu"
+    "kernels/flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu"
+    )
 
   if(NOT MSVC)
     # Include CUTLASS only when needed
@@ -361,16 +426,16 @@ if(APHRODITE_GPU_LANG STREQUAL "HIP")
     WITH_SOABI)
 endif()
 
-
 if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
   add_dependencies(default _C)
 
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
+
 endif()
 
 if(APHRODITE_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling rocm extension.")
   add_dependencies(default _rocm_C)
-endif()
+endif()
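
The 64 new .cu entries above enumerate every combination of head dimension (32, 64, 96, 128, 160, 192, 224, 256), data type (bf16/fp16), causal vs. non-causal, and regular vs. split-KV forward kernel. As a rough illustration only (this script is not part of the commit), the list can be regenerated from that naming pattern:

# Illustration: rebuild the 64 kernel source names added to CMakeLists.txt
# above from their naming pattern. Not part of the commit.
head_dims = [32, 64, 96, 128, 160, 192, 224, 256]
dtypes = ["bf16", "fp16"]

sources = []
for split in ["", "split_"]:
    for hdim in head_dims:
        for dtype in dtypes:
            for causal in ["_causal", ""]:
                sources.append(
                    f"kernels/flash_attn/flash_fwd_{split}"
                    f"hdim{hdim}_{dtype}{causal}_sm80.cu")

assert len(sources) == 64  # 8 head dims x 2 dtypes x 2 (causal) x 2 (split)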

+ 122 - 0
aphrodite/_custom_ops.py

@@ -1057,6 +1057,128 @@ def top_k_top_p_sampling_from_probs(
         raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
 
 
+# Flash Attention kernels
+def fwd(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        alibi_slopes: torch.Tensor,
+        dropout_p: float,
+        softmax_scale: float,
+        causal: bool,
+        window_size_left: int,
+        window_size_right: int,
+        softcap: float,
+        return_softmax: bool,
+        out: torch.Tensor,
+        gen: Optional[torch.Generator] = None,
+):
+    return torch.ops._C.fwd(
+        q,
+        k,
+        v,
+        out,
+        alibi_slopes,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+
+def varlen_fwd(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        out: Optional[torch.Tensor],
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        seqused_k: Optional[torch.Tensor],
+        block_table: Optional[torch.Tensor],
+        alibi_slopes: Optional[torch.Tensor],
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        dropout_p: float,
+        softmax_scale: float,
+        zero_tensors: bool,
+        causal: bool,
+        window_size_left: int,
+        window_size_right: int,
+        softcap: float,
+        return_softmax: bool,
+        gen: Optional[torch.Generator] = None,
+):
+    return torch.ops._C.varlen_fwd(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        seqused_k,
+        block_table,
+        alibi_slopes,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        zero_tensors,
+        causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        return_softmax,
+        gen,
+    )
+
+
+def fwd_kvcache(
+        q: torch.Tensor,
+        kcache: torch.Tensor,
+        vcache: torch.Tensor,
+        k: Optional[torch.Tensor],
+        v: Optional[torch.Tensor],
+        seqlens_k: Optional[torch.Tensor],
+        rotary_cos: Optional[torch.Tensor],
+        rotary_sin: Optional[torch.Tensor],
+        cache_batch_idx: Optional[torch.Tensor],
+        block_table: Optional[torch.Tensor],
+        alibi_slopes: Optional[torch.Tensor],
+        out: Optional[torch.Tensor],
+        softmax_scale: float,
+        causal: bool,
+        window_size_left: int,
+        window_size_right: int,
+        softcap: float,
+        rotary_interleaved: bool,
+        num_splits: int,
+):
+    return torch.ops._C.fwd_kvcache(
+        q,
+        kcache,
+        vcache,
+        k,
+        v,
+        seqlens_k,
+        rotary_cos,
+        rotary_sin,
+        cache_batch_idx,
+        block_table,
+        alibi_slopes,
+        out,
+        softmax_scale,
+        causal,
+        window_size_left,
+        window_size_right,
+        softcap,
+        rotary_interleaved,
+        num_splits,
+    )
+
+
 # TODO: remove this later
 names_and_values = globals()
 names_and_values_to_update = {}
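
The three wrappers above are thin pass-throughs to the flash attention ops registered under torch.ops._C (bound in kernels/torch_bindings.cpp from this commit). Below is a minimal sketch of calling the fwd wrapper directly, assuming a CUDA build of aphrodite with these kernels; shapes follow the (batch_size, seqlen, num_heads, head_size) layout documented for mha_fwd in flash_api.h, and the 8-tuple unpacking mirrors _flash_attn_forward further down.

import torch

import aphrodite._custom_ops as ops

# fp16 self-attention inputs: (batch, seqlen, num_heads, head_size)
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")

# out=None lets the kernel allocate the output; the op returns
# (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
out, *_ = ops.fwd(
    q=q, k=k, v=v, out=None, alibi_slopes=None,
    dropout_p=0.0, softmax_scale=64 ** -0.5, causal=True,
    window_size_left=-1, window_size_right=-1,
    softcap=0.0, return_softmax=False,
)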

+ 6 - 7
aphrodite/attention/backends/flash_attn.py

@@ -15,16 +15,15 @@ from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                 compute_slot_mapping,
                                                 compute_slot_mapping_start_idx,
                                                 is_block_tables_empty)
+from aphrodite.attention.ops.aphrodite_flash_attn import (
+    flash_attn_varlen_func as _flash_attn_varlen_func)
+from aphrodite.attention.ops.aphrodite_flash_attn import (
+    flash_attn_with_kvcache as _flash_attn_with_kvcache)
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from aphrodite.worker.model_runner import (ModelInputForGPUBuilder,
-                                               ModelInputForGPUWithSamplingMetadata)
-
-from aphrodite_flash_attn import (
-    flash_attn_varlen_func as _flash_attn_varlen_func)
-from aphrodite_flash_attn import (
-    flash_attn_with_kvcache as _flash_attn_with_kvcache)
+    from aphrodite.worker.model_runner import (
+        ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata)
 
 
 @torch.library.custom_op("aphrodite::flash_attn_varlen_func", mutates_args=[])

+ 287 - 0
aphrodite/attention/ops/aphrodite_flash_attn.py

@@ -0,0 +1,287 @@
+from typing import Optional, Union
+
+import torch
+
+import aphrodite._custom_ops as ops
+
+
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+
+
+def _flash_attn_forward(
+    q, k, v, dropout_p, softmax_scale, causal,
+    window_size, softcap, alibi_slopes,
+    return_softmax, *, out=None
+):
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    (out, q, k, v, out_padded, softmax_lse,
+     S_dmask, rng_state) = ops.fwd(
+        q=q,
+        k=k,
+        v=v,
+        out=out,
+        alibi_slopes=alibi_slopes,
+        dropout_p=dropout_p,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size_left=window_size[0],
+        window_size_right=window_size[1],
+        softcap=softcap,
+        return_softmax=return_softmax,
+        gen=None,
+    )  # type: ignore
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+
+
+def _flash_attn_varlen_forward(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p,
+    softmax_scale,
+    causal,
+    window_size,
+    softcap,
+    alibi_slopes,
+    return_softmax,
+    block_table,
+    *,
+    out=None
+):
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    (out, q, k, v, out_padded, softmax_lse,
+     S_dmask, rng_state) = ops.varlen_fwd(
+        q=q,
+        k=k,
+        v=v,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size_left=window_size[0],
+        window_size_right=window_size[1],
+        softcap=softcap,
+        alibi_slopes=alibi_slopes,
+        block_table=block_table,
+        return_softmax=return_softmax,
+        gen=None,
+        out=out,
+        seqused_k=None,
+        zero_tensors=False,
+    )  # type: ignore
+    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+
+
+class FlashAttnFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        softcap,
+        alibi_slopes,
+        deterministic,
+        return_softmax,
+        out=None,
+    ):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        (out, q, k, v, out_padded, softmax_lse,
+         S_dmask, rng_state) = _flash_attn_forward(
+            q,
+            k,
+            v,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
+            return_softmax=return_softmax and dropout_p > 0,
+            out=out,
+        )
+        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+
+class FlashAttnVarlenFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        softcap,
+        alibi_slopes,
+        deterministic,
+        return_softmax,
+        block_table,
+        out=None,
+    ):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+        (out, q, k, v, out_padded, softmax_lse,
+         S_dmask, rng_state) = _flash_attn_varlen_forward(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
+            return_softmax=return_softmax and dropout_p > 0,
+            block_table=block_table,
+            out=out,
+        )
+        ctx.save_for_backward(
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q,
+            cu_seqlens_k, rng_state
+        )
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+    
+
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    block_table=None,
+    *,
+    out=None,
+):
+    return FlashAttnVarlenFunc.apply(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        softcap,
+        alibi_slopes,
+        deterministic,
+        return_attn_probs,
+        block_table,
+        out,
+    )
+
+
+def flash_attn_with_kvcache(
+    q,
+    k_cache,
+    v_cache,
+    k=None,
+    v=None,
+    rotary_cos=None,
+    rotary_sin=None,
+    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+    cache_batch_idx: Optional[torch.Tensor] = None,
+    block_table: Optional[torch.Tensor] = None,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0, # 0.0 means deactivated
+    rotary_interleaved=True,
+    alibi_slopes=None,
+    num_splits=0,
+    return_softmax_lse=False,
+    *,
+    out=None,
+):
+    assert k_cache.stride(-1) == 1, (
+        "k_cache must have contiguous last dimension"
+    )
+    assert v_cache.stride(-1) == 1, (
+        "v_cache must have contiguous last dimension"
+    )
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    if softmax_scale is None:
+        softmax_scale = q.shape[-1] ** (-0.5)
+    if cache_seqlens is not None and isinstance(cache_seqlens, int):
+        cache_seqlens = torch.full(
+            (k_cache.shape[0],), cache_seqlens,
+            dtype=torch.int32, device=k_cache.device
+        )
+        cache_seqlens = maybe_contiguous(cache_seqlens)
+    cache_batch_idx = maybe_contiguous(cache_batch_idx)
+    block_table = maybe_contiguous(block_table)
+    out, softmax_lse = ops.fwd_kvcache(
+        q=q,
+        kcache=k_cache,
+        vcache=v_cache,
+        k=k,
+        v=v,
+        seqlens_k=cache_seqlens,
+        rotary_cos=rotary_cos,
+        rotary_sin=rotary_sin,
+        cache_batch_idx=cache_batch_idx,
+        block_table=block_table,
+        alibi_slopes=alibi_slopes,
+        out=out,
+        softmax_scale=softmax_scale,
+        causal=causal,
+        window_size_left=window_size[0],
+        window_size_right=window_size[1],
+        softcap=softcap,
+        rotary_interleaved=rotary_interleaved,
+        num_splits=num_splits,
+    )  # type: ignore
+    return (out, softmax_lse) if return_softmax_lse else out
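
flash_attn_varlen_func above takes packed (total_tokens, num_heads, head_size) tensors indexed by cumulative sequence lengths of size batch + 1, matching the layout documented for mha_varlen_fwd in flash_api.h. A minimal usage sketch, assuming a CUDA build with these kernels and two sequences of lengths 5 and 7 packed together:

import torch

from aphrodite.attention.ops.aphrodite_flash_attn import flash_attn_varlen_func

num_heads, head_size = 8, 64
seq_lens = [5, 7]
total = sum(seq_lens)

q = torch.randn(total, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn(total, num_heads, head_size, dtype=torch.float16, device="cuda")
v = torch.randn(total, num_heads, head_size, dtype=torch.float16, device="cuda")
# Prefix sums of the sequence lengths: [0, 5, 12]
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")

# dropout_p defaults to 0.0 and softmax_scale to head_size ** -0.5.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens),
    causal=True,
)
# out keeps the packed shape of q: (total_tokens, num_heads, head_size)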

+ 3 - 4
aphrodite/attention/selector.py

@@ -253,8 +253,7 @@ def which_attn_to_use(
     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
         try:
-            import aphrodite_flash_attn  # noqa: F401
-
+            import aphrodite.attention.ops.aphrodite_flash_attn  # noqa: F401
             from aphrodite.attention.backends.flash_attn import (  # noqa: F401
                 FlashAttentionBackend)
 
@@ -267,8 +266,8 @@ def which_attn_to_use(
         except ImportError:
             logger.info(
                 "Cannot use FlashAttention-2 backend because the "
-                "aphrodite_flash_attn package is not found. "
-                "`pip install aphrodite-flash-attn` for better performance.")
+                "aphrodite._aphrodite_flash_attn_C object is not found. "
+                "This is built by default on supported hardware.")
             selected_backend = _Backend.XFORMERS
 
     return selected_backend

+ 65 - 0
kernels/flash_attn/block_info.h

@@ -0,0 +1,65 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Varlen = true>
+struct BlockInfo {
+  template <typename Params>
+  __device__ BlockInfo(const Params& params, const int bidb)
+      : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr
+                    ? -1
+                    : params.cu_seqlens_q[bidb]),
+        sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ||
+                        !params.is_seqlens_k_cumulative
+                    ? -1
+                    : params.cu_seqlens_k[bidb]),
+        actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr
+                            ? params.seqlen_q
+                            : params.cu_seqlens_q[bidb + 1] - sum_s_q)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] -
+        // cu_seqlens_k[bidb]. Otherwise it's cu_seqlens_k[bidb], i.e., we use
+        // cu_seqlens_k to store the sequence lengths of K.
+        ,
+        seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr
+                           ? params.seqlen_k
+                           : (params.is_seqlens_k_cumulative
+                                  ? params.cu_seqlens_k[bidb + 1] - sum_s_k
+                                  : params.cu_seqlens_k[bidb])),
+        actual_seqlen_k(params.seqused_k
+                            ? params.seqused_k[bidb]
+                            : seqlen_k_cache + (params.knew_ptr == nullptr
+                                                    ? 0
+                                                    : params.seqlen_knew)) {}
+
+  template <typename index_t>
+  __forceinline__ __device__ index_t q_offset(const index_t batch_stride,
+                                              const index_t row_stride,
+                                              const int bidb) const {
+    return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
+  }
+
+  template <typename index_t>
+  __forceinline__ __device__ index_t k_offset(const index_t batch_stride,
+                                              const index_t row_stride,
+                                              const int bidb) const {
+    return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+  }
+
+  const int sum_s_q;
+  const int sum_s_k;
+  const int actual_seqlen_q;
+  // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise
+  // actual_seqlen_k is set to 0.
+  const int seqlen_k_cache;
+  const int actual_seqlen_k;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash
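
BlockInfo resolves per-sequence offsets and lengths from the cu_seqlens arrays on the device. A small Python sketch of the same indexing convention, for illustration only:

# cu_seqlens_q is a prefix sum of length batch + 1 over the packed Q tensor.
cu_seqlens_q = [0, 5, 12, 20]   # three sequences of lengths 5, 7, 8
b = 1                           # batch index (bidb in the header)

sum_s_q = cu_seqlens_q[b]                                # token offset of sequence b
actual_seqlen_q = cu_seqlens_q[b + 1] - cu_seqlens_q[b]  # = 7

# When is_seqlens_k_cumulative is false, cu_seqlens_k[b] holds the length of
# sequence b directly instead of a prefix sum (see the comment in the header).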

+ 121 - 0
kernels/flash_attn/dropout.h

@@ -0,0 +1,121 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "philox.cuh"
+#include "utils.h"
+
+namespace flash {
+
+struct Dropout {
+  const unsigned long long seed, offset;
+  const uint8_t p_dropout_in_uint8_t;
+
+  __forceinline__ __device__ Dropout(const unsigned long long seed,
+                                     const unsigned long long offset,
+                                     const uint8_t p_dropout_in_uint8_t,
+                                     const int bid, const int hid,
+                                     const int tid, const int nheads)
+      : seed(seed),
+        offset(offset + (bid * nheads + hid) * 32 + tid % 32),
+        p_dropout_in_uint8_t(p_dropout_in_uint8_t) {}
+
+  template <bool encode_dropout_in_sign_bit = false, typename Engine,
+            typename Layout>
+  __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout>& tensor_,
+                                                int block_row_start,
+                                                int block_col_start,
+                                                int block_row_stride) {
+    // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
+    Tensor tensor = make_tensor(
+        tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
+    using T = typename Engine::value_type;
+    auto encode_dropout = [](bool keep, T val) {
+      return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
+    };
+    static_assert(decltype(size<2>(tensor))::value % 2 == 0);
+    const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
+    const uint32_t p_dropout_8bit_in_uint32_t =
+        (uint32_t(p_dropout_8bit_in_uint16_t) << 16) |
+        uint32_t(p_dropout_8bit_in_uint16_t);
+// if (cute::thread0()) { printf("threshold2 = 0x%x\n",
+// p_dropout_8bit_in_uint32_t); }
+#pragma unroll
+    for (int m = 0; m < size<1>(tensor);
+         ++m, block_row_start += block_row_stride) {
+      uint2 rowcol = make_uint2(block_row_start, block_col_start);
+#pragma unroll
+      for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
+        // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col =
+        // %d\n", m, n, int(rowcol.x), int(rowcol.y));}
+        uint4 random_uint4 = flash::philox(
+            seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
+        // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n",
+        // random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
+        uint8_t(&rnd_8)[16] = reinterpret_cast<uint8_t(&)[16]>(random_uint4);
+        // Special implementation for 16-bit types: we duplicate the threshold
+        // to the low and high 16 bits of a 32-bit value, then use the f16x2
+        // comparison instruction to get a mask. The low 16 bits of the mask
+        // will be either 0xffff or 0x0000, and the high 16 bits will be either
+        // 0xffff or 0x0000, depending on whether the random value is less than
+        // the threshold. We then do a bit-wise AND between the mask and the
+        // original value (in 32-bit). We're exploiting the fact that floating
+        // point comparison is equivalent to integer comparison, since we're
+        // comparing unsigned integers whose top 8-bits are zero.
+        if (!encode_dropout_in_sign_bit &&
+            (std::is_same<T, cutlass::half_t>::value ||
+             std::is_same<T, cutlass::bfloat16_t>::value)) {
+          uint16_t rnd_16[16];
+#pragma unroll
+          for (int i = 0; i < 16; i++) {
+            rnd_16[i] = uint16_t(rnd_8[i]);
+          }
+          uint32_t(&rnd_32)[8] = reinterpret_cast<uint32_t(&)[8]>(rnd_16);
+#pragma unroll
+          for (int j = 0; j < 2; j++) {
+            Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j *
+// 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } if
+// (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n",
+// tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
+#pragma unroll
+            for (int i = 0; i < 4; i++) {
+              uint32_t mask;
+              asm volatile("set.le.u32.f16x2 %0, %1, %2;\n"
+                           : "=r"(mask)
+                           : "r"(rnd_32[j * 4 + i]),
+                             "r"(p_dropout_8bit_in_uint32_t));
+              tensor_uint32(i) &= mask;
+            }
+            // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x,
+            // 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2),
+            // tensor_uint32(3)); }
+          }
+        } else {
+#pragma unroll
+          for (int j = 0; j < 2; j++) {
+#pragma unroll
+            for (int i = 0; i < 8; i++) {
+              tensor(i, m, n * 2 + j) =
+                  encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t,
+                                 tensor(i, m, n * 2 + j));
+            }
+            Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
+            // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x,
+            // 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2),
+            // tensor_uint32(3)); }
+          }
+        }
+        // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0))
+        // {
+        // //     printf("n = %d, ph  Philox: %u, %u, %u, %u\n", n, rnd_8.x,
+        // rnd_8.y, rnd_8.z, rnd_8.w);
+        // // }
+      }
+    }
+  }
+};
+
+}  // namespace flash
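
The dropout mask above is decided by comparing 8-bit Philox outputs against a pre-scaled keep-probability threshold (p_dropout_in_uint8_t, set in set_params_fprop in flash_api.h below). A sketch of that encoding, for illustration only:

import math

p_dropout = 0.1
p_keep = 1.0 - p_dropout
# Rounded down because the device compares with <= rather than <.
threshold_uint8 = int(math.floor(p_keep * 255.0))   # p_dropout_in_uint8_t

def keep(rnd_byte: int) -> bool:
    # rnd_byte is one of the 16 uint8 values unpacked from a Philox uint4.
    return rnd_byte <= threshold_uint8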

+ 156 - 0
kernels/flash_attn/flash.h

@@ -0,0 +1,156 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cuda.h>
+#include <vector>
+
+#ifdef OLD_GENERATOR_PATH
+  #include <ATen/CUDAGeneratorImpl.h>
+#else
+  #include <ATen/cuda/CUDAGeneratorImpl.h>
+#endif
+
+#include <ATen/cuda/CUDAGraphsUtils.cuh>  // For at::cuda::philox::unpack
+
+constexpr int TOTAL_DIM = 0;
+constexpr int H_DIM = 1;
+constexpr int D_DIM = 2;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Qkv_params {
+  using index_t = int64_t;
+  // The QKV matrices.
+  void* __restrict__ q_ptr;
+  void* __restrict__ k_ptr;
+  void* __restrict__ v_ptr;
+
+  // The stride between rows of the Q, K and V matrices.
+  index_t q_batch_stride;
+  index_t k_batch_stride;
+  index_t v_batch_stride;
+  index_t q_row_stride;
+  index_t k_row_stride;
+  index_t v_row_stride;
+  index_t q_head_stride;
+  index_t k_head_stride;
+  index_t v_head_stride;
+
+  // The number of heads.
+  int h, h_k;
+  // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k
+  // could be different from nheads (query).
+  int h_h_k_ratio;  // precompute h / h_k,
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_fwd_params : public Qkv_params {
+  // The O matrix (output).
+  void* __restrict__ o_ptr;
+  void* __restrict__ oaccum_ptr;
+
+  // The stride between rows of O.
+  index_t o_batch_stride;
+  index_t o_row_stride;
+  index_t o_head_stride;
+
+  // The pointer to the P matrix.
+  void* __restrict__ p_ptr;
+
+  // The pointer to the softmax sum.
+  void* __restrict__ softmax_lse_ptr;
+  void* __restrict__ softmax_lseaccum_ptr;
+
+  // The dimensions.
+  int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded,
+      d_rounded, rotary_dim, total_q;
+
+  // The scaling factors for the kernel.
+  float scale_softmax;
+  float scale_softmax_log2;
+
+  // array of length b+1 holding starting offset of each sequence.
+  int* __restrict__ cu_seqlens_q;
+  int* __restrict__ cu_seqlens_k;
+
+  // If provided, the actual length of each k sequence.
+  int* __restrict__ seqused_k;
+
+  int* __restrict__ blockmask;
+
+  // The K_new and V_new matrices.
+  void* __restrict__ knew_ptr;
+  void* __restrict__ vnew_ptr;
+
+  // The stride between rows of the Q, K and V matrices.
+  index_t knew_batch_stride;
+  index_t vnew_batch_stride;
+  index_t knew_row_stride;
+  index_t vnew_row_stride;
+  index_t knew_head_stride;
+  index_t vnew_head_stride;
+
+  // The cos and sin matrices for rotary embedding.
+  void* __restrict__ rotary_cos_ptr;
+  void* __restrict__ rotary_sin_ptr;
+
+  // The indices to index into the KV cache.
+  int* __restrict__ cache_batch_idx;
+
+  // Paged KV cache
+  int* __restrict__ block_table;
+  index_t block_table_batch_stride;
+  int page_block_size;
+
+  // The dropout probability (probability of keeping an activation).
+  float p_dropout;
+  // uint32_t p_dropout_in_uint;
+  // uint16_t p_dropout_in_uint16_t;
+  uint8_t p_dropout_in_uint8_t;
+
+  // Scale factor of 1 / (1 - p_dropout).
+  float rp_dropout;
+  float scale_softmax_rp_dropout;
+
+  // Local window size
+  int window_size_left, window_size_right;
+  float softcap;
+
+  // Random state.
+  at::PhiloxCudaState philox_args;
+
+  // Pointer to the RNG seed (idx 0) and offset (idx 1).
+  uint64_t* rng_state;
+
+  bool is_bf16;
+  bool is_causal;
+
+  // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] -
+  // cu_seqlens_k[bidb]. Otherwise it's cu_seqlens_k[bidb], i.e., we use
+  // cu_seqlens_k to store the sequence lengths of K.
+  bool is_seqlens_k_cumulative;
+
+  bool is_rotary_interleaved;
+
+  int num_splits;  // For split-KV version
+
+  void* __restrict__ alibi_slopes_ptr;
+  index_t alibi_slopes_batch_stride;
+
+  bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q]
+                      // format instead of [b, nheads, seqlen_q].
+  bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv
+                                 // ngroups), d) to (b, ngroups, nheads_kv, d).
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, int Headdim, bool Is_causal>
+void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
+template <typename T, int Headdim, bool Is_causal>
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params,
+                                  cudaStream_t stream);
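
Several derived fields in Flash_fwd_params (the rounded sizes and the log2-scaled softmax factor) are filled in by mha_fwd and set_params_fprop in flash_api.h below. A small sketch of that arithmetic, for illustration only:

import math

def round_multiple(x: int, m: int) -> int:
    return (x + m - 1) // m * m

head_size_og, seqlen_q, seqlen_k = 80, 300, 1000
head_size = round_multiple(head_size_og, 8)         # 80  (inputs padded to a multiple of 8)
head_size_rounded = round_multiple(head_size, 32)   # 96   -> d_rounded
seqlen_q_rounded = round_multiple(seqlen_q, 128)    # 384  -> seqlen_q_rounded
seqlen_k_rounded = round_multiple(seqlen_k, 128)    # 1024 -> seqlen_k_rounded

softmax_scale = head_size_og ** -0.5
scale_softmax_log2 = softmax_scale * math.log2(math.e)  # M_LOG2E in set_params_fprop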

+ 1167 - 0
kernels/flash_attn/flash_api.h

@@ -0,0 +1,1167 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+// Include these 2 headers instead of torch/extension.h since we don't need all
+// of the torch headers.
+#include "registration.h"
+#include <torch/library.h>
+#include <torch/nn/functional.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cutlass/numeric_types.h>
+
+#include "flash.h"
+#include "static_switch.h"
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...)                                   \
+  TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
+              #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+void set_params_fprop(Flash_fwd_params& params,
+                      // sizes
+                      const size_t b, const size_t seqlen_q,
+                      const size_t seqlen_k, const size_t seqlen_q_rounded,
+                      const size_t seqlen_k_rounded, const size_t h,
+                      const size_t h_k, const size_t d, const size_t d_rounded,
+                      // device pointers
+                      const at::Tensor q, const at::Tensor k,
+                      const at::Tensor v, at::Tensor out, void* cu_seqlens_q_d,
+                      void* cu_seqlens_k_d, void* seqused_k, void* p_d,
+                      void* softmax_lse_d, float p_dropout, float softmax_scale,
+                      int window_size_left, int window_size_right,
+                      const float softcap, bool seqlenq_ngroups_swapped = false,
+                      const bool unpadded_lse = false) {
+  // Reset the parameters
+  params = {};
+
+  params.is_bf16 = q.dtype() == torch::kBFloat16;
+
+  // Set the pointers and strides.
+  params.q_ptr = q.data_ptr();
+  params.k_ptr = k.data_ptr();
+  params.v_ptr = v.data_ptr();
+  // All stride are in elements, not bytes.
+  params.q_row_stride = q.stride(-3);
+  params.k_row_stride = k.stride(-3);
+  params.v_row_stride = v.stride(-3);
+  params.q_head_stride = q.stride(-2);
+  params.k_head_stride = k.stride(-2);
+  params.v_head_stride = v.stride(-2);
+  params.o_ptr = out.data_ptr();
+  params.o_row_stride = out.stride(-3);
+  params.o_head_stride = out.stride(-2);
+
+  if (cu_seqlens_q_d == nullptr) {
+    params.q_batch_stride = q.stride(0);
+    params.k_batch_stride = k.stride(0);
+    params.v_batch_stride = v.stride(0);
+    params.o_batch_stride = out.stride(0);
+    if (seqlenq_ngroups_swapped) {
+      params.q_batch_stride *= seqlen_q;
+      params.o_batch_stride *= seqlen_q;
+    }
+  }
+
+  params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q_d);
+  params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k_d);
+  params.seqused_k = static_cast<int*>(seqused_k);
+
+  // P = softmax(QK^T)
+  params.p_ptr = p_d;
+
+  // Softmax sum
+  params.softmax_lse_ptr = softmax_lse_d;
+
+  // Set the dimensions.
+  params.b = b;
+  params.h = h;
+  params.h_k = h_k;
+  params.h_h_k_ratio = h / h_k;
+  params.seqlen_q = seqlen_q;
+  params.seqlen_k = seqlen_k;
+  params.seqlen_q_rounded = seqlen_q_rounded;
+  params.seqlen_k_rounded = seqlen_k_rounded;
+  params.d = d;
+  params.d_rounded = d_rounded;
+
+// Set the different scale values.
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+  TORCH_CHECK(softcap <= 0.0,
+              "This flash attention build does not support softcap.");
+#endif
+  if (softcap > 0.0) {
+    params.softcap = softmax_scale / softcap;
+    params.scale_softmax = softcap;
+    params.scale_softmax_log2 = softcap * M_LOG2E;
+  } else {
+    // Remove potential NaN
+    params.softcap = 0.0;
+    params.scale_softmax = softmax_scale;
+    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+  }
+
+  // Set this to probability of keeping an element to simplify things.
+  params.p_dropout = 1.f - p_dropout;
+  // Convert p from float to int so we don't have to convert the random uint to
+  // float to compare. [Minor] We want to round down since when we do the
+  // comparison we use <= instead of < params.p_dropout_in_uint =
+  // uint32_t(std::floor(params.p_dropout * 4294967295.0));
+  // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout *
+  // 65535.0));
+  params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+  params.rp_dropout = 1.f / params.p_dropout;
+  params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+  TORCH_CHECK(p_dropout < 1.f);
+#ifdef FLASHATTENTION_DISABLE_DROPOUT
+  TORCH_CHECK(p_dropout == 0.0f,
+              "This flash attention build does not support dropout.");
+#endif
+
+  // Causal is the special case where window_size_right == 0 and
+  // window_size_left < 0. Local is the more general case where
+  // window_size_right >= 0 or window_size_left >= 0.
+  params.is_causal = window_size_left < 0 && window_size_right == 0;
+
+  if (window_size_left < 0 && window_size_right >= 0) {
+    window_size_left = seqlen_k;
+  }
+  if (window_size_left >= 0 && window_size_right < 0) {
+    window_size_right = seqlen_k;
+  }
+  params.window_size_left = window_size_left;
+  params.window_size_right = window_size_right;
+
+#ifdef FLASHATTENTION_DISABLE_LOCAL
+  TORCH_CHECK(
+      params.is_causal || (window_size_left < 0 && window_size_right < 0),
+      "This flash attention build does not support local attention.");
+#endif
+
+  params.is_seqlens_k_cumulative = true;
+
+#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+  TORCH_CHECK(d == d_rounded,
+              "This flash attention build does not support headdim not being a "
+              "multiple of 32.");
+#endif
+
+  params.unpadded_lse = unpadded_lse;
+  params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
+}
+
+void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream,
+                 bool force_split_kernel = false) {
+  FP16_SWITCH(!params.is_bf16, [&] {
+    HEADDIM_SWITCH(params.d, [&] {
+      BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+        if (params.num_splits <= 1 &&
+            !force_split_kernel) {  // If we don't set it num_splits == 0
+          run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+        } else {
+          run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params,
+                                                                       stream);
+        }
+      });
+    });
+  });
+}
+
+// Find the number of splits that maximizes the occupancy. For example, if we
+// have batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency =
+// 0.89) is better than having 3 splits (efficiency = 0.67). However, we also
+// don't want too many splits as that would incur more HBM reads/writes. So we
+// find the best efficiency, then find the smallest number of splits that gets
+// 85% of the best efficiency.
+inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs,
+                                int num_n_blocks, int max_splits) {
+  // If we have enough to almost fill the SMs, then just use 1 split
+  if (batch_nheads_mblocks >= 0.8f * num_SMs) {
+    return 1;
+  }
+  max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+  float max_efficiency = 0.f;
+  std::vector<float> efficiency;
+  efficiency.reserve(max_splits);
+  auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+  // Some splits are not eligible. For example, if we have 64 blocks and choose
+  // 11 splits, we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have
+  // 6 * 11 + (-2) blocks (i.e. it's 11 splits anyway). So we check if the
+  // number of blocks per split is the same as the previous num_splits.
+  auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+    return num_splits == 1 || ceildiv(num_n_blocks, num_splits) !=
+                                  ceildiv(num_n_blocks, num_splits - 1);
+  };
+  for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+    if (!is_split_eligible(num_splits)) {
+      efficiency.push_back(0.f);
+    } else {
+      float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+      float eff = n_waves / ceil(n_waves);
+      // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+      if (eff > max_efficiency) {
+        max_efficiency = eff;
+      }
+      efficiency.push_back(eff);
+    }
+  }
+  for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+    if (!is_split_eligible(num_splits)) {
+      continue;
+    }
+    if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+      // printf("num_splits chosen = %d\n", num_splits);
+      return num_splits;
+    }
+  }
+  return 1;
+}
+
+std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
+    Flash_fwd_params& params, const int batch_size, const int num_heads,
+    const int head_size, const int max_seqlen_k, const int max_seqlen_q,
+    const int head_size_rounded, const float p_dropout, const int num_splits,
+    cudaDeviceProp* dprops, struct c10::TensorOptions opts) {
+  // This needs to match with run_mha_fwd_splitkv_dispatch
+  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+  const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+  // Technically kBlockM = 64 only for the splitKV kernels, not the standard
+  // kernel. In any case we don't expect seqlen_q to be larger than 64 for
+  // inference.
+  const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+  params.num_splits = num_splits;
+  at::Tensor softmax_lse_accum;
+  at::Tensor out_accum;
+
+  if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
+    if (num_splits < 1) {
+      // We multiply number of SMs by 2 to hard-code the fact that we're using
+      // 128 threads per block.
+      params.num_splits = num_splits_heuristic(
+          batch_size * num_heads * num_m_blocks,
+          dprops->multiProcessorCount * 2, num_n_blocks, 128);
+    }
+    if (params.num_splits > 1) {
+      softmax_lse_accum =
+          torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q},
+                       opts.dtype(at::kFloat));
+      out_accum = torch::empty({params.num_splits, batch_size, num_heads,
+                                max_seqlen_q, head_size_rounded},
+                               opts.dtype(at::kFloat));
+      params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
+      params.oaccum_ptr = out_accum.data_ptr();
+    }
+    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
+  }
+
+  return std::make_tuple(softmax_lse_accum, out_accum);
+}
+
+void set_params_alibi(Flash_fwd_params& params,
+                      const c10::optional<at::Tensor>& alibi_slopes_,
+                      int batch_size, int num_heads) {
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+  TORCH_CHECK(!alibi_slopes_.has_value(),
+              "This flash attention build does not support alibi.");
+  params.alibi_slopes_ptr = nullptr;
+#else
+  if (alibi_slopes_.has_value()) {
+    auto alibi_slopes = alibi_slopes_.value();
+    TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32,
+                "ALiBi slopes must have dtype fp32");
+    CHECK_DEVICE(alibi_slopes);
+    TORCH_CHECK(alibi_slopes.stride(-1) == 1,
+                "ALiBi slopes tensor must have contiguous last dimension");
+    TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) ||
+                alibi_slopes.sizes() ==
+                    torch::IntArrayRef({batch_size, num_heads}));
+    params.alibi_slopes_ptr = alibi_slopes.data_ptr();
+    params.alibi_slopes_batch_stride =
+        alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
+  } else {
+    params.alibi_slopes_ptr = nullptr;
+  }
+#endif
+}
+
+std::vector<at::Tensor> mha_fwd(
+    at::Tensor& q,        // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
+    const c10::optional<at::Tensor>&
+        out_,  // batch_size x seqlen_q x num_heads x head_size
+    const c10::optional<at::Tensor>&
+        alibi_slopes_,  // num_heads or batch_size x num_heads
+    const double p_dropout, const double softmax_scale, bool is_causal,
+    int64_t window_size_left, int64_t window_size_right, const double softcap,
+    const bool return_softmax, c10::optional<at::Generator> gen_) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+  TORCH_CHECK(is_sm90 || is_sm8x,
+              "FlashAttention only supports Ampere GPUs or newer.");
+  // We will support Turing in the near future
+  // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports
+  // Turing GPUs or newer.");
+
+  auto q_dtype = q.dtype();
+  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+              "FlashAttention only support fp16 and bf16 data type");
+  if (q_dtype == torch::kBFloat16) {
+    TORCH_CHECK(is_sm90 || is_sm8x,
+                "bfloat16 is only supported on Ampere GPUs or newer");
+  }
+  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+
+  CHECK_DEVICE(q);
+  CHECK_DEVICE(k);
+  CHECK_DEVICE(v);
+
+  TORCH_CHECK(q.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(k.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(v.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+
+  const auto sizes = q.sizes();
+
+  const int batch_size = sizes[0];
+  int seqlen_q = sizes[1];
+  int num_heads = sizes[2];
+  const int head_size_og = sizes[3];
+  const int seqlen_k = k.size(1);
+  const int num_heads_k = k.size(2);
+  TORCH_CHECK(batch_size > 0, "batch size must be positive");
+  TORCH_CHECK(
+      head_size_og <= 256,
+      "FlashAttention forward only supports head dimension at most 256");
+  TORCH_CHECK(
+      num_heads % num_heads_k == 0,
+      "Number of heads in key/value must divide number of heads in query");
+
+  if (softcap > 0.f) {
+    TORCH_CHECK(p_dropout == 0.f,
+                "Softcapping does not support dropout for now");
+  }
+
+  if (window_size_left >= seqlen_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= seqlen_k) {
+    window_size_right = -1;
+  }
+
+  // causal=true is the same as causal=false in this case
+  if (seqlen_q == 1 && !alibi_slopes_.has_value()) {
+    is_causal = false;
+  }
+  if (is_causal) {
+    window_size_right = 0;
+  }
+
+  // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups,
+  // nheads_kv, d) in this case H/t Daniel Haziza
+  const int seqlenq_ngroups_swapped =
+      seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
+      window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 &&
+      !alibi_slopes_.has_value();
+  const int ngroups = num_heads / num_heads_k;
+  if (seqlenq_ngroups_swapped) {
+    q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
+            .transpose(1, 2);
+    seqlen_q = ngroups;
+    num_heads = num_heads_k;
+  }
+
+  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+  CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
+  CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
+
+  at::Tensor q_padded, k_padded, v_padded;
+  if (head_size_og % 8 != 0) {
+    q_padded = torch::nn::functional::pad(
+        q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    k_padded = torch::nn::functional::pad(
+        k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    v_padded = torch::nn::functional::pad(
+        v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+  } else {
+    q_padded = q;
+    k_padded = k;
+    v_padded = v;
+  }
+
+  at::Tensor out;
+  if (out_.has_value()) {
+    out = out_.value();
+    TORCH_CHECK(out.dtype() == q_dtype,
+                "Output must have the same dtype as inputs");
+    CHECK_DEVICE(out);
+    TORCH_CHECK(out.stride(-1) == 1,
+                "Output tensor must have contiguous last dimension");
+    CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+    if (seqlenq_ngroups_swapped) {
+      out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og})
+                .transpose(1, 2);
+    }
+    if (head_size_og % 8 != 0) {
+      out = torch::empty_like(q_padded);
+    }
+  } else {
+    out = torch::empty_like(q_padded);
+  }
+
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size = round_multiple(head_size_og, 8);
+  const int head_size_rounded = round_multiple(head_size, 32);
+  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+  auto opts = q.options();
+
+  auto softmax_lse =
+      torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+  at::Tensor p;
+  // Only return softmax if there's dropout to reduce compilation time
+  if (return_softmax) {
+    TORCH_CHECK(p_dropout > 0.0f,
+                "return_softmax is only supported when p_dropout > 0.0");
+    p = torch::empty(
+        {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
+  }
+
+  Flash_fwd_params params;
+  set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded,
+                   seqlen_k_rounded, num_heads, num_heads_k, head_size,
+                   head_size_rounded, q_padded, k_padded, v_padded, out,
+                   /*cu_seqlens_q_d=*/nullptr,
+                   /*cu_seqlens_k_d=*/nullptr,
+                   /*seqused_k=*/nullptr,
+                   return_softmax ? p.data_ptr() : nullptr,
+                   softmax_lse.data_ptr(), p_dropout, softmax_scale,
+                   window_size_left, window_size_right, softcap);
+
+  // Keep references to these tensors to extend their lifetime
+  at::Tensor softmax_lse_accum, out_accum;
+  std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+      params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+      head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
+
+  // number of times random will be generated per thread, to offset philox
+  // counter in the THC random state. We use a custom RNG that increases the
+  // offset by batch_size * nheads * 32.
+  int64_t counter_offset = params.b * params.h * 32;
+  auto options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+  auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+  // Forward kernel will populate memory with the seed and offset.
+  params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
+  if (p_dropout > 0.0) {
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    params.philox_args = gen->philox_cuda_state(counter_offset);
+  }
+
+  set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+  if (seqlen_k > 0) {
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    run_mha_fwd(params, stream);
+  } else {
+    // If seqlen_k == 0, then we have an empty tensor. We need to set the output
+    // to 0.
+    out.zero_();
+    softmax_lse.fill_(std::numeric_limits<float>::infinity());
+  }
+
+  at::Tensor out_padded = out;
+  if (head_size_og % 8 != 0) {
+    out = out.index(
+        {"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+    if (out_.has_value()) {
+      out_.value().copy_(out);
+    }
+  }
+
+  if (seqlenq_ngroups_swapped) {
+    out = out.transpose(1, 2).reshape(
+        {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+    out_padded = out_padded.transpose(1, 2).reshape(
+        {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+    q_padded = q_padded.transpose(1, 2).reshape(
+        {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+    softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+  }
+  return {out,        q_padded,    k_padded, v_padded,
+          out_padded, softmax_lse, p,        rng_state};
+}
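The seqlen_q <-> num_heads swap used above (and undone in the epilogue just before the return) is pure shape bookkeeping. A minimal standalone libtorch sketch of the same round trip, with made-up sizes, assuming nothing beyond reshape/transpose:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Decode-shaped query: batch 2, one query token, 32 query heads over 4 KV heads, head_size 128.
  const int64_t b = 2, num_heads = 32, num_heads_k = 4, d = 128;
  const int64_t ngroups = num_heads / num_heads_k;  // 8
  auto q = torch::randn({b, 1, num_heads, d});
  // Swap: fold the query-head groups into the seqlen_q axis so the kernel sees (b, ngroups, h_k, d).
  auto q_swapped = q.reshape({b, num_heads_k, ngroups, d}).transpose(1, 2);
  std::cout << q_swapped.sizes() << "\n";  // [2, 8, 4, 128]
  // Un-swap, the same way the epilogue above restores `out`.
  auto restored = q_swapped.transpose(1, 2).reshape({b, 1, num_heads_k * ngroups, d});
  std::cout << restored.sizes() << "\n";   // [2, 1, 32, 128]
  return 0;
}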
+
+std::vector<at::Tensor> mha_varlen_fwd(
+    at::Tensor&
+        q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k :=
+                          // \sum_{i=0}^{b} s_i or num_blocks x page_block_size
+                          // x num_heads_k x head_size if there's a block_table.
+    const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k :=
+                          // \sum_{i=0}^{b} s_i or num_blocks x page_block_size
+                          // x num_heads_k x head_size if there's a block_table.
+    const c10::optional<at::Tensor>&
+        out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const at::Tensor& cu_seqlens_q,  // b+1
+    const at::Tensor& cu_seqlens_k,  // b+1
+    const c10::optional<at::Tensor>&
+        seqused_k,  // b. If given, only this many elements of each batch
+                    // element's keys are used.
+    const c10::optional<at::Tensor>&
+        block_table_,  // batch_size x max_num_blocks_per_seq
+    const c10::optional<at::Tensor>&
+        alibi_slopes_,  // num_heads or b x num_heads
+    int64_t max_seqlen_q, const int64_t max_seqlen_k, const double p_dropout,
+    const double softmax_scale, const bool zero_tensors, bool is_causal,
+    int64_t window_size_left, int64_t window_size_right, const double softcap,
+    const bool return_softmax, c10::optional<at::Generator> gen_) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+  TORCH_CHECK(is_sm90 || is_sm8x,
+              "FlashAttention only supports Ampere GPUs or newer.");
+  // We will support Turing in the near future
+  // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports
+  // Turing GPUs or newer.");
+
+  auto q_dtype = q.dtype();
+  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+              "FlashAttention only support fp16 and bf16 data type");
+  if (q_dtype == torch::kBFloat16) {
+    TORCH_CHECK(is_sm90 || is_sm8x,
+                "bfloat16 is only supported on Ampere GPUs or newer");
+  }
+  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+  TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32,
+              "cu_seqlens_q must have dtype int32");
+  TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32,
+              "cu_seqlens_k must have dtype int32");
+
+  CHECK_DEVICE(q);
+  CHECK_DEVICE(k);
+  CHECK_DEVICE(v);
+  CHECK_DEVICE(cu_seqlens_q);
+  CHECK_DEVICE(cu_seqlens_k);
+
+  at::Tensor block_table;
+  const bool paged_KV = block_table_.has_value();
+  if (paged_KV) {
+    block_table = block_table_.value();
+    CHECK_DEVICE(block_table);
+    TORCH_CHECK(block_table.dtype() == torch::kInt32,
+                "block_table must have dtype torch.int32");
+    TORCH_CHECK(block_table.stride(-1) == 1,
+                "block_table must have contiguous last dimension");
+  }
+
+  TORCH_CHECK(q.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(k.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(v.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  CHECK_CONTIGUOUS(cu_seqlens_q);
+  CHECK_CONTIGUOUS(cu_seqlens_k);
+
+  const auto sizes = q.sizes();
+
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  int num_heads = sizes[1];
+  const int head_size_og = sizes[2];
+  const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+  if (softcap > 0.f) {
+    TORCH_CHECK(p_dropout == 0.f,
+                "Softcapping does not support dropout for now");
+  }
+
+  const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+  const int num_blocks = !paged_KV ? 0 : k.size(0);
+  const int page_block_size = !paged_KV ? 1 : k.size(1);
+  TORCH_CHECK(!paged_KV || page_block_size % 16 == 0,
+              "Paged KV cache block size must be divisible by 16");
+
+  if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) {
+    is_causal = false;
+  }  // causal=true is the same as causal=false in this case
+  if (is_causal) {
+    window_size_right = 0;
+  }
+
+  void* cu_seqlens_q_d = cu_seqlens_q.data_ptr();
+
+  // Faster to transpose q from (b, 1, (nheads_kv * ngroups), d) to
+  // (b, ngroups, nheads_kv, d) in this case. H/t Daniel Haziza.
+  const int seqlenq_ngroups_swapped =
+      max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
+      window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 &&
+      !alibi_slopes_.has_value();
+  const int ngroups = num_heads / num_heads_k;
+  if (seqlenq_ngroups_swapped) {
+    q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
+            .transpose(1, 2)
+            .reshape({batch_size * ngroups, num_heads_k, head_size_og});
+    max_seqlen_q = ngroups;
+    num_heads = num_heads_k;
+    cu_seqlens_q_d = nullptr;
+  }
+
+  const int total_q = q.sizes()[0];
+
+  TORCH_CHECK(batch_size > 0, "batch size must be positive");
+  TORCH_CHECK(
+      head_size_og <= 256,
+      "FlashAttention forward only supports head dimension at most 256");
+  TORCH_CHECK(
+      num_heads % num_heads_k == 0,
+      "Number of heads in key/value must divide number of heads in query");
+
+  if (window_size_left >= max_seqlen_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= max_seqlen_k) {
+    window_size_right = -1;
+  }
+
+  CHECK_SHAPE(q, total_q, num_heads, head_size_og);
+  if (!paged_KV) {
+    const int total_k = k.size(0);
+    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+  } else {
+    CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+  }
+
+  CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+  CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+  if (seqused_k.has_value()) {
+    auto seqused_k_ = seqused_k.value();
+    TORCH_CHECK(seqused_k_.dtype() == torch::kInt32,
+                "seqused_k must have dtype int32");
+    TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+    TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+    CHECK_SHAPE(seqused_k_, batch_size);
+  }
+
+  at::Tensor q_padded, k_padded, v_padded;
+  if (head_size_og % 8 != 0) {
+    q_padded = torch::nn::functional::pad(
+        q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    k_padded = torch::nn::functional::pad(
+        k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    v_padded = torch::nn::functional::pad(
+        v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+  } else {
+    q_padded = q;
+    k_padded = k;
+    v_padded = v;
+  }
+
+  at::Tensor out;
+  if (out_.has_value()) {
+    out = out_.value();
+    TORCH_CHECK(out.dtype() == q_dtype,
+                "Output must have the same dtype as inputs");
+    CHECK_DEVICE(out);
+    TORCH_CHECK(out.stride(-1) == 1,
+                "Output tensor must have contiguous last dimension");
+    CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+    if (seqlenq_ngroups_swapped) {
+      out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og})
+                .transpose(1, 2)
+                .reshape({batch_size * ngroups, num_heads_k, head_size_og});
+    }
+    if (head_size_og % 8 != 0) {
+      out = torch::empty_like(q_padded);
+    }
+  } else {
+    out = torch::empty_like(q_padded);
+  }
+
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size = round_multiple(head_size_og, 8);
+  const int head_size_rounded = round_multiple(head_size, 32);
+  const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+  const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+  auto opts = q.options();
+  auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
+  at::Tensor p;
+  // Only return softmax if there's dropout, to reduce compilation time
+  if (return_softmax) {
+    TORCH_CHECK(p_dropout > 0.0f,
+                "return_softmax is only supported when p_dropout > 0.0");
+    p = torch::empty(
+        {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
+  }
+
+  if (zero_tensors) {
+    out.zero_();
+    softmax_lse.fill_(-std::numeric_limits<float>::infinity());
+    if (return_softmax) {
+      p.zero_();
+    }
+  }
+
+  Flash_fwd_params params;
+  set_params_fprop(
+      params, batch_size, max_seqlen_q, max_seqlen_k, seqlen_q_rounded,
+      seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded,
+      q_padded, k_padded, v_padded, out, cu_seqlens_q_d,
+      cu_seqlens_k.data_ptr(),
+      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
+      return_softmax ? p.data_ptr() : nullptr, softmax_lse.data_ptr(),
+      p_dropout, softmax_scale, window_size_left, window_size_right, softcap,
+      seqlenq_ngroups_swapped,
+      /*unpadded_lse*/ true);
+  params.total_q = total_q;
+
+  if (paged_KV) {
+    params.block_table = block_table.data_ptr<int>();
+    params.block_table_batch_stride = block_table.stride(0);
+    params.k_batch_stride = k_padded.stride(0);
+    params.v_batch_stride = v_padded.stride(0);
+  }
+  params.page_block_size = page_block_size;
+  // Keep references to these tensors to extend their lifetime
+  at::Tensor softmax_lse_accum, out_accum;
+  if (seqlenq_ngroups_swapped) {
+    // Only apply split-k for decoding
+    std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+        params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q,
+        head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
+  }
+
+  // number of times random will be generated per thread, to offset philox
+  // counter in the THC random state. We use a custom RNG that increases the
+  // offset by batch_size * nheads * 32.
+  int64_t counter_offset = params.b * params.h * 32;
+  auto options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+  auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+  // Forward kernel will populate memory with the seed and offset.
+  params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+
+  if (p_dropout > 0.0) {
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
+        gen_, at::cuda::detail::getDefaultCUDAGenerator());
+    // See Note [Acquire lock when using random generators]
+    std::lock_guard<std::mutex> lock(gen->mutex_);
+    params.philox_args = gen->philox_cuda_state(counter_offset);
+  }
+
+  set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+  if (max_seqlen_k > 0) {
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    run_mha_fwd(params, stream, paged_KV);
+  } else {
+    // If seqlen_k == 0, then we have an empty tensor. We need to set the output
+    // to 0.
+    out.zero_();
+    softmax_lse.fill_(std::numeric_limits<float>::infinity());
+  }
+
+  at::Tensor out_padded = out;
+  if (head_size_og % 8 != 0) {
+    out = out.index(
+        {"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+    if (out_.has_value()) {
+      out_.value().copy_(out);
+    }
+  }
+
+  if (seqlenq_ngroups_swapped) {
+    int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k,
+                             head_size_og};
+    int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q,
+                            head_size_og};
+    out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
+    out_padded =
+        out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+    q_padded =
+        q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
+    softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
+  }
+
+  return {out,        q_padded,    k_padded, v_padded,
+          out_padded, softmax_lse, p,        rng_state};
+}
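Both entry points rely on the same rounding convention: the head dimension is padded to the next multiple of 8 (and rounded further for the kernel's internal layouts), while sequence lengths are rounded to multiples of 128. A small standalone illustration with arbitrary values:

#include <cstdio>

int main() {
  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  const int head_size_og = 68;  // not a multiple of 8, so q/k/v get padded on the last dim
  const int pad = head_size_og % 8 != 0 ? 8 - head_size_og % 8 : 0;  // 4
  const int head_size = round_multiple(head_size_og, 8);             // 72
  const int head_size_rounded = round_multiple(head_size, 32);       // 96
  const int seqlen_q_rounded = round_multiple(100, 128);             // 128
  std::printf("pad=%d head_size=%d head_size_rounded=%d seqlen_q_rounded=%d\n",
              pad, head_size, head_size_rounded, seqlen_q_rounded);
  return 0;
}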
+
+std::vector<at::Tensor> mha_fwd_kvcache(
+    at::Tensor& q,  // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor&
+        kcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or
+                 // num_blocks x page_block_size x num_heads_k x head_size if
+                 // there's a block_table.
+    const at::Tensor&
+        vcache,  // batch_size_c x seqlen_k x num_heads_k x head_size or
+                 // num_blocks x page_block_size x num_heads_k x head_size if
+                 // there's a block_table.
+    const c10::optional<at::Tensor>&
+        k_,  // batch_size x seqlen_knew x num_heads_k x head_size
+    const c10::optional<at::Tensor>&
+        v_,  // batch_size x seqlen_knew x num_heads_k x head_size
+    const c10::optional<at::Tensor>& seqlens_k_,  // batch_size
+    const c10::optional<at::Tensor>&
+        rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
+    const c10::optional<at::Tensor>&
+        rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
+    const c10::optional<at::Tensor>&
+        cache_batch_idx_,  // indices to index into the KV cache
+    const c10::optional<at::Tensor>&
+        block_table_,  // batch_size x max_num_blocks_per_seq
+    const c10::optional<at::Tensor>&
+        alibi_slopes_,  // num_heads or batch_size x num_heads
+    const c10::optional<at::Tensor>&
+        out_,  // batch_size x seqlen_q x num_heads x head_size
+    const double softmax_scale, bool is_causal, int64_t window_size_left,
+    int64_t window_size_right, const double softcap,
+    bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else
+                                 // indices 0 & rotary_dim / 2
+    int64_t num_splits) {
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
+  bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+  bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+  TORCH_CHECK(is_sm90 || is_sm8x,
+              "FlashAttention only supports Ampere GPUs or newer.");
+  // We will support Turing in the near future
+  // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports
+  // Turing GPUs or newer.");
+
+  auto q_dtype = q.dtype();
+  TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+              "FlashAttention only support fp16 and bf16 data type");
+  if (q_dtype == torch::kBFloat16) {
+    TORCH_CHECK(is_sm90 || is_sm8x,
+                "bfloat16 is only supported on Ampere GPUs or newer");
+  }
+  TORCH_CHECK(kcache.dtype() == q_dtype,
+              "query and key must have the same dtype");
+  TORCH_CHECK(vcache.dtype() == q_dtype,
+              "query and value must have the same dtype");
+
+  CHECK_DEVICE(q);
+  CHECK_DEVICE(kcache);
+  CHECK_DEVICE(vcache);
+
+  TORCH_CHECK(q.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(kcache.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(vcache.stride(-1) == 1,
+              "Input tensor must have contiguous last dimension");
+
+  at::Tensor block_table;
+  const bool paged_KV = block_table_.has_value();
+  if (paged_KV) {
+    TORCH_CHECK(!cache_batch_idx_.has_value(),
+                "Paged KVcache does not support cache_batch_idx");
+    block_table = block_table_.value();
+    CHECK_DEVICE(block_table);
+    TORCH_CHECK(block_table.dtype() == torch::kInt32,
+                "block_table must have dtype torch.int32");
+    TORCH_CHECK(block_table.stride(-1) == 1,
+                "block_table must have contiguous last dimension");
+  }
+
+  const auto sizes = q.sizes();
+
+  const int batch_size = sizes[0];
+  int seqlen_q = sizes[1];
+  const int seqlen_q_og = seqlen_q;
+  int num_heads = sizes[2];
+  const int num_heads_og = num_heads;
+  const int head_size_og = sizes[3];
+
+  const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+  const int num_blocks = !paged_KV ? 0 : kcache.size(0);
+  const int page_block_size = !paged_KV ? 1 : kcache.size(1);
+  TORCH_CHECK(!paged_KV || page_block_size % 16 == 0,
+              "Paged KV cache block size must be divisible by 16");
+  const int seqlen_k =
+      !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
+  const int num_heads_k = kcache.size(2);
+  const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
+  TORCH_CHECK(batch_size > 0, "batch size must be positive");
+  TORCH_CHECK(
+      head_size_og <= 256,
+      "FlashAttention forward only supports head dimension at most 256");
+  TORCH_CHECK(
+      num_heads % num_heads_k == 0,
+      "Number of heads in key/value must divide number of heads in query");
+
+  // causal=true is the same as causal=false in this case
+  if (seqlen_q == 1 && !alibi_slopes_.has_value()) {
+    is_causal = false;
+  }
+  if (is_causal) {
+    window_size_right = 0;
+  }
+
+  // Faster to transpose q from (b, 1, (nheads_kv * ngroups), d) to
+  // (b, ngroups, nheads_kv, d) in this case. H/t Daniel Haziza.
+  const int seqlenq_ngroups_swapped =
+      seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
+      window_size_right < 0 && head_size_og % 8 == 0 &&
+      !alibi_slopes_.has_value();
+  if (seqlenq_ngroups_swapped) {
+    const int ngroups = num_heads / num_heads_k;
+    q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
+            .transpose(1, 2);
+    seqlen_q = ngroups;
+    num_heads = num_heads_k;
+  }
+
+  if (window_size_left >= seqlen_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= seqlen_k) {
+    window_size_right = -1;
+  }
+
+  CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
+  if (!paged_KV) {
+    CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
+  } else {
+    CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+    CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
+    CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+  }
+
+  at::Tensor q_padded, kcache_padded, vcache_padded;
+  if (head_size_og % 8 != 0) {
+    q_padded = torch::nn::functional::pad(
+        q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    kcache_padded = torch::nn::functional::pad(
+        kcache,
+        torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    vcache_padded = torch::nn::functional::pad(
+        vcache,
+        torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+  } else {
+    q_padded = q;
+    kcache_padded = kcache;
+    vcache_padded = vcache;
+  }
+
+  at::Tensor out;
+  if (out_.has_value()) {
+    out = out_.value();
+    TORCH_CHECK(out.dtype() == q_dtype,
+                "Output must have the same dtype as inputs");
+    CHECK_DEVICE(out);
+    TORCH_CHECK(out.stride(-1) == 1,
+                "Output tensor must have contiguous last dimension");
+    CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
+    if (head_size_og % 8 != 0) {
+      out = torch::empty_like(q_padded);
+    } else if (seqlenq_ngroups_swapped) {
+      out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og})
+                .transpose(1, 2);
+    }
+  } else {
+    out = torch::empty_like(q_padded);
+  }
+
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size = round_multiple(head_size_og, 8);
+  const int head_size_rounded = round_multiple(head_size, 32);
+  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+
+  auto opts = q.options();
+
+  auto softmax_lse =
+      torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
+
+  Flash_fwd_params params;
+  set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded,
+                   seqlen_k_rounded, num_heads, num_heads_k, head_size,
+                   head_size_rounded, q_padded, kcache_padded, vcache_padded,
+                   out,
+                   /*cu_seqlens_q_d=*/nullptr,
+                   /*cu_seqlens_k_d=*/nullptr,
+                   /*seqused_k=*/nullptr,
+                   /*p_ptr=*/nullptr, softmax_lse.data_ptr(),
+                   /*p_dropout=*/0.f, softmax_scale, window_size_left,
+                   window_size_right, softcap);
+
+  at::Tensor k, v, k_padded, v_padded;
+  if (k_.has_value()) {
+    TORCH_CHECK(v_.has_value(),
+                "If key is supplied, value must also be passed in");
+    TORCH_CHECK(seqlens_k_.has_value(),
+                "If key is supplied, seqlens_k must also be passed in");
+    TORCH_CHECK(seqlen_q <= seqlen_k,
+                "If key is supplied, it must have seqlen <= the seqlen of the "
+                "KV cache");
+    k = k_.value();
+    v = v_.value();
+    TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
+    TORCH_CHECK(v.dtype() == q_dtype,
+                "Value must have the same dtype as query");
+    CHECK_DEVICE(k);
+    CHECK_DEVICE(v);
+    TORCH_CHECK(k.stride(-1) == 1,
+                "Key tensor must have contiguous last dimension");
+    TORCH_CHECK(v.stride(-1) == 1,
+                "Value tensor must have contiguous last dimension");
+    int seqlen_knew = k.size(1);
+    CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
+    CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
+    if (head_size_og % 8 != 0) {
+      k_padded = torch::nn::functional::pad(
+          k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+      v_padded = torch::nn::functional::pad(
+          v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
+    } else {
+      k_padded = k;
+      v_padded = v;
+    }
+    params.seqlen_knew = seqlen_knew;
+    params.knew_ptr = k_padded.data_ptr();
+    params.vnew_ptr = v_padded.data_ptr();
+    // All strides are in elements, not bytes.
+    params.knew_batch_stride = k_padded.stride(0);
+    params.vnew_batch_stride = v_padded.stride(0);
+    params.knew_row_stride = k_padded.stride(-3);
+    params.vnew_row_stride = v_padded.stride(-3);
+    params.knew_head_stride = k_padded.stride(-2);
+    params.vnew_head_stride = v_padded.stride(-2);
+  }
+
+  if (seqlens_k_.has_value()) {
+    auto seqlens_k = seqlens_k_.value();
+    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32,
+                "seqlens_k must have dtype int32");
+    CHECK_DEVICE(seqlens_k);
+    CHECK_CONTIGUOUS(seqlens_k);
+    CHECK_SHAPE(seqlens_k, batch_size);
+    params.cu_seqlens_k = static_cast<int*>(seqlens_k.data_ptr());
+  }
+  params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
+
+  if (rotary_cos_.has_value()) {
+    TORCH_CHECK(k_.has_value(),
+                "If rotary cos/sin are provided, new key / value to be "
+                "appended to KV cache must also be provided");
+    auto rotary_cos = rotary_cos_.value();
+    CHECK_DEVICE(rotary_cos);
+    params.rotary_dim = rotary_cos.size(1) * 2;
+    TORCH_CHECK(params.rotary_dim <= head_size,
+                "rotary_dim must be <= headdim");
+    TORCH_CHECK(
+        params.rotary_dim % 16 == 0,
+        "Only rotary dimensions divisible by 16 are currently supported");
+    const int seqlen_ro = rotary_cos.size(0);
+    TORCH_CHECK(seqlen_ro >= seqlen_k,
+                "cos/sin seqlen must be at least the seqlen of KV cache");
+    CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
+    CHECK_CONTIGUOUS(rotary_cos);
+    TORCH_CHECK(rotary_cos.scalar_type() == q_dtype,
+                "rotary_cos must have the same dtype as query");
+
+    TORCH_CHECK(rotary_sin_.has_value(),
+                "If rotary cos is provided, rotary sin must also be provided");
+    auto rotary_sin = rotary_sin_.value();
+    CHECK_DEVICE(rotary_sin);
+    CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
+    CHECK_CONTIGUOUS(rotary_sin);
+    TORCH_CHECK(rotary_sin.scalar_type() == q_dtype,
+                "rotary_cos must have the same dtype as query");
+    params.rotary_cos_ptr = rotary_cos.data_ptr();
+    params.rotary_sin_ptr = rotary_sin.data_ptr();
+    params.is_rotary_interleaved = is_rotary_interleaved;
+  } else {
+    params.rotary_dim = 0;
+  }
+
+  if (cache_batch_idx_.has_value()) {
+    auto cache_batch_idx = cache_batch_idx_.value();
+    CHECK_DEVICE(cache_batch_idx);
+    CHECK_CONTIGUOUS(cache_batch_idx);
+    TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32,
+                "cache_batch_idx must have dtype int32");
+    params.cache_batch_idx = reinterpret_cast<int*>(cache_batch_idx.data_ptr());
+  }
+
+  // Keep references to these tensors to extend their lifetime
+  at::Tensor softmax_lse_accum, out_accum;
+  std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
+      params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
+      head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
+
+  if (paged_KV) {
+    params.block_table = block_table.data_ptr<int>();
+    params.block_table_batch_stride = block_table.stride(0);
+  }
+  params.page_block_size = page_block_size;
+
+  set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  // Only split kernel supports appending to KV cache, or indexing to the cache
+  // with cache_batch_idx, or paged KV cache
+  run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() ||
+                                  cache_batch_idx_.has_value() || paged_KV);
+
+  if (head_size_og % 8 != 0) {
+    out = out.index(
+        {"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
+    if (out_.has_value()) {
+      out_.value().copy_(out);
+    }
+    if (k_.has_value()) {
+      // It's expensive to copy the KV cache here for the case where head size
+      // is not divisible by 8, but we don't expect to hit this case in practice.
+      // This is just so that the code works for that case.
+      kcache.copy_(kcache_padded.index(
+          {"...",
+           torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+      vcache.copy_(vcache_padded.index(
+          {"...",
+           torch::indexing::Slice(torch::indexing::None, head_size_og)}));
+    }
+  }
+
+  if (seqlenq_ngroups_swapped) {
+    out = out.transpose(1, 2).reshape(
+        {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+    softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
+  }
+  return {out, softmax_lse};
+}
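When a block_table is supplied, seqlen_k above is max_num_blocks_per_seq * page_block_size and the table maps each sequence's logical blocks to physical blocks of the cache. A hypothetical helper (names invented here for illustration, not taken from the kernels) showing that indexing:

#include <cstdio>

// Hypothetical: locate token `pos` of sequence `b` in a paged KV cache laid out as
// [num_blocks, page_block_size, num_heads_k, head_size].
struct PagedPos { int physical_block; int in_block_offset; };

PagedPos locate(const int* block_table, int block_table_batch_stride, int b, int pos,
                int page_block_size) {
  const int logical_block = pos / page_block_size;
  return {block_table[b * block_table_batch_stride + logical_block],
          pos % page_block_size};
}

int main() {
  const int block_table[] = {5, 2, 9};  // one sequence spanning three physical blocks
  const PagedPos p = locate(block_table, 3, 0, 35, 16);  // block size 16 (divisible by 16)
  std::printf("block=%d offset=%d\n", p.physical_block, p.in_block_offset);  // block=9 offset=3
  return 0;
}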

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim128_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim128_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params& params,
+                                                   cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim128_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim128_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params& params,
+                                               cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim160_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim160_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params& params,
+                                                   cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim160_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim160_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params& params,
+                                               cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim192_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim192_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params& params,
+                                                   cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim192_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim192_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params& params,
+                                               cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim224_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim224<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim224_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params& params,
+                                                   cudaStream_t stream) {
+  run_mha_fwd_hdim224<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim224_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 224, true>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim224<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim224_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params& params,
+                                               cudaStream_t stream) {
+  run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim256_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim256_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params& params,
+                                                   cudaStream_t stream) {
+  run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim256_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim256_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params& params,
+                                               cudaStream_t stream) {
+  run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim32_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params& params,
+                                                 cudaStream_t stream) {
+  run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim32_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim32_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params& params,
+                                             cudaStream_t stream) {
+  run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim32_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim64_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params& params,
+                                                 cudaStream_t stream) {
+  run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim64_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim64_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params& params,
+                                             cudaStream_t stream) {
+  run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim64_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim96_bf16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params& params,
+                                                 cudaStream_t stream) {
+  run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim96_bf16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params& params,
+                                                  cudaStream_t stream) {
+  run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim96_fp16_causal_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params& params,
+                                             cudaStream_t stream) {
+  run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
+}

+ 11 - 0
kernels/flash_attn/flash_fwd_hdim96_fp16_sm80.cu

@@ -0,0 +1,11 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up
+// compilation. This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params& params,
+                                              cudaStream_t stream) {
+  run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
+}
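Each file above supplies one explicit instantiation of run_mha_fwd_<element, head_dim, is_causal>, so only the specializations that are actually needed get compiled, and each one compiles in its own translation unit. A hedged sketch of the kind of head-dimension dispatch that would pick among them (the real dispatcher is run_mha_fwd, defined elsewhere in this diff; this is not its code):

#include "flash.h"  // assumes the Flash_fwd_params / run_mha_fwd_ declarations from this diff

template <typename Element, bool Is_causal>
void dispatch_headdim(Flash_fwd_params& params, cudaStream_t stream, int headdim) {
  // head_size has already been rounded up to a multiple of 8 and capped at 256,
  // so each bucket below maps onto one of the compiled kernels above.
  if (headdim <= 32) {
    run_mha_fwd_<Element, 32, Is_causal>(params, stream);
  } else if (headdim <= 64) {
    run_mha_fwd_<Element, 64, Is_causal>(params, stream);
  } else if (headdim <= 96) {
    run_mha_fwd_<Element, 96, Is_causal>(params, stream);
  } else if (headdim <= 128) {
    run_mha_fwd_<Element, 128, Is_causal>(params, stream);
  } else if (headdim <= 160) {
    run_mha_fwd_<Element, 160, Is_causal>(params, stream);
  } else if (headdim <= 192) {
    run_mha_fwd_<Element, 192, Is_causal>(params, stream);
  } else if (headdim <= 224) {
    run_mha_fwd_<Element, 224, Is_causal>(params, stream);
  } else {
    run_mha_fwd_<Element, 256, Is_causal>(params, stream);
  }
}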

+ 1715 - 0
kernels/flash_attn/flash_fwd_kernel.h

@@ -0,0 +1,1715 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+
+#include "block_info.h"
+#include "kernel_traits.h"
+#include "utils.h"
+#include "softmax.h"
+#include "mask.h"
+#include "dropout.h"
+#include "rotary.h"
+
+namespace flash {
+
+using namespace cute;
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout>& tensor,
+                                              const float softcap) {
+#pragma unroll
+  for (int i = 0; i < size(tensor); ++i) {
+    tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+  }
+}
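apply_softcap only performs the tanh; the end-to-end soft-capping of a raw attention score is the smooth squashing below, assuming (this is an assumption about set_params_fprop, defined earlier in this diff) that the host folds softmax_scale / cap into the value this kernel receives and folds the cap itself back into the softmax scale:

#include <cmath>

// Reference semantics only (a sketch, not the kernel's code path): the raw score
// is squashed smoothly into (-cap, cap) instead of growing without bound.
inline float softcap_reference(float score, float softmax_scale, float cap) {
  return cap * std::tanh(score * softmax_scale / cap);
}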
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
+__forceinline__ __device__ auto get_lse_tile(
+    const Params& params, const int bidb, const int bidh, const int m_block,
+    const BlockInfo</*Varlen=*/!Is_even_MN>& binfo) {
+  // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) -
+  // this is the non-varlen path. Otherwise, when
+  // params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b)
+  // to account for the seqlen_q <-> h swapping trick. Otherwise, it's written
+  // as (h, b, seqlen_q).
+  const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped;
+  auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0;
+  auto gmem_ptr_lse = make_gmem_ptr(
+      reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset);
+
+  auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q)
+                            : make_shape(params.b, params.h, params.seqlen_q);
+  auto lse_stride =
+      params.seqlenq_ngroups_swapped
+          ? make_stride(1, params.seqlen_q * params.b, params.b)
+          : (params.unpadded_lse
+                 ? make_stride(params.h * params.total_q, params.total_q, 1)
+                 : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1));
+
+  auto lse_layout = make_layout(lse_shape, lse_stride);
+  Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout);
+  auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _);
+  return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block));
+}
+
+template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
+          bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
+          bool Is_softcap, bool Return_softmax, typename Params>
+inline __device__ void compute_attn_1rowblock(const Params& params,
+                                              const int bidb, const int bidh,
+                                              const int m_block) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+
+  // Shared memory.
+  extern __shared__ char smem_[];
+
+  // The thread index.
+  const int tidx = threadIdx.x;
+
+  constexpr int kBlockM = Kernel_traits::kBlockM;
+  constexpr int kBlockN = Kernel_traits::kBlockN;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNWarps = Kernel_traits::kNWarps;
+
+  auto seed_offset = at::cuda::philox::unpack(params.philox_args);
+  flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset),
+                         params.p_dropout_in_uint8_t, bidb, bidh, tidx,
+                         params.h);
+
+  // Save seed and offset for backward, before any early exiting. Otherwise the
+  // 0-th thread block might exit early and no one saves the rng states.
+  if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
+      tidx == 0) {
+    params.rng_state[0] = std::get<0>(seed_offset);
+    params.rng_state[1] = std::get<1>(seed_offset);
+  }
+
+  const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+  if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+  const int n_block_min =
+      !Is_local
+          ? 0
+          : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k -
+                         binfo.actual_seqlen_q - params.window_size_left) /
+                            kBlockN);
+  int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
+  if (Is_causal || Is_local) {
+    n_block_max = std::min(
+        n_block_max,
+        cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k -
+                           binfo.actual_seqlen_q + params.window_size_right,
+                       kBlockN));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+    //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
+    // }
+  }
+  // We exit early and write 0 to gO and gLSE. This also covers the case where
+  // actual_seqlen_k == 0. Otherwise we might read OOB elements from gK and gV.
+  if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+    Tensor mO = make_tensor(
+        make_gmem_ptr(
+            reinterpret_cast<Element*>(params.o_ptr) +
+            binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
+        make_shape(binfo.actual_seqlen_q, params.h, params.d),
+        make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+    Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                           make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+
+    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(
+        params, bidb, bidh, m_block, binfo);
+
+    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+    Tensor tOrO = make_tensor<Element>(shape(tOgO));
+    clear(tOrO);
+    // Construct identity layout for sO
+    Tensor cO = make_identity_tensor(make_shape(
+        size<0>(gO), size<1>(gO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    // Repeat the partitioning with identity layouts
+    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+    if (!Is_even_K) {
+#pragma unroll
+      for (int k = 0; k < size(tOpO); ++k) {
+        tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+      }
+    }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
+                /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO,
+        binfo.actual_seqlen_q - m_block * kBlockM);
+#pragma unroll
+    for (int m = 0; m < size<1>(tOgO); ++m) {
+      const int row = get<0>(tOcO(0, m, 0));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM &&
+          get<1>(tOcO(0, m, 0)) == 0) {
+        gLSE(row) = INFINITY;
+      }
+    }
+    return;
+  }
+  // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max =
+  // %d\n", m_block, n_block_min, n_block_max); }
+
+  // We iterate over the blocks in reverse order. This is because the last block
+  // is the only one that needs masking when we read K and V from global memory.
+  // Moreover, iterating in reverse might save us 1 register (we just need
+  // n_block instead of both n_block and n_block_max).
+
+  const index_t row_offset_p =
+      ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) *
+          params.seqlen_k_rounded +
+      (n_block_max - 1) * kBlockN;
+
+  Tensor mQ =
+      make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) +
+                                binfo.q_offset(params.q_batch_stride,
+                                               params.q_row_stride, bidb)),
+                  make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                  make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+  Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                         make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+  Tensor mK =
+      make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) +
+                                binfo.k_offset(params.k_batch_stride,
+                                               params.k_row_stride, bidb)),
+                  make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                  make_stride(params.k_row_stride, params.k_head_stride, _1{}));
+  Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _),
+                         Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                         make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
+  Tensor mV =
+      make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) +
+                                binfo.k_offset(params.v_batch_stride,
+                                               params.v_row_stride, bidb)),
+                  make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
+                  make_stride(params.v_row_stride, params.v_head_stride, _1{}));
+  Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _),
+                         Shape<Int<kBlockN>, Int<kHeadDim>>{},
+                         make_coord(_, 0));  // (kBlockN, kHeadDim, nblocksN)
+  Tensor gP = make_tensor(
+      make_gmem_ptr(reinterpret_cast<Element*>(params.p_ptr) + row_offset_p),
+      Shape<Int<kBlockM>, Int<kBlockN>>{},
+      make_stride(params.seqlen_k_rounded, _1{}));
+
+  Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
+                          typename Kernel_traits::SmemLayoutQ{});
+  // Careful: we're using the same smem for sQ and sK|sV if Share_Q_K_smem.
+  Tensor sK =
+      make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
+                  typename Kernel_traits::SmemLayoutKV{});
+
+  Tensor sV =
+      make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
+  Tensor sVt =
+      make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+  Tensor sVtNoSwizzle =
+      make_tensor(sV.data().get(),
+                  typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+  typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+  auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+
+  Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+  Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+  Tensor tKgK =
+      gmem_thr_copy_QKV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K, nblocksN)
+  Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+  Tensor tVgV =
+      gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K, nblocksN)
+  Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+
+  typename Kernel_traits::TiledMma tiled_mma;
+  auto thr_mma = tiled_mma.get_thread_slice(tidx);
+  Tensor tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
+  Tensor tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
+  Tensor tOrVt =
+      thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
+
+  Tensor tSgS = thr_mma.partition_C(gP);
+
+  Tensor acc_o = partition_fragment_C(
+      tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
+
+  //
+  // Copy Atom retiling
+  //
+
+  auto smem_tiled_copy_Q =
+      make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+  // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
+  Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+  // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
+
+  auto smem_tiled_copy_K =
+      make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+  Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+  auto smem_tiled_copy_V = make_tiled_copy_B(
+      typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+  Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+  //
+  // PREDICATES
+  //
+
+  // // Allocate predicate tensors for m and n
+  // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)),
+  // Stride<_1,_0>{}); Tensor tKVpKV =
+  // make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)),
+  // Stride<_1,_0>{});
+
+  // Construct identity layout for sQ and sK
+  Tensor cQ = make_identity_tensor(
+      make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor cKV = make_identity_tensor(
+      make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
+  // Tensor tScQ = thr_mma.partition_A(cQ);  // (MMA,MMA_M,MMA_K)
+  // if (cute::thread0()) {
+  //     print(tScQ.layout()); printf("\n");
+  //     for (int i = 0; i < size(tScQ); ++i) {
+  //         printf("%d ", get<0>(tScQ(i)));
+  //     }
+  //     printf("\n");
+  //     for (int i = 0; i < size(tScQ); ++i) {
+  //         printf("%d ", get<1>(tScQ(i)));
+  //     }
+  //     printf("\n");
+  // }
+
+  // Repeat the partitioning with identity layouts
+  Tensor tQcQ = gmem_thr_copy_QKV.partition_S(
+      cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(
+      cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+
+  // Allocate predicate tensors for k
+  Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+  Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
+
+  // Set predicates for k bounds
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tQpQ); ++k) {
+      tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+    }
+#pragma unroll
+    for (int k = 0; k < size(tKVpKV); ++k) {
+      tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+    }
+  }
+
+  // Prologue
+
+  // We don't need to clear the sQ smem tiles since we'll only write out the
+  // valid outputs
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ,
+                                     tQpQ,
+                                     binfo.actual_seqlen_q - m_block * kBlockM);
+  if (Kernel_traits::Is_Q_in_regs) {
+    cute::cp_async_fence();
+  }
+
+  // // if (cute::thread(1, 0)) { print(tQsQ); }
+  // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element
+  // *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
+  // // if (cute::thread0()) { print(sQNoSwizzle); }
+
+  if (Kernel_traits::Share_Q_K_smem) {
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+    CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
+    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+    __syncthreads();
+  }
+
+  int n_block = n_block_max - 1;
+  // We don't need to clear the sK smem tiles since we'll mask out the scores
+  // anyway.
+  flash::copy<Is_even_MN, Is_even_K>(
+      gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
+      binfo.actual_seqlen_k - n_block * kBlockN);
+  cute::cp_async_fence();
+  // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
+  // __syncthreads();
+
+  if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
+    flash::cp_async_wait<1>();
+    __syncthreads();
+    Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+    CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
+    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+  }
+
+  clear(acc_o);
+
+  flash::Softmax<2 * size<1>(acc_o)> softmax;
+
+  const float alibi_slope =
+      !Has_alibi || params.alibi_slopes_ptr == nullptr
+          ? 0.0f
+          : reinterpret_cast<float*>(params.alibi_slopes_ptr)
+                    [bidb * params.alibi_slopes_batch_stride + bidh] /
+                params.scale_softmax;
+  flash::Mask<Is_causal, Is_local, Has_alibi> mask(
+      binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left,
+      params.window_size_right, alibi_slope);
+
+  // For performance reasons, we separate out two kinds of iterations:
+  // those that need masking on S, and those that don't.
+  // We need masking on S for the very last block when the K and V length is
+  // not a multiple of kBlockN. We also need masking on S if it's causal, for
+  // the last ceil_div(kBlockM, kBlockN) blocks. We will have at least 1
+  // "masking" iteration.
+
+  // If not even_N, then seqlen_k might end in the middle of a block. In that
+  // case we need to mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+  constexpr int n_masking_steps =
+      (!Is_causal && !Is_local)
+          ? 1
+          : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN)
+                                       : cute::ceil_div(kBlockM, kBlockN) + 1);
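+  // E.g. with kBlockM = 128 and kBlockN = 64 (tile sizes chosen only for
+  // illustration), a causal, evenly-sized problem runs ceil_div(128, 64) = 2
+  // masking iterations; an uneven seqlen_k adds one more for the partially
+  // filled last K/V block.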
+#pragma unroll
+  for (int masking_step = 0; masking_step < n_masking_steps;
+       ++masking_step, --n_block) {
+    Tensor acc_s = partition_fragment_C(
+        tiled_mma,
+        Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+
+    // Advance gV
+    if (masking_step > 0) {
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(
+          gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
+    } else {
+      // Clear the smem tiles to account for predicated off loads
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+          gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV,
+          binfo.actual_seqlen_k - n_block * kBlockN);
+    }
+    cute::cp_async_fence();
+
+    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
+        smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
+    // if (cute::thread0()) { print(acc_s); }
+    if constexpr (Is_softcap) {
+      apply_softcap(acc_s, params.softcap);
+    }
+
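+    // The mask gets this thread's first accumulator row: tidx / 32 picks the
+    // warp's 16-row slab of the M tile, (tidx % 32) / 4 is the thread's row
+    // within it (SM80 m16n8k16 accumulator layout), and kNWarps * 16 is the
+    // row stride between successive MMA_M tiles.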
+    mask.template apply_mask<Is_causal, Is_even_MN>(
+        acc_s, n_block * kBlockN,
+        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV,
+                                                  tKgK(_, _, _, n_block - 1),
+                                                  tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the
+      // synchronization isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // TODO: when we have key_padding_mask we'll need to Check_inf
+    masking_step == 0
+        ? softmax.template softmax_rescale_o<
+              /*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(
+              acc_s, acc_o, params.scale_softmax_log2)
+        : softmax.template softmax_rescale_o<
+              /*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(
+              acc_s, acc_o, params.scale_softmax_log2);
+
+    // Convert acc_s from fp32 to fp16/bf16
+    Tensor rP = flash::convert_type<Element>(acc_s);
+    int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+    int block_col_idx = n_block * (kBlockN / 32);
+    if (Return_softmax) {
+      Tensor rP_drop = make_fragment_like(rP);
+      cute::copy(rP, rP_drop);
+      dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+          rP_drop, block_row_idx, block_col_idx, kNWarps);
+      cute::copy(rP_drop, tSgS);
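+      // gP spans seqlen_k_rounded columns and n_block decreases each
+      // iteration, so the softmax-dump pointer below is rewound by kBlockN
+      // columns rather than advanced.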
+      tSgS.data() = tSgS.data() + (-kBlockN);
+    }
+    if (Is_dropout) {
+      dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
+    }
+
+    // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(
+        rP.data(),
+        flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    // if (cute::thread0()) { print(tOrP); }
+    flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
+                   smem_thr_copy_V);
+    // if (cute::thread0()) { print(scores); }
+
+    // This check is at the end of the loop since we always have at least 1
+    // iteration
+    if (n_masking_steps > 1 && n_block <= n_block_min) {
+      --n_block;
+      break;
+    }
+  }
+
+  // These are the iterations where we don't need masking on S
+  for (; n_block >= n_block_min; --n_block) {
+    Tensor acc_s = partition_fragment_C(
+        tiled_mma,
+        Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    flash::copy</*Is_even_MN=*/true, Is_even_K>(
+        gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
+    cute::cp_async_fence();
+
+    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+        acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
+        smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
+    if constexpr (Is_softcap) {
+      apply_softcap(acc_s, params.softcap);
+    }
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV,
+                                                  tKgK(_, _, _, n_block - 1),
+                                                  tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the
+      // synchronization isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    mask.template apply_mask</*Causal_mask=*/false>(
+        acc_s, n_block * kBlockN,
+        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
+
+    softmax
+        .template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(
+            acc_s, acc_o, params.scale_softmax_log2);
+
+    Tensor rP = flash::convert_type<Element>(acc_s);
+    int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+    int block_col_idx = n_block * (kBlockN / 32);
+    if (Return_softmax) {
+      Tensor rP_drop = make_fragment_like(rP);
+      cute::copy(rP, rP_drop);
+      dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
+          rP_drop, block_row_idx, block_col_idx, kNWarps);
+      cute::copy(rP_drop, tSgS);
+      tSgS.data() = tSgS.data() + (-kBlockN);
+    }
+    if (Is_dropout) {
+      dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
+    }
+
+    // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(
+        rP.data(),
+        flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
+                   smem_thr_copy_V);
+  }
+
+  // Epilogue
+
+  Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(
+      acc_o, params.scale_softmax, params.rp_dropout);
+
+  // Convert acc_o from fp32 to fp16/bf16
+  Tensor rO = flash::convert_type<Element>(acc_o);
+  Tensor sO = make_tensor(
+      sQ.data(), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
+  // Partition sO to match the accumulator partitioning
+  auto smem_tiled_copy_O =
+      make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+  auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
+  Tensor taccOrO =
+      smem_thr_copy_O.retile_S(rO);  // ((Atom,AtomNum), MMA_M, MMA_N)
+  Tensor taccOsO =
+      smem_thr_copy_O.partition_D(sO);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+  // sO has the same size as sQ, so we don't need to sync here.
+  if (Kernel_traits::Share_Q_K_smem) {
+    __syncthreads();
+  }
+
+  cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
+
+  Tensor mO =
+      make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) +
+                                binfo.q_offset(params.o_batch_stride,
+                                               params.o_row_stride, bidb)),
+                  make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                  make_stride(params.o_row_stride, params.o_head_stride, _1{}));
+  Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                         make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+  Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(
+      params, bidb, bidh, m_block, binfo);
+
+  typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+  auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+  Tensor tOsO =
+      gmem_thr_copy_O.partition_S(sO);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+  Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+
+  __syncthreads();
+
+  Tensor tOrO = make_tensor<Element>(shape(tOgO));
+  cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
+
+  Tensor caccO = make_identity_tensor(
+      Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor taccOcO = thr_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
+  static_assert(decltype(size<0>(taccOcO))::value == 4);
+  // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+  Tensor taccOcO_row =
+      logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+  CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
+  if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+    for (int mi = 0; mi < size(lse); ++mi) {
+      const int row = get<0>(taccOcO_row(mi));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+        gLSE(row) = lse(mi);
+      }
+    }
+  }
+
+  // Construct identity layout for sO
+  Tensor cO = make_identity_tensor(
+      make_shape(size<0>(sO), size<1>(sO)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  // Repeat the partitioning with identity layouts
+  Tensor tOcO =
+      gmem_thr_copy_O.partition_D(cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpO); ++k) {
+      tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+    }
+  }
+  // Clear_OOB_K must be false since we don't want to write zeros to gmem
+  flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
+              /*Clear_OOB_K=*/false>(gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO,
+                                     binfo.actual_seqlen_q - m_block * kBlockM);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
+          bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split,
+          bool Append_KV, typename Params>
+inline __device__ void compute_attn_1rowblock_splitkv(
+    const Params& params, const int bidb, const int bidh, const int m_block,
+    const int n_split_idx, const int num_n_splits) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+
+  // Shared memory.
+  extern __shared__ char smem_[];
+
+  // The thread index.
+  const int tidx = threadIdx.x;
+
+  constexpr int kBlockM = Kernel_traits::kBlockM;
+  constexpr int kBlockN = Kernel_traits::kBlockN;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNWarps = Kernel_traits::kNWarps;
+
+  using GmemTiledCopyO =
+      std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO,
+                         typename Kernel_traits::GmemTiledCopyOaccum>;
+  using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
+
+  const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+  // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+  // printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d,
+  // actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative,
+  // binfo.seqlen_k_cache, binfo.actual_seqlen_k); } if (threadIdx.x == 0 &&
+  // blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p,
+  // seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache
+  // + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
+  if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
+
+  const int n_blocks_per_split =
+      ((binfo.actual_seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) /
+      num_n_splits;
+  const int n_block_min =
+      !Is_local ? n_split_idx * n_blocks_per_split
+                : std::max(n_split_idx * n_blocks_per_split,
+                           (m_block * kBlockM + binfo.actual_seqlen_k -
+                            binfo.actual_seqlen_q - params.window_size_left) /
+                               kBlockN);
+  int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN),
+                             (n_split_idx + 1) * n_blocks_per_split);
+  if (Is_causal || Is_local) {
+    n_block_max = std::min(
+        n_block_max,
+        cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k -
+                           binfo.actual_seqlen_q + params.window_size_right,
+                       kBlockN));
+  }
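+  // E.g. (illustrative numbers only) with actual_seqlen_k = 1024, kBlockN = 64
+  // and num_n_splits = 4, there are 16 K/V blocks and n_blocks_per_split = 4,
+  // so split index 1 covers n_block in [4, 8) (further reduced by the
+  // causal/local handling above when applicable).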
+  if (n_block_min >=
+      n_block_max) {  // This also covers the case where n_block_max <= 0
+    // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
+    // Otherwise we might read OOB elements from gK and gV,
+    // or get wrong results when we combine gOaccum from different blocks.
+    const index_t row_offset_o =
+        binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) +
+        m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+    const index_t row_offset_oaccum =
+        (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
+         m_block * kBlockM) *
+        params.d_rounded;
+    const index_t row_offset_lseaccum =
+        ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
+        m_block * kBlockM;
+    Tensor gOaccum = make_tensor(
+        make_gmem_ptr(reinterpret_cast<ElementO*>(Split ? params.oaccum_ptr
+                                                        : params.o_ptr) +
+                      (Split ? row_offset_oaccum : row_offset_o)),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+    Tensor gLSEaccum = make_tensor(
+        make_gmem_ptr(
+            reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr
+                                                  : params.softmax_lse_ptr) +
+            row_offset_lseaccum),
+        Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+    GmemTiledCopyO gmem_tiled_copy_Oaccum;
+    auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+    Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+    Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+    clear(tOrOaccum);
+    // Construct identity layout for sO
+    Tensor cO = make_identity_tensor(make_shape(
+        size<0>(gOaccum), size<1>(gOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+    // Repeat the partitioning with identity layouts
+    Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
+    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+    if (!Is_even_K) {
+#pragma unroll
+      for (int k = 0; k < size(tOpO); ++k) {
+        tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+      }
+    }
+    // Clear_OOB_K must be false since we don't want to write zeros to gmem
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
+                /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO,
+        binfo.actual_seqlen_q - m_block * kBlockM);
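+    // Writing -INFINITY for an empty split means it contributes exp(-inf) = 0
+    // weight when the per-split LSEs are combined; the non-split path writes
+    // +INFINITY, matching the early exit of the non-split kernel above.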
+#pragma unroll
+    for (int m = 0; m < size<1>(tOgOaccum); ++m) {
+      const int row = get<0>(tOcO(0, m, 0));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM &&
+          get<1>(tOcO(0, m, 0)) == 0) {
+        gLSEaccum(row) = Split ? -INFINITY : INFINITY;
+      }
+    }
+    return;
+  }
+
+  // We iterate over the blocks in reverse order. This is because the last block
+  // is the only one that needs masking when we read K and V from global memory.
+  // Moreover, iterating in reverse might save us 1 register (we just need
+  // n_block instead of both n_block and n_block_max).
+
+  // We move K and V to the last block.
+  const int bidb_cache =
+      params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
+  const int* block_table =
+      params.block_table == nullptr
+          ? nullptr
+          : params.block_table + bidb * params.block_table_batch_stride;
+  const index_t row_offset_k =
+      block_table == nullptr
+          ? binfo.k_offset(params.k_batch_stride, params.k_row_stride,
+                           bidb_cache) +
+                (n_block_max - 1) * kBlockN * params.k_row_stride +
+                (bidh / params.h_h_k_ratio) * params.k_head_stride
+          : (bidh / params.h_h_k_ratio) *
+                params.k_head_stride;  // block addresses are later resolved
+                                       // per-thread
+
+  const index_t row_offset_v =
+      block_table == nullptr
+          ? binfo.k_offset(params.v_batch_stride, params.v_row_stride,
+                           bidb_cache) +
+                (n_block_max - 1) * kBlockN * params.v_row_stride +
+                (bidh / params.h_h_k_ratio) * params.v_head_stride
+          : (bidh / params.h_h_k_ratio) * params.v_head_stride;
+
+  Tensor mQ =
+      make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) +
+                                binfo.q_offset(params.q_batch_stride,
+                                               params.q_row_stride, bidb)),
+                  make_shape(binfo.actual_seqlen_q, params.h, params.d),
+                  make_stride(params.q_row_stride, params.q_head_stride, _1{}));
+  Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                         make_coord(m_block, 0));  // (kBlockM, kHeadDim)
+  Tensor gK = make_tensor(
+      make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k),
+      Shape<Int<kBlockN>, Int<kHeadDim>>{},
+      make_stride(params.k_row_stride, _1{}));
+  // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr
+  // = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k,
+  // gK.data()); }
+  Tensor gV = make_tensor(
+      make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v),
+      Shape<Int<kBlockN>, Int<kHeadDim>>{},
+      make_stride(params.v_row_stride, _1{}));
+  Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
+                          typename Kernel_traits::SmemLayoutQ{});
+  Tensor sK =
+      make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{});
+  Tensor sV =
+      make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
+  Tensor sVt =
+      make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+  Tensor sVtNoSwizzle =
+      make_tensor(sV.data().get(),
+                  typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+  typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
+  auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
+  typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV;
+  auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx);
+
+  Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
+  Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
+
+  Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
+  Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK);
+  Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
+  Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV);
+
+  Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout()));
+  Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout()));
+  Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout()));
+  Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout()));
+
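+  // With a paged KV cache (block_table != nullptr), the per-thread K/V base
+  // pointers are resolved through the block table (starting from the last K/V
+  // block) instead of using the flat row_offset_k / row_offset_v computed
+  // above.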
+  if (block_table != nullptr) {
+    tKgK.data() =
+        gK.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                        tidx, n_block_max, params.page_block_size, block_table,
+                        params.k_batch_stride, params.k_row_stride);
+    tVgV.data() =
+        gV.data() + flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                        tidx, n_block_max, params.page_block_size, block_table,
+                        params.v_batch_stride, params.v_row_stride);
+  }
+
+  typename Kernel_traits::TiledMma tiled_mma;
+  auto thr_mma = tiled_mma.get_thread_slice(tidx);
+  Tensor tSrQ = thr_mma.partition_fragment_A(sQ);  // (MMA,MMA_M,MMA_K)
+  Tensor tSrK = thr_mma.partition_fragment_B(sK);  // (MMA,MMA_N,MMA_K)
+  Tensor tOrVt =
+      thr_mma.partition_fragment_B(sVtNoSwizzle);  // (MMA, MMA_K,MMA_N)
+
+  Tensor acc_o = partition_fragment_C(
+      tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
+
+  //
+  // Copy Atom retiling
+  //
+
+  auto smem_tiled_copy_Q =
+      make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+  Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+  auto smem_tiled_copy_K =
+      make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+  auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+  Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+  auto smem_tiled_copy_V = make_tiled_copy_B(
+      typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+  auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+  Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+  // PREDICATES
+  //
+
+  // // Allocate predicate tensors for m and n
+  // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)),
+  // Stride<_1,_0>{}); Tensor tKVpKV =
+  // make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)),
+  // Stride<_1,_0>{});
+
+  // Construct identity layout for sQ and sK
+  Tensor cQ = make_identity_tensor(
+      make_shape(size<0>(sQ), size<1>(sQ)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor cKV = make_identity_tensor(
+      make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
+
+  // Repeat the partitioning with identity layouts
+  Tensor tQcQ =
+      gmem_thr_copy_Q.partition_S(cQ);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(
+      cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+  Tensor tKVcKV =
+      make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
+
+  // Allocate predicate tensors for k
+  Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
+  Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
+
+  // Set predicates for k bounds
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tQpQ); ++k) {
+      tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+    }
+#pragma unroll
+    for (int k = 0; k < size(tKVpKV); ++k) {
+      tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+    }
+  }
+
+  // Prologue
+
+  // Copy from Knew to K, optionally apply rotary embedding.
+  if constexpr (Append_KV) {
+    typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
+    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyRotcossinContPaged
+        gmem_tiled_copy_rotary_cont;
+    auto gmem_thr_copy_rotary_cont =
+        gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+
+    // Even if we have MQA / GQA, all threadblocks responsible for the same KV
+    // head are writing to gmem. Technically it's a race condition, but they all
+    // write the same content anyway, and it's safe. We want to do this so that
+    // all threadblocks can proceed right after they finish writing the KV
+    // cache.
+    const index_t row_offset_cossin =
+        ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2);
+    Tensor gCos = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+        make_stride(params.rotary_dim / 2, _1{}));
+    Tensor gSin = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockN>, Int<kHeadDim / 2>>{},
+        make_stride(params.rotary_dim / 2, _1{}));
+    Tensor gCosCont = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockN>, Int<kHeadDim>>{},
+        make_stride(params.rotary_dim / 2, _1{}));
+    Tensor gSinCont = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockN>, Int<kHeadDim>>{},
+        make_stride(params.rotary_dim / 2, _1{}));
+
+    Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
+    Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
+    Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+    Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+
+    Tensor tRgCos =
+        make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
+    Tensor tRgSin =
+        make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
+    Tensor tRgCosCont = make_tensor(
+        tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
+    Tensor tRgSinCont = make_tensor(
+        tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
+
+    // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p,
+    // tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr,
+    // gCos.data(), tRgCos.data(), params.rotary_dim); } if (cute::thread(8, 0))
+    // { print_tensor(gCos); } if (cute::thread(0, 0)) { print_tensor(tRgCos); }
+
+    const index_t row_offset_knew =
+        binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) +
+        ((n_block_max - 1) * kBlockN) * params.knew_row_stride +
+        (bidh / params.h_h_k_ratio) * params.knew_head_stride;
+    const index_t row_offset_vnew =
+        binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) +
+        ((n_block_max - 1) * kBlockN) * params.vnew_row_stride +
+        (bidh / params.h_h_k_ratio) * params.vnew_head_stride;
+    // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew
+    // "line up". When we access them, e.g. if gK has 128 rows and gKnew has 64
+    // rows, we access gK[:128] and gKNew[128:128 + 64]. This maps to accessing
+    // the first 64 rows of knew_ptr.
+    Tensor gKnew = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.knew_ptr) +
+                      row_offset_knew -
+                      binfo.seqlen_k_cache * params.knew_row_stride),
+        Shape<Int<kBlockN>, Int<kHeadDim>>{},
+        make_stride(params.knew_row_stride, _1{}));
+    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+    // printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n",
+    // params.knew_ptr, row_offset_knew, gKnew.data()); }
+    Tensor gVnew = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.vnew_ptr) +
+                      row_offset_vnew -
+                      binfo.seqlen_k_cache * params.vnew_row_stride),
+        Shape<Int<kBlockN>, Int<kHeadDim>>{},
+        make_stride(params.vnew_row_stride, _1{}));
+    typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new;
+    auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx);
+    Tensor tKgKnew_ =
+        gmem_thr_copy_KV_new.partition_S(gKnew);  // (KCPY, KCPY_N, KCPY_K)
+    Tensor tVgVnew_ =
+        gmem_thr_copy_KV_new.partition_S(gVnew);  // (VCPY, VCPY_N, VCPY_K)
+
+    auto tKgKnew =
+        make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout()));
+    auto tVgVnew =
+        make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout()));
+
+    const int n_block_copy_min =
+        std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
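+    // Blocks entirely below seqlen_k_cache already hold cached K/V, so the
+    // append loop only needs to go down to block seqlen_k_cache / kBlockN.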
+    auto tKgK_data = tKgK.data();
+    auto tVgV_data = tVgV.data();
+    for (int n_block = n_block_max - 1; n_block >= n_block_copy_min;
+         n_block--) {
+      flash::copy_w_min_idx<Is_even_K>(
+          tVgVnew, tVgV, tKVcKV, tKVpKV,
+          binfo.actual_seqlen_k - n_block * kBlockN,
+          binfo.seqlen_k_cache - n_block * kBlockN);
+      tVgVnew.data() =
+          tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
+      if (params.rotary_dim == 0) {
+        flash::copy_w_min_idx<Is_even_K>(
+            tKgKnew, tKgK, tKVcKV, tKVpKV,
+            binfo.actual_seqlen_k - n_block * kBlockN,
+            binfo.seqlen_k_cache - n_block * kBlockN);
+      } else {
+        if (params.is_rotary_interleaved) {
+          // Don't clear OOB_K because we're writing to global memory
+          flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>(
+              tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV,
+              binfo.actual_seqlen_k - n_block * kBlockN,
+              binfo.seqlen_k_cache - n_block * kBlockN, params.d,
+              params.rotary_dim);
+          tRgCos.data() =
+              tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2));
+          tRgSin.data() =
+              tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2));
+        } else {
+          // Don't clear OOB_K because we're writing to global memory
+          flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>(
+              tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV,
+              binfo.actual_seqlen_k - n_block * kBlockN,
+              binfo.seqlen_k_cache - n_block * kBlockN, params.d,
+              params.rotary_dim);
+          tRgCosCont.data() =
+              tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+          tRgSinCont.data() =
+              tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2));
+        }
+      }
+      tKgKnew.data() =
+          tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
+      if (block_table == nullptr) {
+        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+        tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      } else {
+        if (n_block > n_block_copy_min) {
+          tVgV.data() =
+              gV.data() +
+              flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                  tidx, n_block, params.page_block_size, block_table,
+                  params.v_batch_stride, params.v_row_stride);
+          tKgK.data() =
+              gK.data() +
+              flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                  tidx, n_block, params.page_block_size, block_table,
+                  params.k_batch_stride, params.k_row_stride);
+        }
+      }
+    }
+    // Need this before we can read in K again, so that we'll see the updated K
+    // values.
+    __syncthreads();
+    tKgK.data() = tKgK_data;
+    tVgV.data() = tVgV_data;
+  }
+
+  // Read Q from gmem to smem, optionally apply rotary embedding.
+  if (!Append_KV || params.rotary_dim == 0) {
+    // We don't need to clear the sQ smem tiles since we'll only write out the
+    // valid outputs
+    flash::copy<Is_even_MN, Is_even_K>(
+        gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
+        binfo.actual_seqlen_q - m_block * kBlockM);
+  } else {
+    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyRotcossinCont
+        gmem_tiled_copy_rotary_cont;
+    auto gmem_thr_copy_rotary_cont =
+        gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
+    const index_t row_offset_cossin =
+        (binfo.seqlen_k_cache +
+         (Is_causal || Is_local ? m_block * kBlockM : 0)) *
+        (params.rotary_dim / 2);
+    // If not causal, all the queries get the same cos/sin, taken at
+    // location seqlen_k_cache. We do this by setting the row stride of gCos /
+    // gSin to 0.
+    Tensor gCos = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+        make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor gSin = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockM>, Int<kHeadDim / 2>>{},
+        make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor gCosCont = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_cos_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor gSinCont = make_tensor(
+        make_gmem_ptr(reinterpret_cast<Element*>(params.rotary_sin_ptr) +
+                      row_offset_cossin),
+        Shape<Int<kBlockM>, Int<kHeadDim>>{},
+        make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{}));
+    Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
+    Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
+    Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+    Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
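+    // is_rotary_interleaved selects between rotating adjacent (even, odd)
+    // dimension pairs and rotating (i, i + rotary_dim / 2) pairs; the latter
+    // ("contiguous") path uses the *Cont tensors above.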
+    if (params.is_rotary_interleaved) {
+      flash::copy_rotary_interleaved<Is_even_K>(
+          tQgQ, tQsQ, tRgCos, tRgSin, tQcQ,
+          binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d,
+          params.rotary_dim);
+    } else {
+      flash::copy_rotary_contiguous<Is_even_K>(
+          tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ,
+          binfo.actual_seqlen_q - m_block * kBlockM, 0, params.d,
+          params.rotary_dim);
+    }
+  }
+
+  int n_block = n_block_max - 1;
+  // We don't need to clear the sK smem tiles since we'll mask out the scores
+  // anyway.
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV,
+                                     tKVpKV,
+                                     binfo.actual_seqlen_k - n_block * kBlockN);
+  cute::cp_async_fence();
+
+  // flash::cp_async_wait<0>();
+  // __syncthreads();
+  // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
+  // __syncthreads();
+
+  clear(acc_o);
+
+  flash::Softmax<2 * size<1>(acc_o)> softmax;
+
+  const float alibi_slope =
+      !Has_alibi ? 0.0f
+                 : reinterpret_cast<float*>(params.alibi_slopes_ptr)
+                           [bidb * params.alibi_slopes_batch_stride + bidh] /
+                       params.scale_softmax;
+  flash::Mask<Is_causal, Is_local, Has_alibi> mask(
+      binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left,
+      params.window_size_right, alibi_slope);
+
+  // For performance reasons, we separate out two kinds of iterations:
+  // those that need masking on S, and those that don't.
+  // We need masking on S for the very last block when the K and V length is
+  // not a multiple of kBlockN. We also need masking on S if it's causal, for
+  // the last ceil_div(kBlockM, kBlockN) blocks. We will have at least 1
+  // "masking" iteration.
+
+  // If not even_N, then seqlen_k might end in the middle of a block. In that
+  // case we need to mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+  constexpr int n_masking_steps =
+      (!Is_causal && !Is_local)
+          ? 1
+          : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN)
+                                       : cute::ceil_div(kBlockM, kBlockN) + 1);
+#pragma unroll
+  for (int masking_step = 0; masking_step < n_masking_steps;
+       ++masking_step, --n_block) {
+    Tensor acc_s = partition_fragment_C(
+        tiled_mma,
+        Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+
+    // Advance gV
+    if (masking_step > 0) {
+      if (block_table == nullptr) {
+        tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+      } else {
+        tVgV.data() =
+            gV.data() +
+            flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                tidx, n_block + 1, params.page_block_size, block_table,
+                params.v_batch_stride, params.v_row_stride);
+      }
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV,
+                                                  tVsV, tKVcKV, tKVpKV);
+    } else {
+      // Clear the smem tiles to account for predicated off loads
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+          gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV,
+          binfo.actual_seqlen_k - n_block * kBlockN);
+    }
+    cute::cp_async_fence();
+
+    flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
+                smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
+    // if (cute::thread0()) { print(acc_s); }
+    if constexpr (Is_softcap) {
+      apply_softcap(acc_s, params.softcap);
+    }
+
+    mask.template apply_mask<Is_causal, Is_even_MN>(
+        acc_s, n_block * kBlockN,
+        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); }
+    // __syncthreads();
+
+    if (n_block > n_block_min) {
+      // Advance gK
+      if (block_table == nullptr) {
+        tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      } else {
+        tKgK.data() = gK.data() +
+                      flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                          tidx, n_block, params.page_block_size, block_table,
+                          params.k_batch_stride, params.k_row_stride);
+      }
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK,
+                                                  tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the
+      // synchronization isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    // We have key_padding_mask so we'll need to Check_inf
+    masking_step == 0
+        ? softmax.template softmax_rescale_o</*Is_first=*/true,
+                                             /*Check_inf=*/Is_causal ||
+                                                 Is_local || !Is_even_MN>(
+              acc_s, acc_o, params.scale_softmax_log2)
+        : softmax.template softmax_rescale_o</*Is_first=*/false,
+                                             /*Check_inf=*/Is_causal ||
+                                                 Is_local || !Is_even_MN>(
+              acc_s, acc_o, params.scale_softmax_log2);
+    // if (cute::thread0()) { print(scores_max); print(scores_sum);
+    // print(scores); }
+
+    // Convert acc_s from fp32 to fp16/bf16
+    Tensor rP = flash::convert_type<Element>(acc_s);
+    // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(
+        rP.data(),
+        flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+    flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
+                   smem_thr_copy_V);
+
+    // This check is at the end of the loop since we always have at least 1
+    // iteration
+    if (n_masking_steps > 1 && n_block <= n_block_min) {
+      --n_block;
+      break;
+    }
+  }
+
+  // These are the iterations where we don't need masking on S
+  for (; n_block >= n_block_min; --n_block) {
+    Tensor acc_s = partition_fragment_C(
+        tiled_mma,
+        Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+    clear(acc_s);
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    // Advance gV
+    if (block_table == nullptr) {
+      tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+    } else {
+      tVgV.data() = gV.data() +
+                    flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                        tidx, n_block + 1, params.page_block_size, block_table,
+                        params.v_batch_stride, params.v_row_stride);
+    }
+    flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tVgV, tVsV,
+                                                tKVcKV, tKVpKV);
+    cute::cp_async_fence();
+
+    flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q,
+                smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K);
+    if constexpr (Is_softcap) {
+      apply_softcap(acc_s, params.softcap);
+    }
+
+    flash::cp_async_wait<0>();
+    __syncthreads();
+    if (n_block > n_block_min) {
+      // Advance gK
+      if (block_table == nullptr) {
+        tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      } else {
+        tKgK.data() = gK.data() +
+                      flash::resolve_thread_kv_page_slice_offset<Kernel_traits>(
+                          tidx, n_block, params.page_block_size, block_table,
+                          params.k_batch_stride, params.k_row_stride);
+      }
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_KV, tKgK,
+                                                  tKsK, tKVcKV, tKVpKV);
+      // This cp_async_fence needs to be in the if block, otherwise the
+      // synchronization isn't right and we get race conditions.
+      cute::cp_async_fence();
+    }
+
+    mask.template apply_mask</*Causal_mask=*/false>(
+        acc_s, n_block * kBlockN,
+        m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
+    softmax
+        .template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(
+            acc_s, acc_o, params.scale_softmax_log2);
+
+    Tensor rP = flash::convert_type<Element>(acc_s);
+    // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(
+        rP.data(),
+        flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+
+    flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V,
+                   smem_thr_copy_V);
+  }
+
+  // Epilogue
+
+  Tensor lse =
+      softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(
+          acc_o, params.scale_softmax);
+  // if (cute::thread0()) { print(lse); }
+
+  Tensor sOaccum =
+      make_tensor(make_smem_ptr(reinterpret_cast<ElementO*>(smem_)),
+                  typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
+  // Partition sO to match the accumulator partitioning
+  using SmemTiledCopyO =
+      std::conditional_t<!Split, typename Kernel_traits::SmemCopyAtomO,
+                         typename Kernel_traits::SmemCopyAtomOaccum>;
+  auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma);
+  auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor rO = flash::convert_type<ElementO>(acc_o);
+  Tensor taccOrOaccum =
+      smem_thr_copy_Oaccum.retile_S(rO);  // ((Atom,AtomNum), MMA_M, MMA_N)
+  Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(
+      sOaccum);  // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+  // sOaccum is larger than sQ, so we need to syncthreads here
+  // TODO: allocate enough smem for sOaccum
+  if constexpr (Split) {
+    __syncthreads();
+  }
+
+  cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
+
+  const index_t row_offset_o =
+      binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) +
+      m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+  const index_t row_offset_oaccum =
+      (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q +
+       m_block * kBlockM) *
+      params.d_rounded;
+  const index_t row_offset_lseaccum =
+      (Split || !params.unpadded_lse
+           ? ((n_split_idx * params.b + bidb) * params.h + bidh) *
+                 params.seqlen_q
+           : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb)) +
+      m_block * kBlockM;
+
+  Tensor gOaccum =
+      make_tensor(make_gmem_ptr(reinterpret_cast<ElementO*>(
+                                    Split ? params.oaccum_ptr : params.o_ptr) +
+                                (Split ? row_offset_oaccum : row_offset_o)),
+                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                  make_stride(Split ? kHeadDim : params.o_row_stride, _1{}));
+  Tensor gLSEaccum = make_tensor(
+      make_gmem_ptr(
+          reinterpret_cast<ElementAccum*>(Split ? params.softmax_lseaccum_ptr
+                                                : params.softmax_lse_ptr) +
+          row_offset_lseaccum),
+      Shape<Int<kBlockM>>{}, Stride<_1>{});
+  // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n",
+  // row_offset_o, bidh, gOaccum.data()); }
+
+  GmemTiledCopyO gmem_tiled_copy_Oaccum;
+  auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(
+      sOaccum);  // ((Atom,AtomNum),ATOM_M,ATOM_N)
+  Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
+
+  __syncthreads();
+
+  Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
+  cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
+
+  Tensor caccO = make_identity_tensor(
+      Shape<Int<kBlockM>, Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  Tensor taccOcO = thr_mma.partition_C(caccO);  // (MMA,MMA_M,MMA_K)
+  static_assert(decltype(size<0>(taccOcO))::value == 4);
+  // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+  Tensor taccOcO_row =
+      logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
+  CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));  // MMA_M
+  if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+    for (int mi = 0; mi < size(lse); ++mi) {
+      const int row = get<0>(taccOcO_row(mi));
+      if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+        gLSEaccum(row) = lse(mi);
+      }
+    }
+  }
+
+  // Construct identity layout for sO
+  Tensor cO = make_identity_tensor(make_shape(
+      size<0>(sOaccum), size<1>(sOaccum)));  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+  // Repeat the partitioning with identity layouts
+  Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(
+      cO);  // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpO); ++k) {
+      tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+    }
+  }
+  // Clear_OOB_K must be false since we don't want to write zeros to gmem
+  flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false,
+              /*Clear_OOB_K=*/false>(gmem_tiled_copy_Oaccum, tOrOaccum,
+                                     tOgOaccum, tOcO, tOpO,
+                                     binfo.actual_seqlen_q - m_block * kBlockM);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
+          bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
+          bool Is_softcap, bool Return_softmax, typename Params>
+inline __device__ void compute_attn(const Params& params) {
+  const int m_block = blockIdx.x;
+  // The block index for the batch.
+  const int bidb = blockIdx.y;
+  // The block index for the head.
+  const int bidh = blockIdx.z;
+
+  // We want the fwd and bwd to generate the same dropout pattern (RNG),
+  // without restricting them to have the same number of threads or having to
+  // traverse the attention matrix in the same order. In the Philox RNG, we use
+  // the offset to store the batch, head, and the lane id (within a warp). We
+  // use the subsequence to store the location of the 16 x 32 blocks within the
+  // attention matrix. This way, as long as we have the batch, head, and the
+  // location of the 16 x 32 block within the attention matrix, we can generate
+  // the exact same dropout pattern.
+
+  flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local,
+                                Has_alibi, Is_even_MN, Is_even_K, Is_softcap,
+                                Return_softmax>(params, bidb, bidh, m_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi,
+          bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split,
+          bool Append_KV, typename Params>
+inline __device__ void compute_attn_splitkv(const Params& params) {
+  const int m_block = blockIdx.x;
+  // The block index for the batch.
+  const int bidb = Split ? blockIdx.z / params.h : blockIdx.y;
+  // The block index for the head.
+  const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
+  const int n_split_idx = Split ? blockIdx.y : 0;
+  const int num_n_splits = Split ? gridDim.y : 1;
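+  // Grid mapping: with Split, blockIdx.y indexes the KV split and blockIdx.z
+  // flattens (batch, head); without Split, blockIdx.y is the batch and
+  // blockIdx.z is the head.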
+  flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local,
+                                        Has_alibi, Is_even_MN, Is_even_K,
+                                        Is_softcap, Split, Append_KV>(
+      params, bidb, bidh, m_block, n_split_idx, num_n_splits);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, int kBlockM, int Log_max_splits,
+          bool Is_even_K, typename Params>
+inline __device__ void combine_attn_seqk_parallel(const Params& params) {
+  using Element = typename Kernel_traits::Element;
+  using ElementAccum = typename Kernel_traits::ElementAccum;
+  using index_t = typename Kernel_traits::index_t;
+  constexpr int kMaxSplits = 1 << Log_max_splits;
+  constexpr int kHeadDim = Kernel_traits::kHeadDim;
+  constexpr int kNThreads = Kernel_traits::kNThreads;
+
+  static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
+  static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32,
+                "kBlockM must be 4, 8, 16 or 32");
+  static_assert(kNThreads == 128, "We assume that each block has 128 threads");
+
+  // Shared memory.
+  // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
+  __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
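+  // With kBlockM a power of two, threads walking down a column of sLSE would
+  // keep hitting the same few 4-byte banks; padding each row to kBlockM + 1
+  // floats staggers consecutive rows across banks.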
+
+  // The thread and block index.
+  const int tidx = threadIdx.x;
+  const int bidx = blockIdx.x;
+
+  const index_t lse_size = params.b * params.h * params.seqlen_q;
+
+  const index_t row_offset_lse = bidx * kBlockM;
+  Tensor gLSEaccum = make_tensor(
+      make_gmem_ptr(
+          reinterpret_cast<ElementAccum*>(params.softmax_lseaccum_ptr) +
+          row_offset_lse),
+      Shape<Int<kMaxSplits>, Int<kBlockM>>{}, make_stride(lse_size, _1{}));
+
+  // LSE format is different depending on params.unpadded_lse and
+  // params.seqlenq_ngroups_swapped, see comment in get_lse_tile. This tensor's
+  // layout maps row_offset_lse to {bidb, bidh, q_offset}.
+  Tensor gLSE = make_tensor(
+      make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) +
+                    row_offset_lse),
+      Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+  // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb,
+  // q_offset}.
+  Layout flat_layout = make_layout(lse_size);
+  Layout orig_layout =
+      make_layout(make_shape(params.seqlen_q, params.h, params.b));
+  auto transposed_stride =
+      params.seqlenq_ngroups_swapped
+          ? make_stride(params.b, params.seqlen_q * params.b, 1)
+          : make_stride(1, params.seqlen_q * params.b, params.seqlen_q);
+  Layout remapped_layout = make_layout(
+      make_shape(params.seqlen_q, params.h, params.b), transposed_stride);
+  Layout final_layout = cute::composition(
+      remapped_layout, cute::composition(orig_layout, flat_layout));
+
+  Tensor gLSE_unpadded = make_tensor(
+      make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr)),
+      final_layout);
+
+  constexpr int kNLsePerThread =
+      (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;
+
+  // Read the LSE values from gmem and store them in shared memory, then
+  // transpose them.
+  constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
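+  // For example, with kNThreads = 128, kBlockM = 4 and kMaxSplits = 128:
+  // kNLsePerThread = (128 * 4 + 127) / 128 = 4 and kRowsPerLoadLSE = 128 / 4 =
+  // 32, so each iteration of the loop below covers 32 split-rows of 4 LSE
+  // columns.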
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
+    const int col = tidx % kBlockM;
+    ElementAccum lse =
+        (row < params.num_splits && col < lse_size - bidx * kBlockM)
+            ? gLSEaccum(row, col)
+            : -INFINITY;
+    if (row < kMaxSplits) {
+      sLSE[row][col] = lse;
+    }
+    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse
+    // = %f\n", tidx, row, col, lse); }
+  }
+  // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse =
+  // %f\n", tidx, row_offset_lse, lse_accum(0)); }
+  __syncthreads();
+  Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
+  constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
+  // To make sure that kMaxSplits is within 1 warp: we decide how many elements
+  // within kMaxSplits each thread should hold. If kMaxSplits = 16, then each
+  // thread holds 2 elements (128 threads, kBlockM rows, so each time we load we
+  // can load 128 / kBlockM rows).
+  // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
+  // static_assert(kThreadsPerSplit <= 32);
+  static_assert(kRowsPerLoadTranspose <= 32);
+  static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+    const int col = tidx / kRowsPerLoadTranspose;
+    lse_accum(l) =
+        (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
+    // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse
+    // = %f\n", tidx, row, col, lse_accum(l)); }
+  }
+
+  // Compute the logsumexp of the LSE along the split dimension.
+  ElementAccum lse_max = lse_accum(0);
+#pragma unroll
+  for (int l = 1; l < kNLsePerThread; ++l) {
+    lse_max = max(lse_max, lse_accum(l));
+  }
+  MaxOp<float> max_op;
+  lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
+  lse_max =
+      lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
+  float lse_sum = expf(lse_accum(0) - lse_max);
+#pragma unroll
+  for (int l = 1; l < kNLsePerThread; ++l) {
+    lse_sum += expf(lse_accum(l) - lse_max);
+  }
+  SumOp<float> sum_op;
+  lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
+  // For the case where all local lse == -INFINITY, we want to set lse_logsum to
+  // INFINITY. Otherwise lse_logsum is log(0.0) = -INFINITY and we get NaN when
+  // we do lse_accum(l) - lse_logsum.
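+  // The value computed below is lse_logsum = log(sum over splits of
+  // exp(lse_split)), evaluated stably as lse_max + log(sum(exp(lse_split -
+  // lse_max))).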
+  ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum)
+                                ? INFINITY
+                                : logf(lse_sum) + lse_max;
+  // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f,
+  // lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
+  if (tidx % kRowsPerLoadTranspose == 0 &&
+      tidx / kRowsPerLoadTranspose < kBlockM) {
+    if (params.unpadded_lse) {
+      const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose;
+      if (lse_offset < lse_size) {
+        gLSE_unpadded(lse_offset) = lse_logsum;
+      }
+    } else {
+      gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum;
+    }
+  }
+// Store the scales exp(lse - lse_logsum) in shared memory.
+#pragma unroll
+  for (int l = 0; l < kNLsePerThread; ++l) {
+    const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
+    const int col = tidx / kRowsPerLoadTranspose;
+    if (row < params.num_splits && col < kBlockM) {
+      sLSE[row][col] = expf(lse_accum(l) - lse_logsum);
+    }
+  }
+  __syncthreads();
+
+  const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
+  Tensor gOaccum = make_tensor(
+      make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.oaccum_ptr) +
+                    row_offset_oaccum),
+      Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+  constexpr int kBlockN = kNThreads / kBlockM;
+  using GmemLayoutAtomOaccum =
+      Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+  using GmemTiledCopyOaccum = decltype(make_tiled_copy(
+      Copy_Atom<DefaultCopy, ElementAccum>{}, GmemLayoutAtomOaccum{},
+      Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+  GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+  auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
+  Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
+  Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
+  Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
+  clear(tOrO);
+
+  // Predicates
+  Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
+  // Repeat the partitioning with identity layouts
+  Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
+  Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
+  if (!Is_even_K) {
+#pragma unroll
+    for (int k = 0; k < size(tOpOaccum); ++k) {
+      tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d;
+    }
+  }
+  // Load Oaccum in, then scale and accumulate into O.
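+  // i.e. O = sum over splits s of exp(lse_s - lse_logsum) * Oaccum_s, with the
+  // per-row scales read from sLSE computed above.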
+  for (int split = 0; split < params.num_splits; ++split) {
+    flash::copy</*Is_even_MN=*/false, Is_even_K>(
+        gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum,
+        params.b * params.h * params.seqlen_q - bidx * kBlockM);
+#pragma unroll
+    for (int m = 0; m < size<1>(tOrOaccum); ++m) {
+      int row = get<0>(tOcOaccum(0, m, 0));
+      ElementAccum lse_scale = sLSE[split][row];
+#pragma unroll
+      for (int k = 0; k < size<2>(tOrOaccum); ++k) {
+#pragma unroll
+        for (int i = 0; i < size<0>(tOrOaccum); ++i) {
+          tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
+        }
+      }
+      // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0],
+      // sLSE[split][1]); print(tOrOaccum); }
+    }
+    tOgOaccum.data() = tOgOaccum.data() +
+                       params.b * params.h * params.seqlen_q * params.d_rounded;
+  }
+  // if (cute::thread0()) { print_tensor(tOrO); }
+
+  Tensor rO = flash::convert_type<Element>(tOrO);
+// Write to gO
+#pragma unroll
+  for (int m = 0; m < size<1>(rO); ++m) {
+    const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
+    if (idx < params.b * params.h * params.seqlen_q) {
+      const int batch_idx = idx / (params.h * params.seqlen_q);
+      const int head_idx =
+          (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
+      // The index to the rows of Q
+      const int row = idx - batch_idx * (params.h * params.seqlen_q) -
+                      head_idx * params.seqlen_q;
+      auto o_ptr = reinterpret_cast<Element*>(params.o_ptr) +
+                   batch_idx * params.o_batch_stride +
+                   head_idx * params.o_head_stride + row * params.o_row_stride;
+#pragma unroll
+      for (int k = 0; k < size<2>(rO); ++k) {
+        if (Is_even_K || tOpOaccum(k)) {
+          const int col = get<1>(tOcOaccum(0, m, k));
+          Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
+                                  Shape<Int<decltype(size<0>(rO))::value>>{},
+                                  Stride<_1>{});
+          // TODO: Should check if this is using vectorized store, but it seems
+          // pretty fast
+          copy(rO(_, m, k), gO);
+          // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d,
+          // batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx,
+          // batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
+          // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] =
+          // recast<uint64_t>(rO)(0, m, k);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace flash

+ 356 - 0
kernels/flash_attn/flash_fwd_launch_template.h

@@ -0,0 +1,356 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <ATen/cuda/CUDAContext.h>
+
+#include "static_switch.h"
+#include "flash.h"
+#include "flash_fwd_kernel.h"
+
+// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#define ARCH_SUPPORTS_FLASH
+#define KERNEL_PARAM_MODIFIER __grid_constant__
+#else
+#define KERNEL_PARAM_MODIFIER
+#endif
+
+// Define a macro for unsupported architecture handling to centralize the error message
+#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
+
+// Use a macro to clean up kernel definitions
+#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
+template<typename Kernel_traits, __VA_ARGS__> \
+__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
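+// For example, DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal) expands to
+//   template<typename Kernel_traits, bool Is_causal>
+//   __global__ void flash_fwd_kernel(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)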
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        static_assert(!(Is_causal && Is_local)); // Enforce constraints
+        flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+    #if defined(ARCH_SUPPORTS_FLASH)
+        flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+    #else
+        FLASH_UNSUPPORTED_ARCH
+    #endif
+}
+
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
+    static_assert(Log_max_splits >= 1);
+    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
+}
+
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
+void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr size_t smem_size = Kernel_traits::kSmemSize;
+    // printf("smem_size = %d\n", smem_size);
+
+    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
+    // https://github.com/kokkos/kokkos-kernels/issues/349
+    // https://github.com/HazyResearch/flash-attention/issues/21
+
+    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+    dim3 grid(num_m_block, params.b, params.h);
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+    const bool return_softmax = params.p_ptr != nullptr;
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                        SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                            // Will only return softmax if dropout, to reduce compilation time.
+                            // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                            // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                            // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                            // If Is_local, set Is_causal to false
+                            auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+                            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                            // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                            if (smem_size >= 48 * 1024) {
+                                C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                            }
+                            // int ctas_per_sm;
+                            // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                            //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                            kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                            C10_CUDA_KERNEL_LAUNCH_CHECK();
+                        });
+                    });
+                });
+            });
+        });
+    });
+}
+
+template<typename Kernel_traits, bool Is_causal>
+void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
+    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
+    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
+    constexpr size_t smem_size = Kernel_traits::kSmemSize;
+    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
+    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
+    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(params.num_splits > 1, Split, [&] {
+                    BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
+                        ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                            SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
+                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                                // If Is_local, set Is_causal to false
+                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
+                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                                if (smem_size >= 48 * 1024) {
+                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
+                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                                }
+                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                            });
+                        });
+                    });
+                });
+            });
+        });
+    });
+    if (params.num_splits > 1) {
+        // We want kBlockM to be as small as possible for more parallelism.
+        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
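+        // i.e. 128 threads * 4 floats per vectorized access = 512 elements per
+        // pass, so kBlockM = 512 / 128 = 4 rows when headdim % 128 == 0,
+        // 512 / 64 = 8 rows when headdim % 64 == 0, else 16.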
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
+        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
+            if (params.num_splits <= 2) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 4) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 8) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 16) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 32) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 64) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            } else if (params.num_splits <= 128) {
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+            }
+            C10_CUDA_KERNEL_LAUNCH_CHECK();
+        });
+    }
+}
+
+template<typename T, int Headdim, bool Is_causal>
+void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
+    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
+    // and for headdim 192 with block size 64 x 128.
+    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 32;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 64;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        if constexpr(!Is_dropout) {
+            // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
+            // Using block size (64 x 256) is 27% slower for seqlen=2k
+            // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 96;
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square).
+        if (is_sm8x) {
+            if constexpr(!Is_causal) {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            }
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        // These two are always slower
+        // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 128;
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        if constexpr(!Is_dropout) {
+            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+            // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
+            if (is_sm8x) {
+                if constexpr(!Is_causal) {
+                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+                } else {
+                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+                }
+            } else {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            }
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // The 1st ones are good for H100 and A100.
+            // The 2nd one is good for A6000 because we get slightly better occupancy.
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
+            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
+        }
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 160;
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // For A100, H100, 128 x 32 is the fastest.
+        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
+        // and 128 x 64 with 8 warps is the fastest for non-causal.
+        if (is_sm8x) {
+            if constexpr(!Is_causal) {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            } else {
+                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+            }
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 192;
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        if constexpr(!Is_dropout) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 224;
+    int device;
+    cudaGetDevice(&device);
+    int max_smem_per_block;
+    cudaError status_ = cudaDeviceGetAttribute(
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+    if (status_ != cudaSuccess) {
+      C10_CUDA_CHECK(status_);
+    }
+    // printf("max_smem_per_block = %d\n", max_smem_per_block);
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) {  // 112 KB
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
+        // If we have N = 32, there are only 1024 elements to load at once, where each load
+        // is 8 elements. This means we can only use 128 threads and not 256 threads.
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}
+
+template<typename T, bool Is_causal>
+void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
+    constexpr static int Headdim = 256;
+    int device;
+    cudaGetDevice(&device);
+    int max_smem_per_sm, max_smem_per_block;
+    cudaError status_ = cudaDeviceGetAttribute(
+        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
+    status_ = cudaDeviceGetAttribute(
+        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
+    if (status_ != cudaSuccess) {
+      C10_CUDA_CHECK(status_);
+    }
+    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
+    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
+        // For A100, we want to run with 128 x 64 (128KB smem).
+        // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
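+        // The thresholds below: 2 * 256 * (128 + 2 * 64) B = 128 KB is the smem
+        // for the 128 x 64 tile, and 4 * 256 * (64 + 2 * 64) B = 192 KB is two
+        // 96 KB CTAs of the 64 x 64 tile on one SM.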
+        if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        } else {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        }
+        // 64 KB
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
+        // 96 KB
+        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
+    });
+}

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim128_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim128_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim160_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim160_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim192_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim192_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim224_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim224_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim256_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim256_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim32_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim32_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim64_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim64_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim96_bf16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim96_fp16_causal_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream);

+ 7 - 0
kernels/flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu

@@ -0,0 +1,7 @@
+// Copyright (c) 2023, Tri Dao.
+// Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
+
+#include "flash_fwd_launch_template.h"
+
+template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream);

+ 180 - 0
kernels/flash_attn/kernel_traits.h

@@ -0,0 +1,180 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include "cute/tensor.hpp"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/layout.h"
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
+struct Flash_kernel_traits {
+
+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
+    using Element = elem_type;
+    static constexpr bool Has_cp_async = true;
+#else
+    using Element = cutlass::half_t;
+    static constexpr bool Has_cp_async = false;
+#endif
+
+    using ElementAccum = float;
+    using index_t = int64_t;
+
+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 800
+    using MMA_Atom_Arch = std::conditional_t<
+        std::is_same_v<elem_type, cutlass::half_t>,
+        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
+    >;
+#else
+    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
+#endif
+
+#if defined(__CUDA_ARCH__) &&  __CUDA_ARCH__ >= 750
+    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
+    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
+#else
+    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
+#endif
+};
+
+// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
+         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
+struct Flash_fwd_kernel_traits : public Base {
+    using Element = typename Base::Element;
+    using ElementAccum = typename Base::ElementAccum;
+    using index_t = typename Base::index_t;
+    static constexpr bool Has_cp_async = Base::Has_cp_async;
+    using SmemCopyAtom = typename Base::SmemCopyAtom;
+    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
+
+    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
+    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
+
+    // The number of threads.
+    static constexpr int kNWarps = kNWarps_;
+    static constexpr int kNThreads = kNWarps * 32;
+
+    static constexpr int kBlockM = kBlockM_;
+    static constexpr int kBlockN = kBlockN_;
+    static constexpr int kHeadDim = kHeadDim_;
+    static_assert(kHeadDim % 32 == 0);
+    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
+    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
+    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
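+    // For example, kHeadDim = 128 gives kBlockKSmem = 64, kBlockKGmem = 128 and
+    // kSwizzle = 3, while kHeadDim = 96 gives kBlockKSmem = kBlockKGmem = 32 and
+    // kSwizzle = 2.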
+
+    using TiledMma = TiledMMA<
+        typename Base::MMA_Atom_Arch,
+        Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
+        Tile<Int<16 * kNWarps>, _16, _16>>;
+
+    using SmemLayoutAtomQ = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
+                    Layout<Shape<_8, Int<kBlockKSmem>>,
+                           Stride<Int<kBlockKSmem>, _1>>{}));
+    using SmemLayoutQ = decltype(tile_to_shape(
+        SmemLayoutAtomQ{},
+        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+
+    using SmemLayoutKV = decltype(tile_to_shape(
+        SmemLayoutAtomQ{},
+        Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
+    using SmemLayoutVtransposed = decltype(
+        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
+    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
+
+    using SmemLayoutAtomO = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
+                           Stride<Int<kBlockKSmem>, _1>>{}));
+    using SmemLayoutO = decltype(tile_to_shape(
+        SmemLayoutAtomO{},
+        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
+    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
+
+    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
+    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
+    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
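+    // e.g. for fp16 with kHeadDim = 128, kBlockM = 128, kBlockN = 64:
+    // kSmemQSize = 128 * 128 * 2 B = 32 KB and kSmemKVSize = 64 * 128 * 2 * 2 B
+    // = 32 KB, so kSmemSize = 64 KB (or 32 KB when Share_Q_K_smem).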
+
+    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
+    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
+    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
+    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
+    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
+    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
+    // to the same banks.
+    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
+    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
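+    // e.g. for fp16, kGmemElemsPerLoad = 16 / 2 = 8; with kBlockKSmem = 64 each
+    // row is read by 64 / 8 = 8 threads, so GmemLayoutAtom below is
+    // (kNThreads / 8) x 8 threads.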
+    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+                                  Stride<Int<kGmemThreadsPerRow>, _1>>;
+
+    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
+    // from the same address by the same threadblock. This is slightly faster.
+    using Gmem_copy_struct = std::conditional_t<
+        Has_cp_async,
+        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
+        DefaultCopy
+    >;
+    using GmemTiledCopyQKV = decltype(
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+
+    // from how many rows does each thread have to fetch
+    static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow);
+    // Here we assign a contiguous tile to each thread, rather than a 1x8 row every
+    // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
+    // do not cross a page boundary. This way, each thread need only fetch 1 page index per
+    // mainloop iteration. Rudimentary testing shows no slowdown.
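+    // e.g. kBlockN = 128 with kNThreads = 128 and kGmemThreadsPerRow = 8 gives
+    // kGmemRowsPerThread = 128 / (128 / 8) = 8 contiguous rows per thread.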
+    using GmemTiledCopyQKVPaged = decltype(
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));
+    using GmemTiledCopyO = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+
+    using GmemLayoutAtomOaccum = std::conditional_t<
+        kBlockKSmem == 32,
+        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
+               Stride< _8, _1>>,
+        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
+               Stride< _16, _1>>
+    >;
+    using GmemTiledCopyOaccum = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                        GmemLayoutAtomOaccum{},
+                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
+    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemTiledCopyRotcossin = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinCont = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
+    using GmemTiledCopyRotcossinPaged = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinContPaged = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////

+ 213 - 0
kernels/flash_attn/mask.h

@@ -0,0 +1,213 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cute/tensor.hpp>
+
+namespace flash {
+
+using namespace cute;
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    #pragma unroll
+    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
+        #pragma unroll
+        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+            const int col_idx = col_idx_base + j;
+            if (col_idx >= max_seqlen_k) {
+                // Without the "make_coord" we get wrong results
+                #pragma unroll
+                for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                    tensor(mi, make_coord(j, nj)) = -INFINITY;
+                }
+            }
+        }
+    }
+}
+
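+// Worked example for apply_mask_local below (illustrative numbers): with
+// max_seqlen_q == max_seqlen_k == 8, window_size_left = 2, window_size_right = 0 and row_idx = 3,
+// col_idx_limit_left = max(0, 3 - 2) = 1 and col_idx_limit_right = min(8, 3 + 1) = 4,
+// so only columns 1..3 remain unmasked for that row.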
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    static_assert(Layout::rank == 2, "Only support 2D Tensor");
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    #pragma unroll
+    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        #pragma unroll
+        for (int i = 0; i < size<0, 0>(tensor); ++i) {
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+            #pragma unroll
+            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                const int col_idx_base = col_idx_offset + nj * 8;
+                #pragma unroll
+                for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                    const int col_idx = col_idx_base + j;
+                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
+                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                    }
+                }
+            }
+            // if (cute::thread0()) {
+            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
+            //     print(tensor(make_coord(i, mi), _));
+            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
+            // }
+        }
+    }
+}
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void apply_mask_causal_w_idx(
+    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
+{
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
+    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
+    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
+    #pragma unroll
+    for (int mi = 0; mi < size<0>(tensor); ++mi) {
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
+        #pragma unroll
+        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
+            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
+                tensor(mi, ni) = -INFINITY;
+            }
+        }
+        // if (cute::thread0()) {
+        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
+        //     print(tensor(_, make_coord(j, ni)));
+        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
+        // }
+    }
+}
+
+template <bool Is_causal, bool Is_local, bool Has_alibi>
+struct Mask {
+
+    const int max_seqlen_k, max_seqlen_q;
+    const int window_size_left, window_size_right;
+    const float alibi_slope;
+
+    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
+                                    const int window_size_left, const int window_size_right,
+                                    const float alibi_slope=0.f)
+        : max_seqlen_k(max_seqlen_k)
+        , max_seqlen_q(max_seqlen_q)
+        , window_size_left(window_size_left)
+        , window_size_right(window_size_right)
+        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
+    };
+
+    // Causal_mask: whether this particular iteration needs causal masking
+    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
+    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
+                                               const int col_idx_offset_,
+                                               const int row_idx_offset,
+                                               const int warp_row_stride) {
+        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
+        static_assert(Layout::rank == 3, "Only support 3D Tensor");
+        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
+        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
+        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
+        if constexpr (Need_masking) {
+            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
+            // Do we need both row and column indices, or just column indices?
+            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
+            const int lane_id = threadIdx.x % 32;
+            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+            if constexpr (Col_idx_only) {
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        #pragma unroll
+                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                            // No causal, no local
+                            if constexpr (Has_alibi) {
+                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                            }
+                            if constexpr (!Is_even_MN) {
+                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
+                            }
+                        }
+                    }
+                }
+            } else {
+                #pragma unroll
+                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+                    #pragma unroll
+                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                        const int row_idx = row_idx_base + i * 8;
+                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+                        #pragma unroll
+                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                            const int col_idx_base = col_idx_offset + nj * 8;
+                            #pragma unroll
+                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                                const int col_idx = col_idx_base + j;
+                                if constexpr (Has_alibi) {
+                                    if constexpr (Is_causal) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
+                                    } else {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+
+                                    }
+                                }
+                                if constexpr (Causal_mask) {
+                                    if (col_idx >= col_idx_limit_right) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                                if constexpr (Is_local) {
+                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
+                                    // Causal and Local masking already handle MN masking
+                                    if (col_idx >= max_seqlen_k) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    };
+
+};
+
+} // namespace flash

+ 51 - 0
kernels/flash_attn/philox.cuh

@@ -0,0 +1,51 @@
+// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
+#pragma once
+// Philox CUDA.
+
+namespace flash {
+
+struct ull2 {
+    unsigned long long x;
+    unsigned long long y;
+};
+
+__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
+    uint2 *res;
+    unsigned long long tmp;
+    asm ("mul.wide.u32 %0, %1, %2;\n\t"
+          : "=l"(tmp)
+          : "r"(a), "r"(b));
+    res = (uint2*)(&tmp);
+    return *res;
+}
+
+__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
+    constexpr unsigned long kPhiloxSA = 0xD2511F53;
+    constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
+    uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
+    uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
+    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
+    return ret;
+}
+
+__forceinline__ __device__ uint4 philox(unsigned long long seed,
+                               unsigned long long subsequence,
+                               unsigned long long offset) {
+    constexpr unsigned long kPhilox10A = 0x9E3779B9;
+    constexpr unsigned long kPhilox10B = 0xBB67AE85;
+    uint2 key = reinterpret_cast<uint2&>(seed);
+    uint4 counter;
+    ull2 *tmp = reinterpret_cast<ull2*>(&counter);
+    tmp->x = offset;
+    tmp->y = subsequence;
+    #pragma unroll
+    for (int i = 0; i < 6; i++) {
+        counter = philox_single_round(counter, key);
+        key.x += (kPhilox10A);
+        key.y += (kPhilox10B);
+    }
+    uint4 output = philox_single_round(counter, key);
+    return output;
+}
+
+} // namespace flash

+ 22 - 0
kernels/flash_attn/registration.h

@@ -0,0 +1,22 @@
+#pragma once
+
+#include <Python.h>
+
+#define _CONCAT(A, B) A##B
+#define CONCAT(A, B) _CONCAT(A, B)
+
+#define _STRINGIFY(A) #A
+#define STRINGIFY(A) _STRINGIFY(A)
+
+// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+// could be a macro instead of a literal token.
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+// REGISTER_EXTENSION allows the shared library to be loaded and initialized
+// via python's import statement.
+#define REGISTER_EXTENSION(NAME)                                               \
+  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
+    return PyModule_Create(&module);                                           \
+  }

+ 152 - 0
kernels/flash_attn/rotary.h

@@ -0,0 +1,152 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cute/tensor.hpp>
+
+#include "utils.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
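+// copy_rotary_interleaved applies interleaved (GPT-J style) rotary embedding: adjacent pairs
+// (x[2i], x[2i+1]) are rotated as
+//     out[2i]     = x[2i] * cos[i] - x[2i+1] * sin[i]
+//     out[2i + 1] = x[2i] * sin[i] + x[2i+1] * cos[i]
+// which is what the fp32 inner loop below computes per thread fragment.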
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
+                                               Tensor<Engine1, Layout1> &D,
+                                               Tensor<Engine2, Layout2> const &Cos,
+                                               Tensor<Engine2, Layout2> const &Sin,
+                                               Tensor<Engine3, Layout3> const &identity_MN,
+                                               const int max_MN, const int min_MN,
+                                               const int dim, const int rotary_dim) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                     // MMA_K
+    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
+    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+    Tensor rCos = make_fragment_like(Cos);
+    Tensor rSin = make_fragment_like(Sin);
+    Tensor rS = make_fragment_like(S);
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+                    cute::copy(S(_, m, k), rS(_, m, k));
+                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+                        cute::copy(Cos(_, m, k), rCos(_, m, k));
+                        cute::copy(Sin(_, m, k), rSin(_, m, k));
+                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
+                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
+                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
+                            S_fp32(2 * i) = real;
+                            S_fp32(2 * i + 1) = imag;
+                        }
+                        // Unclear why, but a copy is needed here for convert_type to work
+                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
+                        cute::copy(S_fp32, S_fp32_copy);
+                        using T = typename Engine0::value_type;
+                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
+                        cute::copy(S_og_type, rS(_, m, k));
+                    }
+                    cute::copy(rS(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
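+// copy_rotary_contiguous applies non-interleaved (GPT-NeoX style) rotary embedding: the rotary
+// span is split into two halves and each element is paired with the one rotary_dim/2 positions
+// away, i.e. (sketch, with i < rotary_dim/2)
+//     out[i]                  = x[i] * cos[i] - x[i + rotary_dim/2] * sin[i]
+//     out[i + rotary_dim/2]   = x[i + rotary_dim/2] * cos[i] + x[i] * sin[i]
+// which the is_left / gS_other logic below implements per thread fragment.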
+template <bool Is_even_K=true, bool Clear_OOB_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
+                                              Tensor<Engine1, Layout1> &D,
+                                              Tensor<Engine2, Layout2> const &Cos,
+                                              Tensor<Engine2, Layout2> const &Sin,
+                                              Tensor<Engine3, Layout3> const &identity_MN,
+                                              const int max_MN, const int min_MN,
+                                              const int dim, const int rotary_dim) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                     // MMA_K
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
+    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
+    Tensor rCos = make_fragment_like(Cos);
+    Tensor rSin = make_fragment_like(Sin);
+    Tensor rS = make_fragment_like(S);
+    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
+                    cute::copy(S(_, m, k), rS(_, m, k));
+                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
+                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
+                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
+                        cute::copy(gS_other, rS_other);
+                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
+                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
+                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
+                        cute::copy(gCos, rCos(_, m, k));
+                        cute::copy(gSin, rSin(_, m, k));
+                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
+                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
+                        Tensor S_other_fp32 = convert_type<float>(rS_other);
+                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
+                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
+                        #pragma unroll
+                        for (int i = 0; i < size<0>(rS); ++i) {
+                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
+                        }
+                        // Unclear why, but a copy is needed here for convert_type to work
+                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
+                        cute::copy(S_fp32, S_fp32_copy);
+                        using T = typename Engine0::value_type;
+                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
+                        cute::copy(S_og_type, rS(_, m, k));
+                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
+                    }
+                    cute::copy(rS(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash

+ 188 - 0
kernels/flash_attn/softmax.h

@@ -0,0 +1,188 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cmath>
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/numeric_types.h>
+
+#include "philox.cuh"
+#include "utils.h"
+
+namespace flash {
+
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
+    #pragma unroll
+    for (int mi = 0; mi < size<0>(tensor); mi++) {
+        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
+        #pragma unroll
+        for (int ni = 1; ni < size<1>(tensor); ni++) {
+            summary(mi) = op(summary(mi), tensor(mi, ni));
+        }
+    }
+}
+
+template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
+    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
+    #pragma unroll
+    for (int i = 0; i < size(dst); i++){
+        dst(i) = Allreduce<4>::run(src(i), op);
+    }
+}
+
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
+__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
+    thread_reduce_<zero_init>(tensor, summary, op);
+    quad_allreduce_(summary, summary, op);
+}
+
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
+    MaxOp<float> max_op;
+    reduce_<zero_init>(tensor, max, max_op);
+}
+
+template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
+    SumOp<float> sum_op;
+    thread_reduce_<zero_init>(tensor, sum, sum_op);
+}
+
+// Apply the exp to all the elements.
+template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
+    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
+    #pragma unroll
+    for (int mi = 0; mi < size<0>(tensor); ++mi) {
+        // If max is -inf, then all elements must have been -inf (possibly due to masking).
+        // We don't want (-inf - (-inf)) since that would give NaN.
+        // If we don't have float around M_LOG2E the multiplication is done in fp64.
+        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
+        #pragma unroll
+        for (int ni = 0; ni < size<1>(tensor); ++ni)  {
+            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
+            // max * log_2(e)) This allows the compiler to use the ffma
+            // instruction instead of fadd and fmul separately.
+            // The following macro will disable the use of fma.
+            // See: https://github.com/pytorch/pytorch/issues/121558 for more details
+            // This macro is set in PyTorch and not FlashAttention
+            #ifdef UNFUSE_FMA
+                tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
+            #else
+                tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            #endif
+        }
+    }
+}
+
+// Apply the exp to all the elements.
+template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
+    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
+    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
+    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
+    #pragma unroll
+    for (int mi = 0; mi < size<0>(tensor); ++mi) {
+        MaxOp<float> max_op;
+        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
+        #pragma unroll
+        for (int ni = 1; ni < size<1>(tensor); ni++) {
+            max(mi) = max_op(max(mi), tensor(mi, ni));
+        }
+        max(mi) = Allreduce<4>::run(max(mi), max_op);
+        // If max is -inf, then all elements must have been -inf (possibly due to masking).
+        // We don't want (-inf - (-inf)) since that would give NaN.
+        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
+        sum(mi) = 0;
+        #pragma unroll
+        for (int ni = 0; ni < size<1>(tensor); ++ni)  {
+            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
+            // max * log_2(e)) This allows the compiler to use the ffma
+            // instruction instead of fadd and fmul separately.
+            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
+            sum(mi) += tensor(mi, ni);
+        }
+        SumOp<float> sum_op;
+        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
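+// Sketch of the online-softmax recurrence maintained below: for each row, with running
+// (unscaled) max m and running sum l over the key blocks processed so far,
+//     m_new  = max(m, rowmax(S_block))
+//     acc_o *= exp2((m - m_new) * softmax_scale_log2)
+//     l      = l * exp2((m - m_new) * softmax_scale_log2)
+//              + rowsum(exp2(S_block * softmax_scale_log2 - m_new * softmax_scale_log2))
+// normalize_softmax_lse then divides acc_o by l and returns lse = m * softmax_scale + log(l).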
+template <int kNRows>
+struct Softmax {
+
+    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
+    TensorT row_max, row_sum;
+
+    __forceinline__ __device__ Softmax() {};
+
+    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
+    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        static_assert(decltype(size<0>(scores))::value == kNRows);
+        if (Is_first) {
+            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
+        } else {
+            Tensor scores_max_prev = make_fragment_like(row_max);
+            cute::copy(row_max, scores_max_prev);
+            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
+            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_max); ++mi) {
+                float scores_max_cur = !Check_inf
+                    ? row_max(mi)
+                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
+                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+                row_sum(mi) *= scores_scale;
+                #pragma unroll
+                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
+            }
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            // We don't do the reduce across threads here since we don't need to use the row_sum.
+            // We do that reduce at the end when we need to normalize the softmax.
+            flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
+        }
+    };
+
+    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        SumOp<float> sum_op;
+        quad_allreduce_(row_sum, row_sum, sum_op);
+        TensorT lse = make_fragment_like(row_sum);
+        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+        #pragma unroll
+        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+            float sum = row_sum(mi);
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
+            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
+            #pragma unroll
+            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
+        }
+        return lse;
+    };
+};
+
+}  // namespace flash

+ 117 - 0
kernels/flash_attn/static_switch.h

@@ -0,0 +1,117 @@
+// Inspired by
+// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param CONST_NAME - a name given for the constexpr bool variable.
+/// @param ...       - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+///     some_function<BoolConst>(...);
+/// });
+/// ```
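+///
+/// These switches are typically nested so that runtime parameters become template arguments,
+/// e.g. (illustrative sketch, not the exact call site):
+/// ```
+/// FP16_SWITCH(!params.is_bf16, [&] {
+///     HEADDIM_SWITCH(params.d, [&] {
+///         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+///             run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+///         });
+///     });
+/// });
+/// ```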
+
+#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
+  [&] {                                         \
+    if (COND) {                                 \
+      constexpr static bool CONST_NAME = true;  \
+      return __VA_ARGS__();                     \
+    } else {                                    \
+      constexpr static bool CONST_NAME = false; \
+      return __VA_ARGS__();                     \
+    }                                           \
+  }()
+
+#ifdef FLASHATTENTION_DISABLE_DROPOUT
+  #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;   \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define DROPOUT_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_ALIBI
+  #define ALIBI_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;   \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define ALIBI_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
+  #define EVENK_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = true;    \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define EVENK_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_SOFTCAP
+  #define SOFTCAP_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;    \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define SOFTCAP_SWITCH BOOL_SWITCH
+#endif
+
+#ifdef FLASHATTENTION_DISABLE_LOCAL
+  #define LOCAL_SWITCH(COND, CONST_NAME, ...)   \
+  [&] {                                         \
+    constexpr static bool CONST_NAME = false;    \
+    return __VA_ARGS__();                       \
+  }()
+#else
+  #define LOCAL_SWITCH BOOL_SWITCH
+#endif
+
+#define FP16_SWITCH(COND, ...)               \
+  [&] {                                      \
+    if (COND) {                              \
+      using elem_type = cutlass::half_t;     \
+      return __VA_ARGS__();                  \
+    } else {                                 \
+      using elem_type = cutlass::bfloat16_t; \
+      return __VA_ARGS__();                  \
+    }                                        \
+  }()
+
+#define HEADDIM_SWITCH(HEADDIM, ...)   \
+  [&] {                                    \
+    if (HEADDIM <= 32) {                   \
+      constexpr static int kHeadDim = 32;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 64) {            \
+      constexpr static int kHeadDim = 64;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 96) {            \
+      constexpr static int kHeadDim = 96;  \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 128) {           \
+      constexpr static int kHeadDim = 128; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 160) {           \
+      constexpr static int kHeadDim = 160; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 192) {           \
+      constexpr static int kHeadDim = 192; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 224) {           \
+      constexpr static int kHeadDim = 224; \
+      return __VA_ARGS__();                \
+    } else if (HEADDIM <= 256) {           \
+      constexpr static int kHeadDim = 256; \
+      return __VA_ARGS__();                \
+    }                                      \
+  }()

+ 440 - 0
kernels/flash_attn/utils.h

@@ -0,0 +1,440 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cuda_fp16.h>
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/array.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/numeric_conversion.h>
+#include <cutlass/numeric_types.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace flash {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+__forceinline__ __device__ uint32_t relu2(const uint32_t x);
+
+template<>
+__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+    uint32_t res;
+    const uint32_t zero = 0u;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+#else
+    asm volatile( \
+        "{\n" \
+        "\t .reg .f16x2 sela;\n" \
+        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
+        "\t and.b32 %0, sela, %1;\n" 
+        "}\n" : "=r"(res) : "r"(x), "r"(zero));
+#endif
+    return res;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+    uint32_t res;
+    const uint32_t zero = 0u;
+    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+    return res;
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+template<typename T>
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
+
+template<>
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+    uint32_t res;
+    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
+    return res;
+}
+
+template<>
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+    uint32_t res;
+    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
+    return res;
+}
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct MaxOp {
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
+};
+
+template <>
+struct MaxOp<float> {
+// This is slightly faster
+__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ __forceinline__ T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+struct Allreduce<2> {
+template<typename T, typename Operator> 
+static __device__ __forceinline__ T run(T x, Operator &op) {
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+    return x;
+}
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
+         typename Tensor2, typename Tensor3, typename Tensor4,
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
+__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+                            Tensor4 const& tCsB, TiledMma tiled_mma,
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
+    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
+    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        if (i < size<2>(tCrA) - 1) {
+            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
+__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+                               TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                               ThrCopy smem_thr_copy_B) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                     // MMA_K
+    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        if (i < size<2>(tCrA) - 1) {
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
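+// (With the SM80 16x8 MMA atoms each thread holds 4 accumulator values laid out as 2 rows x 2
+// columns; splitting that mode lets per-element row/column indices be recovered, e.g. for
+// masking and softmax.)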
+template<typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
+    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
+template<typename MMA_traits, typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
+    using X = Underscore;
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
+    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
+    if constexpr (mma_shape_K == 8) {
+        return acc_layout;
+    } else {
+        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+template<typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
+    using X = Underscore;
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename To_type, typename Engine, typename Layout>
+__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+    using From_type = typename Engine::value_type;
+    constexpr int numel = decltype(size(tensor))::value;
+    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+    // HACK: this requires tensor to be "contiguous"
+    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
+    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
+    constexpr int numel = decltype(size(tensor))::value;
+    static_assert(numel % 2 == 0);
+    using value_t = typename Engine::value_type;
+    // HACK: this requires tensor to be "contiguous"
+    Tensor tensor_uint32 = recast<uint32_t>(tensor);
+    #pragma unroll
+    for (int i = 0; i < size(tensor_uint32); ++i) {
+        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
+template <typename To_type, typename Engine, typename Layout>
+__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
+    using From_type = typename Engine::value_type;
+    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
+    static_assert(std::is_same_v<float, From_type>);
+    constexpr int numel = decltype(size(tensor))::value;
+    static_assert(numel % 2 == 0);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // HACK: this requires tensor to be "contiguous"
+    Tensor tensor_float2 = recast<float2>(tensor);
+    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
+    #pragma unroll
+    for (int i = 0; i < size(out_uint32); ++i) {
+        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
+    }
+    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
+#else
+    Tensor out = flash::convert_type<To_type>(tensor);
+    flash::relu_(out);
+#endif
+    return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Blocks until all but N previous cp.async.commit_group operations have committed.
+// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
+// (which is equivalent to commit_group then wait_group 0).
+// Instead we just call cp.async.wait_group 0, which is slightly faster.
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
+template <int N>
+CUTE_HOST_DEVICE
+void cp_async_wait() {
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// resolves offset of a slice of a paged kv copy from gmem.
+// assumes that the tensor has already been positioned at the correct head.
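+// Illustrative example (hypothetical values): with kGmemThreadsPerRow = 8, kGmemElemsPerLoad = 8,
+// kGmemRowsPerThread = 8, kBlockN = 128, n_block_max = 1 and page_block_size = 16, thread
+// tidx = 17 gets col_offset = (17 % 8) * 8 = 8 and global_row_offset = (17 / 8) * 8 = 16, i.e.
+// virtual page 1 at page_offset 0, so the result is
+// block_table[1] * page_stride + 0 * row_stride + 8.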
+template <typename Kernel_traits>
+__forceinline__ __device__
+int64_t resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, 
+                            const int* block_table, const int page_stride, const int row_stride) {
+    constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
+    constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
+    constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
+    constexpr int kBlockN = Kernel_traits::kBlockN;
+    
+    const int64_t col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
+    const int64_t block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
+    const int64_t global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
+    const int64_t page_offset = global_row_offset % page_block_size;
+    const int64_t virtual_page_idx = global_row_offset / page_block_size;
+
+    return ((int64_t) block_table[virtual_page_idx]) * ((int64_t) page_stride)
+        + page_offset * ((int64_t) row_stride)
+        + col_offset;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),         
+// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
+// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
+template <class Shape, class Stride>
+__forceinline__ __device__
+auto reshape_thread_tile(Layout<Shape, Stride> l) {
+    return make_layout(append(get<0>(l.shape()), get<2>(l.shape())),
+                        append(get<0>(l.stride()), get<2>(l.stride())));
+}
+
+// reshapes and flattens the thread tile layout. A separate function is needed for the case where
+// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
+// for the case of swizzled layouts
+template <class Shape, class Stride>
+__forceinline__ __device__
+auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
+    auto mode_0 = filter(flatten(get<0>(l)));
+    return make_layout(append(mode_0.shape(), get<2>(l.shape())),
+                        append(mode_0.stride(), get<2>(l.stride())));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        } else if (Clear_OOB_MN) {
+            cute::clear(D(_, m, _));
+        }
+    }
+    // TD [2023-04-13]: Strange that the code below can cause a race condition.
+    // I think it's because the copies are under an if statement.
+    // if (Is_even_K) {
+    //     #pragma unroll
+    //     for (int m = 0; m < size<1>(S); ++m) {
+    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
+    //         } else if (Clear_OOB_MN) {
+    //             clear(D(_, m, _));
+    //         }
+    //     }
+    // } else {  // It's slightly faster in this case if iterate over K first
+    //     #pragma unroll
+    //     for (int k = 0; k < size<2>(S); ++k) {
+    //         if (predicate_K(k)) {
+    //             #pragma unroll
+    //             for (int m = 0; m < size<1>(S); ++m) {
+    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
+    //                 } else if (Clear_OOB_MN) {
+    //                     clear(D(_, m, k));
+    //                 }
+    //             }
+    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    //             if (Clear_OOB_MN || Is_even_MN) {
+    //                 clear(D(_, _, k));
+    //             } else {
+    //                 #pragma unroll
+    //                 for (int m = 0; m < size<1>(S); ++m) {
+    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
+    //                         clear(D(_, m, k));
+    //                     }
+    //                 }
+    //             }
+    //         }
+    //     }
+    // }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                      Tensor<Engine3, Layout3> const &predicate_K,
+                                      const int max_MN=0, const int min_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash

+ 19 - 0
kernels/torch_bindings.cpp

@@ -3,6 +3,7 @@
 #include "ops.h"
 #include "core/registration.h"
 #include "quantization/quant_ops.h"
+#include "flash_attn/flash_api.h"
 
 #include <torch/library.h>
 
@@ -447,6 +448,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor? final_states_out_,"
       "bool silu_activation) -> Tensor");
   ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
+
+  ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
+          "float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
+          "float softcap, bool return_softmax, Generator? gen) -> Tensor[]");
+  ops.impl("fwd", torch::kCUDA, &mha_fwd);
+
+  ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
+          "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? block_table, Tensor? alibi_slopes, "
+          "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
+          "bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
+          "Generator? gen) -> Tensor[]");
+  ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
+
+  ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
+          "Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? block_table, Tensor? alibi_slopes, "
+          "Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
+          "float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
+  ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
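+
+  // These schemas are registered under TORCH_EXTENSION_NAME, so on the Python side the ops are
+  // reachable as torch.ops.<extension namespace>.{fwd, varlen_fwd, fwd_kvcache}.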
 #endif
 }
 

+ 0 - 1
requirements-cuda.txt

@@ -7,7 +7,6 @@ torch == 2.4.0; platform_system == 'Linux'
 torchvision == 0.19; platform_system == 'Linux'  # for phi3v
 xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
 triton >= 2.2.1; platform_system == 'Linux'
-aphrodite-flash-attn == 2.6.1.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
 
 # Windows dependencies
 winloop; platform_system == 'Windows'