#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP

#include <immintrin.h>
#include <cmath>
#include <iostream>

#include <torch/extension.h>

namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)    \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                           \
                     APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline

namespace {
// Applies f to every index of the integer_sequence via a fold expression.
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
  (f(std::integral_constant<T, indexes>{}), ...);
}
}  // namespace

// Compile-time unrolled loop: invokes f(0), f(1), ..., f(count - 1).
template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
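// Usage sketch (illustrative only, not part of the original header): each index
// is passed as an integral_constant, so the compiler can fully unroll the body.
//
//   float sum = 0.0f;
//   vec_op::unroll_loop<int, 4>([&](int i) { sum += data[i]; });  // data: hypothetical float[4]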
template <typename T> struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};

struct FP32Vec8;
struct FP32Vec16;

#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128h reg;

  explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}

  explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}

  explicit FP16Vec8(__m128h data) : reg(data) {}

  FP16Vec8 operator*(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_mul_ph(reg, b.reg));
  }

  FP16Vec8 operator+(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_add_ph(reg, b.reg));
  }

  FP16Vec8 operator-(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_sub_ph(reg, b.reg));
  }

  FP16Vec8 operator/(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_div_ph(reg, b.reg));
  }

  void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128i reg;

  explicit BF16Vec8(const void *ptr)
      : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}

  explicit BF16Vec8(const FP32Vec8 &);

  void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  __m256i reg;

  explicit BF16Vec16(const void *ptr)
      : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}

  explicit BF16Vec16(const FP32Vec16 &);

  void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};

struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  __m512i reg;

  explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}

  explicit BF16Vec32(__m512i data) : reg(data) {}

  // Broadcasts the 8-element vector into all four 128-bit lanes.
  explicit BF16Vec32(BF16Vec8 &vec8_data)
      : reg((__m512i)_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(
                    _mm512_castsi128_si512((__m128i)vec8_data.reg),
                    (__m128i)vec8_data.reg, 1),
                (__m128i)vec8_data.reg, 2),
            (__m128i)vec8_data.reg, 3)) {}

  void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
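// Usage sketch (illustrative only, not part of the original header): BF16 data
// is normally widened to FP32 for arithmetic and narrowed back on store, e.g.
//
//   vec_op::BF16Vec16 in(src);              // src: hypothetical bf16 buffer
//   vec_op::FP32Vec16 f(in);
//   f = f * vec_op::FP32Vec16(2.0f);
//   vec_op::BF16Vec16 out(f);
//   out.save(dst);                           // dst: hypothetical bf16 buffer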
struct FP32Vec4 : public Vec<FP32Vec4> {
  constexpr static int VEC_ELEM_NUM = 4;

  union AliasReg {
    __m128 reg;
    float values[VEC_ELEM_NUM];
  };

  __m128 reg;

  explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}

  explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}

  explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}

  explicit FP32Vec4(__m128 data) : reg(data) {}

  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  union AliasReg {
    __m256 reg;
    float values[VEC_ELEM_NUM];
  };

  __m256 reg;

  explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}

  explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}

  explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}

  explicit FP32Vec8(__m256 data) : reg(data) {}

  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}

#ifdef __AVX512FP16__
  explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif

  // Widen BF16 to FP32: zero-extend each element to 32 bits, then shift the
  // BF16 payload into the high 16 bits of each lane.
  explicit FP32Vec8(const BF16Vec8 &v)
      : reg(_mm256_castsi256_ps(
            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });
    return result;
  }

  // Element-wise expf via scalar libm calls.
  FP32Vec8 exp() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
                                  expf(ar.values[5]), expf(ar.values[4]),
                                  expf(ar.values[3]), expf(ar.values[2]),
                                  expf(ar.values[1]), expf(ar.values[0])));
  }

  // Element-wise tanhf via scalar libm calls.
  FP32Vec8 tanh() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
                                  tanhf(ar.values[5]), tanhf(ar.values[4]),
                                  tanhf(ar.values[3]), tanhf(ar.values[2]),
                                  tanhf(ar.values[1]), tanhf(ar.values[0])));
  }

  // Element-wise erf via scalar libm calls.
  FP32Vec8 er() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
                                  erf(ar.values[5]), erf(ar.values[4]),
                                  erf(ar.values[3]), erf(ar.values[2]),
                                  erf(ar.values[1]), erf(ar.values[0])));
  }

  FP32Vec8 operator*(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_mul_ps(reg, b.reg));
  }

  FP32Vec8 operator+(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_add_ps(reg, b.reg));
  }

  FP32Vec8 operator-(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_sub_ps(reg, b.reg));
  }

  FP32Vec8 operator/(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_div_ps(reg, b.reg));
  }

  void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
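// Usage sketch (illustrative only, not part of the original header): the FP32
// vector types compose like scalars, e.g.
//
//   float x[8], w[8];                  // hypothetical input buffers
//   vec_op::FP32Vec8 vx(x), vw(w);
//   float dot = (vx * vw).reduce_sum();
//
// Loads and stores are unaligned (loadu/storeu), so 32-byte alignment of the
// buffers is not required.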
struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  union AliasReg {
    __m512 reg;
    float values[VEC_ELEM_NUM];
  };

  __m512 reg;

  explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}

  explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}

  explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}

  explicit FP32Vec16(__m512 data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}

  // Broadcasts the 4-element vector into all four 128-bit lanes.
  explicit FP32Vec16(const FP32Vec4 &data)
      : reg((__m512)_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
                                   (__m128i)data.reg, 1),
                (__m128i)data.reg, 2),
            (__m128i)data.reg, 3)) {}

  // Broadcasts the 8-element vector into both 256-bit lanes.
  explicit FP32Vec16(const FP32Vec8 &data)
      : reg((__m512)_mm512_inserti32x8(
            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}

  explicit FP32Vec16(const BF16Vec16 &v)
      : reg(_mm512_castsi512_ps(
            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}

  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_mul_ps(reg, b.reg));
  }

  FP32Vec16 operator+(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_add_ps(reg, b.reg));
  }

  FP32Vec16 operator-(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_sub_ps(reg, b.reg));
  }

  FP32Vec16 operator/(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_div_ps(reg, b.reg));
  }

  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }

  // Sums the idx-th group of group_size consecutive lanes via a masked reduction.
  template <int group_size> float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);

    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
    __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
    return _mm512_mask_reduce_add_ps(mask, reg);
  }

  void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
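// Usage sketch (illustrative only, not part of the original header):
// reduce_sub_sum reduces one contiguous group of lanes, which is useful when
// several short rows are packed into a single 16-wide register, e.g.
//
//   vec_op::FP32Vec16 v(ptr);             // ptr: hypothetical float[16]
//   float row0 = v.reduce_sub_sum<4>(0);  // sum of lanes 0..3
//   float row2 = v.reduce_sub_sum<4>(2);  // sum of lanes 8..11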
template <typename T> struct VecType { using vec_type = void; };

template <typename T> using vec_t = typename VecType<T>::vec_type;

template <> struct VecType<float> { using vec_type = FP32Vec8; };

#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
#endif

template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
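// Usage sketch (illustrative only, not part of the original header): kernels
// can be written once against vec_t<scalar_t> and dispatched per dtype, e.g.
//
//   using load_vec_t = vec_op::vec_t<scalar_t>;    // e.g. BF16Vec8 for c10::BFloat16
//   constexpr int W = load_vec_t::get_elem_num();  // lanes handled per iteration
//   load_vec_t v(ptr);                             // unaligned load of W elements
//   v.save(out);                                   // unaligned store of W elements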
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }

#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
  *reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif

inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
  acc = acc + a * b;
}
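// Usage sketch (illustrative only, not part of the original header): the fma
// overloads accumulate a * b into acc, e.g. a simple FP32 dot product over a
// hypothetical buffer pair, with n assumed to be a multiple of 16:
//
//   vec_op::FP32Vec16 acc(0.0f);
//   for (int i = 0; i < n; i += 16) {
//     vec_op::FP32Vec16 va(a + i), vb(b + i);
//     vec_op::fma(acc, va, vb);
//   }
//   float dot = acc.reduce_sum();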
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
  *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
    : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}

// BF16 dot-product accumulate: pairs of BF16 elements are multiplied and the
// products are added into the FP32 accumulator lanes.
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
  acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
// Fallback without AVX512-BF16: truncate FP32 to BF16 by keeping the upper
// 16 bits (assumes a little-endian host).
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
      reinterpret_cast<c10::BFloat16 *>(&v);
  *ptr = *(v_ptr + 1);
}

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg(_mm256_cvtepi32_epi16(
          _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
    : reg(_mm512_cvtepi32_epi16(
          _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#endif

inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }

}  // namespace vec_op

#endif