// origin_order.cu
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#include <mma.h>
#endif
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

template <typename U, typename V>
constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return (a / b);
}

template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  // Overflow-safe variant of (a + b - 1) / b
  const uint64_t blocks = a / b + (a % b != 0);
  return blocks;
}

template <typename U, typename V>
constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return divDown(a, b) * b;
}

template <typename U, typename V>
constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return divUp(a, b) * b;
}
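
// Illustrative compile-time sanity checks for the rounding helpers above
// (a minimal sketch; they cost nothing at runtime and can be dropped).
static_assert(divUp(10, 4) == 3, "divUp rounds non-multiples up");
static_assert(roundUp(10, 4) == 12, "roundUp returns the next multiple of b");
static_assert(roundDown(10, 4) == 8, "roundDown returns the previous multiple of b");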

constexpr int32_t kWarpSize = 32;
constexpr int32_t KTilesPerWarp = 8;
constexpr int32_t kMTileSize = 16;
constexpr int32_t kNTileSize = 8;
constexpr int32_t kKTileSize = 16;
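
// Illustration of the tile geometry (inferred from the mma.m16n8k16 shape used
// below): each warp-level mma consumes a 16x16 half A-tile and a 16x8 half
// B-tile, and the host wrappers instantiate the kernel so that each warp walks
// KTilesPerWarp k-tiles per iteration, i.e. 8 * 16 = 128 elements along k.
// kKPerWarpIteration is added here purely for documentation.
constexpr int32_t kKPerWarpIteration = KTilesPerWarp * kKTileSize;
static_assert(kKPerWarpIteration == 128, "8 k-tiles of 16 halves per warp iteration");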

// Eight packed halves (four half2 registers): the per-lane A fragment for one k-tile.
struct __align__(16) f16x2x4_u32 {
  uint32_t vals[4];
};

// Four packed halves (two half2 registers): the per-lane B fragment for one k-tile.
struct __align__(16) f16x2x2_u32 {
  uint32_t vals[2];
};
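
// ALayout_RM loads a row-major fp16 A matrix into per-lane mma.m16n8k16 A
// fragments (f16x2x4_u32) and stores the per-lane fp32 accumulator back to a
// row-major fp16 C matrix.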
struct ALayout_RM {
  template <int KTilesToLoad>
  static __device__ void load(
      const half* A,
      int32_t m,
      int32_t k,
      int32_t mTiles,
      int32_t mTile,
      int32_t kTiles,
      int32_t kTileStart,
      int32_t laneId,
      f16x2x4_u32 out[KTilesToLoad]) {
    const auto mLane = mTile * kMTileSize + (laneId / 4);
    const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 4;
    // access
    // [mTile * kMTileSize + (laneId / 4)]
    // [kTileStart * kKTileSize + (laneId % 4) * 4]
    auto aPtr = A + mLane * k + kLane;
    auto aPtrPlus8Rows = aPtr + 8 * k;

    bool m0InBounds = mLane < m;
    bool m1InBounds = (mLane + 8) < m;

#pragma unroll
    for (int i = 0; i < KTilesToLoad; ++i) {
      out[i].vals[0] = m0InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
          : uint32_t(0);
      out[i].vals[1] = m1InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize)
          : uint32_t(0);
      out[i].vals[2] = m0InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 2)
          : uint32_t(0);
      out[i].vals[3] = m1InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize + 2)
          : uint32_t(0);
    }
  }

  static __device__ void store(
      half* C,
      int32_t m,
      int32_t n,
      int32_t mOutTiles,
      int32_t mTile,
      int32_t nOutTiles,
      int32_t nTile,
      int32_t laneId,
      const float4& out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // sum.x / sum.y are written at
    // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
    // sum.z / sum.w are written at
    // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
    // i.e., same columns, different row.
    const int outRow = mTile * kMTileSize + (laneId / 4);
    const int outCol = nTile * kNTileSize + (laneId % 4) * 2;

    // Pointer where sum.x / sum.y is written
    auto cPtr = C + outRow * n + outCol;

    auto v01 = __float22half2_rn(float2{out.x, out.y});
    auto v23 = __float22half2_rn(float2{out.z, out.w});

    if (outRow < m) {
      *reinterpret_cast<half2*>(cPtr) = v01;
    }
    // sum.z, sum.w at +8 rows from cPtr
    if (outRow + 8 < m) {
      *reinterpret_cast<half2*>(cPtr + 8 * n) = v23;
    }
#endif
  }
};

struct BLayout_D4 {
  static constexpr bool use_codebook = true;

  template <int KTilesPerIteration>
  static __device__ void load(
      const void* __restrict__ B,
      const uint64_t* __restrict__ CB,
      int32_t n,
      int32_t k,
      int32_t nTiles,
      int32_t nTile,
      int32_t kTiles,
      int32_t kTileStart,
      int32_t laneId,
      f16x2x2_u32 b[KTilesPerIteration]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    auto Bptr = reinterpret_cast<const uint8_t*>(B);
#pragma unroll
    for (int i = 0; i < KTilesPerIteration; ++i) {
      const int row = nTile * kNTileSize + laneId / 4;
      const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
      // Each byte of B indexes the 256-entry D4 codebook; the looked-up
      // uint64_t is four packed fp16 weights covering four consecutive k values.
      *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[row * k / 4 + col]];
    }
#endif
  }
};

struct BLayout_HI {
  static constexpr bool use_codebook = false;

  template <int KTilesPerIteration>
  static __device__ void load(
      const void* __restrict__ B,
      const uint64_t* __restrict__ CB,
      int32_t n,
      int32_t k,
      int32_t nTiles,
      int32_t nTile,
      int32_t kTiles,
      int32_t kTileStart,
      int32_t laneId,
      f16x2x2_u32 b[KTilesPerIteration]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    auto Bptr = reinterpret_cast<const uint32_t*>(B);
#pragma unroll
    for (int i = 0; i < KTilesPerIteration; ++i) {
      const int row = nTile * kNTileSize + laneId / 4;
      const int col = (kTileStart + i) * kKTileSize / 8 + (laneId % 4) / 2;
      // Each 4-bit code decodes to (code - 7.5) directly, so no codebook read
      // is needed: the nibble is OR-ed into the mantissa of the fp16 constant
      // 0x6408 (1032.0), then __hfma2 scales by 1/16 and subtracts
      // (1024/16 + 8), leaving code - 7.5.
      uint32_t code = Bptr[row * k / 8 + col];
      const uint32_t c0 = 0x64086408;
      const half y16_ = __float2half_rn(1.0f / 16.0f);
      const half2 y16 = __halves2half2(y16_, y16_);
      const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
      const half2 z16 = __halves2half2(z16_, z16_);
      uint32_t qa = code >> ((laneId & 1) * 8);
      uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
      uint32_t q1 = ((qa & 0x00f000f0) | c0);
      *(half2*)(b[i].vals) = __hfma2(*((half2*)(&q0)), y16, z16);
      *(half2*)(b[i].vals + 1) = __hfma2(*((half2*)(&q1)), y16, z16);
    }
#endif
  }
};

struct BLayout_E8 {
  static constexpr bool use_codebook = true;

  // Decode one compressed E8P weight (eight values packed into 16 bits: the
  // low byte holds sign bits, the high byte indexes the codebook of absolute
  // values) into eight packed int8 values.
  __device__ static inline uint64_t decode8weights(
      uint16_t weight_compressed,
      const int64_t* __restrict__ codebook_abs) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    uint8_t bits_sign = weight_compressed & 0xff;
    uint8_t parity = __popc(bits_sign) & 1;
    uint8_t sign_vec = bits_sign ^ parity;
    uint8_t bits_abs = (weight_compressed >> 8);
    int64_t packed = codebook_abs[bits_abs];
    uint64_t decoded_sign = sign_vec * 0x8040201008040201ll;
    decoded_sign &= 0x8080808080808080;
    decoded_sign >>= 7;
    decoded_sign *= 255 - 3;
    packed ^= decoded_sign;
    packed |= 0x0101010101010101;
    packed -= parity * 0x0202020202020202;
    return packed;
#else
    return 0;
#endif
  }

  // Same decode, but producing only the four values selected by idx
  // (0 = lower half of the codebook entry, 1 = upper half).
  __device__ static inline uint32_t decode8weights(
      uint16_t weight_compressed,
      const int64_t* __restrict__ codebook_abs,
      int idx) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    uint8_t bits_sign = weight_compressed & 0xff;
    const uint32_t magic_nums[2] = {0x08040201u, 0x80402010u};
    uint8_t parity = __popc(bits_sign) & 1;
    uint8_t sign_vec = bits_sign ^ parity;
    uint16_t bits_abs = (weight_compressed >> 8);
    uint32_t packed = ((uint32_t*)codebook_abs)[(bits_abs << 1) + idx];
    uint32_t magic_num = magic_nums[idx];
    uint32_t decoded_sign = sign_vec * magic_num;
    decoded_sign &= 0x80808080;
    decoded_sign >>= 7;
    decoded_sign *= 255 - 3;
    packed ^= decoded_sign;
    packed |= 0x01010101;
    packed -= parity * 0x02020202;
    return packed;
#else
    return 0;
#endif
  }
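
  // Note on the overloads above: the uint64_t variant decodes all eight values
  // of one compressed weight, while the uint32_t variant decodes only the four
  // bytes selected by idx, which is what the per-lane tensor-core load below
  // needs. The multiply by the magic constant spreads the sign bits into the
  // high bit of each byte lane, the 0x80... mask isolates them, and
  // (>> 7) * (255 - 3) turns every set bit into the byte 0xFC that is XOR-ed
  // into the corresponding magnitude byte to apply its sign.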

  template <int KTilesPerIteration>
  static __device__ void load(
      const void* __restrict__ B,
      const uint64_t* __restrict__ CB,
      int32_t n,
      int32_t k,
      int32_t nTiles,
      int32_t nTile,
      int32_t kTiles,
      int32_t kTileStart,
      int32_t laneId,
      f16x2x2_u32 b[KTilesPerIteration]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    auto Bptr = (const uint16_t*)B;
#pragma unroll
    for (int i = 0; i < KTilesPerIteration; ++i) {
      const int row = nTile * kNTileSize + laneId / 4;
      const int col = (kTileStart + i) * kKTileSize / 8 + laneId % 4 / 2;
      uint32_t decoded =
          decode8weights(Bptr[row * k / 8 + col], (const int64_t*)CB, laneId & 1);

      // Convert the packed signed bytes to fp16: XOR-ing a byte v into the
      // fp16 constant 0x5c80 (288.0) yields 288 + v/4, so the -288 adjustment
      // below leaves the dequantized value v/4.
      half2 unpacked[2];
      uint32_t lower_half = decoded & 0x00ff00ff;
      lower_half = (lower_half ^ 0x5c805c80);
      memcpy(unpacked, &lower_half, sizeof(uint32_t));
      uint32_t upper_half = (decoded & 0xff00ff00) >> 8;
      upper_half = (upper_half ^ 0x5c805c80);
      memcpy(unpacked + 1, &upper_half, sizeof(uint32_t));

      const half adjust_ = __float2half_rn(-288.0f);
      const half2 adjust = __halves2half2(adjust_, adjust_);
      unpacked[0] = __hadd2(unpacked[0], adjust);
      unpacked[1] = __hadd2(unpacked[1], adjust);

      *(reinterpret_cast<uint64_t*>(b[i].vals)) =
          *(reinterpret_cast<uint64_t*>(unpacked));
    }
#endif
  }
};

template <
    typename ALayout,
    typename BLayout,
    typename CLayout,
    int Warps,
    int KTilesPerIteration>
__global__ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
    // Data for the A matrix, loaded as per ALayout
    const half* __restrict__ A,
    const void* __restrict__ B,
    const uint64_t* __restrict__ CB,
    // Output data for the C matrix, stored as per CLayout
    half* __restrict__ C,
    // The size of the matrix multiplication
    int32_t m,
    int32_t n,
    int32_t k,
    // The size of the matrix multiplication, in multiples of our TC tile size
    int32_t mTiles,
    int32_t nTiles,
    int32_t kTiles) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // Stage the 256-entry codebook in shared memory, one entry per thread.
  __shared__ uint64_t CB_[256];
  if (BLayout::use_codebook) {
    CB_[threadIdx.x + threadIdx.y * 32] = CB[threadIdx.x + threadIdx.y * 32];
    __syncthreads();
  }

  auto warpId = threadIdx.y;
  auto laneId = threadIdx.x;
  int32_t mTile = blockIdx.z;
  int32_t nTile = blockIdx.y;

  float4 c{0.0f, 0.0f, 0.0f, 0.0f};

  // First, handle whole multiples of KTilesPerIteration
  auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);

  // Each warp handles a set of KTilesPerIteration under the above limit
  for (int32_t kTileBase = warpId * KTilesPerIteration; kTileBase < kTilesLimit;
       kTileBase += Warps * KTilesPerIteration) {
    //
    // Load data from A
    //
    f16x2x4_u32 a[KTilesPerIteration];
    ALayout::template load<KTilesPerIteration>(
        A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);

    //
    // Load data from B and de-quantize as needed
    //
    f16x2x2_u32 b[KTilesPerIteration];
    BLayout::template load<KTilesPerIteration>(
        B, CB_, n, k, nTiles, nTile, kTiles, kTileBase, laneId, b);

    //
    // Now, perform the matrix multiplication
    //
#pragma unroll
    for (int i = 0; i < KTilesPerIteration / 2; ++i) {
      float4 cTmp[2];

#pragma unroll
      for (int j = 0; j < 2; ++j) {
        cTmp[j] = float4{0.0f, 0.0f, 0.0f, 0.0f};
      }

#pragma unroll
      for (int j = 0; j < 2; ++j) {
        asm volatile(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
            : "=f"(cTmp[j].x),
              "=f"(cTmp[j].y),
              "=f"(cTmp[j].z),
              "=f"(cTmp[j].w)
            : "r"(a[i * 2 + j].vals[0]),
              "r"(a[i * 2 + j].vals[1]),
              "r"(a[i * 2 + j].vals[2]),
              "r"(a[i * 2 + j].vals[3]),
              "r"(b[i * 2 + j].vals[0]),
              "r"(b[i * 2 + j].vals[1]),
              "f"(cTmp[j].x),
              "f"(cTmp[j].y),
              "f"(cTmp[j].z),
              "f"(cTmp[j].w));
      }

#pragma unroll
      for (int j = 0; j < 2; ++j) {
        c.x += cTmp[j].x;
        c.y += cTmp[j].y;
        c.z += cTmp[j].z;
        c.w += cTmp[j].w;
      }
    }
  } // for all tiles under kTilesLimit

  // If we have any remainder k-tiles, some warps will handle them, processing
  // one k-tile at a time
  auto kTileBaseRemaining = kTilesLimit + warpId;
  if (kTileBaseRemaining < kTiles) {
    f16x2x4_u32 a;
    ALayout::template load<1>(
        A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, &a);

    f16x2x2_u32 b;
    BLayout::template load<1>(
        B, CB, n, k, nTiles, nTile, kTiles, kTileBaseRemaining, laneId, &b);

    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
        : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
        : "r"(a.vals[0]),
          "r"(a.vals[1]),
          "r"(a.vals[2]),
          "r"(a.vals[3]),
          "r"(b.vals[0]),
          "r"(b.vals[1]),
          "f"(c.x),
          "f"(c.y),
          "f"(c.z),
          "f"(c.w));
  }

  // Reduce independent k-tiles (same m/n) across warps
  __shared__ float4 smem_sum[Warps][kWarpSize];
  smem_sum[warpId][laneId] = c;

  __syncthreads();

  if (warpId == 0) {
    float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};

    // Reduce across the block in the first warp
    for (int i = 0; i < Warps; ++i) {
      float4 v = smem_sum[i][laneId];
      sum_f32.x += v.x;
      sum_f32.y += v.y;
      sum_f32.z += v.z;
      sum_f32.w += v.w;
    }

    // Write the reduced result (in the first warp) into the output
    CLayout::store(
        C,
        m,
        n,
        mTiles,
        mTile,
        // n for C output becomes k for A input, so for m16n8k16,
        // we need to halve the tiles
        nTiles / 2,
        nTile,
        laneId,
        sum_f32);
  }
#endif
}
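
// Launch geometry used by the host wrappers below: grid (1, nTiles, mTiles)
// with block (kWarpSize, Warps), so each block of kWarpSize * Warps threads
// produces one 16x8 output tile of C. The Warps warps split the k dimension
// among themselves and combine their partial sums through shared memory before
// the first warp performs the store.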

at::Tensor d4_mm_origorder(
    const at::Tensor& A,
    const at::Tensor& B,
    const at::Tensor& CB) {
  c10::cuda::CUDAGuard g(A.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  constexpr int Warps = 8;

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto n = B.size(0);
  auto nTiles = divUp(n, kNTileSize);

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));

  auto grid = dim3(1, nTiles, mTiles);
  auto block = dim3(kWarpSize, Warps);
  auto kernel = tinygemm_m16n8k16_chunk_kernel<
      ALayout_RM, BLayout_D4, ALayout_RM, Warps, KTilesPerWarp>;
  kernel<<<grid, block, 0, stream>>>(
      (const half*)A.data_ptr(),
      (const void*)B.data_ptr(),
      (const uint64_t*)CB.data_ptr(),
      (half*)C_final.data_ptr(),
      m,
      n,
      k,
      mTiles,
      nTiles,
      kTiles);

  return C_final;
}

at::Tensor e8p_mm_origorder(
    const at::Tensor& A,
    const at::Tensor& B,
    const at::Tensor& CB) {
  c10::cuda::CUDAGuard g(A.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  constexpr int Warps = 8;

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto n = B.size(0);
  auto nTiles = divUp(n, kNTileSize);

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));

  auto grid = dim3(1, nTiles, mTiles);
  auto block = dim3(kWarpSize, Warps);
  auto kernel = tinygemm_m16n8k16_chunk_kernel<
      ALayout_RM, BLayout_E8, ALayout_RM, Warps, KTilesPerWarp>;
  kernel<<<grid, block, 0, stream>>>(
      (const half*)A.data_ptr(),
      (const void*)B.data_ptr(),
      (const uint64_t*)CB.data_ptr(),
      (half*)C_final.data_ptr(),
      m,
      n,
      k,
      mTiles,
      nTiles,
      kTiles);

  return C_final;
}

at::Tensor hi_mm_origorder(
    const at::Tensor& A,
    const at::Tensor& B) {
  c10::cuda::CUDAGuard g(A.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  constexpr int Warps = 8;

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto n = B.size(0);
  auto nTiles = divUp(n, kNTileSize);

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));

  auto grid = dim3(1, nTiles, mTiles);
  auto block = dim3(kWarpSize, Warps);
  auto kernel = tinygemm_m16n8k16_chunk_kernel<
      ALayout_RM, BLayout_HI, ALayout_RM, Warps, KTilesPerWarp>;
  kernel<<<grid, block, 0, stream>>>(
      (const half*)A.data_ptr(),
      (const void*)B.data_ptr(),
      nullptr,
      (half*)C_final.data_ptr(),
      m,
      n,
      k,
      mTiles,
      nTiles,
      kTiles);

  return C_final;
}
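
// Illustrative call pattern (a sketch only; the exact packing of B and the
// codebook layout are assumptions based on how the wrappers index their
// inputs): A is fp16 [m, k] row major, B holds one row of packed codes per
// output column (n rows), and CB is the codebook viewed as packed entries.
//
//   at::Tensor A = at::randn({m, k}, at::dtype(at::kHalf).device(at::kCUDA));
//   at::Tensor C = d4_mm_origorder(A, B, CB);   // fp16 [m, n]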

#define DECOMPRESS_D4_BLOCK_SIZE 256

__global__ void cuda_decompress_d4_origorder_kernel(
    const uint8_t* __restrict__ YIs,  // m x (n/4)
    const c10::Half* __restrict__ CB, // 256 x 4
    c10::Half* __restrict__ Y         // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x;

  // Each index byte selects one codebook row of four halves (64 bits).
  for (long r = 0; r < 4; r++) {
    uint8_t yidx = ((uint8_t*)YIs)[i * 4 + r];
    ((uint64_t*)Y)[i * 4 + r] = ((uint64_t*)CB)[yidx & 255];
  }
#endif
}

void decompress_d4_origorder(
    torch::Tensor YIs, // m x (n/4)
    torch::Tensor CB,  // 256 x 4
    torch::Tensor Y    // m x n
) {
  size_t m = Y.sizes()[0];
  size_t n = Y.sizes()[1];

  assert(YIs.is_contiguous());
  assert(CB.is_contiguous());
  assert(Y.is_contiguous());

  assert(YIs.sizes()[0] == m);
  assert(YIs.sizes()[1] * 4 == n);
  assert(CB.sizes()[0] == 256);

  const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE);
  const dim3 blocks(m * n / (16 * DECOMPRESS_D4_BLOCK_SIZE));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cuda_decompress_d4_origorder_kernel<<<blocks, threads, 0, stream>>>(
      YIs.data_ptr<uint8_t>(),
      CB.data_ptr<c10::Half>(),
      Y.data_ptr<c10::Half>());
}

#define DECOMPRESS_E8P_BLOCK_SIZE 256

__global__ void cuda_decompress_e8p_origorder_kernel(
    const int16_t* __restrict__ YIs, // m x (n/8)
    const int64_t* __restrict__ CB,  // 256 x 8
    c10::Half* __restrict__ Y        // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const long i = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x;

  uint16_t yidx = ((uint16_t*)YIs)[i];
  uint64_t decoded = BLayout_E8::decode8weights(yidx, CB);

  half2 unpacked[2][2];
  uint64_t lower_half = decoded & 0x00ff00ff00ff00ff;
  lower_half = (lower_half ^ 0x5c805c805c805c80);
  memcpy(unpacked[0], &lower_half, sizeof(uint64_t));
  uint64_t upper_half = (decoded & 0xff00ff00ff00ff00) >> 8;
  upper_half = (upper_half ^ 0x5c805c805c805c80);
  memcpy(unpacked[1], &upper_half, sizeof(uint64_t));

  const half adjust_ = __float2half_rn(-288.0f);
  const half2 adjust = __halves2half2(adjust_, adjust_);

  // Interleave back into the original element order.
  ((__half2*)Y)[i * 4] = __hadd2(unpacked[0][0], adjust);     // 01
  ((__half2*)Y)[i * 4 + 1] = __hadd2(unpacked[1][0], adjust); // 23
  ((__half2*)Y)[i * 4 + 2] = __hadd2(unpacked[0][1], adjust); // 45
  ((__half2*)Y)[i * 4 + 3] = __hadd2(unpacked[1][1], adjust); // 67
#endif
}

void decompress_e8p_origorder(
    torch::Tensor YIs, // m x (n/8)
    torch::Tensor CB,  // 256 x 8
    torch::Tensor& Y   // m x n
) {
  size_t m = Y.sizes()[0];
  size_t n = Y.sizes()[1];

  assert(YIs.is_contiguous());
  assert(CB.is_contiguous());
  assert(Y.is_contiguous());

  assert(YIs.sizes()[0] == m);
  assert(YIs.sizes()[1] * 8 == n);
  assert(CB.sizes()[0] == 256);

  const dim3 threads(DECOMPRESS_E8P_BLOCK_SIZE);
  const dim3 blocks(m * n / (8 * DECOMPRESS_E8P_BLOCK_SIZE));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cuda_decompress_e8p_origorder_kernel<<<blocks, threads, 0, stream>>>(
      YIs.data_ptr<int16_t>(),
      CB.data_ptr<int64_t>(),
      Y.data_ptr<c10::Half>());
}

#define DECOMPRESS_HI_BLOCK_SIZE 256

__global__ void cuda_decompress_hi_origorder_kernel(
    const uint32_t* __restrict__ YIs, // m x (n/8)
    c10::Half* __restrict__ Y         // m x n
) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  const long i = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x;

  uint32_t qa = YIs[i];

  // Same no-codebook trick as BLayout_HI::load: each 4-bit code decodes to
  // (code - 7.5) via the fp16 bias, a 1/16 scale, and a constant offset.
  const uint32_t c0 = 0x64086408;
  const half y16_ = __float2half_rn(1.0f / 16.0f);
  const half2 y16 = __halves2half2(y16_, y16_);
  const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
  const half2 z16 = __halves2half2(z16_, z16_);

  uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
  uint32_t q1 = ((qa & 0x00f000f0) | c0);
  qa >>= 8;
  uint32_t q2 = (((qa & 0x000f000f) << 4) | c0);
  uint32_t q3 = ((qa & 0x00f000f0) | c0);
  ((__half2*)Y)[i * 4] = __hfma2(*((half2*)(&q0)), y16, z16);
  ((__half2*)Y)[i * 4 + 1] = __hfma2(*((half2*)(&q1)), y16, z16);
  ((__half2*)Y)[i * 4 + 2] = __hfma2(*((half2*)(&q2)), y16, z16);
  ((__half2*)Y)[i * 4 + 3] = __hfma2(*((half2*)(&q3)), y16, z16);
#endif
}

void decompress_hi_origorder(
    torch::Tensor YIs, // m x (n/8)
    torch::Tensor Y    // m x n
) {
  size_t m = Y.sizes()[0];
  size_t n = Y.sizes()[1];

  assert(YIs.is_contiguous());
  assert(Y.is_contiguous());

  assert(YIs.sizes()[0] == m);
  assert(YIs.sizes()[1] * 8 == n);

  const dim3 threads(DECOMPRESS_HI_BLOCK_SIZE);
  const dim3 blocks(m * n / (8 * DECOMPRESS_HI_BLOCK_SIZE));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cuda_decompress_hi_origorder_kernel<<<blocks, threads, 0, stream>>>(
      (uint32_t*)YIs.data_ptr<int32_t>(),
      Y.data_ptr<c10::Half>());
}
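
// Minimal binding sketch, assuming this file were compiled on its own as a
// PyTorch C++/CUDA extension; the surrounding project may register these
// functions elsewhere, in which case this block is unnecessary. The
// ORIGIN_ORDER_STANDALONE_EXTENSION guard is a hypothetical opt-in macro.
#ifdef ORIGIN_ORDER_STANDALONE_EXTENSION
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("d4_mm_origorder", &d4_mm_origorder, "fp16 x D4-quantized GEMM");
  m.def("e8p_mm_origorder", &e8p_mm_origorder, "fp16 x E8P-quantized GEMM");
  m.def("hi_mm_origorder", &hi_mm_origorder, "fp16 x 4-bit HI-quantized GEMM");
  m.def("decompress_d4_origorder", &decompress_d4_origorder, "decompress D4 weights");
  m.def("decompress_e8p_origorder", &decompress_e8p_origorder, "decompress E8P weights");
  m.def("decompress_hi_origorder", &decompress_hi_origorder, "decompress HI weights");
}
#endif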