vec_dtypes.cuh

  1. /*
  2. * Copyright (c) 2024 by PygmalionAI team.
  3. * Copyright (c) 2023 by FlashInfer team.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #ifndef VEC_DTYPES_CUH_
  18. #define VEC_DTYPES_CUH_
  19. #include <cuda_bf16.h>
  20. #include <cuda_fp16.h>
  21. #include <cuda_fp8.h>
  22. #include <cuda_runtime.h>
  23. #include <type_traits>
  24. namespace aphrodite {
  25. #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
  26. #define APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  27. #endif
  28. #define APHRODITE_INLINE inline __attribute__((always_inline)) __device__
  29. /******************* vec_t type cast *******************/
  30. template <typename dst_t, typename src_t>
  31. struct vec_cast {
  32. template <size_t vec_size>
  33. APHRODITE_INLINE static void cast(dst_t* dst, const src_t* src) {
  34. #pragma unroll
  35. for (size_t i = 0; i < vec_size; ++i) {
  36. dst[i] = (dst_t)src[i];
  37. }
  38. }
  39. };
  40. template <>
  41. struct vec_cast<float, half> {
  42. template <size_t vec_size>
  43. APHRODITE_INLINE static void cast(float* dst, const half* src) {
  44. if constexpr (vec_size == 1) {
  45. dst[0] = (float)src[0];
  46. } else {
  47. #pragma unroll
  48. for (size_t i = 0; i < vec_size / 2; ++i) {
  49. ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
  50. }
  51. }
  52. }
  53. };
  54. template <>
  55. struct vec_cast<half, float> {
  56. template <size_t vec_size>
  57. APHRODITE_INLINE static void cast(half* dst, const float* src) {
  58. if constexpr (vec_size == 1) {
  59. dst[0] = __float2half(src[0]);
  60. } else {
  61. #pragma unroll
  62. for (size_t i = 0; i < vec_size / 2; ++i) {
  63. ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
  64. }
  65. }
  66. }
  67. };
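/*!
 * \brief Usage sketch (illustrative only; the function and buffers are
 * hypothetical): converting eight halfs to floats with the specialization
 * above. Because vec_size is even, the packed half2 -> float2 loop is taken.
 * \code
 * __device__ void to_float8(const half* in, float* out) {
 *   __align__(16) half  h[8];
 *   __align__(16) float f[8];
 *   for (int i = 0; i < 8; ++i) h[i] = in[i];
 *   vec_cast<float, half>::cast<8>(f, h);
 *   for (int i = 0; i < 8; ++i) out[i] = f[i];
 * }
 * \endcode
 */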
  68. template <typename T>
  69. constexpr APHRODITE_INLINE int get_exponent_bits() {
  70. if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
  71. return 4;
  72. } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
  73. return 5;
  74. } else if constexpr (std::is_same<T, half>::value) {
  75. return 5;
  76. } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
  77. return 8;
  78. }
  79. }
  80. template <typename T>
  81. constexpr APHRODITE_INLINE int get_mantissa_bits() {
  82. if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
  83. return 3;
  84. } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
  85. return 2;
  86. } else if constexpr (std::is_same<T, half>::value) {
  87. return 11;
  88. } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
  89. return 7;
  90. }
  91. }
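// For reference, the standard layouts (sign / exponent / mantissa bits) are:
// __nv_fp8_e4m3 = 1/4/3, __nv_fp8_e5m2 = 1/5/2, half = 1/5/10,
// nv_bfloat16 = 1/8/7. fast_dequant_f8f16x4 below only consumes the exponent
// widths of the 16-bit types together with both widths of the FP8 types, so
// get_mantissa_bits<half>() returning 11 has no effect on the conversion.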
  92. /*!
  93. * \brief Fall back to a software fast-dequantization implementation when
  94. * hardware dequantization is not available. \note Inspired by Marlin's fast
  95. * dequantization, but here we don't have to permute the weight order. \ref
  96. * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120
  97. */
  98. template <typename fp8_dtype, typename fp16_dtype>
  99. __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
  100. uint32_t q = *input;
  101. if constexpr (std::is_same<fp8_dtype, __nv_fp8_e5m2>::value &&
  102. std::is_same<fp16_dtype, half>::value) {
  103. output->x = __byte_perm(0U, q, 0x5140);
  104. output->y = __byte_perm(0U, q, 0x7362);
  105. } else {
  106. constexpr int FP8_EXPONENT = get_exponent_bits<fp8_dtype>();
  107. constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_dtype>();
  108. constexpr int FP16_EXPONENT = get_exponent_bits<fp16_dtype>();
  109. constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
  110. // Calculate MASK for extracting mantissa and exponent
  111. constexpr int MASK1 = 0x80000000;
  112. constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
  113. constexpr int MASK3 = MASK2 & 0x7fffffff;
  114. constexpr int MASK = MASK3 | (MASK3 >> 16);
  115. q = __byte_perm(q, q, 0x1302);
  116. // Extract and shift FP8 values to FP16 format
  117. uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
  118. uint32_t Out2 =
  119. ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
  120. constexpr int BIAS_OFFSET =
  121. (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
  122. // Construct and apply exponent bias
  123. if constexpr (std::is_same<fp16_dtype, half>::value) {
  124. const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
  125. // Convert to half2 and apply bias
  126. *(half2*)&(output->x) =
  127. __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
  128. *(half2*)&(output->y) =
  129. __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
  130. } else {
  131. constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
  132. const nv_bfloat162 bias_reg =
  133. __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
  134. // Convert to bfloat162 and apply bias
  135. *(nv_bfloat162*)&(output->x) =
  136. __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
  137. *(nv_bfloat162*)&(output->y) =
  138. __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
  139. }
  140. }
  141. }
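/*!
 * \brief Reference sketch (illustrative only, not used by this header) of the
 * bit trick above for a single e4m3 -> half conversion: place the FP8 byte in
 * the upper byte of a 16-bit lane, shift its exponent/mantissa field into the
 * FP16 position, then rescale by the exponent-bias difference
 * 2^(15 - 7) = 256 with one multiply. NaN/saturation handling is omitted.
 * \code
 * __device__ inline half ref_dequant_e4m3(unsigned char b) {
 *   unsigned short sign  = (unsigned short)(b & 0x80) << 8;  // sign -> bit 15
 *   unsigned short field = (unsigned short)(b & 0x7f) << 8;  // exp + mantissa
 *   unsigned short h = sign | (field >> 1);  // RIGHT_SHIFT = 5 - 4 = 1
 *   return __hmul(__ushort_as_half(h), __float2half(256.0f));
 * }
 * \endcode
 */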
  142. template <>
  143. struct vec_cast<nv_bfloat16, __nv_fp8_e4m3> {
  144. template <size_t vec_size>
  145. APHRODITE_INLINE static void cast(nv_bfloat16* dst,
  146. const __nv_fp8_e4m3* src) {
  147. if constexpr (vec_size == 1) {
  148. dst[0] = nv_bfloat16(src[0]);
  149. } else if constexpr (vec_size == 2) {
  150. dst[0] = nv_bfloat16(src[0]);
  151. dst[1] = nv_bfloat16(src[1]);
  152. } else {
  153. static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
  154. #pragma unroll
  155. for (uint32_t i = 0; i < vec_size / 4; ++i) {
  156. fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>((uint32_t*)&src[i * 4],
  157. (uint2*)&dst[i * 4]);
  158. }
  159. }
  160. }
  161. };
  162. template <>
  163. struct vec_cast<nv_bfloat16, __nv_fp8_e5m2> {
  164. template <size_t vec_size>
  165. APHRODITE_INLINE static void cast(nv_bfloat16* dst,
  166. const __nv_fp8_e5m2* src) {
  167. if constexpr (vec_size == 1) {
  168. dst[0] = nv_bfloat16(src[0]);
  169. } else if constexpr (vec_size == 2) {
  170. dst[0] = nv_bfloat16(src[0]);
  171. dst[1] = nv_bfloat16(src[1]);
  172. } else {
  173. static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
  174. #pragma unroll
  175. for (uint32_t i = 0; i < vec_size / 4; ++i) {
  176. fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>((uint32_t*)&src[i * 4],
  177. (uint2*)&dst[i * 4]);
  178. }
  179. }
  180. }
  181. };
  182. template <>
  183. struct vec_cast<__nv_fp8_e4m3, half> {
  184. template <size_t vec_size>
  185. APHRODITE_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) {
  186. #ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  187. if constexpr (vec_size == 1) {
  188. dst[0] = __nv_fp8_e4m3(src[0]);
  189. } else {
  190. #pragma unroll
  191. for (size_t i = 0; i < vec_size / 2; ++i) {
  192. uint16_t y;
  193. uint32_t x = *(uint32_t*)&src[i * 2];
  194. asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;"
  195. : "=h"(y)
  196. : "r"(x));
  197. *(uint16_t*)&dst[i * 2] = y;
  198. }
  199. }
  200. #else
  201. #pragma unroll
  202. for (size_t i = 0; i < vec_size; ++i) {
  203. dst[i] = __nv_fp8_e4m3(src[i]);
  204. }
  205. #endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  206. }
  207. };
  208. template <>
  209. struct vec_cast<__nv_fp8_e5m2, half> {
  210. template <size_t vec_size>
  211. APHRODITE_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) {
  212. #ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  213. if constexpr (vec_size == 1) {
  214. dst[0] = __nv_fp8_e5m2(src[0]);
  215. } else {
  216. #pragma unroll
  217. for (size_t i = 0; i < vec_size / 2; ++i) {
  218. uint16_t y;
  219. uint32_t x = *(uint32_t*)&src[i * 2];
  220. asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;"
  221. : "=h"(y)
  222. : "r"(x));
  223. *(uint16_t*)&dst[i * 2] = y;
  224. }
  225. }
  226. #else
  227. #pragma unroll
  228. for (size_t i = 0; i < vec_size; ++i) {
  229. dst[i] = __nv_fp8_e5m2(src[i]);
  230. }
  231. #endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  232. }
  233. };
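/*!
 * \brief Usage sketch (illustrative only): quantizing four halfs to e4m3.
 * With hardware FP8 conversion enabled (SM90+ here), the specialization above
 * emits cvt.rn.satfinite.e4m3x2.f16x2, converting a packed pair of halfs per
 * instruction with round-to-nearest-even and saturation to the finite e4m3
 * range; otherwise it falls back to per-element constructor casts. src is
 * assumed to be 4-byte aligned for the packed path.
 * \code
 * __device__ void quantize4(const half* src, __nv_fp8_e4m3* dst) {
 *   vec_cast<__nv_fp8_e4m3, half>::cast<4>(dst, src);
 * }
 * \endcode
 */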
  234. template <>
  235. struct vec_cast<half, __nv_fp8_e4m3> {
  236. template <size_t vec_size>
  237. APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) {
  238. #ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  239. if constexpr (vec_size == 1) {
  240. dst[0] = half(src[0]);
  241. } else {
  242. #pragma unroll
  243. for (size_t i = 0; i < vec_size / 2; ++i) {
  244. uint32_t y;
  245. uint16_t x = *(uint16_t*)&src[i * 2];
  246. asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x));
  247. *(uint32_t*)&dst[i * 2] = y;
  248. }
  249. }
  250. #else
  251. if constexpr (vec_size == 1) {
  252. dst[0] = half(src[0]);
  253. } else if constexpr (vec_size == 2) {
  254. dst[0] = half(src[0]);
  255. dst[1] = half(src[1]);
  256. } else {
  257. static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
  258. #pragma unroll
  259. for (uint32_t i = 0; i < vec_size / 4; ++i) {
  260. fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4],
  261. (uint2*)&dst[i * 4]);
  262. }
  263. }
  264. #endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  265. }
  266. };
  267. template <>
  268. struct vec_cast<half, __nv_fp8_e5m2> {
  269. template <size_t vec_size>
  270. APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) {
  271. #ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  272. if constexpr (vec_size == 1) {
  273. dst[0] = half(src[0]);
  274. } else {
  275. #pragma unroll
  276. for (size_t i = 0; i < vec_size / 2; ++i) {
  277. uint32_t y;
  278. uint16_t x = *(uint16_t*)&src[i * 2];
  279. asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x));
  280. *(uint32_t*)&dst[i * 2] = y;
  281. }
  282. }
  283. #else
  284. if constexpr (vec_size == 1) {
  285. dst[0] = half(src[0]);
  286. } else if constexpr (vec_size == 2) {
  287. dst[0] = half(src[0]);
  288. dst[1] = half(src[1]);
  289. } else {
  290. static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
  291. #pragma unroll
  292. for (uint32_t i = 0; i < vec_size / 4; ++i) {
  293. fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4],
  294. (uint2*)&dst[i * 4]);
  295. }
  296. }
  297. #endif // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  298. }
  299. };
  300. template <>
  301. struct vec_cast<float, nv_bfloat16> {
  302. template <size_t vec_size>
  303. APHRODITE_INLINE static void cast(float* dst, const nv_bfloat16* src) {
  304. if constexpr (vec_size == 1) {
  305. dst[0] = (float)src[0];
  306. } else {
  307. #pragma unroll
  308. for (size_t i = 0; i < vec_size / 2; ++i) {
  309. ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
  310. }
  311. }
  312. }
  313. };
  314. template <>
  315. struct vec_cast<nv_bfloat16, float> {
  316. template <size_t vec_size>
  317. APHRODITE_INLINE static void cast(nv_bfloat16* dst, const float* src) {
  318. if constexpr (vec_size == 1) {
  319. dst[0] = nv_bfloat16(src[0]);
  320. } else {
  321. #pragma unroll
  322. for (size_t i = 0; i < vec_size / 2; ++i) {
  323. ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
  324. }
  325. }
  326. }
  327. };
  328. template <typename float_t, size_t vec_size>
  329. struct vec_t {
  330. APHRODITE_INLINE float_t& operator[](size_t i);
  331. APHRODITE_INLINE const float_t& operator[](size_t i) const;
  332. APHRODITE_INLINE void fill(float_t val);
  333. APHRODITE_INLINE void load(const float_t* ptr);
  334. APHRODITE_INLINE void store(float_t* ptr) const;
  335. template <typename T>
  336. APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src);
  337. template <typename T>
  338. APHRODITE_INLINE void cast_load(const T* ptr);
  339. template <typename T>
  340. APHRODITE_INLINE void cast_store(T* ptr) const;
  341. APHRODITE_INLINE static void memcpy(float_t* dst, const float_t* src);
  342. APHRODITE_INLINE float_t* ptr();
  343. };
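/*!
 * \brief Typical usage sketch of vec_t (illustrative; kernel name and buffers
 * are hypothetical). Each specialization below packs vec_size elements into
 * the widest natural storage type so that load()/store() can compile to a
 * single vectorized memory access.
 * \code
 * __global__ void scale_half8(half* x, float alpha, size_t n) {
 *   size_t i = (blockIdx.x * (size_t)blockDim.x + threadIdx.x) * 8;
 *   if (i + 8 <= n) {
 *     vec_t<half, 8> v;
 *     v.load(x + i);   // one 16-byte load
 *     for (size_t j = 0; j < 8; ++j) {
 *       v[j] = __float2half(alpha * __half2float(v[j]));
 *     }
 *     v.store(x + i);  // one 16-byte store
 *   }
 * }
 * \endcode
 */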
  344. template <typename src_float_t, typename tgt_float_t, size_t vec_size>
  345. APHRODITE_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size>& dst,
  346. const vec_t<src_float_t, vec_size>& src) {
  347. vec_cast<tgt_float_t, src_float_t>::cast<vec_size>(
  348. dst.ptr(), const_cast<vec_t<src_float_t, vec_size>*>(&src)->ptr());
  349. }
  350. template <typename src_float_t, typename tgt_float_t, size_t vec_size>
  351. APHRODITE_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
  352. const src_float_t* src_ptr) {
  353. if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
  354. dst.load(src_ptr);
  355. } else {
  356. vec_t<src_float_t, vec_size> tmp;
  357. tmp.load(src_ptr);
  358. dst.cast_from(tmp);
  359. }
  360. }
  361. template <typename src_float_t, typename tgt_float_t, size_t vec_size>
  362. APHRODITE_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
  363. const vec_t<src_float_t, vec_size>& src) {
  364. if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
  365. src.store(dst_ptr);
  366. } else {
  367. vec_t<tgt_float_t, vec_size> tmp;
  368. tmp.cast_from(src);
  369. tmp.store(dst_ptr);
  370. }
  371. }
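/*!
 * \brief Sketch of the cast_load / cast_store helpers (illustrative; names
 * are hypothetical). cast_load reads vec_size elements of the source dtype
 * and converts them through vec_cast into the register dtype; cast_store
 * converts back on the way out. When the two dtypes match, they reduce to a
 * plain load/store.
 * \code
 * __global__ void dequant_e4m3_to_half(const __nv_fp8_e4m3* w, half* out,
 *                                      size_t n) {
 *   size_t i = (blockIdx.x * (size_t)blockDim.x + threadIdx.x) * 8;
 *   if (i + 8 <= n) {
 *     vec_t<half, 8> v;
 *     v.cast_load(w + i);  // 8 x e4m3 -> 8 x half
 *     v.store(out + i);
 *   }
 * }
 * \endcode
 */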
  372. /******************* vec_t<__nv_fp8_e4m3> *******************/
  373. // __nv_fp8_e4m3 x 1
  374. template <>
  375. struct vec_t<__nv_fp8_e4m3, 1> {
  376. __nv_fp8_e4m3 data;
  377. APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
  378. return ((__nv_fp8_e4m3*)(&data))[i];
  379. }
  380. APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
  381. return ((const __nv_fp8_e4m3*)(&data))[i];
  382. }
  383. APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
  384. return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  385. }
  386. APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  387. APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  388. APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  389. template <typename T>
  390. APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
  391. cast_from_impl(*this, src);
  392. }
  393. template <typename T>
  394. APHRODITE_INLINE void cast_load(const T* ptr) {
  395. cast_load_impl(*this, ptr);
  396. }
  397. template <typename T>
  398. APHRODITE_INLINE void cast_store(T* ptr) const {
  399. cast_store_impl(ptr, *this);
  400. }
  401. APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
  402. const __nv_fp8_e4m3* src);
  403. };
  404. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
  405. data = val;
  406. }
  407. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) {
  408. data = *ptr;
  409. }
  410. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3* ptr) const {
  411. *ptr = data;
  412. }
  413. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
  414. __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  415. *dst = *src;
  416. }
  417. // __nv_fp8_e4m3 x 2
  418. template <>
  419. struct vec_t<__nv_fp8_e4m3, 2> {
  420. __nv_fp8x2_e4m3 data;
  421. APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
  422. return ((__nv_fp8_e4m3*)(&data))[i];
  423. }
  424. APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
  425. return ((const __nv_fp8_e4m3*)(&data))[i];
  426. }
  427. APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
  428. return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  429. }
  430. APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  431. APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  432. APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  433. template <typename T>
  434. APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
  435. cast_from_impl(*this, src);
  436. }
  437. template <typename T>
  438. APHRODITE_INLINE void cast_load(const T* ptr) {
  439. cast_load_impl(*this, ptr);
  440. }
  441. template <typename T>
  442. APHRODITE_INLINE void cast_store(T* ptr) const {
  443. cast_store_impl(ptr, *this);
  444. }
  445. APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
  446. const __nv_fp8_e4m3* src);
  447. };
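// fill() broadcasts the single FP8 byte into every byte of the packed
// storage word; the same pattern is used by the wider specializations below.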
  448. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
  449. data.__x =
  450. (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
  451. }
  452. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) {
  453. data = *((__nv_fp8x2_e4m3*)ptr);
  454. }
  455. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3* ptr) const {
  456. *((__nv_fp8x2_e4m3*)ptr) = data;
  457. }
  458. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
  459. __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  460. *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src);
  461. }
  462. // __nv_fp8_e4m3 x 4
  463. template <>
  464. struct vec_t<__nv_fp8_e4m3, 4> {
  465. __nv_fp8x4_e4m3 data;
  466. APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
  467. return ((__nv_fp8_e4m3*)(&data))[i];
  468. }
  469. APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
  470. return ((const __nv_fp8_e4m3*)(&data))[i];
  471. }
  472. APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
  473. return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  474. }
  475. APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  476. APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  477. APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  478. template <typename T>
  479. APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
  480. cast_from_impl(*this, src);
  481. }
  482. template <typename T>
  483. APHRODITE_INLINE void cast_load(const T* ptr) {
  484. cast_load_impl(*this, ptr);
  485. }
  486. template <typename T>
  487. APHRODITE_INLINE void cast_store(T* ptr) const {
  488. cast_store_impl(ptr, *this);
  489. }
  490. APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
  491. const __nv_fp8_e4m3* src);
  492. };
  493. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
  494. data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
  495. (__nv_fp8x4_storage_t(val.__x) << 16) |
  496. (__nv_fp8x4_storage_t(val.__x) << 8) |
  497. __nv_fp8x4_storage_t(val.__x);
  498. }
  499. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) {
  500. data = *((__nv_fp8x4_e4m3*)ptr);
  501. }
  502. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3* ptr) const {
  503. *((__nv_fp8x4_e4m3*)ptr) = data;
  504. }
  505. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
  506. __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  507. *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src);
  508. }
  509. // __nv_fp8_e4m3 x 8
  510. template <>
  511. struct vec_t<__nv_fp8_e4m3, 8> {
  512. uint2 data;
  513. APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
  514. return ((__nv_fp8_e4m3*)(&data))[i];
  515. }
  516. APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
  517. return ((const __nv_fp8_e4m3*)(&data))[i];
  518. }
  519. APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
  520. return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  521. }
  522. APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  523. APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  524. APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  525. template <typename T>
  526. APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
  527. cast_from_impl(*this, src);
  528. }
  529. template <typename T>
  530. APHRODITE_INLINE void cast_load(const T* ptr) {
  531. cast_load_impl(*this, ptr);
  532. }
  533. template <typename T>
  534. APHRODITE_INLINE void cast_store(T* ptr) const {
  535. cast_store_impl(ptr, *this);
  536. }
  537. APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
  538. const __nv_fp8_e4m3* src);
  539. };
  540. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
  541. ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
  542. (__nv_fp8x4_storage_t(val.__x) << 16) |
  543. (__nv_fp8x4_storage_t(val.__x) << 8) |
  544. __nv_fp8x4_storage_t(val.__x);
  545. ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
  546. (__nv_fp8x4_storage_t(val.__x) << 16) |
  547. (__nv_fp8x4_storage_t(val.__x) << 8) |
  548. __nv_fp8x4_storage_t(val.__x);
  549. }
  550. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) {
  551. data = *((uint2*)ptr);
  552. }
  553. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3* ptr) const {
  554. *((uint2*)ptr) = data;
  555. }
  556. APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
  557. __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  558. *((uint2*)dst) = *((uint2*)src);
  559. }
  560. // __nv_fp8_e4m3 x 16 or more
  561. template <size_t vec_size>
  562. struct vec_t<__nv_fp8_e4m3, vec_size> {
  563. uint4 data[vec_size / 16];
  564. APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
  565. return ((__nv_fp8_e4m3*)data)[i];
  566. }
  567. APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
  568. return ((const __nv_fp8_e4m3*)data)[i];
  569. }
  570. APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
  571. return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  572. }
  573. APHRODITE_INLINE void fill(__nv_fp8_e4m3 val) {
  574. #pragma unroll
  575. for (size_t i = 0; i < vec_size / 16; ++i) {
  576. ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x =
  577. (__nv_fp8x4_storage_t(val.__x) << 24) |
  578. (__nv_fp8x4_storage_t(val.__x) << 16) |
  579. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  580. ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x =
  581. (__nv_fp8x4_storage_t(val.__x) << 24) |
  582. (__nv_fp8x4_storage_t(val.__x) << 16) |
  583. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  584. ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x =
  585. (__nv_fp8x4_storage_t(val.__x) << 24) |
  586. (__nv_fp8x4_storage_t(val.__x) << 16) |
  587. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  588. ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x =
  589. (__nv_fp8x4_storage_t(val.__x) << 24) |
  590. (__nv_fp8x4_storage_t(val.__x) << 16) |
  591. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  592. }
  593. }
  594. APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr) {
  595. #pragma unroll
  596. for (size_t i = 0; i < vec_size / 16; ++i) {
  597. data[i] = ((uint4*)ptr)[i];
  598. }
  599. }
  600. APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const {
  601. #pragma unroll
  602. for (size_t i = 0; i < vec_size / 16; ++i) {
  603. ((uint4*)ptr)[i] = data[i];
  604. }
  605. }
  606. template <typename T>
  607. APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
  608. cast_from_impl(*this, src);
  609. }
  610. template <typename T>
  611. APHRODITE_INLINE void cast_load(const T* ptr) {
  612. cast_load_impl(*this, ptr);
  613. }
  614. template <typename T>
  615. APHRODITE_INLINE void cast_store(T* ptr) const {
  616. cast_store_impl(ptr, *this);
  617. }
  618. APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
  619. const __nv_fp8_e4m3* src) {
  620. #pragma unroll
  621. for (size_t i = 0; i < vec_size / 16; ++i) {
  622. ((uint4*)dst)[i] = ((uint4*)src)[i];
  623. }
  624. }
  625. };
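/*!
 * \brief Note: sizes 1, 2, 4 and 8 use the dedicated specializations above;
 * any larger vec_size lands in this generic specialization and is assumed to
 * be a multiple of 16 (one uint4 per 16 elements). Illustrative sketch of a
 * 16-wide dequantize into bf16 registers (fp8_ptr is hypothetical and must be
 * 16-byte aligned):
 * \code
 * vec_t<nv_bfloat16, 16> v;
 * v.cast_load(fp8_ptr);  // const __nv_fp8_e4m3*, uses fast_dequant_f8f16x4
 * \endcode
 */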
  626. /******************* vec_t<__nv_fp8_e5m2> *******************/
  627. // __nv_fp8_e5m2 x 1
  628. template <>
  629. struct vec_t<__nv_fp8_e5m2, 1> {
  630. __nv_fp8_e5m2 data;
  631. APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
  632. return ((__nv_fp8_e5m2*)(&data))[i];
  633. }
  634. APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
  635. return ((const __nv_fp8_e5m2*)(&data))[i];
  636. }
  637. APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
  638. return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  639. }
  640. APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  641. APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  642. APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  643. template <typename T>
  644. APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
  645. cast_from_impl(*this, src);
  646. }
  647. template <typename T>
  648. APHRODITE_INLINE void cast_load(const T* ptr) {
  649. cast_load_impl(*this, ptr);
  650. }
  651. template <typename T>
  652. APHRODITE_INLINE void cast_store(T* ptr) const {
  653. cast_store_impl(ptr, *this);
  654. }
  655. APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
  656. const __nv_fp8_e5m2* src);
  657. };
  658. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
  659. data = val;
  660. }
  661. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) {
  662. data = *ptr;
  663. }
  664. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2* ptr) const {
  665. *ptr = data;
  666. }
  667. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
  668. __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  669. *dst = *src;
  670. }
  671. // __nv_fp8_e5m2 x 2
  672. template <>
  673. struct vec_t<__nv_fp8_e5m2, 2> {
  674. __nv_fp8x2_e5m2 data;
  675. APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
  676. return ((__nv_fp8_e5m2*)(&data))[i];
  677. }
  678. APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
  679. return ((const __nv_fp8_e5m2*)(&data))[i];
  680. }
  681. APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
  682. return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  683. }
  684. APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  685. APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  686. APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  687. template <typename T>
  688. APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
  689. cast_from_impl(*this, src);
  690. }
  691. template <typename T>
  692. APHRODITE_INLINE void cast_load(const T* ptr) {
  693. cast_load_impl(*this, ptr);
  694. }
  695. template <typename T>
  696. APHRODITE_INLINE void cast_store(T* ptr) const {
  697. cast_store_impl(ptr, *this);
  698. }
  699. APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
  700. const __nv_fp8_e5m2* src);
  701. };
  702. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
  703. data.__x =
  704. (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
  705. }
  706. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) {
  707. data = *((__nv_fp8x2_e5m2*)ptr);
  708. }
  709. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2* ptr) const {
  710. *((__nv_fp8x2_e5m2*)ptr) = data;
  711. }
  712. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
  713. __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  714. *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src);
  715. }
  716. // __nv_fp8_e5m2 x 4
  717. template <>
  718. struct vec_t<__nv_fp8_e5m2, 4> {
  719. __nv_fp8x4_e5m2 data;
  720. APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
  721. return ((__nv_fp8_e5m2*)(&data))[i];
  722. }
  723. APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
  724. return ((const __nv_fp8_e5m2*)(&data))[i];
  725. }
  726. APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
  727. return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  728. }
  729. APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  730. APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  731. APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  732. template <typename T>
  733. APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
  734. cast_from_impl(*this, src);
  735. }
  736. template <typename T>
  737. APHRODITE_INLINE void cast_load(const T* ptr) {
  738. cast_load_impl(*this, ptr);
  739. }
  740. template <typename T>
  741. APHRODITE_INLINE void cast_store(T* ptr) const {
  742. cast_store_impl(ptr, *this);
  743. }
  744. APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
  745. const __nv_fp8_e5m2* src);
  746. };
  747. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
  748. data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
  749. (__nv_fp8x4_storage_t(val.__x) << 16) |
  750. (__nv_fp8x4_storage_t(val.__x) << 8) |
  751. __nv_fp8x4_storage_t(val.__x);
  752. }
  753. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) {
  754. data = *((__nv_fp8x4_e5m2*)ptr);
  755. }
  756. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2* ptr) const {
  757. *((__nv_fp8x4_e5m2*)ptr) = data;
  758. }
  759. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
  760. __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  761. *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src);
  762. }
  763. // __nv_fp8_e5m2 x 8
  764. template <>
  765. struct vec_t<__nv_fp8_e5m2, 8> {
  766. uint2 data;
  767. APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
  768. return ((__nv_fp8_e5m2*)(&data))[i];
  769. }
  770. APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
  771. return ((const __nv_fp8_e5m2*)(&data))[i];
  772. }
  773. APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
  774. return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  775. }
  776. APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  777. APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  778. APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  779. template <typename T>
  780. APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
  781. cast_from_impl(*this, src);
  782. }
  783. template <typename T>
  784. APHRODITE_INLINE void cast_load(const T* ptr) {
  785. cast_load_impl(*this, ptr);
  786. }
  787. template <typename T>
  788. APHRODITE_INLINE void cast_store(T* ptr) const {
  789. cast_store_impl(ptr, *this);
  790. }
  791. APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
  792. const __nv_fp8_e5m2* src);
  793. };
  794. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
  795. ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
  796. (__nv_fp8x4_storage_t(val.__x) << 16) |
  797. (__nv_fp8x4_storage_t(val.__x) << 8) |
  798. __nv_fp8x4_storage_t(val.__x);
  799. ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
  800. (__nv_fp8x4_storage_t(val.__x) << 16) |
  801. (__nv_fp8x4_storage_t(val.__x) << 8) |
  802. __nv_fp8x4_storage_t(val.__x);
  803. }
  804. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) {
  805. data = *((uint2*)ptr);
  806. }
  807. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2* ptr) const {
  808. *((uint2*)ptr) = data;
  809. }
  810. APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
  811. __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  812. *((uint2*)dst) = *((uint2*)src);
  813. }
  814. // __nv_fp8_e5m2 x 16 or more
  815. template <size_t vec_size>
  816. struct vec_t<__nv_fp8_e5m2, vec_size> {
  817. uint4 data[vec_size / 16];
  818. APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
  819. return ((__nv_fp8_e5m2*)data)[i];
  820. }
  821. APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
  822. return ((const __nv_fp8_e5m2*)data)[i];
  823. }
  824. APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
  825. return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  826. }
  827. APHRODITE_INLINE void fill(__nv_fp8_e5m2 val) {
  828. #pragma unroll
  829. for (size_t i = 0; i < vec_size / 16; ++i) {
  830. ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x =
  831. (__nv_fp8x4_storage_t(val.__x) << 24) |
  832. (__nv_fp8x4_storage_t(val.__x) << 16) |
  833. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  834. ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x =
  835. (__nv_fp8x4_storage_t(val.__x) << 24) |
  836. (__nv_fp8x4_storage_t(val.__x) << 16) |
  837. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  838. ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x =
  839. (__nv_fp8x4_storage_t(val.__x) << 24) |
  840. (__nv_fp8x4_storage_t(val.__x) << 16) |
  841. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  842. ((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x =
  843. (__nv_fp8x4_storage_t(val.__x) << 24) |
  844. (__nv_fp8x4_storage_t(val.__x) << 16) |
  845. (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
  846. }
  847. }
  848. APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr) {
  849. #pragma unroll
  850. for (size_t i = 0; i < vec_size / 16; ++i) {
  851. data[i] = ((uint4*)ptr)[i];
  852. }
  853. }
  854. APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const {
  855. #pragma unroll
  856. for (size_t i = 0; i < vec_size / 16; ++i) {
  857. ((uint4*)ptr)[i] = data[i];
  858. }
  859. }
  860. template <typename T>
  861. APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
  862. cast_from_impl(*this, src);
  863. }
  864. template <typename T>
  865. APHRODITE_INLINE void cast_load(const T* ptr) {
  866. cast_load_impl(*this, ptr);
  867. }
  868. template <typename T>
  869. APHRODITE_INLINE void cast_store(T* ptr) const {
  870. cast_store_impl(ptr, *this);
  871. }
  872. APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
  873. const __nv_fp8_e5m2* src) {
  874. #pragma unroll
  875. for (size_t i = 0; i < vec_size / 16; ++i) {
  876. ((uint4*)dst)[i] = ((uint4*)src)[i];
  877. }
  878. }
  879. };
  880. /******************* vec_t<half> *******************/
  881. // half x 1
  882. template <>
  883. struct vec_t<half, 1> {
  884. half data;
  885. APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
  886. APHRODITE_INLINE const half& operator[](size_t i) const {
  887. return ((const half*)(&data))[i];
  888. }
  889. APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  890. APHRODITE_INLINE void fill(half val);
  891. APHRODITE_INLINE void load(const half* ptr);
  892. APHRODITE_INLINE void store(half* ptr) const;
  893. template <typename T>
  894. APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
  895. cast_from_impl(*this, src);
  896. }
  897. template <typename T>
  898. APHRODITE_INLINE void cast_load(const T* ptr) {
  899. cast_load_impl(*this, ptr);
  900. }
  901. template <typename T>
  902. APHRODITE_INLINE void cast_store(T* ptr) const {
  903. cast_store_impl(ptr, *this);
  904. }
  905. APHRODITE_INLINE static void memcpy(half* dst, const half* src);
  906. };
  907. APHRODITE_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
  908. APHRODITE_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; }
  909. APHRODITE_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; }
  910. APHRODITE_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) {
  911. *dst = *src;
  912. }
  913. // half x 2
  914. template <>
  915. struct vec_t<half, 2> {
  916. half2 data;
  917. APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
  918. APHRODITE_INLINE const half& operator[](size_t i) const {
  919. return ((const half*)(&data))[i];
  920. }
  921. APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  922. APHRODITE_INLINE void fill(half val);
  923. APHRODITE_INLINE void load(const half* ptr);
  924. APHRODITE_INLINE void store(half* ptr) const;
  925. template <typename T>
  926. APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
  927. cast_from_impl(*this, src);
  928. }
  929. template <typename T>
  930. APHRODITE_INLINE void cast_load(const T* ptr) {
  931. cast_load_impl(*this, ptr);
  932. }
  933. template <typename T>
  934. APHRODITE_INLINE void cast_store(T* ptr) const {
  935. cast_store_impl(ptr, *this);
  936. }
  937. APHRODITE_INLINE static void memcpy(half* dst, const half* src);
  938. };
  939. APHRODITE_INLINE void vec_t<half, 2>::fill(half val) {
  940. data = make_half2(val, val);
  941. }
  942. APHRODITE_INLINE void vec_t<half, 2>::load(const half* ptr) {
  943. data = *((half2*)ptr);
  944. }
  945. APHRODITE_INLINE void vec_t<half, 2>::store(half* ptr) const {
  946. *((half2*)ptr) = data;
  947. }
  948. APHRODITE_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) {
  949. *((half2*)dst) = *((half2*)src);
  950. }
  951. // half x 4
  952. template <>
  953. struct vec_t<half, 4> {
  954. uint2 data;
  955. APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
  956. APHRODITE_INLINE const half& operator[](size_t i) const {
  957. return ((const half*)(&data))[i];
  958. }
  959. APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  960. APHRODITE_INLINE void fill(half val);
  961. APHRODITE_INLINE void load(const half* ptr);
  962. APHRODITE_INLINE void store(half* ptr) const;
  963. template <typename T>
  964. APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
  965. cast_from_impl(*this, src);
  966. }
  967. template <typename T>
  968. APHRODITE_INLINE void cast_load(const T* ptr) {
  969. cast_load_impl(*this, ptr);
  970. }
  971. template <typename T>
  972. APHRODITE_INLINE void cast_store(T* ptr) const {
  973. cast_store_impl(ptr, *this);
  974. }
  975. APHRODITE_INLINE static void memcpy(half* dst, const half* src);
  976. };
  977. APHRODITE_INLINE void vec_t<half, 4>::fill(half val) {
  978. *(half2*)(&data.x) = make_half2(val, val);
  979. *(half2*)(&data.y) = make_half2(val, val);
  980. }
  981. APHRODITE_INLINE void vec_t<half, 4>::load(const half* ptr) {
  982. data = *((uint2*)ptr);
  983. }
  984. APHRODITE_INLINE void vec_t<half, 4>::store(half* ptr) const {
  985. *((uint2*)ptr) = data;
  986. }
  987. APHRODITE_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) {
  988. *((uint2*)dst) = *((uint2*)src);
  989. }
  990. // half x 8 or more
  991. template <size_t vec_size>
  992. struct vec_t<half, vec_size> {
  993. uint4 data[vec_size / 8];
  994. APHRODITE_INLINE half& operator[](size_t i) { return ((half*)data)[i]; }
  995. APHRODITE_INLINE const half& operator[](size_t i) const {
  996. return ((const half*)data)[i];
  997. }
  998. APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  999. APHRODITE_INLINE void fill(half val) {
  1000. #pragma unroll
  1001. for (size_t i = 0; i < vec_size / 8; ++i) {
  1002. *(half2*)(&(data[i].x)) = make_half2(val, val);
  1003. *(half2*)(&(data[i].y)) = make_half2(val, val);
  1004. *(half2*)(&(data[i].z)) = make_half2(val, val);
  1005. *(half2*)(&(data[i].w)) = make_half2(val, val);
  1006. }
  1007. }
  1008. APHRODITE_INLINE void load(const half* ptr) {
  1009. #pragma unroll
  1010. for (size_t i = 0; i < vec_size / 8; ++i) {
  1011. data[i] = ((uint4*)ptr)[i];
  1012. }
  1013. }
  1014. APHRODITE_INLINE void store(half* ptr) const {
  1015. #pragma unroll
  1016. for (size_t i = 0; i < vec_size / 8; ++i) {
  1017. ((uint4*)ptr)[i] = data[i];
  1018. }
  1019. }
  1020. template <typename T>
  1021. APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
  1022. cast_from_impl(*this, src);
  1023. }
  1024. template <typename T>
  1025. APHRODITE_INLINE void cast_load(const T* ptr) {
  1026. cast_load_impl(*this, ptr);
  1027. }
  1028. template <typename T>
  1029. APHRODITE_INLINE void cast_store(T* ptr) const {
  1030. cast_store_impl(ptr, *this);
  1031. }
  1032. APHRODITE_INLINE static void memcpy(half* dst, const half* src) {
  1033. #pragma unroll
  1034. for (size_t i = 0; i < vec_size / 8; ++i) {
  1035. ((uint4*)dst)[i] = ((uint4*)src)[i];
  1036. }
  1037. }
  1038. };
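/*!
 * \brief Sketch (illustrative; smem and src are hypothetical): the static
 * memcpy above moves vec_size halfs as 16-byte words, which is a convenient
 * way to stage tiles through shared memory without an element-wise loop.
 * \code
 * extern __shared__ half smem[];
 * vec_t<half, 8>::memcpy(smem + threadIdx.x * 8, src + threadIdx.x * 8);
 * \endcode
 */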
  1039. /******************* vec_t<nv_bfloat16> *******************/
  1040. // nv_bfloat16 x 1
  1041. template <>
  1042. struct vec_t<nv_bfloat16, 1> {
  1043. nv_bfloat16 data;
  1044. APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
  1045. return ((nv_bfloat16*)(&data))[i];
  1046. }
  1047. APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
  1048. return ((const nv_bfloat16*)(&data))[i];
  1049. }
  1050. APHRODITE_INLINE nv_bfloat16* ptr() {
  1051. return reinterpret_cast<nv_bfloat16*>(&data);
  1052. }
  1053. APHRODITE_INLINE void fill(nv_bfloat16 val);
  1054. APHRODITE_INLINE void load(const nv_bfloat16* ptr);
  1055. APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
  1056. template <typename T>
  1057. APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
  1058. cast_from_impl(*this, src);
  1059. }
  1060. template <typename T>
  1061. APHRODITE_INLINE void cast_load(const T* ptr) {
  1062. cast_load_impl(*this, ptr);
  1063. }
  1064. template <typename T>
  1065. APHRODITE_INLINE void cast_store(T* ptr) const {
  1066. cast_store_impl(ptr, *this);
  1067. }
  1068. APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
  1069. };
  1070. APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
  1071. data = val;
  1072. }
  1073. APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
  1074. data = *ptr;
  1075. }
  1076. APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
  1077. *ptr = data;
  1078. }
  1079. APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst,
  1080. const nv_bfloat16* src) {
  1081. *dst = *src;
  1082. }
  1083. // nv_bfloat16 x 2
  1084. template <>
  1085. struct vec_t<nv_bfloat16, 2> {
  1086. nv_bfloat162 data;
  1087. APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
  1088. return ((nv_bfloat16*)(&data))[i];
  1089. }
  1090. APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
  1091. return ((const nv_bfloat16*)(&data))[i];
  1092. }
  1093. APHRODITE_INLINE nv_bfloat16* ptr() {
  1094. return reinterpret_cast<nv_bfloat16*>(&data);
  1095. }
  1096. APHRODITE_INLINE void fill(nv_bfloat16 val);
  1097. APHRODITE_INLINE void load(const nv_bfloat16* ptr);
  1098. APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
  1099. template <typename T>
  1100. APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
  1101. cast_from_impl(*this, src);
  1102. }
  1103. template <typename T>
  1104. APHRODITE_INLINE void cast_load(const T* ptr) {
  1105. cast_load_impl(*this, ptr);
  1106. }
  1107. template <typename T>
  1108. APHRODITE_INLINE void cast_store(T* ptr) const {
  1109. cast_store_impl(ptr, *this);
  1110. }
  1111. APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
  1112. };
  1113. APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
  1114. data = make_bfloat162(val, val);
  1115. }
  1116. APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
  1117. data = *((nv_bfloat162*)ptr);
  1118. }
  1119. APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
  1120. *((nv_bfloat162*)ptr) = data;
  1121. }
  1122. APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst,
  1123. const nv_bfloat16* src) {
  1124. *((nv_bfloat162*)dst) = *((nv_bfloat162*)src);
  1125. }
  1126. // nv_bfloat16 x 4
  1127. template <>
  1128. struct vec_t<nv_bfloat16, 4> {
  1129. uint2 data;
  1130. APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
  1131. return ((nv_bfloat16*)(&data))[i];
  1132. }
  1133. APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
  1134. return ((const nv_bfloat16*)(&data))[i];
  1135. }
  1136. APHRODITE_INLINE nv_bfloat16* ptr() {
  1137. return reinterpret_cast<nv_bfloat16*>(&data);
  1138. }
  1139. APHRODITE_INLINE void fill(nv_bfloat16 val);
  1140. APHRODITE_INLINE void load(const nv_bfloat16* ptr);
  1141. APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
  1142. template <typename T>
  1143. APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
  1144. cast_from_impl(*this, src);
  1145. }
  1146. template <typename T>
  1147. APHRODITE_INLINE void cast_load(const T* ptr) {
  1148. cast_load_impl(*this, ptr);
  1149. }
  1150. template <typename T>
  1151. APHRODITE_INLINE void cast_store(T* ptr) const {
  1152. cast_store_impl(ptr, *this);
  1153. }
  1154. APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
  1155. };
  1156. APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
  1157. *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val);
  1158. *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val);
  1159. }
  1160. APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
  1161. data = *((uint2*)ptr);
  1162. }
  1163. APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
  1164. *((uint2*)ptr) = data;
  1165. }
  1166. APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst,
  1167. const nv_bfloat16* src) {
  1168. *((uint2*)dst) = *((uint2*)src);
  1169. }
  1170. // nv_bfloat16 x 8 or more
  1171. template <size_t vec_size>
  1172. struct vec_t<nv_bfloat16, vec_size> {
  1173. uint4 data[vec_size / 8];
  1174. APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
  1175. return ((nv_bfloat16*)data)[i];
  1176. }
  1177. APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
  1178. return ((const nv_bfloat16*)data)[i];
  1179. }
  1180. APHRODITE_INLINE nv_bfloat16* ptr() {
  1181. return reinterpret_cast<nv_bfloat16*>(&data);
  1182. }
  1183. APHRODITE_INLINE void fill(nv_bfloat16 val) {
  1184. #pragma unroll
  1185. for (size_t i = 0; i < vec_size / 8; ++i) {
  1186. *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val);
  1187. *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val);
  1188. *(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
  1189. *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
  1190. }
  1191. }
  1192. APHRODITE_INLINE void load(const nv_bfloat16* ptr) {
  1193. #pragma unroll
  1194. for (size_t i = 0; i < vec_size / 8; ++i) {
  1195. data[i] = ((uint4*)ptr)[i];
  1196. }
  1197. }
  1198. APHRODITE_INLINE void store(nv_bfloat16* ptr) const {
  1199. #pragma unroll
  1200. for (size_t i = 0; i < vec_size / 8; ++i) {
  1201. ((uint4*)ptr)[i] = data[i];
  1202. }
  1203. }
  1204. template <typename T>
  1205. APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
  1206. cast_from_impl(*this, src);
  1207. }
  1208. template <typename T>
  1209. APHRODITE_INLINE void cast_load(const T* ptr) {
  1210. cast_load_impl(*this, ptr);
  1211. }
  1212. template <typename T>
  1213. APHRODITE_INLINE void cast_store(T* ptr) const {
  1214. cast_store_impl(ptr, *this);
  1215. }
  1216. APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
  1217. const nv_bfloat16* src) {
  1218. #pragma unroll
  1219. for (size_t i = 0; i < vec_size / 8; ++i) {
  1220. ((uint4*)dst)[i] = ((uint4*)src)[i];
  1221. }
  1222. }
  1223. };
  1224. /******************* vec_t<float> *******************/
  1225. // float x 1
  1226. template <>
  1227. struct vec_t<float, 1> {
  1228. float data;
  1229. APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
  1230. APHRODITE_INLINE const float& operator[](size_t i) const {
  1231. return ((const float*)(&data))[i];
  1232. }
  1233. APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
  1234. APHRODITE_INLINE void fill(float val);
  1235. APHRODITE_INLINE void load(const float* ptr);
  1236. APHRODITE_INLINE void store(float* ptr) const;
  1237. template <typename T>
  1238. APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
  1239. cast_from_impl(*this, src);
  1240. }
  1241. template <typename T>
  1242. APHRODITE_INLINE void cast_load(const T* ptr) {
  1243. cast_load_impl(*this, ptr);
  1244. }
  1245. template <typename T>
  1246. APHRODITE_INLINE void cast_store(T* ptr) const {
  1247. cast_store_impl(ptr, *this);
  1248. }
  1249. APHRODITE_INLINE static void memcpy(float* dst, const float* src);
  1250. };
  1251. APHRODITE_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
  1252. APHRODITE_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
  1253. APHRODITE_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
  1254. APHRODITE_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
  1255. *dst = *src;
  1256. }
  1257. // float x 2
  1258. template <>
  1259. struct vec_t<float, 2> {
  1260. float2 data;
  1261. APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
  1262. APHRODITE_INLINE const float& operator[](size_t i) const {
  1263. return ((const float*)(&data))[i];
  1264. }
  1265. APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
  1266. APHRODITE_INLINE void fill(float val);
  1267. APHRODITE_INLINE void load(const float* ptr);
  1268. APHRODITE_INLINE void store(float* ptr) const;
  1269. template <typename T>
  1270. APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
  1271. cast_from_impl(*this, src);
  1272. }
  1273. template <typename T>
  1274. APHRODITE_INLINE void cast_load(const T* ptr) {
  1275. cast_load_impl(*this, ptr);
  1276. }
  1277. template <typename T>
  1278. APHRODITE_INLINE void cast_store(T* ptr) const {
  1279. cast_store_impl(ptr, *this);
  1280. }
  1281. APHRODITE_INLINE static void memcpy(float* dst, const float* src);
  1282. };
  1283. APHRODITE_INLINE void vec_t<float, 2>::fill(float val) {
  1284. data = make_float2(val, val);
  1285. }
  1286. APHRODITE_INLINE void vec_t<float, 2>::load(const float* ptr) {
  1287. data = *((float2*)ptr);
  1288. }
  1289. APHRODITE_INLINE void vec_t<float, 2>::store(float* ptr) const {
  1290. *((float2*)ptr) = data;
  1291. }
  1292. APHRODITE_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
  1293. *((float2*)dst) = *((float2*)src);
  1294. }
  1295. // float x 4 or more
  1296. template <size_t vec_size>
  1297. struct vec_t<float, vec_size> {
  1298. float4 data[vec_size / 4];
  1299. APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
  1300. APHRODITE_INLINE const float& operator[](size_t i) const {
  1301. return ((const float*)(data))[i];
  1302. }
  1303. APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
  1304. APHRODITE_INLINE void fill(float val) {
  1305. #pragma unroll
  1306. for (size_t i = 0; i < vec_size / 4; ++i) {
  1307. data[i] = make_float4(val, val, val, val);
  1308. }
  1309. }
  1310. APHRODITE_INLINE void load(const float* ptr) {
  1311. #pragma unroll
  1312. for (size_t i = 0; i < vec_size / 4; ++i) {
  1313. data[i] = ((float4*)ptr)[i];
  1314. }
  1315. }
  1316. APHRODITE_INLINE void store(float* ptr) const {
  1317. #pragma unroll
  1318. for (size_t i = 0; i < vec_size / 4; ++i) {
  1319. ((float4*)ptr)[i] = data[i];
  1320. }
  1321. }
  1322. template <typename T>
  1323. APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
  1324. cast_from_impl(*this, src);
  1325. }
  1326. template <typename T>
  1327. APHRODITE_INLINE void cast_load(const T* ptr) {
  1328. cast_load_impl(*this, ptr);
  1329. }
  1330. template <typename T>
  1331. APHRODITE_INLINE void cast_store(T* ptr) const {
  1332. cast_store_impl(ptr, *this);
  1333. }
  1334. APHRODITE_INLINE static void memcpy(float* dst, const float* src) {
  1335. #pragma unroll
  1336. for (size_t i = 0; i < vec_size / 4; ++i) {
  1337. ((float4*)dst)[i] = ((float4*)src)[i];
  1338. }
  1339. }
  1340. };
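/*!
 * \brief End-to-end sketch (illustrative, not part of this header): a generic
 * conversion kernel built on cast_load / cast_store. DTypeIn, DTypeOut and
 * the launch configuration are hypothetical; float registers are used as the
 * intermediate type.
 * \code
 * template <typename DTypeIn, typename DTypeOut, size_t VEC_SIZE>
 * __global__ void convert_kernel(const DTypeIn* in, DTypeOut* out, size_t n) {
 *   size_t i = (blockIdx.x * (size_t)blockDim.x + threadIdx.x) * VEC_SIZE;
 *   if (i + VEC_SIZE <= n) {
 *     vec_t<float, VEC_SIZE> v;
 *     v.cast_load(in + i);    // DTypeIn -> float registers
 *     v.cast_store(out + i);  // float   -> DTypeOut
 *   }
 * }
 * \endcode
 */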
  1341. } // namespace aphrodite
  1342. #endif // VEC_DTYPES_CUH_