瀏覽代碼

[Gen] Add kernel from FasterTransformer for benchmarking

Tri Dao 2 年之前
父節點
當前提交
a01d1213d7

+ 8 - 0
csrc/ft_attention/README.md

@@ -0,0 +1,8 @@
+# Attention kernel from FasterTransformer
+
+This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from
+FasterTransformer v5.2.1 for benchmarking purpose.
+
+```sh
+cd csrc/ft_attention && pip install .
+```

+ 257 - 0
csrc/ft_attention/cuda_bf16_fallbacks.cuh

@@ -0,0 +1,257 @@
+// Downloaded from from FasterTransformer v5.2.1
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
+/*
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cuda_bf16_wrapper.h"
+#include <cuda_fp16.h>
+
+namespace fastertransformer {
+
+#ifdef ENABLE_BF16
+inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float2 f_val;
+    f_val.x = __low2float(val);
+    f_val.y = __high2float(val);
+    return f_val;
+#else
+    return __bfloat1622float2(val);
+#endif
+}
+
+inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float2 f_val;
+    f_val.x = max(min(__low2float(val), 127.f), -128.f);
+    f_val.y = max(min(__high2float(val), 127.f), -128.f);
+    union { int8_t int8[2]; int16_t int16; };
+    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
+    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
+    return int16;
+#else
+    val = __hmin2(val, make_bfloat162(127., 127.));
+    val = __hmax2(val, make_bfloat162(-128., -128.));
+    union { int8_t int8[2]; int16_t int16; };
+    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
+    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
+    return int16;
+#endif
+}
+
+inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __floats2bfloat162_rn(val.x, val.y);
+#else
+    return __float22bfloat162_rn(val);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    __nv_bfloat162 val2;
+    val2.x = val;
+    val2.y = val;
+    return val2;
+#else
+    return __bfloat162bfloat162(val);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fxl, fxh, fyl, fyh;
+    fxl = __low2float(x);
+    fxh = __high2float(x);
+    fyl = __low2float(y);
+    fyh = __high2float(y);
+    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
+#else
+    return __hadd2(x, y);
+#endif
+}
+
+inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) );
+#else
+    return __hadd(x, y);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fxl, fxh, fyl, fyh;
+    fxl = __low2float(x);
+    fxh = __high2float(x);
+    fyl = __low2float(y);
+    fyh = __high2float(y);
+    return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
+#else
+    return __hsub2(x, y);
+#endif
+}
+
+inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) );
+#else
+    return __hsub(x, y);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fxl, fxh, fyl, fyh;
+    fxl = __low2float(x);
+    fxh = __high2float(x);
+    fyl = __low2float(y);
+    fyh = __high2float(y);
+    return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
+#else
+    return __hmul2(x, y);
+#endif
+}
+
+inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) );
+#else
+    return __hmul(x, y);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fxl, fxh, fyl, fyh, fzl, fzh;
+    fxl = __low2float(x);
+    fxh = __high2float(x);
+    fyl = __low2float(y);
+    fyh = __high2float(y);
+    fzl = __low2float(z);
+    fzh = __high2float(z);
+    return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
+#else
+    return __hfma2(x, y, z);
+#endif
+}
+
+inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
+#else
+    return __hfma(x, y, z);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fxl, fxh;
+    fxl = __low2float(x);
+    fxh = __high2float(x);;
+    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
+#else
+    return h2exp(x);
+#endif
+}
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
+inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
+inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };
+
+inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
+{
+    __nv_bfloat162 t; t.x = x; t.y = y; return t;
+}
+
+#endif
+
+inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
+#else
+    return a + b + c;
+#endif
+}
+
+inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
+#else
+    return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fal, fah, fbl, fbh, fcl, fch;
+    fal = __low2float(a);
+    fah = __high2float(a);
+    fbl = __low2float(b);
+    fbh = __high2float(b);
+    fcl = __low2float(c);
+    fch = __high2float(c);
+    return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
+#else
+    return a + b + c;
+#endif
+}
+
+inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
+#else
+    return a * b * c;
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fal, fah, fbl, fbh, fcl, fch;
+    fal = __low2float(a);
+    fah = __high2float(a);
+    fbl = __low2float(b);
+    fbh = __high2float(b);
+    fcl = __low2float(c);
+    fch = __high2float(c);
+    return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
+#else
+    return a * b * c;
+#endif
+}
+
+inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
+    fal = __low2float(a);
+    fah = __high2float(a);
+    fbl = __low2float(b);
+    fbh = __high2float(b);
+    fcl = __low2float(c);
+    fch = __high2float(c);
+    fdl = __low2float(d);
+    fdh = __high2float(d);
+    return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
+#else
+    return a * b * c + d;
+#endif
+}
+
+#endif // ENABLE_BF16
+
+}  // namespace fastertransformer

+ 23 - 0
csrc/ft_attention/cuda_bf16_wrapper.h

@@ -0,0 +1,23 @@
+// Downloaded from from FasterTransformer v5.2.1
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
+/*
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#ifdef ENABLE_BF16
+#include <cuda_bf16.h>
+#endif

+ 152 - 0
csrc/ft_attention/decoder_masked_multihead_attention.cu

@@ -0,0 +1,152 @@
+// Adapted from from FasterTransformer v5.2.1
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "decoder_masked_multihead_attention.h"
+#include "decoder_masked_multihead_attention_utils.h"
+#include "cuda_bf16_wrapper.h"
+#include <assert.h>
+#include <float.h>
+#include <type_traits>
+
+#include "decoder_masked_multihead_attention_template.hpp"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream)    \
+    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK);          \
+    dim3   grid(params.num_heads, params.batch_size);                                                                  \
+    mmha::masked_multihead_attention_kernel<T,                                                                         \
+                                            Dh,                                                                        \
+                                            Dh_MAX,                                                                    \
+                                            THDS_PER_KEY,                                                              \
+                                            THDS_PER_VALUE,                                                            \
+                                            THDS_PER_BLOCK,                                                            \
+                                            DO_CROSS_ATTENTION><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// !!! Specialize the launcher for Cross attention
+template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
+void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
+{
+    constexpr int  THREADS_PER_VALUE  = Dh_MAX * sizeof(T) / 16;
+    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
+    int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
+    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
+    if (tlength < 32) {
+        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
+    }
+    else if (tlength < 2048) {
+        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
+    }
+    else {
+        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#undef MMHA_LAUNCH_KERNEL
+
+template<typename T, typename KERNEL_PARAMS_TYPE>
+void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
+{
+    switch (params.hidden_size_per_head) {
+        case 32:
+            mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 48:
+            mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 64:
+            mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 80:
+            mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 96:
+            mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 128:
+            mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 160:
+            mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 192:
+            mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 224:
+            mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 256:
+            mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        default:
+            assert(false);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
+{
+    multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
+{
+    multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
+                                const cudaStream_t&                                     stream)
+{
+    multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
+}
+#endif
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
+{
+    multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
+{
+    multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
+                               const cudaStream_t&                                    stream)
+{
+    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////

+ 181 - 0
csrc/ft_attention/decoder_masked_multihead_attention.h

@@ -0,0 +1,181 @@
+// Downloaded from from FasterTransformer v5.2.1
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cuda_bf16_wrapper.h"
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define CHECK_CUDA(call)                                                                                               \
+    do {                                                                                                               \
+        cudaError_t status_ = call;                                                                                    \
+        if (status_ != cudaSuccess) {                                                                                  \
+            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_));              \
+            exit(1);                                                                                                   \
+        }                                                                                                              \
+    } while (0)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// The structure of parameters for the masked multihead attention kernel.
+//
+// We use the following terminology to describe the different dimensions.
+//
+// B:  Batch size (number of sequences),
+// L:  Sequence length,
+// D:  Hidden dimension,
+// H:  Number of heads,
+// Dh: Hidden dimension per head - Dh = D / H.
+
+template<typename T>
+struct Multihead_attention_params_base {
+
+    // The output buffer. Dimensions B x D.
+    T* out = nullptr;
+
+    // The input Qs and the associated bias. Dimensions B x D and D, resp.
+    const T *q = nullptr, *q_bias = nullptr;
+    // The input Ks and the associated bias. Dimensions B x D and D, resp.
+    const T *k = nullptr, *k_bias = nullptr;
+    // The input Vs and the associated bias. Dimensions B x D and D, resp.
+    const T *v = nullptr, *v_bias = nullptr;
+
+    // The cache for the Ks. The size must be at least B x L x D.
+    T* k_cache = nullptr;
+    // The cache for the Vs. The size must be at least B x L x D.
+    T* v_cache = nullptr;
+    // The indirections to use for cache when beam sampling.
+    const int* cache_indir = nullptr;
+
+    // Stride to handle the case when KQV is a single buffer
+    int stride = 0;
+
+    // The batch size.
+    int batch_size = 0;
+    // The beam width
+    int beam_width = 0;
+    // The sequence length.
+    int memory_max_len = 0;
+    // The number of heads (H).
+    int num_heads = 0;
+    // The hidden dimension per head (Dh).
+    int hidden_size_per_head = 0;
+    // The per-head latent space reserved for rotary embeddings.
+    int  rotary_embedding_dim = 0;
+    bool neox_rotary_style    = false;
+    // The maximum length of input sentences.
+    int max_input_length = 0;
+    // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
+    int timestep = 0;
+    // The current timestep of each sentences (support different timestep for different sentences)
+
+    // The 1.f / sqrt(Dh). Computed on the host.
+    float inv_sqrt_dh = 0.0f;
+
+    // Used when we have some input context like gpt
+    const int* total_padding_tokens = nullptr;
+
+    const bool* masked_tokens            = nullptr;
+    const int*  prefix_prompt_lengths    = nullptr;
+    int         max_prefix_prompt_length = 0;
+
+    const T* relative_attention_bias        = nullptr;
+    int      relative_attention_bias_stride = 0;
+    // The slope per head of linear position bias to attention score (H).
+    const T* linear_bias_slopes = nullptr;
+
+    const T*   ia3_key_weights   = nullptr;
+    const T*   ia3_value_weights = nullptr;
+    const int* ia3_tasks         = nullptr;
+
+    const float* qkv_scale_out       = nullptr;
+    const float* attention_out_scale = nullptr;
+    int          int8_mode           = 0;
+};
+
+template<typename T, bool CROSS_ATTENTION>
+struct Multihead_attention_params: public Multihead_attention_params_base<T> {
+    // output cross attentions
+    float* cross_attention_out        = nullptr;
+    int    max_decoder_seq_len        = 0;
+    bool   is_return_cross_attentions = false;
+
+    // allows to exist attention eary
+    bool* finished = nullptr;
+
+    // required in case of cross attention
+    // will need it here till if constexpr in c++17
+    int* memory_length_per_sample = nullptr;
+
+    // required in case of masked attention with different length
+    const int* length_per_sample = nullptr;
+};
+
+template<typename T>
+struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
+    // output cross attentions
+    float* cross_attention_out        = nullptr;
+    int    max_decoder_seq_len        = 0;
+    bool   is_return_cross_attentions = false;
+
+    // allows to exist attention eary
+    bool* finished = nullptr;
+
+    // required in case of cross attention
+    int* memory_length_per_sample = nullptr;
+
+    // required in case of masked attention with different length
+    const int* length_per_sample = nullptr;
+};
+
+template<class T>
+using Masked_multihead_attention_params = Multihead_attention_params<T, false>;
+
+template<class T>
+using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
+
+template<typename T>
+struct outputCrossAttentionParam {
+    // max decoder output length
+    int  max_decoder_seq_len        = 0;
+    T*   cross_attention_out        = nullptr;
+    bool is_return_cross_attentions = false;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
+void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
+#ifdef ENABLE_BF16
+void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
+                                const cudaStream_t&                                     stream);
+#endif
+void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
+void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
+#ifdef ENABLE_BF16
+void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
+                               const cudaStream_t&                                    stream);
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////

+ 1605 - 0
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp

@@ -0,0 +1,1605 @@
+// Downloaded from from FasterTransformer v5.2.1
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "decoder_masked_multihead_attention.h"
+#include "decoder_masked_multihead_attention_utils.h"
+#include "cuda_bf16_wrapper.h"
+#include "cuda_bf16_fallbacks.cuh"
+#include <assert.h>
+#include <float.h>
+#include <type_traits>
+
+// #define MMHA_USE_HMMA_FOR_REDUCTION
+
+// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
+
+// Does not seem to affect the accuracy that much
+// #define MMHA_USE_FP32_ACUM_FOR_FMA
+
+// Seems to slightly improve the accuracy
+#define MMHA_USE_FP32_ACUM_FOR_OUT
+
+#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
+ // Does not seem to improve the accuracy
+ //#define MMHA_USE_FP32_ACUM_FOR_LOGITS
+#endif
+
+namespace mmha {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+//
+// We use the following terminology to describe the different dimensions.
+//
+// B:  Batch size (number of sequences),
+// L:  Sequence length,
+// D:  Hidden dimension,
+// H:  Number of heads,
+// Dh: Hidden dimension per head - Dh = D / H.
+//
+// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
+// 64, 128 and 256 threads per block.
+//
+// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
+// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
+// cache buffer helps with memory accesses and contains keys with bias.
+//
+// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
+// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
+// values for x are chosen to create chunks of 16 bytes.
+//
+// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
+// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
+// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
+// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32.
+//
+// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
+// shared memory.
+//
+// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
+// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
+// except for the current timestep. The layout of the cache buffer for the values is much simpler
+// as it is [B, H, L, Dh].
+//
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int Dh>
+struct Qk_vec_ {
+};
+
+template<>
+struct Qk_vec_<float, 32> {
+    using Type = float;
+};
+template<>
+struct Qk_vec_<float, 64> {
+    using Type = float2;
+};
+template<>
+struct Qk_vec_<float, 128> {
+    using Type = float4;
+};
+template<>
+struct Qk_vec_<float, 256> {
+    using Type = float4;
+};
+template<>
+struct Qk_vec_<uint16_t, 32> {
+    using Type = uint32_t;
+};
+template<>
+struct Qk_vec_<uint16_t, 64> {
+    using Type = uint32_t;
+};
+template<>
+struct Qk_vec_<uint16_t, 128> {
+    using Type = uint2;
+};
+template<>
+struct Qk_vec_<uint16_t, 256> {
+    using Type = uint4;
+};
+#ifdef ENABLE_BF16
+template<>
+struct Qk_vec_<__nv_bfloat16, 32> {
+    using Type = __nv_bfloat162;
+};
+template<>
+struct Qk_vec_<__nv_bfloat16, 64> {
+    using Type = __nv_bfloat162;
+};
+template<>
+struct Qk_vec_<__nv_bfloat16, 128> {
+    using Type = bf16_4_t;
+};
+template<>
+struct Qk_vec_<__nv_bfloat16, 256> {
+    using Type = bf16_8_t;
+};
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int THREADS_PER_KEY>
+struct K_vec_ {
+};
+
+template<>
+struct K_vec_<float, 4> {
+    using Type = float;
+};
+template<>
+struct K_vec_<float, 2> {
+    using Type = float2;
+};
+template<>
+struct K_vec_<float, 1> {
+    using Type = float4;
+};
+template<>
+struct K_vec_<uint16_t, 4> {
+    using Type = uint32_t;
+};
+template<>
+struct K_vec_<uint16_t, 2> {
+    using Type = uint2;
+};
+template<>
+struct K_vec_<uint16_t, 1> {
+    using Type = uint4;
+};
+#ifdef ENABLE_BF16
+template<>
+struct K_vec_<__nv_bfloat16, 4> {
+    using Type = __nv_bfloat162;
+};
+template<>
+struct K_vec_<__nv_bfloat16, 2> {
+    using Type = bf16_4_t;
+};
+template<>
+struct K_vec_<__nv_bfloat16, 1> {
+    using Type = bf16_8_t;
+};
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int V_VEC_SIZE>
+struct V_vec_ {
+};
+
+template<>
+struct V_vec_<float, 1> {
+    using Type = float;
+};
+template<>
+struct V_vec_<float, 2> {
+    using Type = float2;
+};
+template<>
+struct V_vec_<float, 4> {
+    using Type = float4;
+};
+template<>
+struct V_vec_<uint16_t, 2> {
+    using Type = uint32_t;
+};
+template<>
+struct V_vec_<uint16_t, 4> {
+    using Type = uint2;
+};
+template<>
+struct V_vec_<uint16_t, 8> {
+    using Type = uint4;
+};
+#ifdef ENABLE_BF16
+template<>
+struct V_vec_<__nv_bfloat16, 2> {
+    using Type = __nv_bfloat162;
+};
+template<>
+struct V_vec_<__nv_bfloat16, 4> {
+    using Type = bf16_4_t;
+};
+template<>
+struct V_vec_<__nv_bfloat16, 8> {
+    using Type = bf16_8_t;
+};
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
+template<typename T>
+struct Qk_vec_acum_fp32_ {
+};
+
+template<>
+struct Qk_vec_acum_fp32_<float> {
+    using Type = float;
+};
+template<>
+struct Qk_vec_acum_fp32_<float2> {
+    using Type = float2;
+};
+template<>
+struct Qk_vec_acum_fp32_<float4> {
+    using Type = float4;
+};
+// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float;        };
+template<>
+struct Qk_vec_acum_fp32_<uint32_t> {
+    using Type = float2;
+};
+template<>
+struct Qk_vec_acum_fp32_<uint2> {
+    using Type = Float4_;
+};
+template<>
+struct Qk_vec_acum_fp32_<uint4> {
+    using Type = Float8_;
+};
+template<>
+struct Qk_vec_acum_fp32_<__nv_bfloat16> {
+    using Type = float;
+};
+template<>
+struct Qk_vec_acum_fp32_<__nv_bfloat162> {
+    using Type = float2;
+};
+template<>
+struct Qk_vec_acum_fp32_<bf16_4_t> {
+    using Type = Float4_;
+};
+template<>
+struct Qk_vec_acum_fp32_<bf16_8_t> {
+    using Type = Float8_;
+};
+
+template<>
+struct Qk_vec_acum_fp32_<uint4> {
+    using Type = Float8_;
+};
+template<>
+struct Qk_vec_acum_fp32_<__nv_bfloat16> {
+    using Type = float;
+};
+template<>
+struct Qk_vec_acum_fp32_<__nv_bfloat162> {
+    using Type = float2;
+};
+template<>
+struct Qk_vec_acum_fp32_<bf16_4_t> {
+    using Type = Float4_;
+};
+template<>
+struct Qk_vec_acum_fp32_<bf16_8_t> {
+    using Type = Float8_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct K_vec_acum_fp32_ {
+};
+
+template<>
+struct K_vec_acum_fp32_<float> {
+    using Type = float;
+};
+template<>
+struct K_vec_acum_fp32_<float2> {
+    using Type = float2;
+};
+template<>
+struct K_vec_acum_fp32_<float4> {
+    using Type = float4;
+};
+template<>
+struct K_vec_acum_fp32_<uint32_t> {
+    using Type = float2;
+};
+template<>
+struct K_vec_acum_fp32_<uint2> {
+    using Type = Float4_;
+};
+template<>
+struct K_vec_acum_fp32_<uint4> {
+    using Type = Float8_;
+};
+template<>
+struct K_vec_acum_fp32_<__nv_bfloat16> {
+    using Type = float;
+};
+template<>
+struct K_vec_acum_fp32_<__nv_bfloat162> {
+    using Type = float2;
+};
+template<>
+struct K_vec_acum_fp32_<bf16_4_t> {
+    using Type = Float4_;
+};
+template<>
+struct K_vec_acum_fp32_<bf16_8_t> {
+    using Type = Float8_;
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+template<typename T>
+struct V_vec_acum_fp32_ {
+};
+
+template<>
+struct V_vec_acum_fp32_<float> {
+    using Type = float;
+};
+template<>
+struct V_vec_acum_fp32_<float2> {
+    using Type = float2;
+};
+template<>
+struct V_vec_acum_fp32_<float4> {
+    using Type = float4;
+};
+template<>
+struct V_vec_acum_fp32_<uint32_t> {
+    using Type = float2;
+};
+template<>
+struct V_vec_acum_fp32_<uint2> {
+    using Type = Float4_;
+};
+template<>
+struct V_vec_acum_fp32_<uint4> {
+    using Type = Float8_;
+};
+#ifdef ENABLE_BF16
+template<>
+struct V_vec_acum_fp32_<__nv_bfloat162> {
+    using Type = float2;
+};
+template<>
+struct V_vec_acum_fp32_<bf16_4_t> {
+    using Type = Float4_;
+};
+template<>
+struct V_vec_acum_fp32_<bf16_8_t> {
+    using Type = Float8_;
+};
+#endif  // ENABLE_BF16
+#endif
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int THREADS_PER_KEY, typename K_vec, int N>
+inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
+{
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
+    using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
+#else
+    using K_vec_acum = K_vec;
+#endif
+    // Compute the parallel products for Q*K^T (treat vector lanes separately).
+    K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
+#pragma unroll
+    for (int ii = 1; ii < N; ++ii) {
+        qk_vec = fma(q[ii], k[ii], qk_vec);
+    }
+
+    // Finalize the reduction across lanes.
+    float qk = sum(qk_vec);
+#pragma unroll
+    for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
+        qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+    }
+    return qk;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int THREADS_PER_KEY>
+struct Qk_dot {
+    template<typename K_vec, int N>
+    static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
+    {
+        return qk_dot_<THREADS_PER_KEY>(q, k);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
+{
+    float4 c;
+    float  zero = 0.f;
+    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
+                 "    {%0, %1, %2, %3}, \n"
+                 "    {%4, %5}, \n"
+                 "    {%6}, \n"
+                 "    {%7, %7, %7, %7}; \n"
+
+                 : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
+                 : "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int N>
+inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
+{
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
+    using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
+#else
+    using K_vec_acum = uint32_t;
+#endif
+    K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
+#pragma unroll
+    for (int ii = 1; ii < N; ++ii) {
+        qk_vec = fma(q[ii], k[ii], qk_vec);
+    }
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
+    uint32_t qk_vec_ = float2_to_half2(qk_vec);
+    return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
+#else
+    return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
+#endif
+#else
+    return 0.f;
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+struct Qk_dot<uint16_t, 4> {
+    template<int N>
+    static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
+    {
+#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
+        return qk_hmma_dot_(q, k);
+#else
+        return qk_dot_<4>(q, k);
+#endif  // defined MMHA_USE_HMMA_FOR_REDUCTION
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
+inline __device__ float block_sum(float* red_smem, float sum)
+{
+
+    // Decompose the thread index into warp / lane.
+    int warp = threadIdx.x / WARP_SIZE;
+    int lane = threadIdx.x % WARP_SIZE;
+
+// Compute the sum per warp.
+#pragma unroll
+    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    }
+
+    // Warp leaders store the data to shared memory.
+    if (lane == 0) {
+        red_smem[warp] = sum;
+    }
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // The warps compute the final sums.
+    if (lane < WARPS_PER_BLOCK) {
+        sum = red_smem[lane];
+    }
+
+// Parallel reduction inside the warp.
+#pragma unroll
+    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
+        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    }
+
+    // Broadcast to other threads.
+    return __shfl_sync(uint32_t(-1), sum, 0);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(float& dst, float src)
+{
+    dst = src;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(uint16_t& dst, float src)
+{
+    dst = float_to_half(src);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(uint32_t& dst, float2 src)
+{
+    dst = float2_to_half2(src);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+#ifdef ENABLE_BF16
+inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
+{
+    dst = __float2bfloat16(src);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
+{
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    dst = __float22bfloat162_rn(src);
+#else
+    dst   = __floats2bfloat162_rn(src.x, src.y);
+#endif
+}
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(uint2& dst, Float4_ src)
+{
+    dst.x = float2_to_half2(src.x);
+    dst.y = float2_to_half2(src.y);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(uint2& dst, float4 src)
+{
+    convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(uint4& dst, Float8_ src)
+{
+    dst.x = float2_to_half2(src.x);
+    dst.y = float2_to_half2(src.y);
+    dst.z = float2_to_half2(src.z);
+    dst.w = float2_to_half2(src.w);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
+{
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    dst.x = __float22bfloat162_rn(src.x);
+    dst.y = __float22bfloat162_rn(src.y);
+#else
+    dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
+    dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
+{
+    convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
+{
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    dst.x = __float22bfloat162_rn(src.x);
+    dst.y = __float22bfloat162_rn(src.y);
+    dst.z = __float22bfloat162_rn(src.z);
+    dst.w = __float22bfloat162_rn(src.w);
+#else
+    dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
+    dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
+    dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
+    dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
+#endif
+}
+#endif  // ENABLE_BF16
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(float2& dst, float2 src)
+{
+    dst = src;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void convert_from_float(float4& dst, float4 src)
+{
+    dst = src;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float convert_to_float(float4 u)
+{
+    return u.x;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float convert_to_float(uint4 u)
+{
+    float2 tmp = half2_to_float2(u.x);
+    return tmp.x;
+}
+
+#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float cast_to_float(float u)
+{
+    return u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 cast_to_float(float2 u)
+{
+    return u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 cast_to_float(float4 u)
+{
+    return u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ cast_to_float(Float4_ u)
+{
+    return u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ cast_to_float(Float8_ u)
+{
+    return u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 cast_to_float(uint32_t u)
+{
+    return half2_to_float2(u);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ cast_to_float(uint2 u)
+{
+    Float4_ tmp;
+    tmp.x = half2_to_float2(u.x);
+    tmp.y = half2_to_float2(u.y);
+    return tmp;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ cast_to_float(uint4 u)
+{
+    Float8_ tmp;
+    tmp.x = half2_to_float2(u.x);
+    tmp.y = half2_to_float2(u.y);
+    tmp.z = half2_to_float2(u.z);
+    tmp.w = half2_to_float2(u.w);
+    return tmp;
+}
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float float_from_int8(int8_t u)
+{
+    return u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 float_from_int8(int16_t u)
+{
+    union {
+        int16_t int16;
+        int8_t  int8[2];
+    };
+    int16 = u;
+    return make_float2(int8[0], int8[1]);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 float_from_int8(int32_t u)
+{
+    union {
+        int32_t int32;
+        int8_t  int8[4];
+    };
+    int32 = u;
+    return make_float4(int8[0], int8[1], int8[2], int8[3]);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// clang-format off
+inline __device__ Float8_ float_from_int8(int64_t u)
+{
+    union {
+        int64_t int64;
+        int16_t int16[4];
+    };
+    int64 = u;
+    return Float8_ {float_from_int8(int16[0]),
+                    float_from_int8(int16[1]),
+                    float_from_int8(int16[2]),
+                    float_from_int8(int16[3])};
+}
+// clang-format on
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ int8_t cast_to_int8(float val)
+{
+    union {
+        int8_t  int8[2];
+        int16_t int16;
+    };
+    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
+    return int8[0];
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ int32_t cast_to_int8(float4 val)
+{
+    union {
+        int8_t  int8[4];
+        int32_t int32;
+    };
+    int8[0] = cast_to_int8(val.x);
+    int8[1] = cast_to_int8(val.y);
+    int8[2] = cast_to_int8(val.z);
+    int8[3] = cast_to_int8(val.w);
+    return int32;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ int64_t cast_to_int8(Float8_ val)
+{
+    union {
+        int8_t  int8[8];
+        int64_t int64;
+    };
+    int8[0] = cast_to_int8(val.x.x);
+    int8[1] = cast_to_int8(val.x.y);
+    int8[2] = cast_to_int8(val.y.x);
+    int8[3] = cast_to_int8(val.y.y);
+    int8[4] = cast_to_int8(val.z.x);
+    int8[5] = cast_to_int8(val.z.y);
+    int8[6] = cast_to_int8(val.w.x);
+    int8[7] = cast_to_int8(val.w.y);
+    return int64;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+inline __device__ __host__ T div_up(T m, T n)
+{
+    return (m + n - 1) / n;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, bool DO_CROSS_ATTENTION>
+inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
+                                 int                                                      threads_per_value,
+                                 int                                                      threads_per_block)
+{
+    // The amount of shared memory needed to store the Q*K^T values in float.
+    const int max_timesteps = min(params.timestep, params.memory_max_len);
+    size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
+
+    // The extra memory needed if we are not using floats for the final logits.
+    size_t logits_sz = 0;
+#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
+    if (sizeof(T) != 4) {
+        // TDOD
+        logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
+                                           div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
+    }
+#endif
+
+    // The total size needed during softmax.
+    size_t softmax_sz = qk_sz + logits_sz;
+
+    // The number of partial rows to reduce in the final reduction.
+    int rows_per_red = threads_per_block / threads_per_value;
+    // The amount of storage needed to finalize the outputs.
+    size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
+
+    size_t transpose_rotary_size = 0;
+    if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
+        transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
+    }
+
+    // The max.
+    return max(max(softmax_sz, red_sz), transpose_rotary_size);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ constexpr uint32_t shfl_mask(int threads)
+{
+    return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<
+    // The type of the inputs. Supported types: float and half.
+    typename T,
+    // The hidden dimension per head.
+    int Dh,
+    int Dh_MAX,
+    // The number of threads per key.
+    int THREADS_PER_KEY,
+    // The number of threads per value.
+    int THREADS_PER_VALUE,
+    // The number of threads in a threadblock.
+    int  THREADS_PER_BLOCK,
+    bool DO_CROSS_ATTENTION>
+__global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
+{
+
+    // Make sure the hidden dimension per head is a multiple of the number of threads per key.
+    static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
+    // Make sure the hidden dimension per head is a multiple of the number of threads per value.
+    static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
+
+    // The size of a warp.
+    constexpr int WARP_SIZE = 32;
+    // The number of warps in a threadblock.
+    constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
+
+    // Use smem_size_in_bytes (above) to determine the amount of shared memory.
+    extern __shared__ char smem_[];
+
+    // The shared memory for the Q*K^T values and partial logits in softmax.
+    float* qk_smem = reinterpret_cast<float*>(smem_);
+
+    // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
+    char* logits_smem_ = smem_;
+#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
+    if (sizeof(T) != 4) {
+        // TODO - change to tlength
+        const int max_timesteps = min(params.timestep, params.memory_max_len);
+        logits_smem_ +=
+            (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
+    }
+    T* logits_smem = reinterpret_cast<T*>(logits_smem_);
+#else
+    float* logits_smem = reinterpret_cast<float*>(logits_smem_);
+#endif
+
+    // The shared memory to do the final reduction for the output values. Reuse qk_smem.
+    T* out_smem = reinterpret_cast<T*>(smem_);
+
+    // The shared memory buffers for the block-wide reductions. One for max, one for sum.
+    __shared__ float red_smem[WARPS_PER_BLOCK * 2];
+
+    // A vector of Q or K elements for the current timestep.
+    using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+
+    // Use alignment for safely casting the shared buffers as Qk_vec.
+    // Shared memory to store Q inputs.
+    __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
+
+    // This is one of the reasons we should have a separate kernel for cross attention
+    __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
+
+    // A vector of Q or K elements for the current timestep.
+    using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
+    // The number of elements per vector.
+    constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
+    // Make sure the hidden size per head is a multiple of the vector size.
+    static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
+    // We will use block wide reduction if needed
+    // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
+    // The number of vectors per warp.
+    constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
+
+    // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
+    // owns x elements, we have to decompose the linear index into chunks of x values and the posi-
+    // tion of the thread in that chunk.
+
+    // The number of elements in a chunk of 16B (that's the x in the above formula).
+    constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
+    // The number of K vectors in 16B.
+    constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
+
+    // The batch/beam idx
+    const int bi = blockIdx.y;
+    if (params.finished != nullptr && params.finished[bi] == true) {
+        return;
+    }
+    // The beam idx
+    const int beami = bi % params.beam_width;
+    // The "beam-aware" batch idx
+    const int bbi = bi / params.beam_width;
+    // The head.
+    const int hi = blockIdx.x;
+    // Combine the batch and the head indices.
+    const int bhi = bi * params.num_heads + hi;
+    // Combine the "beam-aware" batch idx and the head indices.
+    const int bbhi = bbi * params.beam_width * params.num_heads + hi;
+    // The thread in the block.
+    const int tidx = threadIdx.x;
+
+    const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
+
+    // While doing the product Q*K^T for the different keys we track the max.
+    float qk_max = -FLT_MAX;
+
+    float qk = 0.0F;
+
+    int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
+
+    const size_t bi_seq_len_offset = bi * params.memory_max_len;
+
+    // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
+    int       tlength      = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
+                             (params.length_per_sample == nullptr) ?
+                                                    params.timestep :
+                                                    params.length_per_sample[bi] + params.max_prefix_prompt_length;
+    const int first_step   = max(0, tlength + 1 - params.memory_max_len);
+    const int tlength_circ = tlength % params.memory_max_len;
+
+    // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
+    const bool is_masked = tidx >= QK_VECS_PER_WARP;
+
+    // The offset in the Q and K buffer also accounts for the batch.
+    int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
+    // The offset in the bias buffer.
+    int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+
+    const bool do_ia3      = handle_kv && params.ia3_tasks != nullptr;
+    const int  ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
+
+    // Trigger the loads from the Q and K buffers.
+    Qk_vec q;
+    zero(q);
+    if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
+        if (params.int8_mode == 2) {
+            using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
+            using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
+            const auto q_scaling = params.qkv_scale_out[0];
+            const auto q_quant =
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
+
+            convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
+        }
+        else {
+            q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
+        }
+    }
+
+    Qk_vec k;
+    zero(k);
+    if (DO_CROSS_ATTENTION) {
+        // The 16B chunk written by the thread.
+        int co = tidx / QK_VECS_IN_16B;
+        // The position of the thread in that 16B chunk.
+        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
+
+        // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
+        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+                     // params.timestep*QK_ELTS_IN_16B +
+                     tlength * QK_ELTS_IN_16B + ci;
+        k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
+                *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
+                k;
+    }
+    else {
+        if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
+            if (params.int8_mode == 2) {
+                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
+                using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
+                const auto k_scaling = params.qkv_scale_out[1];
+                const auto k_quant =
+                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
+
+                convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
+            }
+            else {
+                k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
+            }
+        }
+    }
+
+    // Trigger the loads from the Q and K bias buffers.
+    Qk_vec q_bias;
+    zero(q_bias);
+    q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
+                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
+                 q_bias;
+
+    Qk_vec k_bias;
+    zero(k_bias);
+    if (handle_kv) {
+        k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
+                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
+                     k_bias;
+    }
+
+    // Computes the Q/K values with bias.
+    q = add(q, q_bias);
+    if (handle_kv) {
+        k = add(k, k_bias);
+    }
+    if (do_ia3 && !is_masked) {
+        k = mul<Qk_vec, Qk_vec, Qk_vec>(
+            k,
+            *reinterpret_cast<const Qk_vec*>(
+                &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
+    }
+
+    // Padded len
+    const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
+    if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
+        if (handle_kv) {
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
+        }
+        else {
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
+        }
+    }
+    else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
+        const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
+
+        T* q_smem = reinterpret_cast<T*>(smem_);
+        T* k_smem = q_smem + params.rotary_embedding_dim;
+
+        const int half_rotary_dim = params.rotary_embedding_dim / 2;
+        const int half_idx        = (tidx * QK_VEC_SIZE) / half_rotary_dim;
+        const int intra_half_idx  = (tidx * QK_VEC_SIZE) % half_rotary_dim;
+        const int smem_pitch      = half_rotary_dim;  // TODO: adjust for bank conflicts
+
+        assert(half_rotary_dim % QK_VEC_SIZE == 0);
+
+        if (do_rotary) {
+            *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
+
+            if (handle_kv) {
+                *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
+            }
+        }
+
+        __syncthreads();
+
+        const int     transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
+        constexpr int tidx_factor   = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
+        if (do_rotary) {
+            mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
+
+            if (handle_kv) {
+                mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
+
+                mmha::apply_rotary_embedding(
+                    q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len);
+
+                mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
+            }
+            else {
+                mmha::apply_rotary_embedding(
+                    q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep);
+            }
+            mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
+        }
+
+        __syncthreads();
+
+        if (do_rotary) {
+            q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
+            if (handle_kv) {
+                k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
+            }
+        }
+
+        __syncthreads();
+    }
+
+    if (!is_masked) {
+        // Store the Q values to shared memory.
+        *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
+
+        // Store Dh values of k_bias into smem, since will need to add later
+        // if params.timestep == 0
+        if (DO_CROSS_ATTENTION && params.timestep == 0) {
+            *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
+        }
+
+        // Write the K values to the global memory cache.
+        //
+        // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
+        // system. We designed it this way as it allows much better memory loads (and there are many
+        // more loads) + the stores are really "write and forget" since we won't need the ack before
+        // the end of the kernel. There's plenty of time for the transactions to complete.
+
+        // The 16B chunk written by the thread.
+        int co = tidx / QK_VECS_IN_16B;
+        // The position of the thread in that 16B chunk.
+        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
+
+        // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
+        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+                     // params.timestep*QK_ELTS_IN_16B +
+                     tlength_circ * QK_ELTS_IN_16B + ci;
+
+        if (handle_kv) {
+            // Trigger the stores to global memory.
+            if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
+                *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
+            }
+        }
+
+        // Compute \sum_i Q[i] * K^T[i] for the current timestep.
+#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
+        using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
+#else
+        using Qk_vec_acum = Qk_vec;
+#endif
+        qk = dot<Qk_vec_acum, Qk_vec>(q, k);
+        if (QK_VECS_PER_WARP <= WARP_SIZE) {
+#pragma unroll
+            for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
+                qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+            }
+        }
+    }
+
+    if (QK_VECS_PER_WARP > WARP_SIZE) {
+        constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
+        qk                          = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
+    }
+
+    // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
+    if (tidx == 0) {
+        // Normalize qk.
+        qk *= params.inv_sqrt_dh;
+        if (params.relative_attention_bias != nullptr) {
+            qk = add(qk,
+                     params.relative_attention_bias[hi * params.relative_attention_bias_stride
+                                                        * params.relative_attention_bias_stride
+                                                    + (tlength - padd_len) * params.relative_attention_bias_stride
+                                                    + (tlength - padd_len)]);
+        }
+        // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
+
+        qk_max                        = qk;
+        qk_smem[tlength - first_step] = qk;
+        // qk_smem[params.timestep] = qk;
+    }
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // The type of queries and keys for the math in the Q*K^T product.
+    using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
+    // The number of elements per vector.
+    constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
+    // Make sure the hidden size per head is a multiple of the vector size.
+    static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
+    // The number of elements per thread.
+    constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
+    // The number of vectors per thread.
+    constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
+
+    // The position the first key loaded by each thread from the cache buffer (for this B * H).
+    int ko = tidx / THREADS_PER_KEY;
+    // The position of the thread in the chunk of keys.
+    int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
+
+    static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
+
+    // Load the Q values from shared memory. The values are reused during the loop on K.
+    K_vec q_vec[K_VECS_PER_THREAD];
+#pragma unroll
+    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+        q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
+    }
+
+    K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
+    if (DO_CROSS_ATTENTION && params.timestep == 0) {
+#pragma unroll
+        for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+            k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
+        }
+    }
+
+    // The number of timesteps loaded per iteration.
+    constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
+    // The number of keys per warp.
+    constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
+
+    // The base pointer for the key in the cache buffer.
+    T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+    // Base pointer for the beam's batch, before offsetting with indirection buffer
+    T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
+
+    // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
+    // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
+    int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
+
+    // prefix prompt length if has
+    const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
+
+    // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
+    const bool has_beams    = params.cache_indir != nullptr;
+    const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
+
+    for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
+        const int ti_circ = ti % params.memory_max_len;
+
+        // The keys loaded from the key cache.
+        K_vec k[K_VECS_PER_THREAD];
+        K_vec k_vec_zero;
+        zero(k_vec_zero);
+#pragma unroll
+        for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+            int jj = ii * params.memory_max_len + ti_circ;
+            // if( ti < params.timestep ) {
+            const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
+            if (ti < tlength) {
+                if (!within_bounds) {
+                    k[ii] = k_vec_zero;
+                }
+                else {
+                    if (has_beams) {
+                        const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
+                        k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
+                    }
+                    else {
+                        k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
+                    }
+                }
+                // add bias and update k_cache
+                if (DO_CROSS_ATTENTION && params.timestep == 0) {
+                    k[ii] = add(k[ii], k_bias_vec[ii]);
+
+                    if (do_ia3) {
+                        k[ii] = mul<K_vec, K_vec, K_vec>(
+                            k[ii],
+                            *reinterpret_cast<const K_vec*>(
+                                &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
+                                                        + ii * THREADS_PER_KEY * K_VEC_SIZE]));
+                    }
+
+                    if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
+                        *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
+                    }
+                }
+            }
+        }
+
+        // Perform the dot product and normalize qk.
+        //
+        // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
+        float qk      = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
+        bool  is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
+
+        // Store the product to shared memory. There's one qk value per timestep. Update the max.
+        // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
+        if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
+            if (params.relative_attention_bias != nullptr) {
+                qk = add(qk,
+                         params.relative_attention_bias[hi * params.relative_attention_bias_stride
+                                                            * params.relative_attention_bias_stride
+                                                        + tlength * params.relative_attention_bias_stride + ti]);
+            }
+            if (params.linear_bias_slopes != nullptr) {
+                // Apply the linear position bias: (ki - qi) * slope[hi].
+                // The padding token locates between the input context and the generated tokens.
+                // We need to remove the number of padding tokens in the distance computation.
+                //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
+                //   token: i i i i p p p o o o where i=input, p=pad, o=output.
+                // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
+                int   max_context_length = params.max_prefix_prompt_length + params.max_input_length;
+                float dist               = (ti < max_context_length ? ti + padd_len : ti) - tlength;
+
+                qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
+            }
+            qk_max                   = is_mask ? qk_max : fmaxf(qk_max, qk);
+            qk_smem[ti - first_step] = qk;
+        }
+    }
+
+// Perform the final reduction to compute the max inside each warp.
+//
+// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
+// group so it's not needed to run the reduction inside the group (again).
+#pragma unroll
+    for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
+        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    }
+
+    // Decompose the thread index into warp and lane.
+    const int warp = tidx / WARP_SIZE;
+    const int lane = tidx % WARP_SIZE;
+
+    // The warp leader writes the max to shared memory.
+    if (lane == 0) {
+        red_smem[warp] = qk_max;
+    }
+
+    // Make sure the products are in shared memory.
+    __syncthreads();
+
+    // The warps finalize the reduction.
+    qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
+        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    }
+
+    // Broadcast to all the threads in the warp.
+    qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+
+    // Compute the logits and start the sum.
+    float sum = 0.f;
+    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
+    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
+        bool  is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
+        float logit   = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
+        sum += logit;
+        qk_smem[ti - first_step] = logit;
+    }
+
+    // Compute the sum.
+    sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
+
+    // Normalize the logits.
+    float inv_sum = __fdividef(1.f, sum + 1.e-6f);
+    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
+    const size_t cross_attention_out_offset =
+        params.is_return_cross_attentions ?
+            bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
+            0;
+    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
+        float logit = qk_smem[ti - first_step] * inv_sum;
+        if (params.is_return_cross_attentions) {
+            params.cross_attention_out[cross_attention_out_offset + ti] = logit;
+        }
+        convert_from_float(logits_smem[ti - first_step], logit);
+    }
+
+    // Put Values part below so we leverage __syncthreads
+    // from the previous step
+
+    // The number of elements per vector.
+    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
+    // A vector of V elements for the current timestep.
+    using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
+
+    // The value computed by this thread.
+    int vo = tidx / THREADS_PER_VALUE;
+    // The hidden dimensions computed by this particular thread.
+    int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
+
+    // The base pointer for the value in the cache buffer.
+    T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+    // Base pointer for the beam's batch, before offsetting with indirection buffer
+    T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
+
+    // The number of values processed per iteration of the loop.
+    constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
+
+    // One group of threads computes the product(s) for the current timestep.
+    V_vec v_bias;
+    zero(v_bias);
+    // if( vo == params.timestep % V_PER_ITER ) {
+    if (Dh == Dh_MAX || vi < Dh) {
+        if (handle_kv) {
+            if (vo == tlength % V_PER_ITER) {
+                // Trigger the loads from the V bias buffer.
+                if (params.v_bias != nullptr) {
+                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
+                }
+                if (DO_CROSS_ATTENTION) {
+                    *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
+                }
+            }
+        }
+    }
+
+    // From previous, before values, step
+    // Also make sure the logits are in shared memory.
+    __syncthreads();
+
+    // Values continued
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+    using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
+#else
+    using V_vec_acum = V_vec;
+#endif
+    // The partial outputs computed by each thread.
+    V_vec_acum out;
+    zero(out);
+
+    // Loop over the timesteps to compute the partial outputs.
+    // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
+    if (Dh == Dh_MAX || vi < Dh) {
+        for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
+            const int ti_circ = ti % params.memory_max_len;
+
+            // Fetch offset based on cache_indir when beam sampling
+            const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
+            const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
+            // Load the values from the cache.
+            V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
+            if (DO_CROSS_ATTENTION && params.timestep == 0) {
+                v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
+                if (do_ia3) {
+                    v = mul<V_vec, V_vec, V_vec>(
+                        v,
+                        *reinterpret_cast<const V_vec*>(
+                            &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
+                }
+                *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
+            }
+            // Load the logits from shared memory.
+#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
+            float logit = logits_smem[ti - first_step];
+            out         = fma(logit, cast_to_float(v), out);
+#else
+            T logit = logits_smem[ti - first_step];
+
+            // Update the partial sums.
+            out = fma(logit, v, out);
+#endif
+        }
+    }
+
+    // One group of threads computes the product(s) for the current timestep.
+    // if( vo == params.timestep % V_PER_ITER ) {
+    if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
+
+        V_vec v;
+        if (DO_CROSS_ATTENTION) {
+            v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
+        }
+        else {
+            // Trigger the loads from the V buffer.
+            const auto v_offset = qkv_base_offset + vi;
+            if (params.int8_mode == 2) {
+                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
+                using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
+                const auto v_scaling = params.qkv_scale_out[2];
+                const auto v_quant =
+                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
+
+                convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
+            }
+            else {
+                v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
+            }
+            // Trigger the loads from the V bias buffer.
+            // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
+        }
+
+        // Compute the V values with bias.
+        if (handle_kv) {
+            v = add(v, v_bias);
+
+            if (do_ia3) {
+                v = mul<V_vec, V_vec, V_vec>(
+                    v,
+                    *reinterpret_cast<const V_vec*>(
+                        &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
+            }
+
+            // Store the values with bias back to global memory in the cache for V.
+            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
+            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
+        }
+
+        // Initialize the output value with the current timestep.
+#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
+        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
+        out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
+#else
+        // out = fma(logits_smem[params.timestep], v, out);
+        out = fma(logits_smem[tlength - first_step], v, out);
+#endif
+    }
+
+    // Make sure we can start writing to shared memory.
+    __syncthreads();
+
+    // Run the final reduction amongst the different groups computing different partial outputs.
+    if (Dh == Dh_MAX || vi < Dh) {
+#pragma unroll
+        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
+
+            // The midpoint in the number of active groups.
+            int midpoint = active_groups / 2;
+
+            // The upper part of active threads store to shared memory.
+            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+                convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
+#else
+                *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
+#endif
+            }
+            __syncthreads();
+
+            // The bottom warps update their values.
+            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
+                out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
+            }
+            __syncthreads();
+        }
+    }
+
+    // Output the final values.
+    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+        if (params.int8_mode == 2) {
+            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
+            out                 = mul<V_vec_acum, float>(*params.attention_out_scale, out);
+            *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
+                cast_to_int8(out);
+        }
+        else {
+            convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
+        }
+#else
+        // TODO: support int8_mode?
+        *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
+#endif
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace mmha
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
+void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);

+ 1788 - 0
csrc/ft_attention/decoder_masked_multihead_attention_utils.h

@@ -0,0 +1,1788 @@
+// Downloaded from from FasterTransformer v5.2.1
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cuda_bf16_wrapper.h"
+#include "cuda_bf16_fallbacks.cuh"
+#include <stdint.h>
+
+using namespace fastertransformer;
+
+namespace mmha {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Float8_ {
+    float2 x;
+    float2 y;
+    float2 z;
+    float2 w;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Float4_ {
+    float2 x;
+    float2 y;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+struct bf16_4_t {
+    __nv_bfloat162 x;
+    __nv_bfloat162 y;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct bf16_8_t {
+    __nv_bfloat162 x;
+    __nv_bfloat162 y;
+    __nv_bfloat162 z;
+    __nv_bfloat162 w;
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct num_elems;
+template<>
+struct num_elems<float> {
+    static constexpr int value = 1;
+};
+template<>
+struct num_elems<float2> {
+    static constexpr int value = 2;
+};
+template<>
+struct num_elems<float4> {
+    static constexpr int value = 4;
+};
+template<>
+struct num_elems<Float4_> {
+    static constexpr int value = 4;
+};
+template<>
+struct num_elems<Float8_> {
+    static constexpr int value = 8;
+};
+
+template<>
+struct num_elems<uint32_t> {
+    static constexpr int value = 2;
+};
+template<>
+struct num_elems<uint2> {
+    static constexpr int value = 4;
+};
+template<>
+struct num_elems<uint4> {
+    static constexpr int value = 8;
+};
+
+#ifdef ENABLE_BF16
+template<>
+struct num_elems<__nv_bfloat162> {
+    static constexpr int value = 2;
+};
+template<>
+struct num_elems<bf16_4_t> {
+    static constexpr int value = 4;
+};
+template<>
+struct num_elems<bf16_8_t> {
+    static constexpr int value = 8;
+};
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T, int N>
+struct packed_type;
+template<typename T>
+struct packed_type<T, 1> {
+    using type = T;
+};
+template<>
+struct packed_type<int8_t, 2> {
+    using type = int16_t;
+};
+template<>
+struct packed_type<int8_t, 4> {
+    using type = int32_t;
+};
+template<>
+struct packed_type<int8_t, 8> {
+    using type = int64_t;
+};
+
+template<>
+struct packed_type<float, 2> {
+    using type = float2;
+};
+template<>
+struct packed_type<float, 4> {
+    using type = float4;
+};
+template<>
+struct packed_type<float, 8> {
+    using type = Float8_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float add(float a, float b)
+{
+    return a + b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 add(float2 a, float2 b)
+{
+    float2 c;
+    c.x = add(a.x, b.x);
+    c.y = add(a.y, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 add(float4 a, float4 b)
+{
+    float4 c;
+    c.x = add(a.x, b.x);
+    c.y = add(a.y, b.y);
+    c.z = add(a.z, b.z);
+    c.w = add(a.w, b.w);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b)
+{
+    return a + b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b)
+{
+    return bf16hadd2(a, b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b)
+{
+    bf16_4_t c;
+    c.x = add(a.x, b.x);
+    c.y = add(a.y, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b)
+{
+    bf16_8_t c;
+    c.x = add(a.x, b.x);
+    c.y = add(a.y, b.y);
+    c.z = add(a.z, b.z);
+    c.w = add(a.w, b.w);
+    return c;
+}
+#endif  // ENABLE_BF16
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint16_t add(uint16_t a, uint16_t b)
+{
+    uint16_t c;
+    asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint32_t add(uint32_t a, uint32_t b)
+{
+    uint32_t c;
+    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint2 add(uint2 a, uint2 b)
+{
+    uint2 c;
+    c.x = add(a.x, b.x);
+    c.y = add(a.y, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint4 add(uint4 a, uint4 b)
+{
+    uint4 c;
+    c.x = add(a.x, b.x);
+    c.y = add(a.y, b.y);
+    c.z = add(a.z, b.z);
+    c.w = add(a.w, b.w);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint16_t float_to_half(float f)
+{
+    union {
+        uint32_t u32;
+        uint16_t u16[2];
+    } tmp;
+#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800  // Is it better?
+  float zero = 0.f;
+  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
+#else
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#endif
+    return tmp.u16[0];
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint32_t float2_to_half2(float2 f)
+{
+    union {
+        uint32_t u32;
+        uint16_t u16[2];
+    } tmp;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+#else
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+#endif
+    return tmp.u32;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float half_to_float(uint16_t h)
+{
+    float f;
+    asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+    return f;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 half2_to_float2(uint32_t v)
+{
+    uint16_t lo, hi;
+    asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
+    return make_float2(half_to_float(lo), half_to_float(hi));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float add(float a, uint16_t b)
+{
+    return a + half_to_float(b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+inline __device__ float add(float a, __nv_bfloat16 b)
+{
+    return a + __bfloat162float(b);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 add(uint32_t a, float2 fb)
+{
+    float2 fa = half2_to_float2(a);
+    return add(fa, fb);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ add(uint2 a, Float4_ fb)
+{
+    Float4_ fc;
+    fc.x = add(a.x, fb.x);
+    fc.y = add(a.y, fb.y);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ add(uint4 a, Float8_ fb)
+{
+    Float8_ fc;
+    fc.x = add(a.x, fb.x);
+    fc.y = add(a.y, fb.y);
+    fc.z = add(a.z, fb.z);
+    fc.w = add(a.w, fb.w);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint32_t h0_h0(uint16_t a)
+{
+    uint32_t b;
+    asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
+    return b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float fma(float a, float b, float c)
+{
+    return a * b + c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c)
+{
+    float2 d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 fma(float a, float2 b, float2 c)
+{
+    float2 d;
+    d.x = fma(a, b.x, c.x);
+    d.y = fma(a, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c)
+{
+    float4 d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    d.z = fma(a.z, b.z, c.z);
+    d.w = fma(a.w, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float4 fma(float a, float4 b, float4 c)
+{
+    float4 d;
+    d.x = fma(a, b.x, c.x);
+    d.y = fma(a, b.y, c.y);
+    d.z = fma(a, b.z, c.z);
+    d.w = fma(a, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c)
+{
+    Float4_ d;
+    d.x = fma(a, b.x, c.x);
+    d.y = fma(a, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c)
+{
+    Float8_ d;
+    d.x = fma(a, b.x, c.x);
+    d.y = fma(a, b.y, c.y);
+    d.z = fma(a, b.z, c.z);
+    d.w = fma(a, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+inline __device__ float2 add(__nv_bfloat162 a, float2 fb)
+{
+    float2 fa = bf1622float2(a);
+    return add(fa, fb);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ add(bf16_4_t a, Float4_ fb)
+{
+    Float4_ fc;
+    fc.x = add(a.x, fb.x);
+    fc.y = add(a.y, fb.y);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ add(bf16_8_t a, Float8_ fb)
+{
+    Float8_ fc;
+    fc.x = add(a.x, fb.x);
+    fc.y = add(a.y, fb.y);
+    fc.z = add(a.z, fb.z);
+    fc.w = add(a.w, fb.w);
+    return fc;
+}
+#endif  // ENABLE_BF16
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c)
+{
+    uint32_t d;
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c)
+{
+    return fma(h0_h0(a), b, c);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c)
+{
+    uint2 d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c)
+{
+    uint32_t s = h0_h0(a);
+    uint2    d;
+    d.x = fma(s, b.x, c.x);
+    d.y = fma(s, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c)
+{
+    uint4 d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    d.z = fma(a.z, b.z, c.z);
+    d.w = fma(a.w, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c)
+{
+    uint32_t s = h0_h0(a);
+    uint4    d;
+    d.x = fma(s, b.x, c.x);
+    d.y = fma(s, b.y, c.y);
+    d.z = fma(s, b.z, c.z);
+    d.w = fma(s, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float fma(uint16_t a, uint16_t b, float fc)
+{
+    float fa = half_to_float(a);
+    float fb = half_to_float(b);
+    return fa * fb + fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc)
+{
+    float2 fa = half2_to_float2(a);
+    float2 fb = half2_to_float2(b);
+    return fma(fa, fb, fc);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc)
+{
+    return fma(h0_h0(a), b, fc);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc)
+{
+    Float4_ fd;
+    fd.x = fma(a.x, b.x, fc.x);
+    fd.y = fma(a.y, b.y, fc.y);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc)
+{
+    uint32_t s = h0_h0(a);
+    Float4_  fd;
+    fd.x = fma(s, b.x, fc.x);
+    fd.y = fma(s, b.y, fc.y);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc)
+{
+    Float8_ fd;
+    fd.x = fma(a.x, b.x, fc.x);
+    fd.y = fma(a.y, b.y, fc.y);
+    fd.z = fma(a.z, b.z, fc.z);
+    fd.w = fma(a.w, b.w, fc.w);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc)
+{
+    uint32_t s = h0_h0(a);
+    Float8_  fd;
+    fd.x = fma(s, b.x, fc.x);
+    fd.y = fma(s, b.y, fc.y);
+    fd.z = fma(s, b.z, fc.z);
+    fd.w = fma(s, b.w, fc.w);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+#ifdef ENABLE_BF16
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
+{
+    return bf16hfma2(a, b, c);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c)
+{
+    return bf16hfma2(bf162bf162(a), b, c);
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c)
+{
+    bf16_4_t d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    bf16_4_t       d;
+    d.x = fma(s, b.x, c.x);
+    d.y = fma(s, b.y, c.y);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c)
+{
+    bf16_8_t d;
+    d.x = fma(a.x, b.x, c.x);
+    d.y = fma(a.y, b.y, c.y);
+    d.z = fma(a.z, b.z, c.z);
+    d.w = fma(a.w, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    bf16_8_t       d;
+    d.x = fma(s, b.x, c.x);
+    d.y = fma(s, b.y, c.y);
+    d.z = fma(s, b.z, c.z);
+    d.w = fma(s, b.w, c.w);
+    return d;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc)
+{
+    return __bfloat162float(a) * __bfloat162float(b) + fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc)
+{
+    float2 fa = bf1622float2(a);
+    float2 fb = bf1622float2(b);
+    return fma(fa, fb, fc);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc)
+{
+    return fma(bf162bf162(a), b, fc);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc)
+{
+    Float4_ fd;
+    fd.x = fma(a.x, b.x, fc.x);
+    fd.y = fma(a.y, b.y, fc.y);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    Float4_        fd;
+    fd.x = fma(s, b.x, fc.x);
+    fd.y = fma(s, b.y, fc.y);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc)
+{
+    Float8_ fd;
+    fd.x = fma(a.x, b.x, fc.x);
+    fd.y = fma(a.y, b.y, fc.y);
+    fd.z = fma(a.z, b.z, fc.z);
+    fd.w = fma(a.w, b.w, fc.w);
+    return fd;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    Float8_        fd;
+    fd.x = fma(s, b.x, fc.x);
+    fd.y = fma(s, b.y, fc.y);
+    fd.z = fma(s, b.z, fc.z);
+    fd.w = fma(s, b.w, fc.w);
+    return fd;
+}
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b)
+{
+    return a * b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float mul<float, float>(float a, float b)
+{
+    return a * b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float2 mul(float2 a, float2 b)
+{
+    float2 c;
+    c.x = a.x * b.x;
+    c.y = a.y * b.y;
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float2 mul(float a, float2 b)
+{
+    float2 c;
+    c.x = a * b.x;
+    c.y = a * b.y;
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float4 mul(float4 a, float4 b)
+{
+    float4 c;
+    c.x = a.x * b.x;
+    c.y = a.y * b.y;
+    c.z = a.z * b.z;
+    c.w = a.w * b.w;
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float4 mul(float a, float4 b)
+{
+    float4 c;
+    c.x = a * b.x;
+    c.y = a * b.y;
+    c.z = a * b.z;
+    c.w = a * b.w;
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float8_ mul(float a, Float8_ b)
+{
+    Float8_ c;
+    c.x = make_float2(a * b.x.x, a * b.x.y);
+    c.y = make_float2(a * b.y.x, a * b.y.y);
+    c.z = make_float2(a * b.z.x, a * b.z.y);
+    c.w = make_float2(a * b.w.x, a * b.w.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint16_t mul(uint16_t a, uint16_t b)
+{
+    uint16_t c;
+    asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint32_t mul(uint32_t a, uint32_t b)
+{
+    uint32_t c;
+    asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint32_t mul(uint16_t a, uint32_t b)
+{
+    return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint2 mul(uint2 a, uint2 b)
+{
+    uint2 c;
+    c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+    c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint2 mul(uint16_t a, uint2 b)
+{
+    uint32_t s = h0_h0(a);
+    uint2    c;
+    c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+    c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint4 mul(uint4 a, uint4 b)
+{
+    uint4 c;
+    c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+    c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+    c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+    c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ uint4 mul(uint16_t a, uint4 b)
+{
+    uint32_t s = h0_h0(a);
+    uint4    c;
+    c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+    c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+    c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
+    c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float mul(uint16_t a, uint16_t b)
+{
+    float fa = half_to_float(a);
+    float fb = half_to_float(b);
+    return fa * fb;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float mul(uint16_t a, float b)
+{
+    return half_to_float(a) * b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float2 mul(uint32_t a, uint32_t b)
+{
+    float2 fa = half2_to_float2(a);
+    float2 fb = half2_to_float2(b);
+    return mul<float2, float2, float2>(fa, fb);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float2 mul(uint16_t a, uint32_t b)
+{
+    return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float4_ mul(uint2 a, uint2 b)
+{
+    Float4_ fc;
+    fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+    fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float4_ mul(uint16_t a, uint2 b)
+{
+    uint32_t s = h0_h0(a);
+    Float4_  fc;
+    fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+    fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float8_ mul(uint4 a, uint4 b)
+{
+    Float8_ fc;
+    fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+    fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+    fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
+    fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float8_ mul(uint16_t a, uint4 b)
+{
+    uint32_t s = h0_h0(a);
+    Float8_  fc;
+    fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+    fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+    fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
+    fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+template<>
+inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b)
+{
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    return __hmul(a, b);
+#else
+    return bf16hmul(a, b);
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)
+{
+    return bf16hmul2(a, b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
+{
+    return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b)
+{
+    bf16_4_t c;
+    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    bf16_4_t       c;
+    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
+{
+    bf16_8_t c;
+    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+    c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    bf16_8_t       c;
+    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+    c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+    return c;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b)
+{
+    float fa = (float)a;
+    float fb = (float)b;
+    return fa * fb;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float mul(__nv_bfloat16 a, float b)
+{
+    return __bfloat162float(a) * b;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b)
+{
+    float2 fa = bf1622float2(a);
+    float2 fb = bf1622float2(b);
+    return mul<float2, float2, float2>(fa, fb);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b)
+{
+    return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b)
+{
+    Float4_ fc;
+    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    Float4_        fc;
+    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b)
+{
+    Float8_ fc;
+    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
+    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
+    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
+    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
+    return fc;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b)
+{
+    __nv_bfloat162 s = bf162bf162(a);
+    Float8_        fc;
+    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
+    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
+    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
+    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
+    return fc;
+}
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(float v)
+{
+    return v;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(float2 v)
+{
+    return v.x + v.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(float4 v)
+{
+    return v.x + v.y + v.z + v.w;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+inline __device__ float sum(__nv_bfloat162 v)
+{
+    float2 vf = bf1622float2(v);
+    return vf.x + vf.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(bf16_4_t v)
+{
+    return sum(v.x) + sum(v.y);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(bf16_8_t v)
+{
+    return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
+}
+#endif  // ENABLE_BF16
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(uint16_t v)
+{
+    return half_to_float(v);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(uint32_t v)
+{
+    float2 tmp = half2_to_float2(v);
+    return tmp.x + tmp.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(uint2 v)
+{
+    uint32_t c = add(v.x, v.y);
+    return sum(c);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(uint4 v)
+{
+#if 1
+    uint32_t c = add(v.x, v.y);
+    c          = add(c, v.z);
+    c          = add(c, v.w);
+#else
+    uint32_t c = add(v.x, v.y);
+    uint32_t d = add(v.z, v.w);
+    c          = add(c, d);
+#endif
+    return sum(c);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(Float4_ v)
+{
+    return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float sum(Float8_ v)
+{
+    return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+inline __device__ float dot(T a, T b)
+{
+    return sum(mul<T, T, T>(a, b));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename A, typename T>
+inline __device__ float dot(T a, T b)
+{
+    return sum(mul<A, T, T>(a, b));
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ void zero(uint16_t& dst)
+{
+    dst = uint16_t(0);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+inline __device__ void zero(T& dst)
+{
+    constexpr int WORDS = sizeof(T) / 4;
+    union {
+        T        raw;
+        uint32_t words[WORDS];
+    } tmp;
+#pragma unroll
+    for (int ii = 0; ii < WORDS; ++ii) {
+        tmp.words[ii] = 0u;
+    }
+    dst = tmp.raw;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
+{
+    const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
+    return {cos(inv_freq), sin(inv_freq)};
+}
+
+inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
+{
+    float2 rot_v;
+    rot_v.x = coef.x * v.x - coef.y * v.y;
+    rot_v.y = coef.x * v.y + coef.y * v.x;
+    return rot_v;
+}
+
+inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
+{
+    float2 fv     = half2_to_float2(v);
+    float2 rot_fv = rotary_embedding_transform(fv, coef);
+    return float2_to_half2(rot_fv);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
+{
+    float2 fv     = bf1622float2(v);
+    float2 rot_fv = rotary_embedding_transform(fv, coef);
+    return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
+}
+#endif
+
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
+{
+    return;
+}
+
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
+{
+    return;
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    q               = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    q               = rotary_embedding_transform(q, coef);
+    k               = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+
+    Float4_&   q_    = *reinterpret_cast<Float4_*>(&q);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    q_.x             = rotary_embedding_transform(q_.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    q_.y             = rotary_embedding_transform(q_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+
+    Float4_&   q_    = *reinterpret_cast<Float4_*>(&q);
+    Float4_&   k_    = *reinterpret_cast<Float4_*>(&k);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    q_.x             = rotary_embedding_transform(q_.x, coef0);
+    k_.x             = rotary_embedding_transform(k_.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    q_.y             = rotary_embedding_transform(q_.y, coef1);
+    k_.y             = rotary_embedding_transform(k_.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    q               = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    q               = rotary_embedding_transform(q, coef);
+    k               = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    k.x              = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+    k.y              = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    q.z              = rotary_embedding_transform(q.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    q.w              = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    k.x              = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+    k.y              = rotary_embedding_transform(k.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    q.z              = rotary_embedding_transform(q.z, coef2);
+    k.z              = rotary_embedding_transform(k.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    q.w              = rotary_embedding_transform(q.w, coef3);
+    k.w              = rotary_embedding_transform(k.w, coef3);
+}
+
+#ifdef ENABLE_BF16
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    q               = rotary_embedding_transform(q, coef);
+}
+
+inline __device__ void
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (2 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    q               = rotary_embedding_transform(q, coef);
+    k               = rotary_embedding_transform(k, coef);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (4 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    k.x              = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+    k.y              = rotary_embedding_transform(k.y, coef1);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    q.z              = rotary_embedding_transform(q.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    q.w              = rotary_embedding_transform(q.w, coef3);
+}
+
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+{
+    if (8 * tid >= rot_embed_dim) {
+        return;
+    }
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    q.x              = rotary_embedding_transform(q.x, coef0);
+    k.x              = rotary_embedding_transform(k.x, coef0);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    q.y              = rotary_embedding_transform(q.y, coef1);
+    k.y              = rotary_embedding_transform(k.y, coef1);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    q.z              = rotary_embedding_transform(q.z, coef2);
+    k.z              = rotary_embedding_transform(k.z, coef2);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    q.w              = rotary_embedding_transform(q.w, coef3);
+    k.w              = rotary_embedding_transform(k.w, coef3);
+}
+#endif  // ENABLE_BF16
+
+template<typename Vec_T, typename T>
+__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch)
+{
+    return;
+}
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t u32;
+        uint16_t u16[2];
+    } tmp;
+    tmp.u16[0] = smem[transpose_idx];
+    tmp.u16[1] = smem[smem_pitch + transpose_idx];
+
+    vec = tmp.u32;
+}
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t u32;
+        uint16_t u16[2];
+    } tmp_1, tmp_2;
+    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
+    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
+
+    union {
+        uint2    u32x2;
+        uint16_t u16[4];
+    } tmp_3;
+    tmp_3.u16[0] = tmp_1.u16[0];
+    tmp_3.u16[1] = tmp_2.u16[0];
+    tmp_3.u16[2] = tmp_1.u16[1];
+    tmp_3.u16[3] = tmp_2.u16[1];
+
+    vec = tmp_3.u32x2;
+}
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint64_t u64;
+        uint16_t u16[4];
+    } tmp_1, tmp_2;
+    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
+    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
+
+    union {
+        uint4    u32x4;
+        uint16_t u16[8];
+    } tmp_3;
+    tmp_3.u16[0] = tmp_1.u16[0];
+    tmp_3.u16[1] = tmp_2.u16[0];
+    tmp_3.u16[2] = tmp_1.u16[1];
+    tmp_3.u16[3] = tmp_2.u16[1];
+    tmp_3.u16[4] = tmp_1.u16[2];
+    tmp_3.u16[5] = tmp_2.u16[2];
+    tmp_3.u16[6] = tmp_1.u16[3];
+    tmp_3.u16[7] = tmp_2.u16[3];
+
+    vec = tmp_3.u32x4;
+}
+
+#ifdef ENABLE_BF16
+template<>
+__device__ __inline__ void
+vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t      u32;
+        __nv_bfloat16 bf16[2];
+    } tmp_1, tmp_2;
+    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
+    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
+
+    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
+    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
+}
+
+template<>
+__device__ __inline__ void
+vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint64_t      u64;
+        __nv_bfloat16 bf16[4];
+    } tmp_1, tmp_2;
+    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
+    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
+
+    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
+    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
+    vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
+    vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
+}
+#endif  // ENABLE_BF16
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
+{
+    vec.x = smem[transpose_idx];
+    vec.z = smem[transpose_idx + 1];
+    vec.y = smem[smem_pitch + transpose_idx];
+    vec.w = smem[smem_pitch + transpose_idx + 1];
+}
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t u32;
+        half     u16[2];
+    } tmp;
+    tmp.u16[0] = smem[transpose_idx];
+    tmp.u16[1] = smem[smem_pitch + transpose_idx];
+
+    vec = tmp.u32;
+}
+
+#ifdef ENABLE_BF16
+template<>
+__device__ __inline__ void
+vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    vec.x = smem[transpose_idx];
+    vec.y = smem[smem_pitch + transpose_idx];
+}
+#endif
+
+template<>
+__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
+{
+    vec.x = smem[transpose_idx];
+    vec.y = smem[smem_pitch + transpose_idx];
+}
+
+template<typename Vec_T, typename T>
+__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);
+
+template<>
+__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
+{
+    return;
+}
+
+#ifdef ENABLE_BF16
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    return;
+}
+
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    return;
+}
+#endif
+
+template<>
+__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint64_t u64;
+        uint16_t u16[4];
+    } tmp_1, tmp_2;
+
+    union {
+        uint4    u32x4;
+        uint16_t u16[8];
+    } tmp_3;
+    tmp_3.u32x4  = vec;
+    tmp_1.u16[0] = tmp_3.u16[0];
+    tmp_2.u16[0] = tmp_3.u16[1];
+    tmp_1.u16[1] = tmp_3.u16[2];
+    tmp_2.u16[1] = tmp_3.u16[3];
+    tmp_1.u16[2] = tmp_3.u16[4];
+    tmp_2.u16[2] = tmp_3.u16[5];
+    tmp_1.u16[3] = tmp_3.u16[6];
+    tmp_2.u16[3] = tmp_3.u16[7];
+
+    *reinterpret_cast<uint64_t*>(&smem[transpose_idx])              = tmp_1.u64;
+    *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
+}
+
+template<>
+__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t u32;
+        uint16_t u16[2];
+    } tmp_1, tmp_2;
+
+    union {
+        uint2    u32x2;
+        uint16_t u16[4];
+    } tmp_3;
+    tmp_3.u32x2  = vec;
+    tmp_1.u16[0] = tmp_3.u16[0];
+    tmp_2.u16[0] = tmp_3.u16[1];
+    tmp_1.u16[1] = tmp_3.u16[2];
+    tmp_2.u16[1] = tmp_3.u16[3];
+
+    *reinterpret_cast<uint32_t*>(&smem[transpose_idx])              = tmp_1.u32;
+    *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
+}
+
+template<>
+__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t u32;
+        uint16_t u16[2];
+    } tmp;
+    tmp.u32 = vec;
+
+    smem[transpose_idx]              = tmp.u16[0];
+    smem[smem_pitch + transpose_idx] = tmp.u16[1];
+}
+
+template<>
+__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
+{
+    smem[transpose_idx]                  = vec.x;
+    smem[transpose_idx + 1]              = vec.z;
+    smem[smem_pitch + transpose_idx]     = vec.y;
+    smem[smem_pitch + transpose_idx + 1] = vec.w;
+}
+
+template<>
+__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
+{
+    union {
+        uint32_t u32;
+        half     u16[2];
+    } tmp;
+
+    tmp.u32                          = vec;
+    smem[transpose_idx]              = tmp.u16[0];
+    smem[smem_pitch + transpose_idx] = tmp.u16[1];
+}
+
+#ifdef ENABLE_BF16
+template<>
+__device__ __inline__ void
+write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    smem[transpose_idx]              = vec.x;
+    smem[smem_pitch + transpose_idx] = vec.y;
+}
+#endif
+
+template<>
+__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
+{
+    smem[transpose_idx]              = vec.x;
+    smem[smem_pitch + transpose_idx] = vec.y;
+}
+
+}  // namespace mmha

+ 167 - 0
csrc/ft_attention/ft_attention.cpp

@@ -0,0 +1,167 @@
+#include <torch/extension.h>
+#include "ATen/cuda/CUDAContext.h"
+
+#include "decoder_masked_multihead_attention.h"
+
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                  \
+  if (TYPE == at::ScalarType::Half) {                                      \
+    using scalar_t = at::Half;                                             \
+    __VA_ARGS__();                                                         \
+  } else if (TYPE == at::ScalarType::BFloat16) {                           \
+    using scalar_t = at::BFloat16;                                         \
+    __VA_ARGS__();                                                         \
+  } else if (TYPE == at::ScalarType::Float)  {                             \
+    using scalar_t = float;                                                \
+    __VA_ARGS__();                                                         \
+  } else {                                                                 \
+    AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
+  }
+
+// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                  \
+//   if (TYPE == at::ScalarType::Half) {                                      \
+//     using scalar_t = at::Half;                                             \
+//     __VA_ARGS__();                                                         \
+//   } else if (TYPE == at::ScalarType::Float)  {                             \
+//     using scalar_t = float;                                                \
+//     __VA_ARGS__();                                                         \
+//   } else {                                                                 \
+//     AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
+//   }
+
+template<typename T>
+void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
+                                const cudaStream_t& stream);
+
+template<typename T>
+void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
+                               const cudaStream_t& stream);
+
+template<typename T>
+struct SATypeConverter {
+    using Type = T;
+};
+
+template<>
+struct SATypeConverter<at::Half> {
+    using Type = uint16_t;
+};
+
+template<>
+struct SATypeConverter<at::BFloat16> {
+    using Type = __nv_bfloat16;
+};
+
+template <typename T>
+void set_params(Masked_multihead_attention_params<T> &params,
+                const size_t batch_size,
+                const size_t nheads,
+                const size_t memory_max_seqlen,
+                const size_t headdim,
+                const int timestep,
+                const int rotary_embedding_dim,
+                const bool neox_rotary_style,
+                T *q_ptr,
+                T *k_ptr,
+                T *v_ptr,
+                T *k_cache_ptr,
+                T *v_cache_ptr,
+                int *length_per_sample,
+                T *out_ptr) {
+    // Reset the parameters
+    memset(&params, 0, sizeof(params));
+    params.q = q_ptr;
+    params.k = k_ptr;
+    params.v = v_ptr;
+    params.q_bias = nullptr;
+    params.k_bias = nullptr;
+    params.v_bias = nullptr;
+    params.k_cache = k_cache_ptr;
+    params.v_cache = v_cache_ptr;
+    params.out = out_ptr;
+    params.cache_indir = nullptr;
+    params.stride = 0;
+    params.batch_size = batch_size;
+    params.beam_width = 1;
+    params.memory_max_len = memory_max_seqlen;
+    params.num_heads = nheads;
+    params.hidden_size_per_head = headdim;
+    params.rotary_embedding_dim = rotary_embedding_dim;
+    params.neox_rotary_style = neox_rotary_style;
+    params.timestep = timestep;
+    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
+    params.total_padding_tokens = nullptr;
+    params.masked_tokens = nullptr;
+    params.prefix_prompt_lengths = nullptr;
+    // params.max_prefix_prompt_length = memory_max_seqlen;  // TODO: waht should this be?
+    params.max_prefix_prompt_length = 0;  // TODO: waht should this be?
+    params.relative_attention_bias = nullptr;
+    params.relative_attention_bias_stride = 0;
+    params.cross_attention_out = nullptr;
+    params.max_decoder_seq_len = 0;
+    params.is_return_cross_attentions = false;
+    params.finished = nullptr;
+    params.memory_length_per_sample = nullptr;
+    params.length_per_sample = length_per_sample;
+}
+
+torch::Tensor single_query_attention(const torch::Tensor q,
+                                     const torch::Tensor k,
+                                     const torch::Tensor v,
+                                     torch::Tensor k_cache,
+                                     torch::Tensor v_cache,
+                                     c10::optional<const torch::Tensor> length_per_sample_,
+                                     const int timestep,
+                                     const int rotary_embedding_dim = 0,
+                                     const bool neox_rotary_style=true) {
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
+    int batch_size = v_cache.size(0);
+    int nheads = v_cache.size(1);
+    int memory_max_seqlen = v_cache.size(2);
+    int headdim = v_cache.size(3);
+    CHECK_SHAPE(q, batch_size, nheads, headdim);
+    CHECK_SHAPE(k, batch_size, nheads, headdim);
+    CHECK_SHAPE(v, batch_size, nheads, headdim);
+    // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
+    // TODO: avoid contiguous requirment by storing the stride
+    CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
+    CHECK_CONTIGUOUS(v_cache);
+
+    if (length_per_sample_.has_value()) {
+        auto length_per_sample = length_per_sample_.value();
+        CHECK_DEVICE(length_per_sample);
+        CHECK_SHAPE(length_per_sample, batch_size);
+        CHECK_CONTIGUOUS(length_per_sample);
+        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
+    }
+
+    torch::Tensor out = torch::empty_like(q);
+
+    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
+        using DataType = typename SATypeConverter<scalar_t>::Type;
+        Masked_multihead_attention_params<DataType> params;
+        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
+                   rotary_embedding_dim, neox_rotary_style,
+                   reinterpret_cast<DataType*>(q.data_ptr()),
+                   reinterpret_cast<DataType*>(k.data_ptr()),
+                   reinterpret_cast<DataType*>(v.data_ptr()),
+                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
+                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
+                   length_per_sample_.has_value()
+                       ? length_per_sample_.value().data_ptr<int>() : nullptr,
+                   reinterpret_cast<DataType*>(out.data_ptr()));
+        auto stream = at::cuda::getCurrentCUDAStream();
+        masked_multihead_attention(params, stream);
+    });
+    return out;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
+          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
+          py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
+          py::arg("neox_rotary_style")=true);
+}

+ 143 - 0
csrc/ft_attention/setup.py

@@ -0,0 +1,143 @@
+# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from setuptools import setup, find_packages
+import subprocess
+
+import sys
+import warnings
+import os
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries.  "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors:  "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+if not torch.cuda.is_available():
+    # https://github.com/NVIDIA/apex/issues/486
+    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
+    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
+    print(
+        "\nWarning: Torch did not find available GPUs on this system.\n",
+        "If your intention is to cross-compile, this is not an error.\n"
+        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
+        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
+        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
+        "If you wish to cross-compile for a single specific architecture,\n"
+        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
+    )
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+        if int(bare_metal_major) == 11:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+            if int(bare_metal_minor) > 0:
+                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        else:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+
+print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+cmdclass = {}
+ext_modules = []
+
+# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+# See https://github.com/pytorch/pytorch/pull/70650
+generator_flag = []
+torch_dir = torch.__path__[0]
+if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+    generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+raise_if_cuda_home_none("--ft_attention")
+# Check, if CUDA11 is installed for compute capability 8.0
+cc_flag = []
+# cc_flag.append("-gencode")
+# cc_flag.append("arch=compute_70,code=sm_70")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
+
+ext_modules.append(
+    CUDAExtension(
+        name="ft_attention",
+        sources=[
+            "ft_attention.cpp",
+            "decoder_masked_multihead_attention.cu",
+        ],
+        extra_compile_args={
+            "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag,
+            "nvcc": append_nvcc_threads(
+                [
+                    "-DENABLE_BF16",  # TODO
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                ]
+                + generator_flag
+                + cc_flag
+            ),
+        },
+        include_dirs=[this_dir],
+    )
+)
+
+setup(
+    name="ft_attention",
+    version="0.1",
+    description="Attention for single query from FasterTransformer",
+    ext_modules=ext_modules,
+    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+)