#ifndef DNNL_HELPER_HPP
#define DNNL_HELPER_HPP

#include <c10/util/BFloat16.h>

#include "oneapi/dnnl/dnnl.hpp"

namespace {
template <typename T>
struct DNNLType {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::bf16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}
};  // namespace

template <bool InputNoScale>
class DNNLPrimitiveHelper {
 public:
  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
  // A: [M, K], row-major
  // B: [K, N], column-major
  // C: [M, N], row-major
  // bias: [N], row-major, optional
  // a_scales: [MS]
  // b_scales: [NS]
  // Note: Due to the limitation of oneDNN
  // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
  // not supported.
  template <typename OutputT, typename BiasT>
  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
                            dnnl_dim_t K, const float* a_scales,
                            const float* b_scales, dnnl_dim_t MS,
                            dnnl_dim_t NS) {
    auto&& OutputType = get_dnnl_type<OutputT>();
    auto&& BiasType = get_dnnl_type<BiasT>();

    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});

    dnnl::primitive_attr attr;
    if constexpr (!InputNoScale) {
      if (MS == 1) {
        // per-tensor
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
      } else {
        // per-token
        TORCH_CHECK(false, "per-token quantization is unsupported.");
      }
    }

    if (NS == 1) {
      // per-tensor
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    } else {
      // per-channel
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
    }

    dnnl::matmul::primitive_desc matmul_pd;
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               bias_md, c_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               c_md, attr);
    }
    dnnl::matmul matmul(matmul_pd);

    auto& engine = default_engine();

    dnnl::memory a_m(a_md, engine, (void*)a);
    dnnl::memory b_m(b_md, engine, (void*)b);
    dnnl::memory c_m(c_md, engine, (void*)c);
    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)a_scales);
    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)b_scales);

    auto& stream = default_stream();
    if constexpr (InputNoScale) {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    } else {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
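        // Quantized input without bias: bind both the activation and weight
        // scales as runtime DNNL_ARG_ATTR_SCALES arguments.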
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    }
    stream.wait();
  }

 private:
  static dnnl::engine& default_engine() {
    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    return engine;
  }

  static dnnl::stream& default_stream() {
    static dnnl::stream stream(default_engine());
    return stream;
  }
};

#endif
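// Usage sketch (illustrative only; the function and buffer names below are
// hypothetical, not part of this header): a per-tensor input scale (MS == 1)
// with per-output-channel weight scales (NS == N), fp32 output and fp32 bias.
// B must be laid out column-major ([K, N] with strides {1, K}), as documented
// on gemm_s8s8_jit above.
//
//   void run_int8_gemm(const int8_t* A, const int8_t* B, float* C,
//                      const float* bias, const float* a_scale,
//                      const float* b_scales, int64_t M, int64_t N,
//                      int64_t K) {
//     // InputNoScale = false: the s8 activations carry a dequantization scale.
//     DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, float>(
//         A, B, C, bias, M, N, K, a_scale, b_scales,
//         /*MS=*/1, /*NS=*/N);
//   }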