Sfoglia il codice sorgente

Add Megatron attention implementation for benchmarking

Tri Dao 2 anni fa
parent
commit
ed553e9238

+ 32 - 3
benchmarks/benchmark_causal.py

@@ -6,12 +6,17 @@ import torch.nn.functional as F
 
 
 from einops import rearrange, repeat
 from einops import rearrange, repeat
 
 
-from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
 from flash_attn.triton.fused_attention import attention as attention
 from flash_attn.triton.fused_attention import attention as attention
 
 
+try:
+    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
+except ImportError:
+    scaled_upper_triang_masked_softmax = None
 
 
-def attention_pytorch(qkv, dropout_p=0.0, causal=False):
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     """
     """
     Arguments:
     Arguments:
         qkv: (batch_size, seqlen, 3, nheads, head_dim)
         qkv: (batch_size, seqlen, 3, nheads, head_dim)
@@ -53,10 +58,31 @@ def attention_triton(q, k, v):
     return attention(q, k, v, softmax_scale)
     return attention(q, k, v, softmax_scale)
 
 
 
 
+def attention_megatron(qkv):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, head_dim)
+    Output:
+        output: (batch_size, seqlen, nheads, head_dim)
+    """
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
+    output = torch.einsum('bhts,bshd->bthd', attention, v)
+    return output.to(dtype=qkv.dtype)
+
+
 torch.manual_seed(0)
 torch.manual_seed(0)
 repeats = 30
 repeats = 30
 batch_size = 2
 batch_size = 2
-seqlen = 2048
+seqlen = 4096
 nheads = 12
 nheads = 12
 headdim = 128
 headdim = 128
 dropout_p = 0.0
 dropout_p = 0.0
@@ -77,3 +103,6 @@ benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
 q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                        requires_grad=True) for _ in range(3)]
                        requires_grad=True) for _ in range(3)]
 benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')
 benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')
+
+if scaled_upper_triang_masked_softmax is not None:
+    benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')

+ 148 - 0
csrc/fused_softmax/fused_softmax.cpp

@@ -0,0 +1,148 @@
+/* coding=utf-8
+ * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include <vector>
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_masked_softmax {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input, 
+    torch::Tensor const& mask,
+    float scale_factor);
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads, 
+    torch::Tensor const& softmax_results,
+    float scale_factor);
+
+int get_batch_per_block_cuda(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads);
+
+torch::Tensor fwd(
+    torch::Tensor const& input,
+    torch::Tensor const& mask,
+    float scale_factor) {
+  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+	     (input.scalar_type() == at::ScalarType::BFloat16), 
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
+
+  return fwd_cuda(input, mask, scale_factor);
+}
+
+torch::Tensor bwd(
+    torch::Tensor const& output_grads, 
+    torch::Tensor const& softmax_results,
+    float scale_factor) {
+
+  AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
+  AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");
+
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+	     (output_grads.scalar_type() == at::ScalarType::BFloat16), 
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+	     (softmax_results.scalar_type() == at::ScalarType::BFloat16), 
+      "Only fp16 and bf16 are supported");
+
+  return bwd_cuda(output_grads, softmax_results, scale_factor);
+}
+
+int get_batch_per_block(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads) {
+    return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
+} // end namespace scaled_masked_softmax
+} // end namespace fused_softmax
+} // end namespace multihead_attn
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    float scale_factor);
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads,
+    torch::Tensor const& softmax_results,
+    float scale_factor);
+
+torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
+  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+	     (input.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+
+  return fwd_cuda(input, scale_factor);
+}
+
+torch::Tensor bwd(
+    torch::Tensor const& output_grads,
+    torch::Tensor const& softmax_results,
+    float scale_factor) {
+
+  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
+  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
+
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+	     (output_grads.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+	     (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+      "Only fp16 and bf16 are supported");
+
+  return bwd_cuda(output_grads, softmax_results, scale_factor);
+}
+
+} // end namespace scaled_upper_triang_masked_softmax
+} // end namespace fused_softmax
+} // end namespace multihead_attn
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("scaled_masked_softmax_forward",
+        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, 
+	"Self Multihead Attention scaled, time masked softmax -- Forward.");
+
+  m.def("scaled_masked_softmax_backward",
+        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
+	"Self Multihead Attention scaled, time masked softmax -- Backward.");
+
+  m.def("scaled_masked_softmax_get_batch_per_block",
+        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
+        "Return Batch per block size."
+  );
+
+  m.def("scaled_upper_triang_masked_softmax_forward",
+        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
+        "Self Multihead Attention scaled, time masked softmax -- Forward.");
+  m.def("scaled_upper_triang_masked_softmax_backward",
+        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
+        "Self Multihead Attention scaled, time masked softmax -- Backward.");
+}

+ 528 - 0
csrc/fused_softmax/scaled_masked_softmax.h

@@ -0,0 +1,528 @@
+/* coding=utf-8
+ * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <assert.h>
+#include <cuda_fp16.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <cuda_fp16.h>
+#include <c10/macros/Macros.h>
+
+namespace {
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
+
+int log2_ceil(int value) {
+    int log2_value = 0;
+    while ((1 << log2_value) < value) ++log2_value;
+    return log2_value;
+}
+
+template<typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+template<typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+    return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+    return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
+__device__ __forceinline__ void warp_reduce(acc_t* sum) {
+    ReduceOp<acc_t> r;
+    #pragma unroll
+    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+        #pragma unroll
+        for (int i = 0;  i < WARP_BATCH;  ++i) {
+            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
+            sum[i] = r(sum[i], b);
+        }
+    }
+}
+
+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ * 2) Explicit masking
+ */	
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_masked_softmax_warp_forward(
+    output_t *dst, 
+    const input_t *src,
+    const uint8_t *mask, 
+    const acc_t scale, 
+    int micro_batch_size, 
+    int element_count,
+    int pad_batches) 
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 
+    // warp_size of method warp_softmax_forward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
+    // gridDim/blockIdx = (seq_len, attn_heads, batches) 
+    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH;
+    int pad_first_batch = 0;
+    if (pad_batches != 1) { // bert style
+        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
+    } else { // gpt2 style
+        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+    }
+
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+
+    // load data from global memory
+    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < batch_element_count) {
+                int itr_idx = i*element_count+it*WARP_SIZE;
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
+                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);
+
+                #pragma unroll
+                  for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                      if (temp_mask[element] != 1) {
+                          elements[i][it + element] = (acc_t)temp_data[element] * scale;
+                      } else {
+                          elements[i][it + element] = -10000.0;
+                      }
+                  }
+            } else {
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
+            }
+        }
+    }
+
+    // compute max_value
+    acc_t max_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        max_value[i] = elements[i][0];
+        #pragma unroll
+        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
+            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+    // compute scale value to account for full mask
+    acc_t scale_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
+    }
+ 
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+            sum[i] += elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = elements[i][it + element] * scale_value[i]/ sum[i];
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);  
+            } else {
+                break;
+            } 
+        }
+    }
+}
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_masked_softmax_warp_backward(
+    output_t *gradInput, 
+    input_t *grad, 
+    const input_t *output,
+    acc_t scale, 
+    int micro_batch_size, 
+    int element_count)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 
+    // warp_size of method warp_softmax_backward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
+    // gridDim/blockIdx = (seq_len, attn_heads, batches) 
+    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+    
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    // the first element to process by the current thread
+    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    grad += thread_offset;
+    output += thread_offset;
+    gradInput += thread_offset;
+
+    // load data from global memory
+    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : element_count;
+
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < batch_element_count) {
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    output_reg[i][it + element] = (acc_t)temp_output[element];
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                }
+            } 
+        }
+    }
+   
+    acc_t sum[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        sum[i] = grad_reg[i][0];
+        #pragma unroll
+        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
+            sum[i] += grad_reg[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                // compute gradients
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
+            } 
+        }
+    }
+}
+} // end of anonymous namespace
+
+int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    int log2_elements = log2_ceil(key_seq_len);
+    const int next_power_of_two = 1 << log2_elements;
+
+    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+    constexpr int threads_per_block = 128;
+    int warps_per_block = (threads_per_block / warp_size);
+    int batches_per_block = warps_per_block * batches_per_warp;
+
+    return batches_per_block;
+}
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_masked_softmax_forward(
+    output_t *dst, 
+    const input_t *src, 
+    const uint8_t *mask,
+    const input_t scale, 
+    int query_seq_len, 
+    int key_seq_len, 
+    int batches,
+    int attn_heads,
+    int pad_batches)
+{
+    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 );
+    if (key_seq_len == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(key_seq_len);
+        const int next_power_of_two = 1 << log2_elements;
+        int batch_count = batches * attn_heads * query_seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximimize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
+        dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 1: // 2
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 2: // 4
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 3: // 8
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 4: // 16
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 5: // 32
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 6: // 64
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 7: // 128
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 8: // 256
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 9: // 512
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 10: // 1024
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 11: // 2048
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 12: // 4096
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            case 13: // 8192
+                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
+                break;
+            default:
+                break;
+        }
+    }
+}
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_masked_softmax_backward(
+    output_t *grad_input, 
+    input_t *grad, 
+    const input_t *output, 
+    const acc_t scale, 
+    int query_seq_len, 
+    int key_seq_len, 
+    int batches,
+    int attn_heads)
+{
+    TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 );
+    if (key_seq_len == 0) {
+       return;
+    } else {
+        int log2_elements = log2_ceil(key_seq_len);
+        const int next_power_of_two = 1 << log2_elements;
+        int batch_count = batches *  attn_heads * query_seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximimize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        int blocks = batch_count/batches_per_block;
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 1: // 2
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 2: // 4
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 3: // 8
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 4: // 16
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 5: // 32
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 6: // 64
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 7: // 128
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 8: // 256
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 9: // 512
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 10: // 1024
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 11: // 2048
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 12: // 4096
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            case 13: // 8192
+                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
+                break;
+            default:
+                break;
+        }
+    }
+}

+ 121 - 0
csrc/fused_softmax/scaled_masked_softmax_cuda.cu

@@ -0,0 +1,121 @@
+/* coding=utf-8
+ * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_masked_softmax.h"
+#include "type_shim.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_masked_softmax {
+
+int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    torch::Tensor const& mask,
+    float scale_factor)
+{
+  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
+  const int batches = input.size(0);
+  const int pad_batches = mask.size(0);
+  const int attn_heads = input.size(1);
+  const int query_seq_len = input.size(2);
+  const int key_seq_len = input.size(3);
+  TORCH_INTERNAL_ASSERT(key_seq_len <= 8192);
+  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
+  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
+  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
+  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
+  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
+
+  // Output 
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor softmax_results = 
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+
+  // Softmax Intermediate Result Ptr
+  void* input_ptr = static_cast<void*>(input.data_ptr());
+  void* mask_ptr = static_cast<void*>(mask.data_ptr());
+  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches
+      );
+  );
+  return softmax_results;
+}
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads_, 
+    torch::Tensor const& softmax_results_, 
+    float scale_factor)  {
+    
+  auto output_grads = output_grads_.contiguous();
+  auto softmax_results = softmax_results_.contiguous();
+
+  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
+  const int batches = output_grads.size(0);
+  const int attn_heads = output_grads.size(1);
+  const int query_seq_len = output_grads.size(2);
+  const int key_seq_len = output_grads.size(3);
+
+  auto act_options = output_grads.options().requires_grad(false);
+  torch::Tensor input_grads = 
+      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
+  void* input_grads_ptr = static_cast<void*>(input_grads.data_ptr());
+  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+  //Softmax Grad
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(input_grads_ptr), 
+          reinterpret_cast<scalar_t*>(output_grads_ptr), 
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads
+      );
+  );
+  return input_grads;
+}
+}
+}
+}

+ 529 - 0
csrc/fused_softmax/scaled_upper_triang_masked_softmax.h

@@ -0,0 +1,529 @@
+/* coding=utf-8
+ * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <assert.h>
+#include <cuda_fp16.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <c10/macros/Macros.h>
+
+namespace {
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+  
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
+
+template <typename Datatype, int ELEMENTS_PER_LDG>
+__device__ __inline__ void copy_zero_vector(Datatype *dst);
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+
+int log2_ceil(int value) {
+    int log2_value = 0;
+    while ((1 << log2_value) < value) ++log2_value;
+    return log2_value;
+}
+
+template<typename T>
+struct Add {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+template<typename T>
+struct Max {
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? b : a;
+  }
+};
+
+template <typename T>
+__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
+{
+#if CUDA_VERSION >= 9000
+    return __shfl_xor_sync(mask, value, laneMask, width);
+#else
+    return __shfl_xor(value, laneMask, width);
+#endif
+}
+
+template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
+__device__ __forceinline__ void warp_reduce(acc_t* sum) {
+    ReduceOp<acc_t> r;
+    #pragma unroll
+    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+        #pragma unroll
+        for (int i = 0;  i < WARP_BATCH;  ++i) {
+            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
+            sum[i] = r(sum[i], b);
+        }
+    }
+}
+
+/*
+ * Extended softmax (from native aten pytorch) with following additional features
+ * 1) input scaling
+ * 2) Implicit time (diagonal masking)
+ */
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_upper_triang_masked_softmax_warp_forward(
+    output_t *dst, 
+    const input_t *src, 
+    const acc_t scale, 
+    int micro_batch_size, 
+    int stride, 
+    int element_count) 
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 
+    // warp_size of method warp_softmax_forward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+    int local_seq = blockIdx.x + 1; 
+    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
+
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+
+    // load data from global memory
+    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
+    input_t temp_data[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : local_seq;
+
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < batch_element_count) {
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i*element_count*stride + it*WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if ((element_index + element) < batch_element_count) {
+                        elements[i][it+element] = (acc_t)temp_data[element] * scale;
+                    } else {
+                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                    }
+                }
+            } else {
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
+                }
+            }
+        }
+    }
+
+    // compute max_value
+    acc_t max_value[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        max_value[i] = elements[i][0];
+        #pragma unroll
+        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
+            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
+
+    acc_t sum[WARP_BATCH] { 0.0f };
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
+            if (it < warp_iteration_limit) {
+                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
+                sum[i] += elements[i][it];
+            } 
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    output_t out[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+
+            if (element_index < local_seq) {
+
+                #pragma unroll  
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < local_seq) {
+                        out[element] = elements[i][it + element] / sum[i];
+                    } else {
+                        out[element] = 0;
+                    }
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
+            } else if (element_index < element_count) {
+                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
+            } else {
+                break;
+            } 
+        }
+    }
+}
+
+template <typename input_t, typename output_t, typename acc_t, int log2_elements>
+__global__ void scaled_upper_triang_masked_softmax_warp_backward(
+    output_t *gradInput, 
+    input_t *grad, 
+    const input_t *output,
+    acc_t scale, 
+    int micro_batch_size, 
+    int stride, 
+    int element_count)
+{
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 
+    // warp_size of method warp_softmax_backward_kernel.
+    constexpr int next_power_of_two = 1 << log2_elements;
+    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
+    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
+
+    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+    int local_seq = blockIdx.x + 1; 
+    
+    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
+    // many batches have to computed within this WARP.
+    int local_batches = micro_batch_size - first_batch;
+    if (local_batches > WARP_BATCH)
+        local_batches = WARP_BATCH;
+
+    // there might be multiple batches per warp. compute the index within the batch
+    int local_idx = threadIdx.x;
+
+    // the first element to process by the current thread
+    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    grad += thread_offset;
+    output += thread_offset;
+    gradInput += thread_offset;
+
+    // load data from global memory
+    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
+    input_t temp_grad[ELEMENTS_PER_LDG_STG];
+    input_t temp_output[ELEMENTS_PER_LDG_STG];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        int batch_element_count = (i >= local_batches) ? 0 : local_seq;
+
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < batch_element_count) {
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
+                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);
+
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        output_reg[i][it + element] = (acc_t)temp_output[element];
+                    }
+                }
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    if (element_index + element < batch_element_count) {
+                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
+                    }
+                }
+            }
+        }
+    }
+   
+    acc_t sum[WARP_BATCH];
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        sum[i] = grad_reg[i][0];
+        #pragma unroll
+        for (int it = 1;  it < WARP_ITERATIONS;  ++it) {
+            sum[i] += grad_reg[i][it];
+        }
+    }
+    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);
+
+    // store result
+    #pragma unroll
+    for (int i = 0;  i < WARP_BATCH;  ++i) {
+        if (i >= local_batches)
+            break;
+        #pragma unroll
+        for (int it = 0;  it < WARP_ITERATIONS;  it+=ELEMENTS_PER_LDG_STG) {
+            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
+            if (element_index < element_count) {
+                // compute gradients
+                output_t out[ELEMENTS_PER_LDG_STG];
+                #pragma unroll
+                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
+                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
+                }
+                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
+            } 
+        }
+    }
+}
+
+} // end of anonymous namespace
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_upper_triang_masked_softmax_forward(
+    output_t *dst, 
+    const input_t *src, 
+    const input_t scale, 
+    int softmax_elements, 
+    int softmax_elements_stride, 
+    int attn_batches)
+{
+    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 );
+    if (softmax_elements == 0) {
+        return;
+    } else {
+        int log2_elements = log2_ceil(softmax_elements);
+        const int next_power_of_two = 1 << log2_elements;
+        int seq_len = softmax_elements;
+        int batch_count = attn_batches * seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximimize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
+        int blocks_per_seq = attn_batches / batches_per_block;
+        dim3 blocks(seq_len, blocks_per_seq, 1);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 1: // 2
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 2: // 4
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 3: // 8
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 4: // 16
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 5: // 32
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 6: // 64
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 7: // 128
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 8: // 256
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 9: // 512
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 10: // 1024
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 11: // 2048
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 12: // 4096
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 13: // 8192
+                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            default:
+                break;
+        }
+    }
+}
+
+template<typename input_t, typename output_t, typename acc_t>
+void dispatch_scaled_upper_triang_masked_softmax_backward(
+    output_t *grad_input, 
+    input_t *grad, 
+    const input_t *output, 
+    const acc_t scale, 
+    int softmax_elements, 
+    int softmax_elements_stride, 
+    int attn_batches)
+{
+    TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 );
+    if (softmax_elements == 0) {
+       return;
+    } else {
+        int log2_elements = log2_ceil(softmax_elements);
+        const int next_power_of_two = 1 << log2_elements;
+        int seq_len = softmax_elements;
+        int batch_count = attn_batches * seq_len;
+
+        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
+        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+
+        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximimize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
+        int blocks_per_seq = attn_batches / batches_per_block;
+        dim3 blocks(seq_len, blocks_per_seq, 1);
+        dim3 threads(warp_size, warps_per_block, 1);
+        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
+        switch (log2_elements) {
+            case 0: // 1
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 1: // 2
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 2: // 4
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 3: // 8
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 4: // 16
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 5: // 32
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 6: // 64
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 7: // 128
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 8: // 256
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 9: // 512
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 10: // 1024
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 11: // 2048
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 12: // 4096
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            case 13: // 8192
+                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
+                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
+                break;
+            default:
+                break;
+        }
+    }
+}

+ 98 - 0
csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu

@@ -0,0 +1,98 @@
+/* coding=utf-8
+ * Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"
+
+namespace multihead_attn {
+namespace fused_softmax {
+namespace scaled_upper_triang_masked_softmax {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input, 
+    float scale_factor)
+{
+  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  const int attn_batches = input.size(0);
+  const int seq_len = input.size(1);
+  TORCH_INTERNAL_ASSERT(seq_len <= 8192);
+
+  // Output 
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor softmax_results = 
+      torch::empty({attn_batches, seq_len, seq_len}, act_options);
+
+  // Softmax Intermediate Result Ptr
+  void* input_ptr = static_cast<void*>(input.data_ptr());
+  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
+
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+	  reinterpret_cast<scalar_t*>(softmax_results_ptr),
+	  reinterpret_cast<const scalar_t*>(input_ptr),
+	  scale_factor,
+	  seq_len,
+	  seq_len,
+	  attn_batches);
+      );
+  return softmax_results;
+}
+				      
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads_, 
+    torch::Tensor const& softmax_results_, 
+    float scale_factor)  {
+	
+  auto output_grads = output_grads_.contiguous();
+  auto softmax_results = softmax_results_.contiguous();
+
+  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
+  const int attn_batches = output_grads.size(0);
+  const int seq_len = output_grads.size(1);
+  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
+
+  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
+
+  //Softmax Grad
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr), 
+	  reinterpret_cast<scalar_t*>(output_grads_ptr), 
+	  reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+	  scale_factor,
+	  seq_len,
+	  seq_len,
+	  attn_batches);
+      );
+  
+  //backward pass is completely in-place
+  return output_grads;
+}
+}
+}
+}

+ 49 - 0
csrc/fused_softmax/setup.py

@@ -0,0 +1,49 @@
+# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron
+# We add the case where seqlen = 4k and seqlen = 8k
+import os
+import subprocess
+
+import torch
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+cc_flag = []
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
+
+setup(
+    name='fused_softmax_lib',
+    ext_modules=[
+        CUDAExtension(
+            name='fused_softmax_lib',
+            sources=['fused_softmax.cpp', 'scaled_masked_softmax_cuda.cu', 'scaled_upper_triang_masked_softmax_cuda.cu'],
+            extra_compile_args={
+                               'cxx': ['-O3',],
+                               'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag)
+                               }
+            )
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+})

+ 20 - 0
csrc/fused_softmax/type_shim.h

@@ -0,0 +1,20 @@
+#include <ATen/ATen.h>
+
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                   \
+switch(TYPE)                                                        \
+{                                                                   \
+case at::ScalarType::Half:                                          \
+    {                                                               \
+using scalar_t = at::Half;                                          \
+__VA_ARGS__;                                                        \
+break;                                                              \
+    }                                                               \
+case at::ScalarType::BFloat16:                                      \
+    {                                                               \
+using scalar_t = at::BFloat16;                                      \
+__VA_ARGS__;                                                        \
+break;                                                              \
+    }                                                               \
+default:                                                            \
+    AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");	\
+}

+ 205 - 0
flash_attn/fused_softmax.py

@@ -0,0 +1,205 @@
+# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py
+# for benchmarking.
+# We added support for seqlen=2k and seqlen=4k
+
+# coding=utf-8
+# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from apex._autocast_utils import _cast_if_autocast_enabled
+from apex.transformer.enums import AttnMaskType
+
+from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
+from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
+from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
+
+
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs following three operations in sequence
+    1. Scale the tensor.
+    2. Apply upper triangular mask (typically used in gpt models).
+    3. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, scale):
+        scale_t = torch.tensor([scale])
+        softmax_results = scaled_upper_triang_masked_softmax_forward(
+            inputs, scale_t[0]
+        )
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        softmax_results, scale_t = ctx.saved_tensors
+        input_grads = scaled_upper_triang_masked_softmax_backward(
+            output_grads, softmax_results, scale_t[0]
+        )
+        return input_grads, None
+
+
+def scaled_upper_triang_masked_softmax(inputs, _, scale):
+    b, np, sq, sk = inputs.size()
+    assert sq == sk, "causal mask is only for self attention"
+    # Reshaping input to 3D tensor (attn_batches, sq, sk)
+    inputs = inputs.view(-1, sq, sk)
+    args = _cast_if_autocast_enabled(inputs, scale)
+    with torch.cuda.amp.autocast(enabled=False):
+        probs = ScaledUpperTriangMaskedSoftmax.apply(*args)
+    return probs.view(b, np, sq, sk)
+
+
+# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`.
+# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context.
+# So I needed to manually write two `torch.autograd.Function` inheritances.
+# Fused operation which performs following three operations in sequence
+# 1. Scale the tensor.
+# 2. Apply the mask.
+# 3. Perform softmax.
+class ScaledMaskedSoftmax(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, inputs, mask, scale):
+        scale_t = torch.tensor([scale])
+        softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0])
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        softmax_results, scale_t = ctx.saved_tensors
+        input_grads = scaled_masked_softmax_backward(
+            output_grads, softmax_results, scale_t[0]
+        )
+        return input_grads, None, None
+
+
+def scaled_masked_softmax(inputs, mask, scale):
+    # input is 4D tensor (b, np, sq, sk)
+    args = _cast_if_autocast_enabled(inputs, mask, scale)
+    with torch.cuda.amp.autocast(enabled=False):
+        return ScaledMaskedSoftmax.apply(*args)
+
+
+class FusedScaleMaskSoftmax(torch.nn.Module):
+    """
+    fused operation: scaling + mask + softmax
+
+    Arguments:
+        input_in_fp16: flag to indicate if input in fp16 data format.
+        input_in_bf16: flag to indicate if input in bf16 data format.
+        attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
+        mask_func: mask function to be applied.
+        softmax_in_fp32: if true, softmax in performed at fp32 precision.
+        scale: scaling factor used in input tensor scaling.
+    """
+
+    def __init__(
+        self,
+        input_in_fp16,
+        input_in_bf16,
+        attn_mask_type,
+        scaled_masked_softmax_fusion,
+        mask_func,
+        softmax_in_fp32,
+        scale,
+    ):
+        super().__init__()
+        self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        if self.input_in_fp16 and self.input_in_bf16:
+            raise RuntimeError(
+                "both fp16 and bf16 flags cannot be active at the same time."
+            )
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
+        self.mask_func = mask_func
+        self.softmax_in_fp32 = softmax_in_fp32
+        self.scale = scale
+
+        if not (self.scale is None or softmax_in_fp32):
+            raise RuntimeError("softmax should be in fp32 when scaled")
+
+        if self.scaled_masked_softmax_fusion:
+            if self.attn_mask_type == AttnMaskType.causal:
+                self.fused_softmax_func = scaled_upper_triang_masked_softmax
+            elif self.attn_mask_type == AttnMaskType.padding:
+                self.fused_softmax_func = scaled_masked_softmax
+            else:
+                raise ValueError("Invalid attn_mask_type.")
+
+    def forward(self, input, mask):
+        # [b, np, sq, sk]
+        assert input.dim() == 4
+
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
+        else:
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user want to fuse
+            and self.input_in_float16  # input must be fp16
+            and (
+                self.attn_mask_type == AttnMaskType.causal
+                or (self.attn_mask_type == AttnMaskType.padding and mask is not None)
+            )
+            and 16 < sk <= 8192  # sk must be 16 ~ 8192
+            and sq % 4 == 0  # sq must be divisor of 4
+            and sk % 4 == 0  # sk must be divisor of 4
+            and attn_batches % 4 == 0  # np * b must be divisor of 4
+        ):
+            if 0 <= sk <= 8192:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False
+
+    def forward_fused_softmax(self, input, mask):
+        # input.shape = [b, np, sq, sk]
+        scale = self.scale if self.scale is not None else 1.0
+        return self.fused_softmax_func(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+
+        return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np)