- #ifndef DNNL_HELPER_HPP
- #define DNNL_HELPER_HPP
- #include <cstdint>
- #include <type_traits>
- #include <c10/util/BFloat16.h>
- #include <c10/util/Exception.h>  // TORCH_CHECK
- #include "oneapi/dnnl/dnnl.hpp"
- namespace {
- // Maps C++ element types to the corresponding oneDNN data types.
- // Unsupported types fall through to undef, which oneDNN rejects when a
- // memory descriptor is created with them.
- template <typename T>
- struct DNNLType {
-   static constexpr dnnl::memory::data_type type =
-       dnnl::memory::data_type::undef;
- };
- template <>
- struct DNNLType<int8_t> {
-   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
- };
- template <>
- struct DNNLType<int32_t> {
-   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
- };
- template <>
- struct DNNLType<float> {
-   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
- };
- template <>
- struct DNNLType<c10::BFloat16> {
-   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
- };
- template <typename T>
- constexpr inline dnnl::memory::data_type get_dnnl_type() {
-   return DNNLType<std::decay_t<T>>::type;
- }
- }  // namespace
- template <bool InputNoScale>
- class DNNLPrimitiveHelper {
-  public:
-   // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
-   // A: [M, K], row-major
-   // B: [K, N], column-major (i.e., the transposed weight stored row-major)
-   // C: [M, N], row-major
-   // bias: [N], row-major, optional
-   // a_scales: [MS]
-   // b_scales: [NS]
-   // Note: due to a oneDNN limitation
-   // (https://github.com/oneapi-src/oneDNN/issues/1636), quantized bias is
-   // not supported. See the usage sketch after this header for a call example.
-   template <typename OutputT, typename BiasT>
-   static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
-                             const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
-                             dnnl_dim_t K, const float* a_scales,
-                             const float* b_scales, dnnl_dim_t MS,
-                             dnnl_dim_t NS) {
-     constexpr auto OutputType = get_dnnl_type<OutputT>();
-     constexpr auto BiasType = get_dnnl_type<BiasT>();
-     // Strides {K, 1} describe a row-major [M, K] view of A; strides {1, K}
-     // describe a column-major [K, N] view of B.
-     dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
-     dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
-     dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
-     dnnl::primitive_attr attr;
-     if constexpr (!InputNoScale) {
-       if (MS == 1) {
-         // Per-tensor: mask 0 means a single scale for the whole tensor.
-         attr.set_scales_mask(DNNL_ARG_SRC, 0);
-       } else {
-         // per-token
-         TORCH_CHECK(false, "per-token quantization is unsupported.");
-       }
-     }
-     if (NS == 1) {
-       // per-tensor
-       attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
-     } else {
-       // Per-output-channel: mask bit 1 (value 2) selects one scale per
-       // element along the N dimension of the [K, N] weights.
-       attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
-     }
-     dnnl::matmul::primitive_desc matmul_pd;
-     if (bias) {
-       // The primitive descriptor takes a 2-D [1, N] bias so that it
-       // broadcasts over the M rows.
-       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
-       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
-                                                bias_md, c_md, attr);
-     } else {
-       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
-                                                c_md, attr);
-     }
-     dnnl::matmul matmul(matmul_pd);
-     auto& engine = default_engine();
-     dnnl::memory a_m(a_md, engine, (void*)a);
-     dnnl::memory b_m(b_md, engine, (void*)b);
-     dnnl::memory c_m(c_md, engine, (void*)c);
-     // Scales are plain 1-D f32 buffers of length MS and NS respectively.
-     dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
-                             (void*)a_scales);
-     dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
-                             (void*)b_scales);
-     auto& stream = default_stream();
-     if constexpr (InputNoScale) {
-       // Activations arrive unscaled: only the weight scales are bound.
-       if (bias) {
-         dnnl::memory::desc bias_md({N}, BiasType, {1});
-         dnnl::memory bias_m(bias_md, engine, (void*)bias);
-         matmul.execute(
-             stream,
-             {{DNNL_ARG_SRC, a_m},
-              {DNNL_ARG_WEIGHTS, b_m},
-              {DNNL_ARG_BIAS, bias_m},
-              {DNNL_ARG_DST, c_m},
-              {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}});
-       } else {
-         matmul.execute(
-             stream,
-             {{DNNL_ARG_SRC, a_m},
-              {DNNL_ARG_WEIGHTS, b_m},
-              {DNNL_ARG_DST, c_m},
-              {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}});
-       }
-     } else {
-       if (bias) {
-         dnnl::memory::desc bias_md({N}, BiasType, {1});
-         dnnl::memory bias_m(bias_md, engine, (void*)bias);
-         matmul.execute(
-             stream,
-             {{DNNL_ARG_SRC, a_m},
-              {DNNL_ARG_WEIGHTS, b_m},
-              {DNNL_ARG_BIAS, bias_m},
-              {DNNL_ARG_DST, c_m},
-              {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
-              {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}});
-       } else {
-         matmul.execute(
-             stream,
-             {{DNNL_ARG_SRC, a_m},
-              {DNNL_ARG_WEIGHTS, b_m},
-              {DNNL_ARG_DST, c_m},
-              {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
-              {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}});
-       }
-     }
-     // The matmul runs asynchronously on the stream; block until it finishes.
-     stream.wait();
-   }
-  private:
-   // Process-wide CPU engine and stream, lazily constructed on first use.
-   static dnnl::engine& default_engine() {
-     static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
-     return engine;
-   }
-   static dnnl::stream& default_stream() {
-     static dnnl::stream stream(default_engine());
-     return stream;
-   }
- };
- #endif  // DNNL_HELPER_HPP
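
For context, here is a minimal usage sketch of the kernel above. It is not part of the header: the function and buffer names are hypothetical, and it assumes the int8 weights were pre-transposed into the [K, N] column-major layout the kernel expects, with a per-tensor activation scale (MS == 1) and per-channel weight scales (NS == N).

```cpp
// Hypothetical caller; only dnnl_helper.hpp and <cstdint> are assumed.
#include <cstdint>

#include "dnnl_helper.hpp"

void example_w8a8_gemm(const int8_t* a,        // [M, K] row-major activations
                       const int8_t* b,        // [K, N] column-major weights
                       float* c,               // [M, N] row-major output
                       const float* bias,      // [N], may be nullptr
                       int64_t M, int64_t N, int64_t K,
                       const float* a_scale,   // 1 element: per-tensor
                       const float* b_scales)  // N elements: per-channel
{
  // InputNoScale = false: the activation scale is applied inside the kernel.
  // OutputT and BiasT are deduced as float from the c and bias pointers.
  DNNLPrimitiveHelper<false>::gemm_s8s8_jit(a, b, c, bias, M, N, K, a_scale,
                                            b_scales, /*MS=*/1, /*NS=*/N);
}
```

Passing a floating-point bias here sidesteps the oneDNN quantized-bias limitation noted in the doc comment.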