
rocm: add custom paged attention kernels for ROCm (#1043)

AlpinDale committed 2 months ago
commit 4a7cb8f232

+ 25 - 0
CMakeLists.txt

@@ -342,10 +342,35 @@ define_gpu_extension_target(
   USE_SABI 3
   WITH_SOABI)
 
+
+if(APHRODITE_GPU_LANG STREQUAL "HIP")
+  #
+  # _rocm_C extension
+  #
+  set(APHRODITE_ROCM_EXT_SRC
+    "kernels/rocm/torch_bindings.cpp"
+    "kernels/rocm/attention.cu")
+  define_gpu_extension_target(
+    _rocm_C
+    DESTINATION aphrodite
+    LANGUAGE ${APHRODITE_GPU_LANG}
+    SOURCES ${APHRODITE_ROCM_EXT_SRC}
+    COMPILE_FLAGS ${APHRODITE_GPU_FLAGS}
+    ARCHITECTURES ${APHRODITE_GPU_ARCHES}
+    USE_SABI 3
+    WITH_SOABI)
+endif()
+
+
 if(APHRODITE_GPU_LANG STREQUAL "CUDA" OR APHRODITE_GPU_LANG STREQUAL "HIP")
   message(STATUS "Enabling C extension.")
   add_dependencies(default _C)
 
   message(STATUS "Enabling moe extension.")
   add_dependencies(default _moe_C)
+endif()
+
+if(APHRODITE_GPU_LANG STREQUAL "HIP")
+  message(STATUS "Enabling rocm extension.")
+  add_dependencies(default _rocm_C)
 endif()
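
With APHRODITE_GPU_LANG set to HIP, the build now produces a third extension, _rocm_C, alongside _C and _moe_C. A minimal sketch (not part of the commit, assuming a ROCm build of Aphrodite is installed) for checking that the extension compiled and registered its op:

# Editorial sketch: verify the new _rocm_C extension is importable and that
# its custom op was registered with PyTorch.
import torch
import aphrodite._rocm_C  # noqa: F401  (importing registers torch.ops._rocm_C)

# Prints True on a ROCm build where kernels/rocm/torch_bindings.cpp registered
# the op; prints False (or the import above fails) otherwise.
print(hasattr(torch.ops._rocm_C, "paged_attention"))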

+ 27 - 0
aphrodite/_custom_ops.py

@@ -16,6 +16,9 @@ if not current_platform.is_tpu():
     except ImportError as e:
         logger.warning(f"Failed to import from aphrodite._C with {e}")
 
+if current_platform.is_rocm():
+    import aphrodite._rocm_C  # noqa: F401
+
 with contextlib.suppress(ImportError):
     # ruff: noqa: F401
     import aphrodite._moe_C
@@ -127,6 +130,30 @@ def paged_attention_v2(
         blocksparse_block_size, blocksparse_head_sliding_step)
 
 
+def paged_attention_rocm(
+    out: torch.Tensor,
+    exp_sum: torch.Tensor,
+    max_logits: torch.Tensor,
+    tmp_out: torch.Tensor,
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    num_kv_heads: int,
+    scale: float,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    block_size: int,
+    max_seq_len: int,
+    alibi_slopes: Optional[torch.Tensor],
+    kv_cache_dtype: str,
+) -> None:
+    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
+                                      key_cache, value_cache, num_kv_heads,
+                                      scale, block_tables, seq_lens,
+                                      block_size, max_seq_len, alibi_slopes,
+                                      kv_cache_dtype)
+
+
 # pos encoding ops
 def rotary_embedding(
     positions: torch.Tensor,

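The new paged_attention_rocm wrapper expects the caller to pre-allocate the partitioned scratch buffers. A minimal sketch (not part of the commit) of that allocation, with shapes and dtypes mirroring the decode path added to rocm_flash_attn.py below and the 256-token partition size used there; the helper name is hypothetical:

# Editorial sketch: allocate the scratch tensors and invoke the ROCm op.
# Shapes follow this diff; run_rocm_paged_attention is a hypothetical helper.
import torch
from aphrodite import _custom_ops as ops

_PARTITION_SIZE = 256  # must be a multiple of the KV-cache block size

def run_rocm_paged_attention(output, query, key_cache, value_cache,
                             num_kv_heads, scale, block_tables, seq_lens,
                             block_size, max_seq_len, alibi_slopes=None):
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    tmp_output = torch.empty(num_seqs, num_heads, max_num_partitions, head_size,
                             dtype=output.dtype, device=output.device)
    exp_sums = torch.empty(num_seqs, num_heads, max_num_partitions,
                           dtype=torch.float32, device=output.device)
    max_logits = torch.empty_like(exp_sums)
    ops.paged_attention_rocm(output, exp_sums, max_logits, tmp_output, query,
                             key_cache, value_cache, num_kv_heads, scale,
                             block_tables, seq_lens, block_size, max_seq_len,
                             alibi_slopes, "auto")  # kernel supports "auto" only
    return output
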
+ 69 - 14
aphrodite/attention/backends/rocm_flash_attn.py

@@ -6,6 +6,7 @@ import torch
 from loguru import logger
 
 import aphrodite.common.envs as envs
+from aphrodite import _custom_ops as ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
@@ -15,6 +16,8 @@ from aphrodite.attention.backends.utils import (CommonAttentionState,
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
+_PARTITION_SIZE = 256
+ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 
@@ -479,20 +482,61 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = PagedAttention.forward_decode(
-                decode_query,
-                key_cache,
-                value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
-                self.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                k_scale,
-                v_scale,
-            )
+            # Whether to use rocm custom paged attention or not
+            num_seqs, num_heads, head_size = decode_query.shape
+            block_size = value_cache.shape[3]
+            gqa_ratio = num_heads // self.num_kv_heads
+            use_custom = use_rocm_custom_paged_attention(
+                decode_query.dtype, head_size, block_size, self.kv_cache_dtype,
+                gqa_ratio, decode_meta.max_decode_seq_len)
+            if use_custom:
+                max_seq_len = decode_meta.max_decode_seq_len
+                max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
+                                      _PARTITION_SIZE)
+                assert _PARTITION_SIZE % block_size == 0
+                tmp_output = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions, head_size),
+                    dtype=output.dtype,
+                    device=output.device,
+                )
+                exp_sums = torch.empty(
+                    size=(num_seqs, num_heads, max_num_partitions),
+                    dtype=torch.float32,
+                    device=output.device,
+                )
+                max_logits = torch.empty_like(exp_sums)
+                ops.paged_attention_rocm(
+                    output[num_prefill_tokens:],
+                    exp_sums,
+                    max_logits,
+                    tmp_output,
+                    decode_query,
+                    key_cache,
+                    value_cache,
+                    self.num_kv_heads,
+                    self.scale,
+                    decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor,
+                    block_size,
+                    max_seq_len,
+                    self.alibi_slopes,
+                    self.kv_cache_dtype,
+                )
+            else:
+                output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                    decode_query,
+                    key_cache,
+                    value_cache,
+                    decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor,
+                    decode_meta.max_decode_seq_len,
+                    self.kv_cache_dtype,
+                    self.num_kv_heads,
+                    self.scale,
+                    self.alibi_slopes,
+                    k_scale,
+                    v_scale,
+                )
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
@@ -531,3 +575,14 @@ def _sdpa_attention(
             start = end
 
     return output
+
+
+def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
+                                    block_size: int, kv_cache_dtype: str,
+                                    gqa_ratio: int, max_seq_len: int) -> bool:
+    # ROCm custom paged attention is not supported on Navi (gfx1*)
+    return (not ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+            and (head_size == 64 or head_size == 128)
+            and (block_size == 16 or block_size == 32)
+            and kv_cache_dtype == "auto"
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
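
For reference, a standalone restatement (not part of the commit) of the gating predicate above, with on_navi standing in for the module-level ON_NAVI check, plus two illustrative cases:

# Editorial sketch: the conditions under which the custom ROCm kernel is used
# instead of PagedAttention.forward_decode, restated outside the module.
import torch

def custom_paged_attention_supported(qtype, head_size, block_size,
                                     kv_cache_dtype, gqa_ratio, max_seq_len,
                                     on_navi=False):
    return (not on_navi
            and qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and kv_cache_dtype == "auto"
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)

# fp16, head_size 128, block_size 16, "auto" cache, GQA ratio 8, 4k context.
assert custom_paged_attention_supported(torch.half, 128, 16, "auto", 8, 4096)
# Unsupported head size falls back to the generic decode path.
assert not custom_paged_attention_supported(torch.half, 96, 16, "auto", 8, 4096)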

+ 958 - 0
kernels/rocm/attention.cu

@@ -0,0 +1,958 @@
+/*
+ * Copyright (c) 2024, The vLLM team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <hip/hip_bf16.h>
+#include <algorithm>
+#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
+                           defined(__gfx941__) || defined(__gfx942__))
+  #define __HIP__MI300_MI250__
+#endif
+#if defined(NDEBUG)
+  #undef NDEBUG
+  #include <assert.h>
+  #define UNREACHABLE_CODE assert(false);
+  #define NDEBUG
+#else
+  #define UNREACHABLE_CODE assert(false);
+#endif
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+#define WARP_SIZE 64
+#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support
+  #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
+  #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
+using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
+using float16x4 =
+    __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
+typedef float16x4 _Half4;
+typedef struct _Half8 {
+  _Half4 xy[2];
+} _Half8;
+using bit16_t = uint16_t;
+using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
+typedef bit16x4 _B16x4;
+typedef struct _B16x8 {
+  _B16x4 xy[2];
+} _B16x8;
+////// Non temporal load stores ///////
+template <typename T>
+__device__ __forceinline__ T load(T* addr) {
+  return addr[0];
+}
+template <typename T>
+__device__ __forceinline__ void store(T value, T* addr) {
+  addr[0] = value;
+}
+template <typename T, int absz, int cbid, int blgp>
+__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
+                                                  const _B16x4& inpB,
+                                                  const floatx4& inpC) {
+  if constexpr (std::is_same<T, _Float16>::value) {
+    return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
+                                              blgp);
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
+                                                  blgp);
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+template <typename T>
+__device__ __forceinline__ float to_float(const T& inp) {
+  if constexpr (std::is_same<T, _Float16>::value) {
+    return (float)inp;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    return __bfloat162float(inp);
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+template <typename T>
+__device__ __forceinline__ T from_float(const float& inp) {
+  if constexpr (std::is_same<T, _Float16>::value) {
+    return (_Float16)inp;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+    return __float2bfloat16(inp);
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+template <typename T>
+__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
+  union tmpcvt {
+    uint16_t u;
+    _Float16 f;
+    __hip_bfloat16 b;
+  } t16;
+  _B16x4 ret;
+  if constexpr (std::is_same<T, _Float16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t16.f = (_Float16)inp[i];
+      ret[i] = t16.u;
+    }
+    return ret;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t16.b = __float2bfloat16(inp[i]);
+      ret[i] = t16.u;
+    }
+    return ret;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+template <typename T>
+__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
+                                        const _B16x4& inp2) {
+  union tmpcvt {
+    uint16_t u;
+    _Float16 f;
+    __hip_bfloat16 b;
+  } t1, t2, res;
+  _B16x4 ret;
+  if constexpr (std::is_same<T, _Float16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t1.u = inp1[i];
+      t2.u = inp2[i];
+      res.f = t1.f + t2.f;
+      ret[i] = res.u;
+    }
+    return ret;
+  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
+  #pragma unroll
+    for (int i = 0; i < 4; i++) {
+      t1.u = inp1[i];
+      t2.u = inp2[i];
+      res.b = t1.b + t2.b;
+      ret[i] = res.u;
+    }
+    return ret;
+  } else {
+    static_assert(false, "unsupported 16b dtype");
+  }
+}
+///////////////////////////////////////
+// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
+// block (partition size)
+template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
+          int GQA_RATIO>
+__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
+    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                           // head_size/x, block_size, x]
+    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                           // head_size, block_size]
+    const int num_kv_heads, const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
+  #if 0
+  scalar_t* __restrict__ qk_out,             // [num_heads, num_seqs, max_ctx_blocks,block_size]
+  #endif
+    int max_ctx_blocks) {
+  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
+  const int warpid = threadIdx.x / WARP_SIZE;
+  const int laneid = threadIdx.x % WARP_SIZE;
+  const int lane4id = laneid % 4;
+  const int seq_idx = blockIdx.x;
+  const int partition_idx = blockIdx.y;
+  const int partition_size = blockDim.x;
+  const int max_num_partitions = gridDim.y;
+  const int context_len = context_lens[seq_idx];
+  const int partition_start_token_idx = partition_idx * partition_size;
+  // exit if partition is out of context for seq
+  if (partition_start_token_idx >= context_len) {
+    return;
+  }
+  constexpr int QHLOOP =
+      DIVIDE_ROUND_UP(GQA_RATIO, 4);  // each 4 lanes fetch 4 different qheads,
+                                      // total qheads =8, so qhloop is 2
+  constexpr int GQA_RATIO4 = 4 * QHLOOP;
+  __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
+  __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
+  _B16x8 Qlocal[QHLOOP];
+  constexpr int x = 16 / sizeof(scalar_t);
+  constexpr int KHELOOP = HEAD_SIZE / x;
+  _B16x8 Klocal[KHELOOP];
+  constexpr int VHELOOP =
+      HEAD_SIZE /
+      WARP_SIZE;  // v head_size dimension is distributed across lanes
+  constexpr int VTLOOP = 8;  // 16 separate 4xtokens across warp -> 16/2
+                             // 8xtokens
+  _B16x8 Vlocal[VHELOOP][VTLOOP];
+  floatx4 dout[QHLOOP];
+  float qk_max[QHLOOP];
+  #pragma unroll
+  for (int h = 0; h < QHLOOP; h++) {
+    dout[h] = {0};
+    qk_max[h] = -FLT_MAX;
+  }
+  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
+  const int wg_start_kv_head_idx = blockIdx.z;
+  const int warp_start_token_idx =
+      partition_start_token_idx + warpid * WARP_SIZE;
+  if (warp_start_token_idx >= context_len) {  // warp out of context
+  #pragma unroll
+    for (int h = 0; h < GQA_RATIO4; h++) {
+      shared_qk_max[warpid][h] = -FLT_MAX;
+      shared_exp_sum[warpid][h] = 0.0f;
+    }
+  } else {  // warp within context
+    const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+    const int last_ctx_block = num_context_blocks - 1;
+    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+    const int local_token_idx = threadIdx.x;
+    const int global_token_idx = partition_start_token_idx + local_token_idx;
+    const int block_idx = (global_token_idx < context_len)
+                              ? global_token_idx / BLOCK_SIZE
+                              : last_ctx_block;
+    // fetch block number for q and k
+    // int32 physical_block_number leads to overflow when multiplied with
+    // kv_block_stride
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
+    // fetch vphysical block numbers up front
+    constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
+    int vphysical_blocks[VBLOCKS];
+    const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
+  #pragma unroll
+    for (int b = 0; b < VBLOCKS; b++) {
+      const int vblock_idx = warp_start_block_idx + b;
+      const int vblock_idx_ctx =
+          (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
+      vphysical_blocks[b] = block_table[vblock_idx_ctx];
+    }
+    // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
+    const scalar_t* q_ptr =
+        q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
+    const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
+    const int qhead_elemh8 = laneid / 4;
+  #pragma unroll
+    for (int h = 0; h < QHLOOP - 1; h++) {
+      const int qhead_idx = h * 4 + lane4id;
+      Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
+    }
+    const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
+    if (final_qhead_idx < GQA_RATIO) {
+      Qlocal[QHLOOP - 1] =
+          q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
+    } else {
+      Qlocal[QHLOOP - 1].xy[0] = {0};
+      Qlocal[QHLOOP - 1].xy[1] = {0};
+    }
+    const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
+                            wg_start_kv_head_idx * kv_head_stride;
+    const int physical_block_offset =
+        local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
+                                       // is already cast as _H8
+    const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
+  #pragma unroll
+    for (int d = 0; d < KHELOOP; d++) {
+      Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
+    }
+    float alibi_slope[QHLOOP];
+    if (alibi_slopes != nullptr) {
+  #pragma unroll
+      for (int h = 0; h < QHLOOP; h++) {
+        const int qhead_idx = h * 4 + lane4id;
+        alibi_slope[h] = (qhead_idx < GQA_RATIO)
+                             ? alibi_slopes[wg_start_head_idx + qhead_idx]
+                             : 0.f;
+      }
+    }
+    const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
+    const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
+  // iterate over each v block
+  #pragma unroll
+    for (int b = 0; b < VBLOCKS; b++) {
+      // int32 physical_block_number leads to overflow when multiplied with
+      // kv_block_stride
+      const int64_t vphysical_block_number =
+          static_cast<int64_t>(vphysical_blocks[b]);
+      const _B16x8* v_ptrh8b =
+          v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
+  // iterate over each head elem (within head_size)
+  #pragma unroll
+      for (int h = 0; h < VHELOOP; h++) {
+        const int head_size_elem = h * WARP_SIZE + laneid;
+        const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
+  // iterate over all velems within block
+  #pragma unroll
+        for (int d = 0; d < BLOCK_SIZE / 8; d++) {
+          Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
+        }
+      }
+    }
+  #pragma unroll
+    for (int h = 0; h < QHLOOP; h++) {
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0],
+                                                  Klocal[0].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1],
+                                                  Klocal[0].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0],
+                                                  Klocal[1].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1],
+                                                  Klocal[1].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
+                                                  Klocal[2].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1],
+                                                  Klocal[2].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0],
+                                                  Klocal[3].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1],
+                                                  Klocal[3].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0],
+                                                  Klocal[4].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1],
+                                                  Klocal[4].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0],
+                                                  Klocal[5].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
+                                                  Klocal[5].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
+                                                  Klocal[6].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
+                                                  Klocal[6].xy[1], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
+                                                  Klocal[7].xy[0], dout[h]);
+      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
+                                                  Klocal[7].xy[1], dout[h]);
+      if constexpr (KHELOOP > 8) {
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
+                                                    Klocal[8].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
+                                                    Klocal[8].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
+                                                    Klocal[9].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
+                                                    Klocal[9].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
+                                                     Klocal[10].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
+                                                     Klocal[10].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
+                                                     Klocal[11].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
+                                                     Klocal[11].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
+                                                     Klocal[12].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
+                                                     Klocal[12].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
+                                                     Klocal[13].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
+                                                     Klocal[13].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
+                                                     Klocal[14].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
+                                                     Klocal[14].xy[1], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
+                                                     Klocal[15].xy[0], dout[h]);
+        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
+                                                     Klocal[15].xy[1], dout[h]);
+      }  // KHELOOP>8
+      dout[h] *= scale;
+    }
+  // transpose dout so that 4 token ids are in each lane, and 4 heads are across
+  // 4 lanes
+  #pragma unroll
+    for (int h = 0; h < QHLOOP; h++) {
+      floatx4 tmp = {0};
+  #pragma unroll
+      for (int i = 0; i < 4; i++) {
+        const float B = (lane4id == i) ? 1.0f : 0.0f;
+        // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
+        tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
+        // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
+      }
+      dout[h] = tmp;
+    }
+    const int lane4_token_idx = 4 * (global_token_idx >> 2);
+    const int alibi_offset = lane4_token_idx - context_len + 1;
+    if (alibi_slopes != nullptr) {
+  #pragma unroll
+      for (int h = 0; h < QHLOOP; h++) {
+  #pragma unroll
+        for (int i = 0; i < 4; i++) {
+          dout[h][i] += alibi_slope[h] * (alibi_offset + i);
+        }
+      }
+    }
+  #pragma unroll
+    for (int h = 0; h < QHLOOP; h++) {
+      qk_max[h] = -FLT_MAX;
+  #pragma unroll
+      for (int i = 0; i < 4; i++) {
+        qk_max[h] = (lane4_token_idx + i < context_len)
+                        ? fmaxf(qk_max[h], dout[h][i])
+                        : qk_max[h];
+      }
+  #pragma unroll
+      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
+        qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
+      }
+    }
+    float exp_sum[QHLOOP];
+  #pragma unroll
+    for (int h = 0; h < QHLOOP; h++) {
+      exp_sum[h] = 0.0f;
+  #pragma unroll
+      for (int i = 0; i < 4; i++) {
+        dout[h][i] = (lane4_token_idx + i < context_len)
+                         ? __expf(dout[h][i] - qk_max[h])
+                         : 0.0f;
+        exp_sum[h] += dout[h][i];
+      }
+  #pragma unroll
+      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
+        exp_sum[h] += __shfl_xor(exp_sum[h], mask);
+      }
+    }
+  #pragma unroll
+    for (int h = 0; h < QHLOOP; h++) {
+      const int head_idx = 4 * h + lane4id;
+      shared_qk_max[warpid][head_idx] = qk_max[h];
+      shared_exp_sum[warpid][head_idx] = exp_sum[h];
+    }
+  }  // warp within context
+  __syncthreads();
+  const int num_heads = gridDim.z * GQA_RATIO;
+  float* max_logits_ptr =
+      max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
+  float* exp_sums_ptr =
+      exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
+  #pragma unroll
+  for (int h = 0; h < QHLOOP; h++) {
+    float global_qk_max = -FLT_MAX;
+    float warp_qk_max[NWARPS];
+    const int head_idx = 4 * h + lane4id;
+  #pragma unroll
+    for (int w = 0; w < NWARPS; w++) {
+      warp_qk_max[w] = shared_qk_max[w][head_idx];
+      global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
+    }
+    float global_exp_sum = 0.0f;
+  #pragma unroll
+    for (int w = 0; w < NWARPS; w++) {
+      global_exp_sum +=
+          shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
+    }
+    if (head_idx < GQA_RATIO) {
+      max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
+          global_qk_max;
+      exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
+          global_exp_sum;
+    }
+    const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) *
+                                       __expf(qk_max[h] - global_qk_max);
+    dout[h] *= global_inv_sum_scale;
+  }
+  // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
+  // are 4x16 tokens across warp
+  _B16x4 logits[QHLOOP];
+  #pragma unroll
+  for (int h = 0; h < QHLOOP; h++) {
+    logits[h] = from_floatx4<scalar_t>(dout[h]);
+  }
+  __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
+  if (warp_start_token_idx >= context_len) {  // warp out of context
+  #pragma unroll
+    for (int qh = 0; qh < QHLOOP; qh++) {
+  #pragma unroll
+      for (int vh = 0; vh < VHELOOP; vh++) {
+        vout_shared[qh][vh][laneid][warpid] = {0};
+      }
+    }
+  } else {  // warp in context
+  // iterate across heads
+  #pragma unroll
+    for (int qh = 0; qh < QHLOOP; qh++) {
+  // iterate over each v head elem (within head_size)
+  #pragma unroll
+      for (int vh = 0; vh < VHELOOP; vh++) {
+        floatx4 acc = {0};
+        // iterate over tokens
+        acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
+                                                acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
+                                                 Vlocal[vh][5].xy[0], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
+                                                 Vlocal[vh][5].xy[1], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
+                                                 Vlocal[vh][6].xy[0], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
+                                                 Vlocal[vh][6].xy[1], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
+                                                 Vlocal[vh][7].xy[0], acc);
+        acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
+                                                 Vlocal[vh][7].xy[1], acc);
+        vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
+      }
+    }
+  }  // warp in context
+  __syncthreads();
+  if (warpid == 0) {
+    _B16x4 vout[QHLOOP][VHELOOP];
+    // iterate across heads
+    scalar_t* out_ptr;
+    int out_num_partitions;
+    if (context_len > partition_size) {
+      out_num_partitions = max_num_partitions;
+      out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+                partition_idx * HEAD_SIZE;
+    } else {
+      out_num_partitions = 1;
+      out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
+    }
+  #pragma unroll
+    for (int qh = 0; qh < QHLOOP; qh++) {
+  // iterate over each v head elem (within head_size)
+  #pragma unroll
+      for (int vh = 0; vh < VHELOOP; vh++) {
+        vout[qh][vh] = {0};
+  #pragma unroll
+        for (int w = 0; w < NWARPS; w++) {
+          vout[qh][vh] =
+              addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
+        }
+        const int head_size_elem = vh * WARP_SIZE + laneid;
+        bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
+  #pragma unroll
+        for (int i = 0; i < 4; i++) {
+          const int head_idx = 4 * qh + i;
+          if (head_idx < GQA_RATIO) {
+            out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
+                            HEAD_SIZE +
+                        head_size_elem] = vout[qh][vh][i];
+          }
+        }
+      }
+    }
+  }
+}
+// Grid: (num_heads, num_seqs).
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
+__global__
+__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // if num_partitions==1, main kernel will write to out directly, no work in
+    // reduction kernel
+    return;
+  }
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warpid = threadIdx.x / WARP_SIZE;
+  const int laneid = threadIdx.x % WARP_SIZE;
+  __shared__ float shared_global_exp_sum;
+  __shared__ float shared_exp_sums[2 * WARP_SIZE];
+  if (warpid == 0) {
+    const float* max_logits_ptr = max_logits +
+                                  seq_idx * num_heads * max_num_partitions +
+                                  head_idx * max_num_partitions;
+    // valid partition is the last valid partition in case threadid > num
+    // partitions
+    const int valid_partition =
+        (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1;
+    const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions)
+                                     ? WARP_SIZE + threadIdx.x
+                                     : num_partitions - 1;
+    float reg_max_logit = max_logits_ptr[valid_partition];
+    float reg_max_logit2 = max_logits_ptr[valid_partition2];
+    float max_logit = fmaxf(reg_max_logit, reg_max_logit2);
+  #pragma unroll
+    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+      max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
+    }
+    const float* exp_sums_ptr = exp_sums +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
+    float global_exp_sum = 0.0f;
+    float rescaled_exp_sum = exp_sums_ptr[valid_partition];
+    float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2];
+    rescaled_exp_sum *=
+        (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f;
+    rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions)
+                             ? expf(reg_max_logit2 - max_logit)
+                             : 0.0f;
+    global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2;
+    shared_exp_sums[threadIdx.x] = rescaled_exp_sum;
+    shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2;
+  #pragma unroll
+    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+      global_exp_sum += __shfl_xor(global_exp_sum, mask);
+    }
+    if (threadIdx.x == 0) {
+      shared_global_exp_sum = global_exp_sum;
+    }
+  }  // warpid == 0
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
+  constexpr int MAX_NPAR = 64;
+  scalar_t tmps[MAX_NPAR];
+  const float dzero = 0.0f;
+  #pragma unroll
+  for (int j = 0; j < MAX_NPAR; j++) {
+    tmps[j] = from_float<scalar_t>(dzero);
+  }
+  const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
+  const int num_partition_offset = (num_partitions)*HEAD_SIZE;
+  int idx = 0;
+  constexpr int JCHUNK = 16;
+  #pragma unroll
+  for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
+    // lastj is last valid partition
+    const int lastj_offset =
+        (j < num_partition_offset) ? j : last_partition_offset;
+    tmps[idx] = tmp_out_ptr[lastj_offset];
+    idx++;
+  }
+  __syncthreads();
+  if (num_partitions > JCHUNK) {
+  #pragma unroll
+    for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
+         j += HEAD_SIZE) {
+      const int lastj_offset =
+          (j < num_partition_offset) ? j : last_partition_offset;
+      tmps[idx] = tmp_out_ptr[lastj_offset];
+      idx++;
+    }
+    if (num_partitions > 2 * JCHUNK) {
+  #pragma unroll
+      for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
+           j += HEAD_SIZE) {
+        const int lastj_offset =
+            (j < num_partition_offset) ? j : last_partition_offset;
+        tmps[idx] = tmp_out_ptr[lastj_offset];
+        idx++;
+      }
+    }
+  }  // num_partitions > JCHUNK
+  // Aggregate tmp_out to out.
+  float acc = 0.0f;
+  #pragma unroll
+  for (int j = 0; j < JCHUNK; j++) {
+    acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
+  }
+  if (num_partitions > JCHUNK) {
+  #pragma unroll
+    for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
+      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
+    }
+    if (num_partitions > 2 * JCHUNK) {
+  #pragma unroll
+      for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
+        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
+      }
+    }
+  }
+  if (num_partitions > MAX_NPAR) {
+    idx = 0;
+  #pragma unroll
+    for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE;
+         j += HEAD_SIZE) {
+      // lastj is last valid partition
+      const int lastj_offset =
+          (j < num_partition_offset) ? j : last_partition_offset;
+      tmps[idx] = tmp_out_ptr[lastj_offset];
+      idx++;
+    }
+  #pragma unroll
+    for (int j = 0; j < MAX_NPAR; j++) {
+      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + MAX_NPAR];
+    }
+  }
+  const float inv_global_exp_sum =
+      __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  acc *= inv_global_exp_sum;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
+}
+#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
+          int GQA_RATIO>
+__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
+    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                           // head_size/x, block_size, x]
+    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                           // head_size, block_size]
+    const int num_kv_heads, const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
+  #if 0
+  scalar_t* __restrict__ qk_out,             // [num_heads, num_seqs, max_ctx_blocks,block_size]
+  #endif
+    int max_ctx_blocks) {
+  UNREACHABLE_CODE
+}
+// Grid: (num_heads, num_seqs).
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
+__global__
+__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_partitions){UNREACHABLE_CODE}
+#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
+#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO)                                    \
+  paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
+      <<<grid, block, 0, stream>>>(                                           \
+          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,     \
+          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,         \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks);
+template <typename T, int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 256>
+void paged_attention_custom_launcher(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, const int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len,
+#if 0
+  torch::Tensor& qk_out,
+  torch::Tensor& softmax_out,
+#endif
+    const c10::optional<torch::Tensor>& alibi_slopes) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
+  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+#if 0
+  T* qk_out_ptr = reinterpret_cast<T*>(qk_out.data_ptr());
+  T* softmax_out_ptr = reinterpret_cast<T*>(softmax_out.data_ptr());
+#endif
+  const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
+  const int max_num_partitions =
+      DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+  const int gqa_ratio = num_heads / num_kv_heads;
+  assert(num_heads % num_kv_heads == 0);
+  assert(head_size == HEAD_SIZE);
+  assert(max_num_partitions <= 128);
+  constexpr int NTHR = PARTITION_SIZE;
+  dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
+  dim3 block(NTHR);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (gqa_ratio) {
+    case 1:
+      LAUNCH_CUSTOM_ATTENTION(1);
+      break;
+    case 2:
+      LAUNCH_CUSTOM_ATTENTION(2);
+      break;
+    case 3:
+      LAUNCH_CUSTOM_ATTENTION(3);
+      break;
+    case 4:
+      LAUNCH_CUSTOM_ATTENTION(4);
+      break;
+    case 5:
+      LAUNCH_CUSTOM_ATTENTION(5);
+      break;
+    case 6:
+      LAUNCH_CUSTOM_ATTENTION(6);
+      break;
+    case 7:
+      LAUNCH_CUSTOM_ATTENTION(7);
+      break;
+    case 8:
+      LAUNCH_CUSTOM_ATTENTION(8);
+      break;
+    case 9:
+      LAUNCH_CUSTOM_ATTENTION(9);
+      break;
+    case 10:
+      LAUNCH_CUSTOM_ATTENTION(10);
+      break;
+    case 11:
+      LAUNCH_CUSTOM_ATTENTION(11);
+      break;
+    case 12:
+      LAUNCH_CUSTOM_ATTENTION(12);
+      break;
+    case 13:
+      LAUNCH_CUSTOM_ATTENTION(13);
+      break;
+    case 14:
+      LAUNCH_CUSTOM_ATTENTION(14);
+      break;
+    case 15:
+      LAUNCH_CUSTOM_ATTENTION(15);
+      break;
+    case 16:
+      LAUNCH_CUSTOM_ATTENTION(16);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
+      break;
+  }
+  // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
+  // dim3 block2(1024);
+  //  LAUNCH_CUSTOM_ATTENTION2;
+  // reduction kernel is only required if max_context_len > partition size,
+  // otherwise main kernel writes directly to final output
+  //  note there are cases with graphing where max_context_len is the max
+  //  supported by graphing, not the actual max among all the sequences: in that
+  //  case reduction kernel will still run but return immediately
+  if (max_context_len > PARTITION_SIZE) {
+    dim3 reduce_grid(num_heads, num_seqs);
+    dim3 reduce_block(head_size);
+    paged_attention_ll4mi_reduce_kernel<T, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE>
+        <<<reduce_grid, reduce_block, 0, stream>>>(
+            out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,
+            context_lens_ptr, max_num_partitions);
+  }
+}
+#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE)                     \
+  paged_attention_custom_launcher<T, BLK_SIZE, HEAD_SIZE>(               \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, max_context_len,  \
+      alibi_slopes);
+#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE)                    \
+  switch (block_size) {                                           \
+    case 16:                                                      \
+      CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE);                     \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE);                     \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
+  }
+#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T)                        \
+  switch (head_size) {                                          \
+    case 64:                                                    \
+      CALL_CUSTOM_LAUNCHER_BLK(T, 64);                          \
+      break;                                                    \
+    case 128:                                                   \
+      CALL_CUSTOM_LAUNCHER_BLK(T, 128);                         \
+      break;                                                    \
+    default:                                                    \
+      TORCH_CHECK(false, "Unsupported head size: ", head_size); \
+      break;                                                    \
+  }
+void paged_attention(
+    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
+    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor&
+        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    torch::Tensor& context_lens,  // [num_seqs]
+    int64_t block_size, int64_t max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype) {
+  assert(kv_cache_dtype == "auto");
+  const int head_size = query.size(2);
+  if (query.dtype() == at::ScalarType::Half) {
+    CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16);
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+  }
+}
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
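
The kernel follows a two-phase split-context scheme: paged_attention_ll4mi_QKV_kernel writes one partial result per 256-token partition together with that partition's softmax statistics (max_logits, exp_sums), and paged_attention_ll4mi_reduce_kernel merges them when the context spans more than one partition. A small PyTorch sketch (not part of the commit) of the merge math for a single sequence and head, assuming each tmp_out row holds that partition's already-normalized partial output:

# Editorial sketch: the per-head reduction performed by
# paged_attention_ll4mi_reduce_kernel, restated in PyTorch.
import torch

def reduce_partitions(max_logits: torch.Tensor,   # [num_partitions]
                      exp_sums: torch.Tensor,     # [num_partitions]
                      tmp_out: torch.Tensor       # [num_partitions, head_size]
                      ) -> torch.Tensor:          # [head_size]
    global_max = max_logits.max()
    # Rescale each partition's softmax mass to the global maximum logit.
    weights = exp_sums * torch.exp(max_logits - global_max)
    # Each tmp_out row was normalized by its own partition's exp_sum, so a
    # weighted sum re-normalizes against the global sum (with the same 1e-6
    # guard the kernel uses).
    return (weights[:, None] * tmp_out).sum(dim=0) / (weights.sum() + 1e-6)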

+ 13 - 0
kernels/rocm/ops.h

@@ -0,0 +1,13 @@
+#pragma once
+
+#include <torch/all.h>
+
+void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
+                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
+                     torch::Tensor& query, torch::Tensor& key_cache,
+                     torch::Tensor& value_cache, int64_t num_kv_heads,
+                     double scale, torch::Tensor& block_tables,
+                     torch::Tensor& context_lens, int64_t block_size,
+                     int64_t max_context_len,
+                     const c10::optional<torch::Tensor>& alibi_slopes,
+                     const std::string& kv_cache_dtype);

+ 29 - 0
kernels/rocm/torch_bindings.cpp

@@ -0,0 +1,29 @@
+#include "core/registration.h"
+#include "rocm/ops.h"
+// Note on op signatures:
+// The X_meta signatures are for the meta functions corresponding to op X.
+// They must be kept in sync with the signature for X. Generally, only
+// functions that return Tensors require a meta function.
+//
+// See the following links for detailed docs on op registration and function
+// schemas.
+// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
+  // Aphrodite custom ops for rocm
+  // Custom attention op
+  // Compute the attention between an input query and the cached
+  // keys/values using PagedAttention.
+  rocm_ops.def(
+      "paged_attention(Tensor! out, Tensor exp_sums,"
+      "                Tensor max_logits, Tensor tmp_out,"
+      "                Tensor query, Tensor key_cache,"
+      "                Tensor value_cache, int num_kv_heads,"
+      "                float scale, Tensor block_tables,"
+      "                Tensor context_lens, int block_size,"
+      "                int max_context_len,"
+      "                Tensor? alibi_slopes,"
+      "                str kv_cache_dtype) -> ()");
+  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
+}
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
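
The schema registered above is the contract the Python wrapper paged_attention_rocm in aphrodite/_custom_ops.py calls against: "Tensor! out" marks the output as mutated in place, and "-> ()" means the op returns nothing. A small sketch (not part of the commit, assuming a ROCm build and that PyTorch exposes the unnamed overload as .default) for inspecting the registered schema:

# Editorial sketch: look up the op registered by this file and print its schema.
import torch
import aphrodite._rocm_C  # noqa: F401  -- importing runs the TORCH_LIBRARY block

op = torch.ops._rocm_C.paged_attention.default  # default (unnamed) overload
print(op._schema)  # _rocm_C::paged_attention(Tensor! out, Tensor exp_sums, ...) -> ()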

+ 3 - 0
setup.py

@@ -478,6 +478,9 @@ if _is_cuda() or _is_hip():
 if _build_custom_ops():
     ext_modules.append(CMakeExtension(name="aphrodite._C"))
 
+if _is_hip():
+    ext_modules.append(CMakeExtension(name="aphrodite._rocm_C"))
+
 package_data = {
     "aphrodite": [
         "endpoints/kobold/klite.embd", "quantization/hadamard.safetensors",

+ 154 - 3
tests/kernels/test_attention.py

@@ -3,8 +3,6 @@ from typing import List, Optional, Tuple
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip
@@ -12,6 +10,10 @@ from tests.kernels.utils import opcheck
 
 from .allclose_default import get_default_atol, get_default_rtol
 
+if not is_hip():
+    from xformers import ops as xops
+    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
@@ -326,13 +328,162 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)
 
 
-# TODO(woosuk): Add tests for USE_ALIBI=True.
+@pytest.mark.parametrize("version", ["rocm"])
+@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", [64, 128])  # only test 64 128
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(not is_hip(), reason="only for rocm")
+def test_paged_attention_rocm(
+    kv_cache_factory,
+    version: str,
+    num_seqs: int,
+    num_heads: Tuple[int, int],
+    head_size: int,
+    use_alibi: bool,
+    block_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: str,
+    seed: int,
+    device: str,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
+    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
+    query.uniform_(-scale, scale)
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
+    #context_lens = [8192 for _ in range(num_seqs)]
+    max_context_len = max(context_lens)
+    context_lens = torch.tensor(context_lens, dtype=torch.int)
+    #print('>>> ctx lens', context_lens)
+    # Create the block tables.
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for _ in range(num_seqs):
+        block_table = [
+            random.randint(0, NUM_BLOCKS - 1)
+            for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int)
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
+                                                num_kv_heads, head_size,
+                                                kv_cache_dtype, dtype, seed,
+                                                device)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+    # TODO enable fp8 kv cache
+    # Using default kv_scale
+    # kv_scale = 1.0
+    # Call the paged attention kernel.
+    output = torch.empty_like(query)
+    PARTITION_SIZE_ROCM = 256
+    num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
+                      PARTITION_SIZE_ROCM)
+    assert PARTITION_SIZE_ROCM % block_size == 0
+    num_seqs, num_heads, head_size = output.shape
+    tmp_output = torch.empty(
+        size=(num_seqs, num_heads, num_partitions, head_size),
+        dtype=output.dtype,
+    )
+    exp_sums = torch.empty(
+        size=(num_seqs, num_heads, num_partitions),
+        dtype=torch.float32,
+    )
+    max_logits = torch.empty_like(exp_sums)
+    if version == "rocm":
+        ops.paged_attention_rocm(
+            output,
+            exp_sums,
+            max_logits,
+            tmp_output,
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            scale,
+            block_tables,
+            context_lens,
+            block_size,
+            max_context_len,
+            alibi_slopes,
+            kv_cache_dtype,
+        )
+    else:
+        raise AssertionError(f"Unknown version: {version}")
+    # Run the reference implementation.
+    if kv_cache_dtype == "fp8":
+        # Convert cache data back to dtype.
+        x = 16 // torch.tensor([], dtype=dtype).element_size()
+        key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
+                           block_size, x)
+        dequantized_key_cache = torch.empty(size=key_cache_shape,
+                                            dtype=dtype,
+                                            device=device)
+        ops.convert_fp8(key_cache, dequantized_key_cache)
+        key_cache = dequantized_key_cache
+        value_cache_shape = value_cache.shape
+        dequantized_value_cache = torch.empty(size=value_cache_shape,
+                                              dtype=dtype,
+                                              device=device)
+        ops.convert_fp8(value_cache, dequantized_value_cache)
+        value_cache = dequantized_value_cache
+    ref_output = torch.empty_like(query)
+    ref_single_query_cached_kv_attention(
+        ref_output,
+        query,
+        num_queries_per_kv,
+        key_cache,
+        value_cache,
+        block_tables,
+        context_lens,
+        scale,
+        alibi_slopes,
+    )
+    # NOTE: Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
+    # NOTE: FP8 KV Cache will introduce quantization error,
+    # so we use a relaxed tolerance for the test.
+    atol, rtol = 1e-4, 1e-5
+    if dtype == torch.bfloat16:
+        atol, rtol = 2e-4, 1e-5
+    if use_alibi:
+        if dtype == torch.half:
+            atol, rtol = 5e-4, 1e-5
+        if dtype == torch.bfloat16:
+            atol, rtol = 1e-3, 1e-5
+    if kv_cache_dtype == "fp8":
+        atol, rtol = 1e-2, 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+
+
+# TODO: Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.skipif(is_hip(), reason="skip for rocm")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,