#ifndef DNNL_HELPER_HPP
#define DNNL_HELPER_HPP

#include <c10/util/BFloat16.h>

#include <type_traits>

#include "oneapi/dnnl/dnnl.hpp"

namespace {
template <typename T>
struct DNNLType {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::undef;
};

template <>
struct DNNLType<int8_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
};

template <>
struct DNNLType<int32_t> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
};

template <>
struct DNNLType<float> {
  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
};

template <>
struct DNNLType<c10::BFloat16> {
  static constexpr dnnl::memory::data_type type =
      dnnl::memory::data_type::bf16;
};

template <typename T>
constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}
}  // namespace
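
// Sanity-check sketch for the mapping above (illustrative only, not part of
// the original header). Because get_dnnl_type() is constexpr, the type
// mapping can be verified at compile time, e.g.:
//
//   static_assert(get_dnnl_type<float>() == dnnl::memory::data_type::f32);
//   static_assert(get_dnnl_type<c10::BFloat16>() ==
//                 dnnl::memory::data_type::bf16);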

template <bool InputNoScale>
class DNNLPrimitiveHelper {
 public:
  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
  // A: [M, K], row-major
  // B: [K, N], column-major
  // C: [M, N], row-major
  // bias: [N], optional
  // a_scales: [MS]
  // b_scales: [NS]
  // Note: Due to a limitation of oneDNN
  // (https://github.com/oneapi-src/oneDNN/issues/1636), a quantized bias is
  // not supported.
  template <typename OutputT, typename BiasT>
  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
                            dnnl_dim_t K, const float* a_scales,
                            const float* b_scales, dnnl_dim_t MS,
                            dnnl_dim_t NS) {
    auto&& OutputType = get_dnnl_type<OutputT>();
    auto&& BiasType = get_dnnl_type<BiasT>();

    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});

    dnnl::primitive_attr attr;
    if constexpr (!InputNoScale) {
      if (MS == 1) {
        // per-tensor
        attr.set_scales_mask(DNNL_ARG_SRC, 0);
      } else {
        // per-token
        TORCH_CHECK(false, "per-token quantization is unsupported.");
      }
    }

    if (NS == 1) {
      // per-tensor
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
    } else {
      // per-channel
      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
    }

    dnnl::matmul::primitive_desc matmul_pd;
    if (bias) {
      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               bias_md, c_md, attr);
    } else {
      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
                                               c_md, attr);
    }
    dnnl::matmul matmul(matmul_pd);

    auto& engine = default_engine();

    dnnl::memory a_m(a_md, engine, (void*)a);
    dnnl::memory b_m(b_md, engine, (void*)b);
    dnnl::memory c_m(c_md, engine, (void*)c);
    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)a_scales);
    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
                            (void*)b_scales);

    auto& stream = default_stream();
    if constexpr (InputNoScale) {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    } else {
      if (bias) {
        dnnl::memory::desc bias_md({N}, BiasType, {1});
        dnnl::memory bias_m(bias_md, engine, (void*)bias);
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_BIAS, bias_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      } else {
        matmul.execute(
            stream, {
                        {DNNL_ARG_SRC, a_m},
                        {DNNL_ARG_WEIGHTS, b_m},
                        {DNNL_ARG_DST, c_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
                    });
      }
    }
    stream.wait();
  }

 private:
  static dnnl::engine& default_engine() {
    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
    return engine;
  }

  static dnnl::stream& default_stream() {
    static dnnl::stream stream(default_engine());
    return stream;
  }
};
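
// Example usage (an illustrative sketch, not part of the original header; the
// buffer names and scale values below are hypothetical). With per-tensor
// scales (MS == NS == 1) and an fp32 bias, this computes
// C = a_scale * A @ (b_scale * B^T) + bias with fp32 output, where B is
// stored column-major as described above:
//
//   std::vector<int8_t> A(M * K), B(K * N);
//   std::vector<float> C(M * N), bias(N);
//   float a_scale = 0.05f, b_scale = 0.02f;
//   DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, float>(
//       A.data(), B.data(), C.data(), bias.data(), M, N, K,
//       &a_scale, &b_scale, /*MS=*/1, /*NS=*/1);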

#endif