
feat: per-tensor token epilogue kernels (#630)

* add new kernels

* feat: add to _custom_ops

* fix naming issues

* add documentation
AlpinDale, 6 months ago
Commit a401f8e05d

+ 25 - 1
aphrodite/_custom_ops.py

@@ -248,6 +248,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+    assert bias is None or bias.shape[0] == b.shape[
+        1] and bias.dtype == out_dtype
 
     m = a.shape[0]
     n = b.shape[1]
@@ -258,6 +260,28 @@ def cutlass_scaled_mm(a: torch.Tensor,
     return out
 
 
+def cutlass_scaled_mm_azp(a: torch.Tensor,
+                          b: torch.Tensor,
+                          scale_a: torch.Tensor,
+                          scale_b: torch.Tensor,
+                          out_dtype: torch.dtype,
+                          azp_adj: torch.Tensor,
+                          azp: Optional[torch.Tensor] = None,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+    assert bias is None or bias.numel(
+    ) == b.shape[1] and bias.dtype == out_dtype
+
+    m = a.shape[0]
+    n = b.shape[1]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
+                                       azp, bias)
+    return out
+
+
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
@@ -613,7 +637,7 @@ for k, v in names_and_values.items():
     if isinstance(v, fn_type) \
         and v.__code__.co_filename == __file__ \
         and any(arg is torch.Tensor or arg == "torch.Tensor"
-                   for arg in v.__annotations__.values()):
+                for arg in v.__annotations__.values()):
         names_and_values_to_update[k] = hint_on_error(v)
 
 names_and_values.update(names_and_values_to_update)
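
For orientation, here is a minimal sketch of how the new wrapper could be invoked for per-token asymmetric activation quantization. The shapes and random data are illustrative assumptions; only the function signature comes from the diff above, and running it requires the compiled `_C` extension on an SM75+ GPU.

```python
import torch
from aphrodite import _custom_ops as ops

m, k, n = 16, 32, 64                       # k and n must be multiples of 16
a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
# b must be column-major: transpose a row-major (n, k) weight matrix
b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda").t()

scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")   # per-token scales
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")   # per-channel scales
azp = torch.randint(-128, 128, (m, 1), dtype=torch.int32, device="cuda")  # per-token zero points
azp_adj = b.sum(dim=0, dtype=torch.int32)  # J @ B: column sums of the quantized weights

out = ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.bfloat16,
                                azp_adj, azp=azp)
```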

+ 4 - 0
docs/.vitepress/config.mts

@@ -161,6 +161,10 @@ export default defineConfig({
 						text: "Paged Attention",
 						link: "/pages/developer/paged-attention",
 					},
+					{
+						text: "NVIDIA CUTLASS Epilogues",
+						link: "/pages/developer/cutlass-epilogue",
+					},
 				],
 			},
 		],

+ 124 - 0
docs/pages/developer/cutlass-epilogue.md

@@ -0,0 +1,124 @@
+---
+outline: deep
+---
+
+# NVIDIA CUTLASS Epilogues
+
+## Introduction
+
+Epilogues are the final stages in a sequence of GPU-accelerated matrix multiplications and/or tensor operations. We typically handle these using the [NVIDIA CUTLASS](https://github.com/NVIDIA/CUTLASS) library of Linear Algebra Subroutines and Solvers. The epilogue phase comes after the core matmul (GEMM) has been performed on the input; epilogues are simply the additional operations applied to its output.
+
+In other words, the epilogue rearranges the result of a matrix product through shared memory to match canonical tensor layouts in global memory. Epilogues also support conversion and reduction operations. Note that the shared memory resource is time-sliced across warps.
+
+Currently, Aphrodite only supports symmetric quantization for weights, but both symmetric and asymmetric quantization for activations. Weights can be quantized per-tensor or per-channel, and activations per-tensor or per-token.
+
+In total, we use 4 epilogues:
+
+1. `ScaledEpilogue`: symmetric quantization for activations, no bias.
+2. `ScaledEpilogueBias`: symmetric quantization for activations, supports bias.
+3. `ScaledEpilogueAzp`: asymmetric per-tensor quantization for activations, supports bias.
+4. `ScaledEpilogueAzpPerToken`: asymmetric per-token quantization for activations, supports bias.
+
+We don't provide epilogues for asymmetric quantization of activations without bias, in order to reduce the final binary size. Instead, if no bias is passed, the epilogue uses 0 as the bias.
+That induces a redundant addition operation (and a runtime check), but the performance impact seems to be relatively minor.
+
+## Underlying Linear Algebra
+
+If $\widehat X$ is the quantized $X$, our matrices become the following:
+***
+$$A = s_a (\widehat A - J_a z_a)$$
+
+$$B = s_b \widehat B$$
+
+$$D = A B + C$$
+
+$$D = s_a s_b \widehat D + C$$
+***
+Here, $D$ is the output of the GEMM and $C$ is the bias. $A$ is the activation matrix and supports asymmetric quantization, while $B$ is the weight matrix and supports only symmetric quantization. $s_a$ and $s_b$ are the scales for activations and weights, respectively. $z_a$ is the zero-point for activations, and $J_a$ is the matrix of all ones with the dimensions of $A$. Additional epilogues would be required to support asymmetric quantization for weights.
+
+Expanding further, we can calculate $\widehat D$ as follows:
+
+***
+$$A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B$$
+
+$$A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)$$
+
+$$\widehat D = \widehat A \widehat B - z_a J_a \widehat B$$
+
+***
+Note that $\widehat A \widehat B$ is the raw output of the GEMM, while $J_a \widehat B$ is known ahead of time: each of its rows equals $\mathbf 1 \widehat B$, the row-vector of column sums of $\widehat B$.
+
+
+## Epilogues
+
+### ScaledEpilogue
+This epilogue computes the symmetric quantization for activations without bias, meaning $C = 0$ and  $z_a = 0$. The output of the GEMM is:
+
+***
+$$\widehat D = \widehat A \widehat B$$
+$$D = s_a s_b \widehat D$$
+$$D = s_a s_b \widehat A \widehat B$$
+***
+
+Epilogue parameters:
+- `scale_a`: the scale for activations, can be per-tensor (scalar) or per-token (column-vector)
+- `scale_b`: the scale for weights, can be per-tensor (scalar) or per-channel (row-vector)
+
+### ScaledEpilogueBias
+This epilogue computes the symmetric quantization for activations with bias, meaning $z_a = 0$.
+The output of the GEMM is:
+
+***
+$$\widehat D = \widehat A \widehat B$$
+$$D = s_a s_b \widehat D + C $$
+$$D = s_a s_b \widehat A \widehat B + C$$
+***
+
+Epilogue parameters:
+- `scale_a`: the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
+- `scale_b`: the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
+- `bias`: the bias, is always per-channel (row-vector).
+
+### ScaledEpilogueAzp
+This epilogue computes the asymmetric per-tensor quantization for activations with bias.
+The output of the GEMM is:
+
+***
+$$\widehat D = \widehat A \widehat B - z_a J_a \widehat B$$
+$$D = s_a s_b \widehat D + C $$
+$$D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C$$
+***
+
+Because $z_a$ is a scalar, the zero-point term $z_a J_a \widehat B$ has every row equal to $z_a \mathbf 1 \widehat B$.
+That is precomputed and stored in `azp_with_adj` as a row-vector.
+
+Epilogue parameters:
+- `scale_a`: the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
+  - Generally this will be per-tensor as the zero-points are per-tensor.
+- `scale_b`: the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
+- `azp_with_adj`: the precomputed zero-point term ($z_a J_a \widehat B$), is per-channel (row-vector).
+- `bias`: the bias, is always per-channel (row-vector).
+
+To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
+
+
+### ScaledEpilogueAzpPerToken
+This epilogue computes the asymmetric per-token quantization for activations with bias.
+
+The output of the GEMM is the same as above, but $z_a$ is a column-vector.
+That means the zero-point term $z_a J_a \widehat B$ becomes an outer product of $z_a$ and $\mathbf 1 \widehat B$.
+
+Epilogue parameters:
+- `scale_a`: the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
+  - Generally this will be per-token as the zero-points are per-token.
+- `scale_b`: the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
+- `azp_adj`: the precomputed zero-point adjustment term ($\mathbf 1 \widehat B$), is per-channel (row-vector).
+- `azp`: the zero-point ($z_a$), is per-token (column-vector).
+- `bias`: the bias, is always per-channel (row-vector).
+
+To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
+
+The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
+```
+out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
+```
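
As a sanity check on the algebra above, the whole computation can also be written in a few lines of plain PyTorch. The snippet below is an illustrative reference (not part of this commit); the names are arbitrary, and it covers both the per-tensor and per-token zero-point cases via broadcasting.

```python
import torch

def azp_epilogue_reference(a_q, b_q, scale_a, scale_b, azp, bias=None):
    """Float32 reference for the asymmetric epilogues.
    a_q: (m, k) int8, b_q: (k, n) int8, azp: int scalar or (m, 1) int32 tensor."""
    # note: integer matmul is typically CPU-only in PyTorch
    dq = a_q.to(torch.int32) @ b_q.to(torch.int32)              # raw int32 GEMM output
    azp_adj = b_q.sum(dim=0, keepdim=True, dtype=torch.int32)   # 1 @ B_hat, shape (1, n)
    # ScaledEpilogueAzp folds a per-tensor azp into azp_adj ahead of time;
    # ScaledEpilogueAzpPerToken keeps azp separate as a rank-1 update.
    out = scale_a * scale_b * (dq - azp * azp_adj).to(torch.float32)
    return out if bias is None else out + bias
```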

+ 151 - 1
kernels/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp

@@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast {
 
 };
 
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
+template<
+  class ThreadMap,
+  class Element,
+  class StrideMNL
+>
+struct VisitorRowOrZeroBroadcast {
+
+  // This struct has been modified to remove null_default (because it's always 0)
+  struct Arguments {
+    Element const* ptr_row = nullptr;
+    StrideMNL dRow = {};
+  };
+
+  using Params = Arguments;
+
+  template <class ProblemShape>
+  static constexpr Params
+  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
+    return args;
+  }
+
+  template <class ProblemShape>
+  static size_t
+  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
+    return 0;
+  }
+
+  struct SharedStorage {};
+
+  // Global load type
+  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
+  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
+  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
+
+  CUTLASS_HOST_DEVICE
+  VisitorRowOrZeroBroadcast() { }
+
+  CUTLASS_HOST_DEVICE
+  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
+    : params_ptr(&params) { }
+
+  Params const* params_ptr;
+
+  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
+  struct Callbacks : EmptyCallbacks {
+    CUTLASS_DEVICE
+    Callbacks(
+      GTensor&& tC_gRow,
+      RTensor&& tC_rRow,
+      CTensor&& tC_cRow,
+      ProblemShape problem_shape,
+      Params const* params_ptr
+    ):
+      tC_gRow(cute::forward<GTensor>(tC_gRow)),
+      tC_rRow(cute::forward<RTensor>(tC_rRow)),
+      tC_cRow(cute::forward<CTensor>(tC_cRow)),
+      n(get<1>(problem_shape)),
+      params_ptr(params_ptr) { }
+
+    GTensor tC_gRow;
+    RTensor tC_rRow;
+    CTensor tC_cRow;
+    Params const* params_ptr;
+    int n;
+
+    // This function is modified from VisitorRowBroadcast
+    CUTLASS_DEVICE void
+    begin_epilogue() {
+      clear(tC_rRow);
+      auto src_v = filter(tC_gRow);
+      auto coord_v = filter(tC_cRow);
+      auto dst_v = filter(tC_rRow);
+
+      if (params_ptr->ptr_row != nullptr) {
+        // In this case we are loading from a row vector and broadcasting
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < size(src_v); ++i) {
+          bool guard = get<1>(coord_v(i)) < n;
+          cutlass::arch::global_load<VecType, sizeof(VecType)>(
+              dst_v(i), (void const*)&src_v(i), guard);
+        }
+      } else {
+        // In this case we are broadcasting 0
+        VecType filled_vec;
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < VecLength; i++) {
+          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
+        }
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < size(src_v); ++i) {
+          if (get<1>(coord_v(i)) < n) {
+            dst_v(i) = filled_vec;
+          }
+        }
+      }
+    }
+
+    template <class ElementAccumulator, int FragmentSize>
+    CUTLASS_DEVICE auto // returns an Array
+    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
+          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
+      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
+      return rRow_frg(column_idx);
+    }
+  };
+
+  template <class ProblemShape>
+  CUTLASS_DEVICE auto
+  get_callbacks(
+    gemm::GemmCoord threadblock_tile_offset,
+    int thread_idx,
+    ProblemShape problem_shape
+  ) {
+    Tensor mRow = make_tensor(
+      make_gmem_ptr(params_ptr->ptr_row),
+      problem_shape,
+      params_ptr->dRow);
+
+    // VECTOR, FRAGMENT_COLUMN
+    Tensor tC_gRow = recast<VecType>(
+      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
+    )(_,_,_0{},_0{},_0{},_0{});
+    Tensor tC_rRow = make_tensor_like(tC_gRow);
+
+    // Generate the pred tensor
+    Tensor cRow = make_identity_tensor(mRow.shape());
+    Tensor tC_cRow = outer_partition(
+      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
+      Shape<Int<VecLength>>{},
+      (_0{})
+    );
+
+    return Callbacks<
+      decltype(tC_gRow), decltype(tC_rRow),
+      decltype(tC_cRow), ProblemShape>(
+      cute::move(tC_gRow),
+      cute::move(tC_rRow),
+      cute::move(tC_cRow),
+      problem_shape,
+      params_ptr
+    );
+  }
+
+};
+
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 // Column vector broadcast
@@ -217,7 +367,7 @@ template<
 >
 struct VisitorColOrScalarBroadcast {
 
-  // This struct has been modified to have a bool indicating that ptr_col is a 
+  // This struct has been modified to have a bool indicating that ptr_col is a
   // scalar that must be broadcast.
   struct Arguments {
     Element const* ptr_col = nullptr;
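
In effect, the new `VisitorRowOrZeroBroadcast` implements the following broadcast rule; this is a rough host-side analogy in PyTorch (illustrative only, not the device code):

```python
import torch

def row_or_zero_broadcast(row, m, n, dtype=torch.float16):
    """Broadcast an (n,) row vector across m rows; a null ptr_row means zeros."""
    if row is None:
        row = torch.zeros(n, dtype=dtype)
    return row.reshape(1, n).expand(m, n)
```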

+ 60 - 0
kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cu

@@ -51,6 +51,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
+void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if (azp) {
+    return cutlass_scaled_mm_sm75_epilogue<
+        aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
+                                               azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_mm_sm75_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
+        out, a, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
+
 template <template <typename, typename> typename Epilogue,
           typename... EpilogueArgs>
 void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
@@ -89,6 +109,26 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
   }
 }
 
+void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if (azp) {
+    return cutlass_scaled_mm_sm80_epilogue<
+        aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
+                                               azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_mm_sm80_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
+        out, a, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
+
 template <template <typename, typename> typename Epilogue,
           typename... EpilogueArgs>
 void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
@@ -140,4 +180,24 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
     return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogue>(
         out, a, b, a_scales, b_scales);
   }
+}
+
+void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if (azp) {
+    return cutlass_scaled_mm_sm89_epilogue<
+        aphrodite::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales,
+                                               azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_mm_sm89_epilogue<aphrodite::ScaledEpilogueBiasAzp>(
+        out, a, b, a_scales, b_scales, azp_adj, bias);
+  }
 }

+ 217 - 36
kernels/quantization/cutlass_w8a8/scaled_mm_c2x.cuh

@@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
 };
 
 /*
- * This class provides the common ScaleA and ScaleB descriptors for the
- * ScaledEpilogue and ScaledEpilogueBias classes.
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
  */
 template <typename ElementD, typename OutputTileThreadMap>
 struct ScaledEpilogueBase {
  protected:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
 
-  using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
-      OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
+  template <typename T>
+  using ColOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowOrScalarLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
+      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
+
+  template <typename T>
+  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  template <typename T>
+  using RowOrZeroLoad =
+      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
+          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
+
+  // This utility function constructs the arguments for the load descriptors
+  // from a tensor. It can handle both row and column, as well as row/column or
+  // scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+      return Arguments{data_ptr, tensor.numel() != 1};
+    } else {
+      // it would technically work but no use case as data_ptr is never nullptr
+      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+      return Arguments{data_ptr};
+    }
+  }
 
-  using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
-      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+  // This overload handles the case where there might not be a tensor, in which
+  // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+    return Arguments{data_ptr};
+  }
 };
 
 /*
@@ -110,8 +154,8 @@ struct ScaledEpilogue
  private:
   using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
   using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
       cutlass::multiplies, float, float,
@@ -131,28 +175,32 @@ struct ScaledEpilogue
 
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
-    using ScaleAArgs = typename ScaleA::Arguments;
-    using ScaleBArgs = typename ScaleB::Arguments;
-
-    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
 
-    typename EVTCompute0::Arguments evt0_compute_args{b_args};
-
-    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
-    return evt_compute_args;
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args};
   }
 };
 
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
+ * This bias can also be used in the per-tensor azp case, where the activation
+ * zero point (azp) is used to compute an azp correction term,
+ * which is folded into the bias.
+ *
+ * The bias tensor must be per-output channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
+ */
 template <typename ElementD, typename OutputTileThreadMap>
 struct ScaledEpilogueBias
-    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
- private:
+    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ protected:
   using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
   using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
-
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD>;
   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
       cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
@@ -164,30 +212,163 @@ struct ScaledEpilogueBias
       cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
-      OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
-
  public:
   using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                              EVTCompute0, Bias>;
   using ArgumentType = typename EVTCompute::Arguments;
-
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& bias) {
-    using ScaleAArgs = typename ScaleA::Arguments;
-    using ScaleBArgs = typename ScaleB::Arguments;
-    using BiasArgs = typename Bias::Arguments;
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args, bias_args};
+  }
+};
+
+/*
+ * This epilogue directly supports per-tensor azp in int32 form.
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
+ * term, which should already be multiplied with the scalar azp.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBiasAzp
+    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
+
+  // This is the full AZP term, azp * J @ B, shape (1,n)
+  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
+
+  // Compute float(accum - azp_adj), both operands are int32_t
+  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::minus, float, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeAzp =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
+
+  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeScaleB =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
+                                              EVTComputeAzp>;
+
+  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
+                                              EVTComputeScaleB, Bias>;
+
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& azp_adj,
+                                   c10::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
+
+    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
+    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
+    return ArgumentType{a_args, evt_scale_b_args, bias_args};
+  }
+};
+
+/*
+ * This epilogue supports per-token azp by computing and applying
+ * the correction term using a rank-1 update. If the term were materialized,
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
+ * point for each row of A.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBiasAzpToken
+    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
+
+  // Per-token azp term, shape (m,1)
+  using Azp = typename SUPER::template ColLoad<int32_t>;
 
-    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-    BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
+  // This is the AZP adjustment term, J @ B, shape (1,n)
+  using AzpAdj = typename SUPER::template RowLoad<int32_t>;
+
+  // Compute azp * azp_adj
+  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, int32_t, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
 
-    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+  using EVTComputeAzp =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
 
-    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
-                                                    bias_args};
-    return evt_compute_args;
+  // Compute float(accum - azp*azp_adj), all operands are int32_t
+  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::minus, float, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeAcc =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
+
+  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeScaleB =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
+                                              EVTComputeAcc>;
+
+  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
+                                              EVTComputeScaleB, Bias>;
+
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& azp_adj,
+                                   torch::Tensor const& azp,
+                                   c10::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
+
+    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
+    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
+    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
+    return ArgumentType{a_args, evt_scale_b_args, bias_args};
   }
 };
 

+ 230 - 28
kernels/quantization/cutlass_w8a8/scaled_mm_c3x.cu

@@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
 };
 
 /*
- * This class provides the common ScaleA and ScaleB descriptors for the
- * ScaledEpilogue and ScaledEpilogueBias classes.
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes
  */
 template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
 struct ScaledEpilogueBase {
  protected:
   using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
 
-  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+  template <typename T>
+  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
       Stride<Int<1>, Int<0>, Int<0>>>;
 
-  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+  template <typename T>
+  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
       Stride<Int<0>, Int<1>, Int<0>>>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // This utility function constructs the arguments for the load descriptors
+  // from a tensor. It can handle both row and column, as well as row/column or
+  // scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+      return Arguments{data_ptr, tensor.numel() != 1};
+    } else {
+      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
+                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
+      return Arguments{data_ptr};
+    }
+  }
+
+  // This overload handles the case where there might not be a tensor, in which
+  // case a nullptr is passed and a constant (0) is used.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
+                  std::is_same_v<Descriptor, RowLoad<T, true>>);
+    return Arguments{data_ptr};
+  }
 };
 
 /*
@@ -97,8 +139,8 @@ struct ScaledEpilogue
  private:
   using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
   using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
       cutlass::multiplies, float, float,
@@ -118,24 +160,32 @@ struct ScaledEpilogue
 
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
-    using ScaleA_Args = typename ScaleA::Arguments;
-    using ScaleB_Args = typename ScaleB::Arguments;
-
-    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
 
-    return ArgumentType{a_args, {b_args}};
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args};
   }
 };
 
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
+ * This bias can also be used in the per-tensor azp case, where the activation
+ * zero point (azp) is used to compute an azp correction term,
+ * which is folded into the bias.
+ *
+ * The bias tensor must be per-output channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
+ */
 template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
 struct ScaledEpilogueBias
     : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
  private:
   using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
   using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD>;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
       cutlass::multiplies, float, float,
@@ -148,27 +198,160 @@ struct ScaledEpilogueBias
       cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
-      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
-
  public:
   using EVTCompute =
       cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
-  using ArgumentType = typename EVTCompute::Arguments;
 
+  using ArgumentType = typename EVTCompute::Arguments;
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& bias) {
-    using ScaleA_Args = typename ScaleA::Arguments;
-    using ScaleB_Args = typename ScaleB::Arguments;
-    using Bias_Args = typename Bias::Arguments;
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args, bias_args};
+  }
+};
+
+/*
+ * This epilogue directly supports per-tensor azp in int32 form.
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
+ * term, which should already be multiplied with the scalar azp.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueBiasAzp
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD, true>;
+
+  // This is the full AZP term, azp * J @ B, shape (1,n)
+  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
+
+  // Compute float(accum - azp_adj), both operands are int32_t
+  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::minus, float, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeAzp =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
+
+  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeScaleB =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
+
+  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
+                                         EVTComputeScaleB, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& azp_adj,
+                                   c10::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
+
+    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
+    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
+    return ArgumentType{a_args, evt_scale_b_args, bias_args};
+  }
+};
+
+/*
+ * This epilogue supports per-token azp by computing and applying
+ * the correction term using a rank-1 update. If the term were materialized,
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
+ * point for each row of A.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueBiasAzpToken
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD, true>;
+
+  // Per-token azp term, shape (m,1)
+  using Azp = typename SUPER::template ColLoad<int32_t>;
+
+  // This is the AZP adjustment term, J @ B, shape (1,n)
+  using AzpAdj = typename SUPER::template RowLoad<int32_t>;
+
+  // Compute azp * azp_adj
+  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, int32_t, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeAzp =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
 
-    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-    Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
+  // Compute float(accum - azp*azp_adj), all operands are int32_t
+  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::minus, float, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
 
-    return ArgumentType{a_args, {b_args}, bias_args};
+  using EVTComputeAcc =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
+
+  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeScaleB =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
+
+  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
+                                         EVTComputeScaleB, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& azp_adj,
+                                   torch::Tensor const& azp,
+                                   c10::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
+
+    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
+    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
+    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
+    return ArgumentType{a_args, evt_scale_b_args, bias_args};
   }
 };
 
@@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   }
 }
 
+void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  if (azp) {
+    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
+        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
+        out, a, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
+
 #endif
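
The nested `Sm80EVT`/`Sm90EVT` typedefs compose the epilogue as a small dataflow tree. Here is a plain-Python sketch of the per-token variant's computation (the function and argument names are interpretive, chosen to mirror the C++ typedefs, not generated code):

```python
import torch

def scaled_epilogue_bias_azp_token(accum, scale_a, scale_b, azp, azp_adj, bias):
    """accum: (m, n) int32 accumulator; azp: (m, 1) int32; azp_adj: (1, n) int32."""
    azp_term = azp * azp_adj                    # EVTComputeAzp: int32 rank-1 update
    acc = (accum - azp_term).to(torch.float32)  # EVTComputeAcc: subtract, cast to float
    scaled = scale_b * acc                      # EVTComputeScaleB
    return scale_a * scaled + bias              # ComputeScaleBiasA: fused multiply_add
```

`ScaledEpilogueBiasAzp` uses the same tree minus the first multiply: its `azp_adj` already contains `azp * (J @ B)`, so the accumulator subtracts it directly.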

+ 104 - 7
kernels/quantization/cutlass_w8a8/scaled_mm_entry.cu

@@ -29,6 +29,40 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             c10::optional<torch::Tensor> const& bias);
 #endif
 
+void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias);
+
+void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias);
+
+void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias);
+
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias);
+#endif
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   // CUTLASS FP8 kernels need at least
   //   CUDA 12.0 on SM90 systems (Hopper)
@@ -45,18 +79,20 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }
 
-void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
-                       torch::Tensor const& b, torch::Tensor const& a_scales,
-                       torch::Tensor const& b_scales,
-                       c10::optional<torch::Tensor> const& bias) {
-  int32_t major_capability;
-  int32_t minor_capability;
+int32_t get_sm_version_num() {
+  int32_t major_capability, minor_capability;
   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                          0);
   cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                          0);
   int32_t version_num = major_capability * 10 + minor_capability;
+  return version_num;
+}
 
+void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                       torch::Tensor const& b, torch::Tensor const& a_scales,
+                       torch::Tensor const& b_scales,
+                       c10::optional<torch::Tensor> const& bias) {
   // Checks for conformality
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
@@ -77,7 +113,7 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
   }
 
   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
-
+  int32_t version_num = get_sm_version_num();
   if (version_num >= 90) {
     // Hopper
 
@@ -98,4 +134,65 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
     TORCH_CHECK(version_num >= 75);
     cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
   }
+}
+
+void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
+                           torch::Tensor const& b,
+                           torch::Tensor const& a_scales,
+                           torch::Tensor const& b_scales,
+                           torch::Tensor const& azp_adj,
+                           c10::optional<torch::Tensor> const& azp,
+                           c10::optional<torch::Tensor> const& bias) {
+  // Checks for conformality
+  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+              b.size(1) == c.size(1));
+  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+  TORCH_CHECK(c.stride(0) % 16 == 0 &&
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+
+  // bias, azp, azp_adj are all 1d
+  // bias and azp_adj have n elements, azp has m elements
+  if (bias) {
+    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
+  }
+  if (azp) {
+    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
+  }
+  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
+
+  // azp & bias types
+  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
+  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
+  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
+              "currently bias dtype must match output dtype ", c.dtype());
+
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+  int32_t version_num = get_sm_version_num();
+  if (version_num >= 90) {
+    // Hopper
+
+    // Guard against compilation issues for sm90 kernels
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+#else
+    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+#endif
+  } else if (version_num == 89) {
+    // Ada Lovelace
+    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+  } else if (version_num >= 80) {
+    // Ampere
+    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+  } else {
+    // Turing
+    TORCH_CHECK(version_num >= 75);
+    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+  }
 }
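
For quick reference, the calling convention enforced above can be restated as a small checker; this is an illustrative paraphrase of the main `TORCH_CHECK`s, not code from the commit:

```python
import torch

def check_scaled_mm_azp_args(c, a, b, a_scales, b_scales, azp_adj, azp=None, bias=None):
    """Mirror of the main shape/stride/dtype checks in cutlass_scaled_mm_azp."""
    assert a.dim() == 2 and b.dim() == 2 and c.dim() == 2
    assert c.shape[0] == a.shape[0] and a.shape[1] == b.shape[0] and b.shape[1] == c.shape[1]
    assert a_scales.numel() in (1, a.shape[0]) and b_scales.numel() in (1, b.shape[1])
    assert a.stride(1) == 1 and c.stride(1) == 1        # a and c are row-major
    assert b.stride(0) == 1 and b.stride(1) % 16 == 0   # b is column-major, aligned
    assert c.stride(0) % 16 == 0
    assert azp_adj.numel() == b.shape[1] and azp_adj.dtype == torch.int32
    assert azp is None or (azp.numel() == a.shape[0] and azp.dtype == torch.int32)
    assert bias is None or (bias.numel() == b.shape[1] and bias.dtype == c.dtype)
```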

+ 8 - 0
kernels/quantization/quant_ops.h

@@ -116,6 +116,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        c10::optional<torch::Tensor> const& bias);
 
+void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
+                           torch::Tensor const& b,
+                           torch::Tensor const& a_scales,
+                           torch::Tensor const& b_scales,
+                           torch::Tensor const& azp_adj,
+                           c10::optional<torch::Tensor> const& azp,
+                           c10::optional<torch::Tensor> const& bias);
+
 torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
                               torch::Tensor const& b_q_weight,
                               torch::Tensor const& s_tok,

+ 9 - 0
kernels/torch_bindings.cpp

@@ -180,6 +180,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
            &cutlass_scaled_mm_supports_fp8);
 
+  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
+  // quantization.
+  ops.def(
+      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
+      "                  Tensor b, Tensor a_scales,"
+      "                  Tensor b_scales, Tensor azp_adj,"
+      "                  Tensor? azp, Tensor? bias) -> ()");
+  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
+
   // QuIP# GEMV
   ops.def("quip_gemv", &e8p_mm_origorder);
   ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);