aphrodite_numeric_conversion.cuh 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. #pragma once
  2. #include "cutlass/numeric_conversion.h"
  3. #include "cutlass_extensions/aphrodite_custom_types.cuh"
  4. #include "cutlass_extensions/cute_utils.cuh"
  // This file extends:
  // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
  // with Aphrodite-specific type conversions, namely for aphrodite_uint4b8_t
  // and aphrodite_uint8b128_t, and also adds interleaved numeric array
  // converters for specific types (interleaved numeric array converters can be
  // more efficient for sub-byte types).
  11. namespace cutlass {
  // InterleavedNumericArrayConverter has the same interface as
  // NumericArrayConverter<T, S, N, Round>, but the converted elements are also
  // de-interleaved according to IlvBlkLayout. Keeping sub-byte source values
  // interleaved lets a converter pull several of them out of one 32-bit
  // register cheaply.
  //
  // This primary template is the fallback for combinations that have no
  // specialization: convert() hits CUTE_INVALID_CONTROL_PATH and returns
  // result_type{}.
  template <typename IlvBlkLayout, typename T, typename S, int N,
            FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
            class Enable = void>
  struct InterleavedNumericArrayConverter {
    using Converter = NumericArrayConverter<T, S, N, Round>;
    using source_type = typename Converter::source_type;
    using result_type = typename Converter::result_type;

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }

    // No generic implementation exists for this (layout, T, S) combination.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      CUTE_INVALID_CONTROL_PATH(
          "InterleavedNumericArrayConverter not implemented\n");
      return result_type{};
    }
  };
  // Identity interleave layout: nothing needs to be de-interleaved, so the
  // conversion is exactly NumericArrayConverter<T, S, N, Round>.
  template <typename IlvBlkLayout, typename T, typename S, int N,
            FloatRoundStyle Round>
  struct InterleavedNumericArrayConverter<
      IlvBlkLayout, T, S, N, Round,
      std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
    using Converter = NumericArrayConverter<T, S, N, Round>;
    using source_type = typename Converter::source_type;
    using result_type = typename Converter::result_type;

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }

    // Forward straight to the non-interleaved converter.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      result_type converted = Converter::convert(source);
      return converted;
    }
  };
  47. // TODO (LucasWilkinson): Implement
  48. // for Array<cutlass::float8_e4m3fn, N> <= Array<aphrodite_uint4b8_t, N>
  49. // ....
  // ArrayConverterPacked32Bit is a shared driver for converters whose core
  // step works on one 32-bit register of packed source elements at a time.
  //
  // RegConvert32bit must provide
  //   template <typename PackedResultType>
  //   static PackedResultType convert(uint32_t src);
  // which receives a 2-, 4- or 8-element pack of S, zero-extended to 32 bits
  // (see to_reg), and returns the matching pack of T.
  //
  // convert() chooses pack sizes from how many S fit in 32 bits (8/4/2 when at
  // least 8 fit, 4/2 when at least 4 fit, otherwise 2) and lets
  // detail::VectorizedConverter walk the array with them.
  template <typename RegConvert32bit, typename T, typename S, int N>
  struct ArrayConverterPacked32Bit {
    using result_type = Array<T, N>;
    using source_type = Array<S, N>;

    using result_packed_8_t = Array<T, 8>;
    using result_packed_4_t = Array<T, 4>;
    using result_packed_2_t = Array<T, 2>;
    using src_packed_8_t = Array<S, 8>;
    using src_packed_4_t = Array<S, 4>;
    using src_packed_2_t = Array<S, 2>;

    static_assert(N % 2 == 0, "N must be a multiple of 2");
    static_assert(cutlass::sizeof_bits_v<S> >= 4);  // TODO: add 16 packed sources
    static_assert(32 % cutlass::sizeof_bits_v<S> == 0);

    // How many S elements fit in a single 32-bit register.
    static constexpr auto src_elems_per_32bit_reg =
        32 / cutlass::sizeof_bits_v<S>;

    // Part of the interface detail::VectorizedConverter expects. It only works
    // if NumericConverter<T, S> is implemented, but with N % 2 == 0 the scalar
    // fallback is not expected to be reached.
    using ScalarConverter = NumericConverter<T, S>;

    // Reinterprets a 1-, 2- or 4-byte pack of source elements as a uint32_t;
    // narrower packs are zero-extended.
    template <typename PackedSrc>
    CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
      constexpr auto kPackBytes = sizeof(PackedSrc);
      if constexpr (kPackBytes == 4) {
        return reinterpret_cast<const uint32_t&>(source);
      } else if constexpr (kPackBytes == 2) {
        return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
      } else {
        static_assert(kPackBytes == 1);
        return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
      }
    }

    // Converts a single 2/4/8-element pack by handing its bits to
    // RegConvert32bit as one 32-bit register.
    template <typename PackedResultType, typename PackedSrcType>
    CUTLASS_DEVICE static PackedResultType packed_convert(
        PackedSrcType const& source) {
      static_assert(PackedSrcType::kElements == PackedResultType::kElements);
      static_assert(PackedResultType::kElements == 2 ||
                        PackedResultType::kElements == 4 ||
                        PackedResultType::kElements == 8,
                    "Invalid PackedResultType must be 2, 4 or 8.");
      static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
      static_assert(std::is_same_v<typename PackedResultType::Element, T>);

      uint32_t const packed_bits = to_reg(source);
      return RegConvert32bit::template convert<PackedResultType>(packed_bits);
    }

    friend class detail::VectorizedConverter;

  public:
    CUTLASS_DEVICE static result_type convert(source_type const& source) {
      // Injected class name: this exact specialization of the template.
      using Self = ArrayConverterPacked32Bit;
      result_type result;
      if constexpr (src_elems_per_32bit_reg >= 8) {
        detail::VectorizedConverter::convert<
            Self, result_packed_8_t, src_packed_8_t, result_packed_4_t,
            src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
      } else if constexpr (src_elems_per_32bit_reg >= 4) {
        detail::VectorizedConverter::convert<Self, result_packed_4_t,
                                             src_packed_4_t, result_packed_2_t,
                                             src_packed_2_t>(result, source);
      } else {
        detail::VectorizedConverter::convert<Self, result_packed_2_t,
                                             src_packed_2_t>(result, source);
      }
      return result;
    }
  };
  // for Array<cutlass::half_t, N> <= Array<aphrodite_uint4b8_t, N>
  //
  // Each stored 4-bit value is x + 8. RegConvert turns every source byte (two
  // nibbles) into one half2 with a prmt, a lop3 and a single hfma2, and
  // ArrayConverterPacked32Bit drives it over the whole array.
  template <FloatRoundStyle Round, int N>
  struct NumericArrayConverter<cutlass::half_t, aphrodite_uint4b8_t, N, Round> {
    using result_type = Array<cutlass::half_t, N>;
    using source_type = Array<aphrodite_uint4b8_t, N>;
    // Exposed like in the other NumericArrayConverter specializations.
    static FloatRoundStyle const round_style = Round;

    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        // One 32-bit register holds two fp16 results.
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        RegArray r;

        // Below constructs the following temporaries (bytes high -> low):
        // fp16s_01 = {0x00, i4_01, 0x00, i4_01}
        // fp16s_23 = {0x00, i4_23, 0x00, i4_23}
        // fp16s_45 = {0x00, i4_45, 0x00, i4_45}
        // fp16s_67 = {0x00, i4_67, 0x00, i4_67}
        // We use inline asm instead of __byte_perm intrinsic since we don't
        // want the documented (& 0x7) on the index. NVCC might be able to
        // optimize it out since the index is a constexpr, but we choose to be
        // safe about it here.
        uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
        static_assert(RegArray::kElements <= 4,
                      "Too many inputs for I4 -> F16 vector converter");
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          asm volatile(
              "{\n"
              " prmt.b32 %0, %1, %2, %3;\n"
              "}\n"
              : "=r"(r[ii])
              : "r"(src), "n"(0), "r"(prmt_indices[ii]));
        }

        // The stored values are biased by 8 (stored = x + 8). The lop3 below
        // computes r[i] = (r[i] & and_mask) ^ xor_mask in one instruction:
        // - the AND keeps only the high nibble (bits 4-7) in the upper fp16
        //   and only the low nibble (bits 0-3) in the lower fp16;
        // - the XOR sets the exponent bits of the fp16 magic number 1024
        //   (0x6400), whose mantissa LSB is worth exactly 1.
        // Each register then holds {1024 + 16*(x1+8), 1024 + (x0+8)}, where
        // x1 is the high nibble and x0 the low nibble.
        static constexpr uint32_t xor_mask = 0x64006400;
        static constexpr uint32_t and_mask = 0xFFF0FF0F;
        static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(r[ii])
              : "n"(and_mask), "n"(xor_mask), "n"(immLut));
        }

        // One hfma2 per register recovers {x1, x0}:
        //   {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} + {-72, -1032}
        //   = {64 + (x1+8) - 72, 1024 + (x0+8) - 1032}
        //   = {x1, x0}
        static constexpr uint32_t hfma_bias_rep = 0xD480E408;   // {-72, -1032}
        static constexpr uint32_t hfma_scale_rep = 0x2C003C00;  // {1 / 16, 1}
        const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
        const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
          fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<cutlass::half_t, N> <= Array<aphrodite_uint4b8_t, N>
  // for IlvdLayout: (2, 4):(4, 1)
  //
  // Within each 32-bit source register, RegConvert maps output element 2k to
  // source nibble k and output element 2k+1 to source nibble k+4 (k = 0..3).
  // Both fp16 lanes of a register read the same nibble position of their
  // 16-bit half, so a single shift plus a mask places two values at once.
  // Stored values are biased by 8 (stored = x + 8).
  template <FloatRoundStyle Round, int N>
  struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                          cutlass::half_t, aphrodite_uint4b8_t,
                                          N, Round, void> {
    using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
    static_assert(N % size(IlvdLayout{}) == 0);

    using result_type = Array<cutlass::half_t, N>;
    using source_type = Array<aphrodite_uint4b8_t, N>;
    static FloatRoundStyle const round_style = Round;

  private:
    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

        // lop3 truth table for (a & b) ^ c.
        static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
        // Exponent bits of the fp16 magic number 1024 in both lanes.
        static constexpr uint32_t xor_mask = 0x64006400;
        static constexpr uint32_t low_nib_mask = 0x000F000F;
        static constexpr uint32_t high_nib_mask = 0x00F000F0;

        // Low nibble register:
        //   {1024+(x1+8), 1024+(x0+8)} - {1032, 1032} = {x1, x0}
        // High nibble register:
        //   {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + {-72, -72}
        //   = {x1, x0}
        static constexpr uint32_t low_nib_bias = 0x64086408;    // {1032, 1032}
        static constexpr uint32_t high_nib_scale = 0x2C002C00;  // {1/16, 1/16}
        static constexpr uint32_t high_nib_bias = 0xD480D480;   // {-72, -72}

        RegArray r;
        // Each iteration fills a (low-nibble, high-nibble) register pair.
        for (int reg = 0; reg < RegArray::kElements; reg += 2) {
          uint32_t const shifted = src >> (4 * reg);
          uint32_t& lo_reg = r[reg];
          uint32_t& hi_reg = r[reg + 1];
          lo_reg = shifted;
          hi_reg = shifted;

          // lo_reg = (lo_reg & low_nib_mask) ^ xor_mask
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(lo_reg)
              : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
          // hi_reg = (hi_reg & high_nib_mask) ^ xor_mask
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(hi_reg)
              : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));

          half2& lo_val = reinterpret_cast<__half2&>(lo_reg);
          lo_val = __hsub2(lo_val, reinterpret_cast<const half2&>(low_nib_bias));

          half2& hi_val = reinterpret_cast<__half2&>(hi_reg);
          hi_val = __hfma2(hi_val, reinterpret_cast<const half2&>(high_nib_scale),
                           reinterpret_cast<const half2&>(high_nib_bias));
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<cutlass::half_t, N> <= Array<uint4_t, N>
  // for IlvdLayout: (2, 4):(4, 1)
  //
  // Same interleaving as the aphrodite_uint4b8_t variant: within each 32-bit
  // source register, output element 2k comes from source nibble k and output
  // element 2k+1 from source nibble k+4 (k = 0..3). Values are unbiased, so
  // only the fp16 magic number 1024 has to be removed.
  template <FloatRoundStyle Round, int N>
  struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                          cutlass::half_t, uint4_t, N, Round,
                                          void> {
    using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
    static_assert(N % size(IlvdLayout{}) == 0);

    using result_type = Array<cutlass::half_t, N>;
    using source_type = Array<uint4_t, N>;
    static FloatRoundStyle const round_style = Round;

  private:
    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

        // lop3 truth table for (a & b) ^ c.
        static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
        // Exponent bits of the fp16 magic number 1024 in both lanes.
        static constexpr uint32_t xor_mask = 0x64006400;
        static constexpr uint32_t low_nib_mask = 0x000F000F;
        static constexpr uint32_t high_nib_mask = 0x00F000F0;

        // Low nibble register:
        //   {1024+x1, 1024+x0} - {1024, 1024} = {x1, x0}
        // High nibble register:
        //   {1024+16*x1, 1024+16*x0} * {1/16, 1/16} + {-64, -64} = {x1, x0}
        static constexpr uint32_t low_nib_bias = 0x64006400;    // {1024, 1024}
        static constexpr uint32_t high_nib_scale = 0x2C002C00;  // {1/16, 1/16}
        static constexpr uint32_t high_nib_bias = 0xD400D400;   // {-64, -64}

        RegArray r;
        // Each iteration fills a (low-nibble, high-nibble) register pair.
        for (int reg = 0; reg < RegArray::kElements; reg += 2) {
          uint32_t const shifted = src >> (4 * reg);
          uint32_t& lo_reg = r[reg];
          uint32_t& hi_reg = r[reg + 1];
          lo_reg = shifted;
          hi_reg = shifted;

          // lo_reg = (lo_reg & low_nib_mask) ^ xor_mask
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(lo_reg)
              : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
          // hi_reg = (hi_reg & high_nib_mask) ^ xor_mask
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(hi_reg)
              : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));

          half2& lo_val = reinterpret_cast<__half2&>(lo_reg);
          lo_val = __hsub2(lo_val, reinterpret_cast<const half2&>(low_nib_bias));

          half2& hi_val = reinterpret_cast<__half2&>(hi_reg);
          hi_val = __hfma2(hi_val, reinterpret_cast<const half2&>(high_nib_scale),
                           reinterpret_cast<const half2&>(high_nib_bias));
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<cutlass::half_t, N> <= Array<aphrodite_uint8b128_t, N>
  //
  // Each stored 8-bit value is x + 128. RegConvert places every source byte in
  // the low mantissa byte of the fp16 magic number 1024 (0x64xx), then one
  // hsub2 removes both the magic number and the 128 bias.
  template <FloatRoundStyle Round, int N>
  struct NumericArrayConverter<cutlass::half_t, aphrodite_uint8b128_t, N, Round> {
    using result_type = Array<cutlass::half_t, N>;
    using source_type = Array<aphrodite_uint8b128_t, N>;
    // Exposed like in the other NumericArrayConverter specializations.
    static FloatRoundStyle const round_style = Round;

    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        // Hold output FP16s in reg. We need 1 reg for every 2 elements
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        RegArray r;

        // Selector 0x5150 builds {0x64, byte1, 0x64, byte0} (high -> low),
        // i.e. the half2 {1024 + byte1, 1024 + byte0}; 0x5352 does the same
        // for bytes 3 and 2. Only two selectors exist, so at most two
        // registers (four elements) can be produced.
        uint32_t const prmt_indices[2] = {0x5150, 0x5352};
        static_assert(RegArray::kElements <= 2,
                      "Too many inputs for U8 -> F16 vector converter");
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          asm volatile("prmt.b32 %0,%1,%2,%3;\n"
                       : "=r"(r[ii])
                       : "r"(src), "n"(start_byte_for_fp16),
                         "r"(prmt_indices[ii]));
        }

        // Subtract {1152, 1152} = 1024 (magic number) + 128 (storage bias);
        // the -128 is the 0x80 in the low bytes of bias_rep.
        static constexpr uint32_t bias_rep = 0x64806480;
        const half2& bias = reinterpret_cast<const half2&>(bias_rep);
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
          fp16x2_val = __hsub2(fp16x2_val, bias);
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<float, N> <= Array<aphrodite_uint8b128_t, N>
  //
  // Each stored 8-bit value is x + 128. Every byte is dropped into the low
  // mantissa byte of the float whose bit pattern is 0x4B000000 (8388608.0f).
  // A single float subtraction of (8388608 + 128) then yields x. This avoids a
  // separate cvt.u32.u8 instruction.
  template <FloatRoundStyle Round, int N>
  struct NumericArrayConverter<float, aphrodite_uint8b128_t, N, Round> {
    using result_type = Array<float, N>;
    using source_type = Array<aphrodite_uint8b128_t, N>;
    static FloatRoundStyle const round_style = Round;

  private:
    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        // Selector k: result byte 0 <- source byte k; result bytes 1..3 <-
        // bytes 1..3 of the magic constant (0x00, 0x00, 0x4B).
        uint32_t const selectors[4] = {0x7650, 0x7651, 0x7652, 0x7653};
        constexpr uint32_t kMagicBits = 0x4B000000;  // bit pattern of 8388608.f
        constexpr float kMagicPlusBias = 8388608.f + 128.f;

        PackedResultType out;
        uint32_t* out_bits = reinterpret_cast<uint32_t*>(&out);
        for (int k = 0; k < PackedResultType::kElements; ++k) {
          out_bits[k] = __byte_perm(src, kMagicBits, selectors[k]);
          // Remove the magic number and fold in the -128 bias.
          out[k] = out[k] - kMagicPlusBias;
        }
        return out;
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  422. #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  // for Array<cutlass::bfloat16_t, N> <= Array<aphrodite_uint4b8_t, N>
  //
  // Each stored 4-bit value is x + 8. bf16 has only 7 mantissa bits, so each
  // 16-bit lane carries a single nibble. RegConvert places that nibble in the
  // low mantissa bits of the bf16 magic number 128 (0x4300) and then
  // subtracts 136 (= 128 + 8).
  template <FloatRoundStyle Round, int N>
  struct NumericArrayConverter<cutlass::bfloat16_t, aphrodite_uint4b8_t, N,
                               Round> {
    using result_type = Array<cutlass::bfloat16_t, N>;
    using source_type = Array<aphrodite_uint4b8_t, N>;
    static FloatRoundStyle const round_style = Round;

  private:
    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
        // Two bf16 results per 32-bit register.
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        static_assert(RegArray::kElements <= 4,
                      "Too many inputs for uint4b8_t -> BF16 vector converter");

        // Step 1 (prmt): register k receives source byte k in its low lane
        // and byte k of (src_reg >> 4) in its high lane. The low nibble of the
        // low lane is element 2k; the low nibble of the high lane is element
        // 2k+1. Selector nibble 0xF fills the other bytes, which the mask in
        // step 2 clears.
        uint32_t const shifted_reg = src_reg >> 4;
        uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
        RegArray r;
        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < RegArray::kElements; ++k) {
          asm volatile(
              "{\n"
              " prmt.b32 %0, %1, %2, %3;\n"
              "}\n"
              : "=r"(r[k])
              : "r"(src_reg), "r"(shifted_reg), "r"(prmt_indices[k]));
        }

        // Step 2 (lop3): r[k] = (r[k] & 0x000F000F) ^ 0x43004300. This keeps
        // one nibble per lane and sets the exponent of the bf16 magic number
        // 128, giving {128 + (x1+8), 128 + (x0+8)}.
        static constexpr uint32_t keep_low_nibbles = 0x000F000F;
        static constexpr uint32_t magic_exponent = 0x43004300;
        static constexpr uint32_t and_xor_lut = (0xf0 & 0xcc) ^ 0xaa;
        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < RegArray::kElements; ++k) {
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(r[k])
              : "n"(keep_low_nibbles), "n"(magic_exponent), "n"(and_xor_lut));
        }

        // Step 3 (hsub2): subtract {136, 136} in bf16 to recover {x1, x0}.
        static constexpr uint32_t bias_bits = 0x43084308;  // bf16 {136, 136}
        const __nv_bfloat162& bias =
            reinterpret_cast<const __nv_bfloat162&>(bias_bits);
        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < RegArray::kElements; ++k) {
          __nv_bfloat162& lanes = reinterpret_cast<__nv_bfloat162&>(r[k]);
          lanes = __hsub2(lanes, bias);
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<cutlass::bfloat16_t, N> <= Array<aphrodite_uint4b8_t, N>
  // for IlvdLayout: (2, 4):(4, 1)
  //
  // Within each 32-bit source register, RegConvert maps output element 2k to
  // source nibble k and output element 2k+1 to source nibble k+4 (k = 0..3).
  // Each stored value is x + 8. It is placed in the low mantissa bits of the
  // bf16 magic number 128 (0x4300), and 136 is then subtracted.
  template <FloatRoundStyle Round, int N>
  struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                          cutlass::bfloat16_t,
                                          aphrodite_uint4b8_t, N, Round, void> {
    using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
    static_assert(N % size(IlvdLayout{}) == 0);

    using result_type = Array<cutlass::bfloat16_t, N>;
    using source_type = Array<aphrodite_uint4b8_t, N>;
    // Exposed like in the fp16 interleaved converters and the other
    // NumericArrayConverter specializations.
    static FloatRoundStyle const round_style = Round;

  private:
    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

        // lop3 truth table for (a & b) | c.
        static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t low_nib_mask = 0x000F000F;
        // Exponent bits of the bf16 magic number 128 in both lanes.
        static constexpr uint32_t or_mask = 0x43004300;
        // {x1, x0} = {128+(x1+8), 128+(x0+8)} - {136, 136}
        static constexpr uint32_t low_nib_bias = 0x43084308;  // {136, 136}

        RegArray r;
        // Unlike float16, where the mantissa is large enough to hold 2
        // nibbles, bfloat16 can only fit one, so each register converts the
        // nibble found at bit offset 4*ii of both 16-bit halves.
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          r[ii] = src >> (4 * ii);
          // r[ii] = (r[ii] & low_nib_mask) | or_mask
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(r[ii])
              : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));

          __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
          bf16x2_val =
              __hsub2(bf16x2_val,
                      reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
  // for IlvdLayout: (2, 4):(4, 1)
  //
  // Within each 32-bit source register, RegConvert maps output element 2k to
  // source nibble k and output element 2k+1 to source nibble k+4 (k = 0..3).
  // Values are unbiased. Each is placed in the low mantissa bits of the bf16
  // magic number 128 (0x4300), and 128 is then subtracted.
  template <FloatRoundStyle Round, int N>
  struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
                                          cutlass::bfloat16_t, uint4_t, N, Round,
                                          void> {
    using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
    static_assert(N % size(IlvdLayout{}) == 0);

    using result_type = Array<cutlass::bfloat16_t, N>;
    using source_type = Array<uint4_t, N>;
    // Exposed like in the fp16 interleaved converters and the other
    // NumericArrayConverter specializations.
    static FloatRoundStyle const round_style = Round;

  private:
    struct RegConvert {
      template <typename PackedResultType>
      CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
        using RegArray =
            cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
                                  sizeof(PackedResultType)>;
        static_assert(PackedResultType::kElements <= size(IlvdLayout{}));

        // lop3 truth table for (a & b) | c.
        static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t low_nib_mask = 0x000F000F;
        // Exponent bits of the bf16 magic number 128 in both lanes.
        static constexpr uint32_t or_mask = 0x43004300;
        // {x1, x0} = {128 + x1, 128 + x0} - {128, 128}
        static constexpr uint32_t low_nib_bias = 0x43004300;  // {128, 128}

        RegArray r;
        // Unlike float16, where the mantissa is large enough to hold 2
        // nibbles, bfloat16 can only fit one, so each register converts the
        // nibble found at bit offset 4*ii of both 16-bit halves.
        for (int ii = 0; ii < RegArray::kElements; ++ii) {
          r[ii] = src >> (4 * ii);
          // r[ii] = (r[ii] & low_nib_mask) | or_mask
          asm volatile(
              "{\n"
              " lop3.b32 %0, %0, %1, %2, %3;\n"
              "}\n"
              : "+r"(r[ii])
              : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));

          __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
          bf16x2_val =
              __hsub2(bf16x2_val,
                      reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
        }
        return reinterpret_cast<PackedResultType&>(r);
      }
    };

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
                                       typename source_type::Element,
                                       N>::convert(source);
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  // for Array<cutlass::bfloat16_t, N> <= Array<aphrodite_uint8b128_t, N>
  //
  // Two-step conversion per 2- or 4-element pack:
  //   uint8b128 -> float via NumericArrayConverter<float, aphrodite_uint8b128_t>
  //   float -> bf16 via NumericArrayConverter<cutlass::bfloat16_t, float>
  // Both steps use the same Round. detail::VectorizedConverter walks the array
  // in 4-element packs first, then 2-element packs.
  template <FloatRoundStyle Round, int N>
  struct NumericArrayConverter<cutlass::bfloat16_t, aphrodite_uint8b128_t, N,
                               Round> {
    using result_type = Array<cutlass::bfloat16_t, N>;
    using source_type = Array<aphrodite_uint8b128_t, N>;
    static FloatRoundStyle const round_style = Round;

  private:
    using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
    using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
    using src_packed_4_t = Array<aphrodite_uint8b128_t, 4>;
    using src_packed_2_t = Array<aphrodite_uint8b128_t, 2>;

    // Present only because detail::VectorizedConverter's interface names a
    // ScalarConverter. It will not work until
    // NumericConverter<cutlass::bfloat16_t, aphrodite_uint8b128_t, Round> is
    // implemented.
    using ScalarConverter =
        NumericConverter<cutlass::bfloat16_t, aphrodite_uint8b128_t, Round>;

    // Converts one 2- or 4-element pack through float.
    template <typename PackedResultType, typename PackedSrcType>
    CUTLASS_DEVICE static PackedResultType packed_convert(
        PackedSrcType const& source) {
      static_assert(
          (platform::is_same<PackedSrcType, src_packed_2_t>::value &&
           platform::is_same<PackedResultType, result_packed_2_t>::value) ||
              (platform::is_same<PackedSrcType, src_packed_4_t>::value &&
               platform::is_same<PackedResultType, result_packed_4_t>::value),
          "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
          "convert dispatch.");
      constexpr int kPackSize = PackedResultType::kElements;
      using ToFloat =
          NumericArrayConverter<float, aphrodite_uint8b128_t, kPackSize, Round>;
      using FloatToBf16 =
          NumericArrayConverter<cutlass::bfloat16_t, float, kPackSize, Round>;

      Array<float, kPackSize> const as_float = ToFloat{}(source);
      return FloatToBf16{}(as_float);
    }

    friend class detail::VectorizedConverter;

  public:
    CUTLASS_DEVICE
    static result_type convert(source_type const& source) {
      using ConverterType =
          NumericArrayConverter<typename result_type::Element,
                                typename source_type::Element, N, Round>;
      result_type result;
      detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
                                           src_packed_4_t, result_packed_2_t,
                                           src_packed_2_t>(result, source);
      return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s) const { return convert(s); }
  };
  664. #endif
  665. /////////////////////////////////////////////////////////////////////////////////////////////////
  666. } // namespace cutlass
  667. /////////////////////////////////////////////////////////////////////////////////////////////////