@@ -58,21 +58,63 @@ struct enable_sm90_or_later : Kernel {
 };
 
 /*
- * This class provides the common ScaleA and ScaleB descriptors for the
- * ScaledEpilogue and ScaledEpilogueBias classes.
+ * This class provides the common load descriptors for the
+ * ScaledEpilogue[...] classes.
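+ *
+ * The load descriptors are CUTLASS epilogue-visitor-tree (EVT) broadcast
+ * nodes that read a column vector, a row vector, or a scalar.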
  */
 template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
 struct ScaledEpilogueBase {
  protected:
   using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
 
-  using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+  template <typename T>
+  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
       Stride<Int<1>, Int<0>, Int<0>>>;
 
-  using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
+  template <typename T>
+  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
       Stride<Int<0>, Int<1>, Int<0>>>;
+
+  // Don't want to support nullptr by default
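+  // (128 / sizeof_bits_v<T> expresses a 128-bit access alignment in
+  // elements of T.)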
+  template <typename T, bool EnableNullPtr = false>
+  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // Don't want to support nullptr by default
+  template <typename T, bool EnableNullPtr = false>
+  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
+
+  // This utility function constructs the arguments for the load descriptors
+  // from a tensor. It handles the row and column cases as well as the
+  // row/column-or-scalar cases.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(torch::Tensor const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
+    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
+                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
+      return Arguments{data_ptr, tensor.numel() != 1};
+    } else {
+      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
+                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
+      return Arguments{data_ptr};
+    }
+  }
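+
+  // For example (illustrative): per-token scales of shape (m,1) passed as
+  //   args_from_tensor<ColOrScalarLoad<float>, float>(scales)
+  // yield Arguments{ptr, true}, while a single-element scale tensor yields
+  // Arguments{ptr, false}, selecting the scalar-broadcast path.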
+
+  // This overload handles the case where there might not be a tensor: a
+  // nullptr is passed instead, and a constant (0) is used in its place.
+  template <typename Descriptor, typename T>
+  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+    using Arguments = typename Descriptor::Arguments;
+    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
+    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
+                  std::is_same_v<Descriptor, RowLoad<T, true>>);
+    return Arguments{data_ptr};
+  }
 };
 
 /*
@@ -97,8 +139,8 @@ struct ScaledEpilogue
  private:
   using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
   using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
       cutlass::multiplies, float, float,
@@ -118,24 +160,32 @@ struct ScaledEpilogue
 
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
-    using ScaleA_Args = typename ScaleA::Arguments;
-    using ScaleB_Args = typename ScaleB::Arguments;
-
-    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
 
-    return ArgumentType{a_args, {b_args}};
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args};
   }
 };
 
+/*
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a
+ * bias. The bias can also be used in the per-tensor azp case, where the
+ * activation zero point (azp) is used to compute an azp correction term
+ * that is folded into the bias.
+ *
+ * The bias tensor must be per-output-channel.
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
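+ *
+ * i.e. out[i][j] = scale_a[i] * (scale_b[j] * accum[i][j]) + bias[j]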
+ */
 template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
 struct ScaledEpilogueBias
     : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
  private:
   using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
   using Accum = typename SUPER::Accum;
-  using ScaleA = typename SUPER::ScaleA;
-  using ScaleB = typename SUPER::ScaleB;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD>;
 
   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
       cutlass::multiplies, float, float,
@@ -148,27 +198,160 @@ struct ScaledEpilogueBias
       cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
-      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
-
  public:
   using EVTCompute =
       cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
-  using ArgumentType = typename EVTCompute::Arguments;
 
+  using ArgumentType = typename EVTCompute::Arguments;
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& bias) {
-    using ScaleA_Args = typename ScaleA::Arguments;
-    using ScaleB_Args = typename ScaleB::Arguments;
-    using Bias_Args = typename Bias::Arguments;
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+
+    typename EVTCompute0::Arguments evt0_args{b_args};
+    return ArgumentType{a_args, evt0_args, bias_args};
+  }
+};
+
+/*
+ * This epilogue directly supports per-tensor azp in int32 form.
+ * As opposed to the per-token epilogue below, this epilogue only has an
+ * azp_adj term, which should already be multiplied by the scalar azp.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
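+ *
+ * e.g. with asymmetric quantization A ~= scale_a * (A_q - azp * J), the
+ * product expands as
+ *   A @ B ~= scale_a * scale_b * (A_q @ B_q - azp * (J @ B_q)),
+ * so subtracting azp_adj from the int32 accumulator recovers the integer
+ * product before the scales are applied.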
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueBiasAzp
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD, true>;
+
+  // This is the full AZP term, azp * J @ B, shape (1,n)
+  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
+
+  // Compute float(accum - azp_adj); both operands are int32_t
+  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::minus, float, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeAzp =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
+
+  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeScaleB =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
+
+  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
+                                         EVTComputeScaleB, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& azp_adj,
+                                   c10::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
+
+    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
+    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
+    return ArgumentType{a_args, evt_scale_b_args, bias_args};
+  }
+};
+
+/*
+ * This epilogue supports per-token azp by computing and applying the azp
+ * correction term via a rank-1 update. If the term were materialized, it
+ * would require O(m*n) space; this way it only requires O(m+n) space.
+ * The azp term is a 1D tensor of shape (m,1) and represents the unscaled
+ * zero point for each row of A.
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
+ *
+ * This epilogue also supports bias, which remains per-channel.
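+ *
+ * i.e. out[i][j] = scale_a[i] * scale_b[j]
+ *                    * (accum[i][j] - azp[i] * azp_adj[j]) + bias[j]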
+ */
+template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
+struct ScaledEpilogueBiasAzpToken
+    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
+  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
+  using Bias = typename SUPER::template RowLoad<ElementD, true>;
+
+  // Per-token azp term, shape (m,1)
+  using Azp = typename SUPER::template ColLoad<int32_t>;
+
+  // This is the AZP adjustment term, J @ B, shape (1,n)
+  using AzpAdj = typename SUPER::template RowLoad<int32_t>;
+
+  // Compute azp * azp_adj
+  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, int32_t, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeAzp =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
 
-    ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
-    ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
-    Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
+  // Compute float(accum - azp * azp_adj); all operands are int32_t
+  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::minus, float, int32_t,
+      cutlass::FloatRoundStyle::round_to_nearest>;
 
-    return ArgumentType{a_args, {b_args}, bias_args};
+  using EVTComputeAcc =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
+
+  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTComputeScaleB =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
+
+  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+ public:
+  using EVTCompute =
+      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
+                                         EVTComputeScaleB, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& azp_adj,
+                                   torch::Tensor const& azp,
+                                   c10::optional<torch::Tensor> const& bias) {
+    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
+    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
+    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
+    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
+    auto azp_adj_args =
+        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
+
+    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
+    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
+    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
+    return ArgumentType{a_args, evt_scale_b_args, bias_args};
   }
 };
 
@@ -546,4 +729,23 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
   }
 }
 
+void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
+                                torch::Tensor const& b,
+                                torch::Tensor const& a_scales,
+                                torch::Tensor const& b_scales,
+                                torch::Tensor const& azp_adj,
+                                c10::optional<torch::Tensor> const& azp,
+                                c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
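+  // Per-token azp passes both azp (m,1) and azp_adj = J @ B (1,n);
+  // per-tensor azp passes only azp_adj, already scaled by the scalar azp.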
+  if (azp) {
+    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzpToken>(
+        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
+  } else {
+    return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBiasAzp>(
+        out, a, b, a_scales, b_scales, azp_adj, bias);
+  }
+}
+
 #endif