// cpu_types.hpp

#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP

#include <cmath>
#include <cstdint>
#include <immintrin.h>
#include <torch/all.h>

#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif

namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
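// Dispatch helpers modeled on AT_DISPATCH_FLOATING_TYPES, restricted to the
// scalar types this header vectorizes (float and bfloat16); the lambda passed
// as __VA_ARGS__ is instantiated once per dispatched type, with scalar_t
// bound to that type inside the lambda body.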
#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)       \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                           \
                     APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
  std::cout << #NAME << " exit." << std::endl;
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline
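// Compile-time loop unrolling: unroll_loop_item expands the callable over an
// integer_sequence via a fold expression, so every index is a compile-time
// constant and the loop body is fully unrolled.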
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
  (f(std::integral_constant<T, indexes>{}), ...);
}
} // namespace

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
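// Illustrative use: unroll_loop<int, 4>([&](int i) { acc += v[i]; }); expands
// to acc += v[0]; ... acc += v[3];.
//
// Vec<T> is a CRTP base exposing the derived vector's element count.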
template <typename T> struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};

struct FP32Vec8;
struct FP32Vec16;

#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128h reg;

  explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}

  explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}

  explicit FP16Vec8(__m128h data) : reg(data) {}

  FP16Vec8 operator*(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_mul_ph(reg, b.reg));
  }

  FP16Vec8 operator+(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_add_ph(reg, b.reg));
  }

  FP16Vec8 operator-(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_sub_ph(reg, b.reg));
  }

  FP16Vec8 operator/(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_div_ph(reg, b.reg));
  }

  void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
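// BF16 vectors hold raw bfloat16 bit patterns in integer registers; no bf16
// arithmetic is performed on them directly. Their conversions from the FP32
// vectors are defined at the end of this file, once the available BF16
// instruction set is known.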
struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128i reg;

  explicit BF16Vec8(const void *ptr)
      : reg(_mm_loadu_si128((const __m128i *)ptr)) {}

  explicit BF16Vec8(const FP32Vec8 &);

  void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  __m256i reg;

  explicit BF16Vec16(const void *ptr)
      : reg(_mm256_loadu_si256((const __m256i *)ptr)) {}

  explicit BF16Vec16(const FP32Vec16 &);

  void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};
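// BF16Vec32 is a single 512-bit register when AVX-512F is available and a
// low/high pair of 256-bit registers on AVX2-only targets.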
#ifdef __AVX512F__
struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  __m512i reg;

  explicit BF16Vec32(const void *ptr) : reg(_mm512_loadu_si512(ptr)) {}

  explicit BF16Vec32(__m512i data) : reg(data) {}

  explicit BF16Vec32(BF16Vec8 &vec8_data)
      : reg(_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512(vec8_data.reg),
                                   vec8_data.reg, 1),
                vec8_data.reg, 2),
            vec8_data.reg, 3)) {}

  void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  __m256i reg_low;
  __m256i reg_high;

  explicit BF16Vec32(const void *ptr)
      : reg_low(_mm256_loadu_si256((const __m256i *)ptr)),
        reg_high(_mm256_loadu_si256((const __m256i *)ptr + 1)) {}

  explicit BF16Vec32(__m256i low, __m256i high)
      : reg_low(low), reg_high(high) {}

  // NOTE: the original used _mm256_inserti32x4 here, which requires
  // AVX-512VL; _mm256_inserti128_si256 is the equivalent AVX2 intrinsic for
  // this AVX2-only fallback path.
  explicit BF16Vec32(BF16Vec8 &vec8_data)
      : reg_low(_mm256_inserti128_si256(
            _mm256_castsi128_si256(vec8_data.reg), vec8_data.reg, 1)),
        reg_high(_mm256_inserti128_si256(
            _mm256_castsi128_si256(vec8_data.reg), vec8_data.reg, 1)) {}

  void save(void *ptr) const {
    *reinterpret_cast<__m256i *>(ptr) = reg_low;
    *(reinterpret_cast<__m256i *>(ptr) + 1) = reg_high;
  }
};
#endif
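// FP32 vector types. The AliasReg unions provide per-lane scalar access for
// reductions and the element-wise libm fallbacks below.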
struct FP32Vec4 : public Vec<FP32Vec4> {
  constexpr static int VEC_ELEM_NUM = 4;
  union AliasReg {
    __m128 reg;
    float values[VEC_ELEM_NUM];
  };

  __m128 reg;

  explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}

  explicit FP32Vec4() : reg(_mm_set1_ps(0.0f)) {}

  explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}

  explicit FP32Vec4(__m128 data) : reg(data) {}

  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
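// FP32Vec8 implements exp/tanh/erf by applying scalar libm functions per lane.
// The BF16Vec8 constructor widens bf16 to fp32 by zero-extending each 16-bit
// value to 32 bits and shifting left by two bytes within each 128-bit lane,
// which moves the bf16 bits into the high half of every fp32 element.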
struct FP32Vec8 : public Vec<FP32Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;
  union AliasReg {
    __m256 reg;
    float values[VEC_ELEM_NUM];
  };

  __m256 reg;

  explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}

  explicit FP32Vec8() : reg(_mm256_set1_ps(0.0f)) {}

  explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}

  explicit FP32Vec8(__m256 data) : reg(data) {}

  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}

#ifdef __AVX512FP16__
  explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif

  explicit FP32Vec8(const BF16Vec8 &v)
      : reg(_mm256_castsi256_ps(
            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });
    return result;
  }

  FP32Vec8 exp() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
                                  expf(ar.values[5]), expf(ar.values[4]),
                                  expf(ar.values[3]), expf(ar.values[2]),
                                  expf(ar.values[1]), expf(ar.values[0])));
  }

  FP32Vec8 tanh() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
                                  tanhf(ar.values[5]), tanhf(ar.values[4]),
                                  tanhf(ar.values[3]), tanhf(ar.values[2]),
                                  tanhf(ar.values[1]), tanhf(ar.values[0])));
  }

  // Element-wise erf; uses single-precision erff to match expf/tanhf above.
  FP32Vec8 er() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(erff(ar.values[7]), erff(ar.values[6]),
                                  erff(ar.values[5]), erff(ar.values[4]),
                                  erff(ar.values[3]), erff(ar.values[2]),
                                  erff(ar.values[1]), erff(ar.values[0])));
  }

  FP32Vec8 operator*(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_mul_ps(reg, b.reg));
  }

  FP32Vec8 operator+(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_add_ps(reg, b.reg));
  }

  FP32Vec8 operator-(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_sub_ps(reg, b.reg));
  }

  FP32Vec8 operator/(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_div_ps(reg, b.reg));
  }

  void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
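// FP32Vec16 is a single 512-bit register with AVX-512F; otherwise it is
// emulated with a low/high pair of 256-bit registers.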
#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    __m512 reg;
    float values[VEC_ELEM_NUM];
  };

  __m512 reg;

  explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}

  explicit FP32Vec16() : reg(_mm512_set1_ps(0.0f)) {}

  explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}

  explicit FP32Vec16(__m512 data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}

  explicit FP32Vec16(const FP32Vec4 &data)
      : reg((__m512)_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
                                   (__m128i)data.reg, 1),
                (__m128i)data.reg, 2),
            (__m128i)data.reg, 3)) {}

  explicit FP32Vec16(const FP32Vec8 &data)
      : reg((__m512)_mm512_inserti32x8(
            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}

  explicit FP32Vec16(const BF16Vec16 &v)
      : reg(_mm512_castsi512_ps(
            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}

  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_mul_ps(reg, b.reg));
  }

  FP32Vec16 operator+(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_add_ps(reg, b.reg));
  }

  FP32Vec16 operator-(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_sub_ps(reg, b.reg));
  }

  FP32Vec16 operator/(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_div_ps(reg, b.reg));
  }

  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
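  // reduce_sub_sum sums one contiguous group of group_size lanes selected by
  // idx via a masked reduction; e.g. reduce_sub_sum<4>(2) adds elements 8..11.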
  template <int group_size> float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
    __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
    return _mm512_mask_reduce_add_ps(mask, reg);
  }

  void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    __m256 reg;
    float values[8];
  };

  __m256 reg_low;
  __m256 reg_high;

  explicit FP32Vec16(float v)
      : reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}

  explicit FP32Vec16()
      : reg_low(_mm256_set1_ps(0.0f)), reg_high(_mm256_set1_ps(0.0f)) {}

  explicit FP32Vec16(const float *ptr)
      : reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}

  explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}

  explicit FP32Vec16(const FP32Vec16 &data)
      : reg_low(data.reg_low), reg_high(data.reg_high) {}

  explicit FP32Vec16(const FP32Vec4 &data)
      : reg_low((__m256)_mm256_inserti128_si256(
            _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
        reg_high((__m256)_mm256_inserti128_si256(
            _mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}

  explicit FP32Vec16(const FP32Vec8 &data)
      : reg_low(data.reg), reg_high(data.reg) {}

  explicit FP32Vec16(const BF16Vec16 &v) {
    __m128i low = _mm256_extractf128_si256(v.reg, 0);
    __m128i high = _mm256_extractf128_si256(v.reg, 1);

    __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
    __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);

    __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
    __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);

    reg_low = _mm256_castsi256_ps(v_low_shifted);
    reg_high = _mm256_castsi256_ps(v_high_shifted);
  }

  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16 &b) const {
    return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
                     _mm256_mul_ps(reg_high, b.reg_high));
  }

  FP32Vec16 operator+(const FP32Vec16 &b) const {
    return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
                     _mm256_add_ps(reg_high, b.reg_high));
  }

  FP32Vec16 operator-(const FP32Vec16 &b) const {
    return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
                     _mm256_sub_ps(reg_high, b.reg_high));
  }

  FP32Vec16 operator/(const FP32Vec16 &b) const {
    return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
                     _mm256_div_ps(reg_high, b.reg_high));
  }

  float reduce_sum() const {
    FP32Vec8 low = FP32Vec8(reg_low);
    FP32Vec8 high = FP32Vec8(reg_high);
    return low.reduce_sum() + high.reduce_sum();
  }

  template <int group_size> float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
    uint32_t mask = base_mask << (idx * group_size);
    float sum = 0.0f;
    AliasReg ar;

    auto func = [&sum, &mask, &ar](int i) {
      int flag = mask & 0x1;
      mask = mask >> 1;
      if (flag != 0) sum += ar.values[i];
    };

    ar.reg = reg_low;
    unroll_loop<int, 8>(func);

    ar.reg = reg_high;
    unroll_loop<int, 8>(func);

    return sum;
  }

  void save(float *ptr) const {
    _mm256_storeu_ps(ptr, reg_low);
    _mm256_storeu_ps(ptr + 8, reg_high);
  }
};
#endif
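// VecType maps a Torch scalar type to its CPU vector type; kernels use the
// vec_t<scalar_t> alias (e.g. vec_t<float> is FP32Vec8). storeFP32 writes a
// single float back to memory as the destination scalar type.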
template <typename T> struct VecType { using vec_type = void; };

template <typename T> using vec_t = typename VecType<T>::vec_type;

template <> struct VecType<float> { using vec_type = FP32Vec8; };

#ifdef __AVX512FP16__
// NOTE: the original named FP16Vec16 here, but only FP16Vec8 is defined in
// this header, so the 8-wide type is used.
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
#endif

template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };

template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }

#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
  *reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif

inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
  acc = acc + a * b;
}
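// FP32 -> BF16 conversion paths: with AVX512BF16 the native convert and
// dot-product (dpbf16) instructions are used; otherwise bf16 is produced by
// truncating the fp32 bit pattern to its upper 16 bits (no rounding).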
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
  *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
    : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}

inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
  acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
      reinterpret_cast<c10::BFloat16 *>(&v);
  *ptr = *(v_ptr + 1);
}

#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg(_mm256_cvtepi32_epi16(
          _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
    : reg(_mm512_cvtepi32_epi16(
          _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#else
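// AVX2-only FP32 -> BF16 pack: shift each 32-bit lane right by 16 so the bf16
// bits sit in the low half, pack the 32-bit lanes down to 16 bits, then
// permute the 64-bit blocks so all eight results land in the low 128 bits.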
namespace {
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
  __m256i ai = _mm256_castps_si256(a);
  ai = _mm256_srli_epi32(ai, 16);
  ai = _mm256_packus_epi32(ai, ai);
  ai = _mm256_permute4x64_epi64(ai, 0b00111001);
  return _mm256_extracti128_si256(ai, 0);
}
} // namespace

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
  BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
  BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
  reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif // __AVX512F__
#endif // __AVX512BF16__
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }

} // namespace vec_op

#endif // CPU_TYPES_HPP
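// Illustrative usage sketch (not part of the original header). A hypothetical
// element-wise add over float buffers; the function name add_f32 is made up,
// and n is assumed to be a multiple of FP32Vec8::VEC_ELEM_NUM.
//
//   void add_f32(const float *a, const float *b, float *out, int n) {
//     using vec_op::FP32Vec8;
//     for (int i = 0; i < n; i += FP32Vec8::VEC_ELEM_NUM) {
//       FP32Vec8 va(a + i), vb(b + i);
//       (va + vb).save(out + i);
//     }
//   }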