// origin_order.cu

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#include <mma.h>
#endif

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
template <typename U, typename V>
constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return (a / b);
}

template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  // Overflow-safe variant of (a + b - 1) / b
  const uint64_t blocks = a / b + (a % b != 0);
  return blocks;
}

template <typename U, typename V>
constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return divDown(a, b) * b;
}

template <typename U, typename V>
constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return divUp(a, b) * b;
}
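
// A few compile-time checks illustrating the helpers above (illustrative
// additions; nothing below depends on them).
static_assert(divDown(10, 4) == 2, "divDown truncates");
static_assert(divUp(10, 4) == 3, "divUp rounds toward +infinity");
static_assert(roundDown(10, 4) == 8, "roundDown to the previous multiple of b");
static_assert(roundUp(10, 4) == 12, "roundUp to the next multiple of b");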
constexpr int32_t kWarpSize = 32;
[[maybe_unused]] constexpr int32_t KTilesPerWarp = 8;
constexpr int32_t kMTileSize = 16;
constexpr int32_t kNTileSize = 8;
constexpr int32_t kKTileSize = 16;
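
// kMTileSize x kNTileSize x kKTileSize above matches the mma.m16n8k16 tensor
// core shape: each mma multiplies a 16x16 fp16 A tile by a 16x8 fp16 B tile
// and accumulates into a 16x8 fp32 C tile per warp.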
struct __align__(16) f16x2x4_u32 { uint32_t vals[4]; };
struct __align__(16) f16x2x2_u32 { uint32_t vals[2]; };
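
// f16x2x4_u32 / f16x2x2_u32 are the per-thread register fragments for that
// mma: 4 x uint32_t (eight fp16 values per lane) for the A operand and
// 2 x uint32_t (four fp16 values per lane) for the B operand.
//
// ALayout_RM loads A fragments from a row-major [m, k] fp16 matrix and, via
// store(), writes the fp32 accumulator fragment back to a row-major [m, n]
// fp16 output.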
struct ALayout_RM {
  template <int KTilesToLoad>
  static __device__ void load(const half* A, int32_t m, int32_t k,
                              int32_t mTiles, int32_t mTile, int32_t kTiles,
                              int32_t kTileStart, int32_t laneId,
                              f16x2x4_u32 out[KTilesToLoad]) {
    const auto mLane = mTile * kMTileSize + (laneId / 4);
    const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 4;
    // access
    // [mTile * kMTileSize + (laneId / 4)]
    // [kTileStart * kKTileSize + (laneId % 4) * 4]
    auto aPtr = A + mLane * k + kLane;
    auto aPtrPlus8Rows = aPtr + 8 * k;

    bool m0InBounds = mLane < m;
    bool m1InBounds = (mLane + 8) < m;

#pragma unroll
    for (int i = 0; i < KTilesToLoad; ++i) {
      out[i].vals[0] =
          m0InBounds ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
                     : uint32_t(0);
      out[i].vals[1] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
                                        aPtrPlus8Rows + i * kKTileSize)
                                  : uint32_t(0);
      out[i].vals[2] =
          m0InBounds
              ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 2)
              : uint32_t(0);
      out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
                                        aPtrPlus8Rows + i * kKTileSize + 2)
                                  : uint32_t(0);
    }
  }

  static __device__ void store(half* C, int32_t m, int32_t n, int32_t mOutTiles,
                               int32_t mTile, int32_t nOutTiles, int32_t nTile,
                               int32_t laneId, const float4& out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // sum.x / sum.y are written at
    // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
    // sum.z / sum.w are written at
    // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
    // i.e., same columns, different row.
    const int outRow = mTile * kMTileSize + (laneId / 4);
    const int outCol = nTile * kNTileSize + (laneId % 4) * 2;

    // Pointer where sum.x / sum.y is written
    auto cPtr = C + outRow * n + outCol;

    auto v01 = __float22half2_rn(float2{out.x, out.y});
    auto v23 = __float22half2_rn(float2{out.z, out.w});

    if (outRow < m) {
      *reinterpret_cast<half2*>(cPtr) = v01;
    }
    // sum.z, sum.w at +8 rows from cPtr
    if (outRow + 8 < m) {
      *reinterpret_cast<half2*>(cPtr + 8 * n) = v23;
    }
#endif
  }
};
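
// BLayout_D4: each uint8 of B indexes a 256-entry codebook whose entries are
// four packed fp16 values (one uint64_t each), i.e. one byte encodes four
// consecutive weights along k. The GEMM kernel passes a shared-memory copy of
// the codebook as CB.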
struct BLayout_D4 {
  static constexpr bool use_codebook = true;

  template <int KTilesPerIteration>
  static __device__ void load(const void* __restrict__ B,
                              const uint64_t* __restrict__ CB, int32_t n,
                              int32_t k, int32_t nTiles, int32_t nTile,
                              int32_t kTiles, int32_t kTileStart,
                              int32_t laneId,
                              f16x2x2_u32 b[KTilesPerIteration]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    auto Bptr = reinterpret_cast<const uint8_t*>(B);
#pragma unroll
    for (int i = 0; i < KTilesPerIteration; ++i) {
      const int row = nTile * kNTileSize + laneId / 4;
      const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
      *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[row * k / 4 + col]];
    }
#endif
  }
};
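
// BLayout_HI: B packs eight 4-bit codes per uint32_t and no codebook is used
// (use_codebook == false, so the CB argument is ignored). Each code q is
// dequantized in registers to (q - 7.5): the nibble is OR-ed into the
// mantissa of 0x6408 (fp16 1032.0) and rescaled with a single __hfma2.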
struct BLayout_HI {
  static constexpr bool use_codebook = false;

  template <int KTilesPerIteration>
  static __device__ void load(const void* __restrict__ B,
                              const uint64_t* __restrict__ CB, int32_t n,
                              int32_t k, int32_t nTiles, int32_t nTile,
                              int32_t kTiles, int32_t kTileStart,
                              int32_t laneId,
                              f16x2x2_u32 b[KTilesPerIteration]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    auto Bptr = reinterpret_cast<const uint32_t*>(B);
#pragma unroll
    for (int i = 0; i < KTilesPerIteration; ++i) {
      const int row = nTile * kNTileSize + laneId / 4;
      const int col = (kTileStart + i) * kKTileSize / 8 + (laneId % 4) / 2;
      // simply use (code - 7.5) instead of reading a codebook
      uint32_t code = Bptr[row * k / 8 + col];

      const uint32_t c0 = 0x64086408;
      const half y16_ = __float2half_rn(1.0f / 16.0f);
      const half2 y16 = __halves2half2(y16_, y16_);
      const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
      const half2 z16 = __halves2half2(z16_, z16_);

      uint32_t qa = code >> ((laneId & 1) * 8);
      uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
      uint32_t q1 = ((qa & 0x00f000f0) | c0);
      *(half2*)(b[i].vals) = __hfma2(*((half2*)(&q0)), y16, z16);
      *(half2*)(b[i].vals + 1) = __hfma2(*((half2*)(&q1)), y16, z16);
    }
#endif
  }
};
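
// BLayout_E8: each 16-bit code covers eight weights. The low byte is a sign
// pattern (with a parity bit folded in) and the high byte indexes a codebook
// of packed absolute values; decode8weights() expands one code into eight
// packed signed bytes. load() then converts those bytes to fp16 using the
// 0x5c80 (fp16 288.0) exponent-bias trick followed by a -288 adjustment.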
struct BLayout_E8 {
  static constexpr bool use_codebook = true;

  __device__ static inline uint64_t decode8weights(
      uint16_t weight_compressed, const int64_t* __restrict__ codebook_abs) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    uint8_t bits_sign = weight_compressed & 0xff;
    uint8_t parity = __popc(bits_sign) & 1;
    uint8_t sign_vec = bits_sign ^ parity;
    uint8_t bits_abs = (weight_compressed >> 8);
    int64_t packed = codebook_abs[bits_abs];
    uint64_t decoded_sign = sign_vec * 0x8040201008040201ll;
    decoded_sign &= 0x8080808080808080;
    decoded_sign >>= 7;
    decoded_sign *= 255 - 3;
    packed ^= decoded_sign;
    packed |= 0x0101010101010101;
    packed -= parity * 0x0202020202020202;
    return packed;
#endif
  }

  __device__ static inline uint32_t decode8weights(
      uint16_t weight_compressed, const int64_t* __restrict__ codebook_abs,
      int idx) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    uint8_t bits_sign =
        weight_compressed & 0xff;  //__brev(weight_compressed) >> 24;
    const uint32_t magic_nums[2] = {0x08040201ll, 0x80402010ll};
    uint8_t parity = __popc(bits_sign) & 1;
    uint8_t sign_vec = bits_sign ^ parity;  // (parity << 7);
    uint16_t bits_abs = (weight_compressed >> 8);
    uint32_t packed = ((uint32_t*)codebook_abs)[(bits_abs << 1) + idx];
    uint32_t magic_num = magic_nums[idx];
    uint32_t decoded_sign = sign_vec * magic_num;
    decoded_sign &= 0x80808080;
    decoded_sign >>= 7;
    decoded_sign *= 255 - 3;
    packed ^= decoded_sign;
    packed |= 0x01010101;
    packed -= parity * 0x02020202;
    return packed;
#endif
  };

  template <int KTilesPerIteration>
  static __device__ void load(const void* __restrict__ B,
                              const uint64_t* __restrict__ CB, int32_t n,
                              int32_t k, int32_t nTiles, int32_t nTile,
                              int32_t kTiles, int32_t kTileStart,
                              int32_t laneId,
                              f16x2x2_u32 b[KTilesPerIteration]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    auto Bptr = (const uint16_t*)B;
#pragma unroll
    for (int i = 0; i < KTilesPerIteration; ++i) {
      const int row = nTile * kNTileSize + laneId / 4;
      const int col = (kTileStart + i) * kKTileSize / 8 + laneId % 4 / 2;
      uint32_t decoded = decode8weights(Bptr[row * k / 8 + col],
                                        (const int64_t*)CB, laneId & 1);

      half2 unpacked[2];
      uint32_t lower_half = decoded & 0x00ff00ff;
      lower_half = (lower_half ^ 0x5c805c80);
      memcpy(unpacked, &lower_half, sizeof(uint32_t));
      uint32_t upper_half = (decoded & 0xff00ff00) >> 8;
      upper_half = (upper_half ^ 0x5c805c80);
      memcpy(unpacked + 1, &upper_half, sizeof(uint32_t));

      const half adjust_ = __float2half_rn(-288.0f);
      const half2 adjust = __halves2half2(adjust_, adjust_);
      unpacked[0] = __hadd2(unpacked[0], adjust);
      unpacked[1] = __hadd2(unpacked[1], adjust);

      *(reinterpret_cast<uint64_t*>(b[i].vals)) =
          *(reinterpret_cast<uint64_t*>(unpacked));
      //*((half*)(b[i].vals)) = unpacked[0];
      //*((half*)(b[i].vals) + 1) = unpacked[0].y;
      //*((half*)(b[i].vals) + 2) = unpacked[1].x;
      //*((half*)(b[i].vals) + 3) = unpacked[1].y;
    }
#endif
  }
};
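
// Warp-tiled GEMM kernel: each thread block (kWarpSize x Warps threads) owns
// one 16x8 output tile of C, selected by (blockIdx.z, blockIdx.y). If the B
// layout uses a codebook, the 256-entry codebook is first staged in shared
// memory. Warps split the k-tiles among themselves, issue mma.m16n8k16 on the
// dequantized fragments, and warp 0 reduces the per-warp partial sums through
// shared memory before CLayout::store writes the result.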
template <typename ALayout, typename BLayout, typename CLayout, int Warps,
          int KTilesPerIteration>
__global__ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
    // Data for the A matrix, loaded as per ALayout
    const half* __restrict__ A, const void* __restrict__ B,
    const uint64_t* __restrict__ CB,
    // Output data for the C matrix, stored as per CLayout
    half* __restrict__ C,
    // The size of the matrix multiplication
    int32_t m, int32_t n, int32_t k,
    // The size of the matrix multiplication, in multiples of our TC tile size
    int32_t mTiles, int32_t nTiles, int32_t kTiles) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  __shared__ uint64_t CB_[256];
  if (BLayout::use_codebook) {
    CB_[threadIdx.x + threadIdx.y * 32] = CB[threadIdx.x + threadIdx.y * 32];
    __syncthreads();
  }

  auto warpId = threadIdx.y;
  auto laneId = threadIdx.x;

  int32_t mTile = blockIdx.z;
  int32_t nTile = blockIdx.y;

  float4 c{0.0f, 0.0f, 0.0f, 0.0f};

  // First, handle whole multiples of KTilesPerIteration
  auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);

  // Each warp handles a set of KTilesPerIteration under the above limit
  for (int32_t kTileBase = warpId * KTilesPerIteration; kTileBase < kTilesLimit;
       kTileBase += Warps * KTilesPerIteration) {
    //
    // Load data from A
    //
    f16x2x4_u32 a[KTilesPerIteration];
    ALayout::template load<KTilesPerIteration>(A, m, k, mTiles, mTile, kTiles,
                                               kTileBase, laneId, a);

    //
    // Load data from B and de-quantize as needed
    //
    f16x2x2_u32 b[KTilesPerIteration];
    BLayout::template load<KTilesPerIteration>(B, CB_, n, k, nTiles, nTile,
                                               kTiles, kTileBase, laneId, b);

    //
    // Now, perform the matrix multiplication
    //
#pragma unroll
    for (int i = 0; i < KTilesPerIteration / 2; ++i) {
      float4 cTmp[2];

#pragma unroll
      for (int k = 0; k < 2; ++k) {
        cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
      }

#pragma unroll
      for (int k = 0; k < 2; ++k) {
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
            : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
            : "r"(a[i * 2 + k].vals[0]), "r"(a[i * 2 + k].vals[1]),
              "r"(a[i * 2 + k].vals[2]), "r"(a[i * 2 + k].vals[3]),
              "r"(b[i * 2 + k].vals[0]), "r"(b[i * 2 + k].vals[1]),
              "f"(cTmp[k].x), "f"(cTmp[k].y), "f"(cTmp[k].z), "f"(cTmp[k].w));
      }

#pragma unroll
      for (int k = 0; k < 2; ++k) {
        c.x += cTmp[k].x;
        c.y += cTmp[k].y;
        c.z += cTmp[k].z;
        c.w += cTmp[k].w;
      }
    }
  }  // for all tiles under kTilesLimit

  auto kTileBaseRemaining = kTilesLimit + warpId;

  // If we have any remainder k-tiles, some warps will handle them, processing
  // one k-tile at a time
  if (kTileBaseRemaining < kTiles) {
    f16x2x4_u32 a;
    ALayout::template load<1>(A, m, k, mTiles, mTile, kTiles,
                              kTileBaseRemaining, laneId, &a);

    f16x2x2_u32 b;
    BLayout::template load<1>(B, CB, n, k, nTiles, nTile, kTiles,
                              kTileBaseRemaining, laneId, &b);

    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
        : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
        : "r"(a.vals[0]), "r"(a.vals[1]), "r"(a.vals[2]), "r"(a.vals[3]),
          "r"(b.vals[0]), "r"(b.vals[1]), "f"(c.x), "f"(c.y), "f"(c.z),
          "f"(c.w));
  }

  // Reduce independent k-tiles (same m/n) across warps
  __shared__ float4 smem_sum[Warps][kWarpSize];
  smem_sum[warpId][laneId] = c;

  __syncthreads();

  if (warpId == 0) {
    float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};

    // Reduce across the block in the first warp
    for (int i = 0; i < Warps; ++i) {
      float4 v = smem_sum[i][laneId];
      sum_f32.x += v.x;
      sum_f32.y += v.y;
      sum_f32.z += v.z;
      sum_f32.w += v.w;
    }

    // Write the reduced result (in the first warp) into the output
    CLayout::store(C, m, n, mTiles, mTile,
                   // n for C output becomes k for A input, so for m16n8k16,
                   // we need to halve the tiles
                   nTiles / 2, nTile, laneId, sum_f32);
  }
#endif
}
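
// Host launcher for the D4 path. A is a row-major [m, k] fp16 matrix, B holds
// one uint8 code per four weights in the tensor-core "original order" layout
// with n = B.size(0), and CB is the 256 x 4 fp16 D4 codebook. Returns a
// row-major [m, n] fp16 tensor.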
at::Tensor d4_mm_origorder(const at::Tensor& A, const at::Tensor& B,
                           const at::Tensor& CB) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  c10::cuda::CUDAGuard g(A.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  constexpr int Warps = 8;

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto n = B.size(0);
  auto nTiles = divUp(n, kNTileSize);

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));

  auto grid = dim3(1, nTiles, mTiles);
  auto block = dim3(kWarpSize, Warps);
  auto kernel =
      tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_D4, ALayout_RM, 8, 8>;
  kernel<<<grid, block, 0, stream>>>(
      (const half*)A.data_ptr(), (const void*)B.data_ptr(),
      (const uint64_t*)CB.data_ptr(), (half*)C_final.data_ptr(), m, n, k,
      mTiles, nTiles, kTiles);

  return C_final;
#endif
  return {};  // Squash missing return statement warning
}
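
// Host launcher for the E8P path: same launch configuration, but B holds one
// 16-bit code per eight weights and CB is the E8P codebook of packed absolute
// values consumed by decode8weights().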
at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
                            const at::Tensor& CB) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  c10::cuda::CUDAGuard g(A.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  constexpr int Warps = 8;

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto n = B.size(0);
  auto nTiles = divUp(n, kNTileSize);

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));

  auto grid = dim3(1, nTiles, mTiles);
  auto block = dim3(kWarpSize, Warps);
  auto kernel =
      tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_E8, ALayout_RM, 8, 8>;
  kernel<<<grid, block, 0, stream>>>(
      (const half*)A.data_ptr(), (const void*)B.data_ptr(),
      (const uint64_t*)CB.data_ptr(), (half*)C_final.data_ptr(), m, n, k,
      mTiles, nTiles, kTiles);

  return C_final;
#endif
  return {};  // Squash missing return statement warning
}
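
// Host launcher for the half-integer (HI) 4-bit path: B packs eight 4-bit
// codes per uint32 and needs no codebook, so nullptr is passed for CB.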
at::Tensor hi_mm_origorder(const at::Tensor& A, const at::Tensor& B) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  c10::cuda::CUDAGuard g(A.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  constexpr int Warps = 8;

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto n = B.size(0);
  auto nTiles = divUp(n, kNTileSize);

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));

  auto grid = dim3(1, nTiles, mTiles);
  auto block = dim3(kWarpSize, Warps);
  auto kernel =
      tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_HI, ALayout_RM, 8, 8>;
  kernel<<<grid, block, 0, stream>>>(
      (const half*)A.data_ptr(), (const void*)B.data_ptr(), nullptr,
      (half*)C_final.data_ptr(), m, n, k, mTiles, nTiles, kTiles);

  return C_final;
#endif
  return {};  // Squash missing return statement warning
}
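
// The kernels below fully decompress a quantized weight matrix into a dense
// row-major fp16 tensor Y, mirroring the three B layouts used by the GEMM
// above.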
#define DECOMPRESS_D4_BLOCK_SIZE 256

__global__ void cuda_decompress_d4_origorder_kernel(
    const uint8_t* __restrict__ YIs,   // m x (n/4)
    const c10::Half* __restrict__ CB,  // 256 x 4
    c10::Half* __restrict__ Y          // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x;

  for (long r = 0; r < 4; r++) {
    uint8_t yidx = ((uint8_t*)YIs)[i * 4 + r];
    ((uint64_t*)Y)[i * 4 + r] = ((uint64_t*)CB)[yidx & 255];
  }
#endif
}

void decompress_d4_origorder(torch::Tensor YIs,  // m x (n/4)
                             torch::Tensor CB,   // 256 x 4
                             torch::Tensor Y     // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  size_t m = Y.sizes()[0];
  size_t n = Y.sizes()[1];

  assert(YIs.is_contiguous());
  assert(CB.is_contiguous());
  assert(Y.is_contiguous());

  assert(YIs.sizes()[0] == m);
  assert(YIs.sizes()[1] * 4 == n);
  assert(CB.sizes()[0] == 256);

  const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE);
  const dim3 blocks(m * n / (16 * DECOMPRESS_D4_BLOCK_SIZE));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cuda_decompress_d4_origorder_kernel<<<blocks, threads, 0, stream>>>(
      YIs.data_ptr<uint8_t>(), CB.data_ptr<c10::Half>(),
      Y.data_ptr<c10::Half>());
#endif
}
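
// E8P decompression: each thread decodes one 16-bit code into eight fp16
// values via BLayout_E8::decode8weights and writes them as four half2 stores.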
#define DECOMPRESS_E8P_BLOCK_SIZE 256

__global__ void cuda_decompress_e8p_origorder_kernel(
    const int16_t* __restrict__ YIs,  // m x (n/8)
    const int64_t* __restrict__ CB,   // 256 x 8
    c10::Half* __restrict__ Y         // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const long i = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x;

  uint16_t yidx = ((uint16_t*)YIs)[i];
  uint64_t decoded = BLayout_E8::decode8weights(yidx, CB);

  half2 unpacked[2][2];
  uint64_t lower_half = decoded & 0x00ff00ff00ff00ff;
  lower_half = (lower_half ^ 0x5c805c805c805c80);
  memcpy(unpacked[0], &lower_half, sizeof(uint64_t));
  uint64_t upper_half = (decoded & 0xff00ff00ff00ff00) >> 8;
  upper_half = (upper_half ^ 0x5c805c805c805c80);
  memcpy(unpacked[1], &upper_half, sizeof(uint64_t));

  const half adjust_ = __float2half_rn(-288.0f);
  const half2 adjust = __halves2half2(adjust_, adjust_);

  ((__half2*)Y)[i * 4] = __hadd2(unpacked[0][0], adjust);      // 01
  ((__half2*)Y)[i * 4 + 2] = __hadd2(unpacked[0][1], adjust);  // 45
  ((__half2*)Y)[i * 4 + 1] = __hadd2(unpacked[1][0], adjust);  // 23
  ((__half2*)Y)[i * 4 + 3] = __hadd2(unpacked[1][1], adjust);  // 67
#endif
}

void decompress_e8p_origorder(torch::Tensor YIs,  // m x (n/8)
                              torch::Tensor CB,   // 256 x 8
                              torch::Tensor& Y    // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  size_t m = Y.sizes()[0];
  size_t n = Y.sizes()[1];

  assert(YIs.is_contiguous());
  assert(CB.is_contiguous());
  assert(Y.is_contiguous());

  assert(YIs.sizes()[0] == m);
  assert(YIs.sizes()[1] * 8 == n);
  assert(CB.sizes()[0] == 256);

  const dim3 threads(DECOMPRESS_E8P_BLOCK_SIZE);
  const dim3 blocks(m * n / (8 * DECOMPRESS_E8P_BLOCK_SIZE));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cuda_decompress_e8p_origorder_kernel<<<blocks, threads, 0, stream>>>(
      YIs.data_ptr<int16_t>(), CB.data_ptr<int64_t>(), Y.data_ptr<c10::Half>());
#endif
}
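
// HI decompression: each thread expands one uint32 (eight 4-bit codes) into
// eight fp16 values using the same (q - 7.5) dequantization as BLayout_HI.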
#define DECOMPRESS_HI_BLOCK_SIZE 256

__global__ void cuda_decompress_hi_origorder_kernel(
    const uint32_t* __restrict__ YIs,  // m x (n/8)
    c10::Half* __restrict__ Y          // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const long i = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x;

  uint32_t qa = YIs[i];
  const uint32_t c0 = 0x64086408;
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
  const half2 z16 = __halves2half2(z16_, z16_);

  uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
  uint32_t q1 = ((qa & 0x00f000f0) | c0);
  qa >>= 8;
  uint32_t q2 = (((qa & 0x000f000f) << 4) | c0);
  uint32_t q3 = ((qa & 0x00f000f0) | c0);
  ((__half2*)Y)[i * 4] = __hfma2(*((half2*)(&q0)), y16, z16);
  ((__half2*)Y)[i * 4 + 1] = __hfma2(*((half2*)(&q1)), y16, z16);
  ((__half2*)Y)[i * 4 + 2] = __hfma2(*((half2*)(&q2)), y16, z16);
  ((__half2*)Y)[i * 4 + 3] = __hfma2(*((half2*)(&q3)), y16, z16);
#endif
}

void decompress_hi_origorder(torch::Tensor YIs,  // m x (n/8)
                             torch::Tensor Y     // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  size_t m = Y.sizes()[0];
  size_t n = Y.sizes()[1];

  assert(YIs.is_contiguous());
  assert(Y.is_contiguous());

  assert(YIs.sizes()[0] == m);
  assert(YIs.sizes()[1] * 8 == n);

  const dim3 threads(DECOMPRESS_HI_BLOCK_SIZE);
  const dim3 blocks(m * n / (8 * DECOMPRESS_HI_BLOCK_SIZE));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cuda_decompress_hi_origorder_kernel<<<blocks, threads, 0, stream>>>(
      (uint32_t*)YIs.data_ptr<int32_t>(), Y.data_ptr<c10::Half>());
#endif
}
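
// Usage sketch (illustrative only; the op/pybind registration for these entry
// points lives outside this file). Assuming CUDA tensors with the shapes and
// dtypes implied by the launchers above:
//
//   at::Tensor A;   // [m, k], at::kHalf, row-major activations
//   at::Tensor B;   // [n, k / 4], at::kByte, D4 codes in original order
//   at::Tensor CB;  // [256, 4], at::kHalf, D4 codebook
//   at::Tensor C = d4_mm_origorder(A, B, CB);  // [m, n], at::kHalf
//
//   // Full dequantization of the same weights into a dense fp16 matrix:
//   at::Tensor Y = at::empty({n, k}, A.options());
//   decompress_d4_origorder(B, CB, Y);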