origin_order.cu 22 KB


  1. #include <cuda_bf16.h>
  2. #include <cuda_fp16.h>
  3. #include <cuda_runtime.h>
  4. #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 700
  5. #include <mma.h>
  6. #endif
  7. #include <ATen/ATen.h>
  8. #include <ATen/core/Tensor.h>
  9. #include <ATen/cuda/CUDAContext.h>
  10. #include <ATen/DeviceGuard.h>
  11. #include <torch/all.h>
  12. #include <c10/cuda/CUDAGuard.h>
  // Integer division a / b with the built-in truncation of operator '/'.
  // Both operands must be integral types; the result has the type of (a + b).
  template <typename U, typename V>
  constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
    static_assert(std::is_integral<U>::value, "");
    static_assert(std::is_integral<V>::value, "");
    return a / b;
  }
  // Number of b-sized chunks needed to cover a: the truncated quotient plus one
  // when the division leaves a remainder. Written this way instead of
  // (a + b - 1) / b so the intermediate sum cannot overflow.
  // Both operands must be integral types; the result has the type of (a + b).
  template <typename U, typename V>
  constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
    static_assert(std::is_integral<U>::value, "");
    static_assert(std::is_integral<V>::value, "");
    return a / b + ((a % b != 0) ? 1 : 0);
  }
  // Largest multiple of b obtained from the truncated quotient a / b.
  // Both operands must be integral types; the result has the type of (a + b).
  template <typename U, typename V>
  constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
    static_assert(std::is_integral<U>::value, "");
    static_assert(std::is_integral<V>::value, "");
    const auto wholeChunks = divDown(a, b);
    return wholeChunks * b;
  }
  // Multiple of b obtained by multiplying b with divUp(a, b), i.e. a rounded up
  // to whole b-sized chunks.
  // Both operands must be integral types; the result has the type of (a + b).
  template <typename U, typename V>
  constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
    static_assert(std::is_integral<U>::value, "");
    static_assert(std::is_integral<V>::value, "");
    const auto chunks = divUp(a, b);
    return chunks * b;
  }
  // Threads per warp.
  constexpr int32_t kWarpSize = 32;
  // Not referenced anywhere in the code visible in this file.
  constexpr int32_t KTilesPerWarp = 8;

  // Tile extents of the m16n8k16 tensor-core instruction used below.
  constexpr int32_t kMTileSize = 16;
  constexpr int32_t kNTileSize = 8;
  constexpr int32_t kKTileSize = 16;

  // Four 32-bit registers, each holding a pair of fp16 values
  // (per-lane A fragment of one m16n8k16 tile).
  struct __align__(16) f16x2x4_u32 {
    uint32_t vals[4];
  };

  // Two 32-bit registers, each holding a pair of fp16 values
  // (per-lane B fragment of one m16n8k16 tile).
  struct __align__(16) f16x2x2_u32 {
    uint32_t vals[2];
  };
  // Row-major fp16 layout. Used both to load the A operand fragments and to
  // store the C result of an m16n8k16 tensor-core tile.
  struct ALayout_RM {
    // Loads the per-lane A fragments for KTilesToLoad consecutive k-tiles.
    //
    // With r = mTile * kMTileSize + laneId / 4 and
    //      c = (kTileStart + i) * kKTileSize + (laneId % 4) * 4,
    // out[i] receives:
    //   vals[0] = A[r    ][c .. c+1]     vals[1] = A[r + 8][c .. c+1]
    //   vals[2] = A[r    ][c+2 .. c+3]   vals[3] = A[r + 8][c+2 .. c+3]
    // i.e. each lane reads four consecutive k-elements per row.
    //
    // Rows at or beyond m are replaced by zeros. Columns are NOT bounds-checked
    // against k, and every access is a 32-bit load, so the caller must make
    // sure the addressed k-range exists and is 4-byte aligned.
    // mTiles and kTiles are accepted for interface symmetry and are not used.
    template <int KTilesToLoad>
    static __device__ void load(const half* A, int32_t m, int32_t k,
                                int32_t mTiles, int32_t mTile, int32_t kTiles,
                                int32_t kTileStart, int32_t laneId,
                                f16x2x4_u32 out[KTilesToLoad]) {
      const int32_t mLane = mTile * kMTileSize + (laneId / 4);
      const int32_t kLane = kTileStart * kKTileSize + (laneId % 4) * 4;

      // Element offsets are formed in 64 bits: mLane * k can exceed INT32_MAX
      // for large matrices.
      const half* aPtr = A + static_cast<int64_t>(mLane) * k + kLane;
      const half* aPtrPlus8Rows = aPtr + static_cast<int64_t>(8) * k;

      const bool m0InBounds = mLane < m;
      const bool m1InBounds = (mLane + 8) < m;

  #pragma unroll
      for (int i = 0; i < KTilesToLoad; ++i) {
        const half* row0 = aPtr + i * kKTileSize;
        const half* row1 = aPtrPlus8Rows + i * kKTileSize;

        out[i].vals[0] =
            m0InBounds ? *reinterpret_cast<const uint32_t*>(row0) : uint32_t(0);
        out[i].vals[1] =
            m1InBounds ? *reinterpret_cast<const uint32_t*>(row1) : uint32_t(0);
        out[i].vals[2] = m0InBounds
                             ? *reinterpret_cast<const uint32_t*>(row0 + 2)
                             : uint32_t(0);
        out[i].vals[3] = m1InBounds
                             ? *reinterpret_cast<const uint32_t*>(row1 + 2)
                             : uint32_t(0);
      }
    }

    // Writes one lane's share of a 16x8 fp32 accumulator tile to C as fp16.
    //
    // out.x / out.y go to row  mTile * kMTileSize + laneId / 4,
    // out.z / out.w go to that row + 8; both pairs use columns
    // nTile * kNTileSize + (laneId % 4) * 2 and the following one.
    // Rows at or beyond m are skipped. Columns are NOT bounds-checked against n.
    // mOutTiles and nOutTiles are not used.
    // The body is compiled only for __CUDA_ARCH__ >= 800.
    static __device__ void store(half* C, int32_t m, int32_t n, int32_t mOutTiles,
                                 int32_t mTile, int32_t nOutTiles, int32_t nTile,
                                 int32_t laneId, const float4& out) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
      const int outRow = mTile * kMTileSize + (laneId / 4);
      const int outCol = nTile * kNTileSize + (laneId % 4) * 2;

      // 64-bit element offset: outRow * n can exceed INT32_MAX.
      half* cPtr = C + static_cast<int64_t>(outRow) * n + outCol;

      const half2 v01 = __float22half2_rn(float2{out.x, out.y});
      const half2 v23 = __float22half2_rn(float2{out.z, out.w});

      if (outRow < m) {
        *reinterpret_cast<half2*>(cPtr) = v01;
      }
      // Second pair lives 8 rows below the first one.
      if (outRow + 8 < m) {
        *reinterpret_cast<half2*>(cPtr + static_cast<int64_t>(8) * n) = v23;
      }
  #endif
    }
  };
  // B operand layout for the D4 codebook: B is read as bytes, one byte per
  // group of four consecutive k-elements, and each byte indexes a table of
  // 64-bit entries (four packed fp16 values).
  struct BLayout_D4 {
    static constexpr bool use_codebook = true;

    // Loads the per-lane B fragments for KTilesPerIteration consecutive k-tiles.
    //
    // The lane handles row  nTile * kNTileSize + laneId / 4  of B and, inside
    // each k-tile, the code at byte column  kTile * kKTileSize / 4 + laneId % 4.
    // The 64-bit codebook entry selected by that byte is written over both
    // 32-bit registers of b[i].
    // Neither the row nor the column is bounds-checked against n / k.
    // n, nTiles and kTiles are not used.
    // The body is compiled only for __CUDA_ARCH__ >= 800.
    template <int KTilesPerIteration>
    static __device__ void load(const void* __restrict__ B,
                                const uint64_t* __restrict__ CB, int32_t n,
                                int32_t k, int32_t nTiles, int32_t nTile,
                                int32_t kTiles, int32_t kTileStart,
                                int32_t laneId,
                                f16x2x2_u32 b[KTilesPerIteration]) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
      const uint8_t* Bptr = reinterpret_cast<const uint8_t*>(B);

      // The row is the same for every k-tile; its byte offset is formed in
      // 64 bits because row * k can exceed INT32_MAX for large weights.
      const int row = nTile * kNTileSize + laneId / 4;
      const int64_t rowOffset = static_cast<int64_t>(row) * k / 4;

  #pragma unroll
      for (int i = 0; i < KTilesPerIteration; ++i) {
        const int col = (kTileStart + i) * kKTileSize / 4 + laneId % 4;
        *(reinterpret_cast<uint64_t*>(b[i].vals)) = CB[Bptr[rowOffset + col]];
      }
  #endif
    }
  };
  // B operand layout for 4-bit codes without a codebook: B is read as 32-bit
  // words, each holding eight 4-bit codes, and every code is mapped to the
  // fp16 value (code - 7.5) arithmetically.
  struct BLayout_HI {
    static constexpr bool use_codebook = false;

    // Loads the per-lane B fragments for KTilesPerIteration consecutive k-tiles.
    //
    // The lane handles row  nTile * kNTileSize + laneId / 4  of B. Two
    // neighbouring lanes share one 32-bit word (column
    // kTile * kKTileSize / 8 + (laneId % 4) / 2); the odd lane works on the
    // word shifted right by 8 bits. Four codes are expanded per k-tile.
    // CB is ignored. Neither the row nor the column is bounds-checked.
    // n, nTiles and kTiles are not used.
    // The body is compiled only for __CUDA_ARCH__ >= 800.
    template <int KTilesPerIteration>
    static __device__ void load(const void* __restrict__ B,
                                const uint64_t* __restrict__ CB, int32_t n,
                                int32_t k, int32_t nTiles, int32_t nTile,
                                int32_t kTiles, int32_t kTileStart,
                                int32_t laneId,
                                f16x2x2_u32 b[KTilesPerIteration]) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
      const uint32_t* Bptr = reinterpret_cast<const uint32_t*>(B);

      // Loop-invariant pieces, computed once.
      // The word offset of the row is formed in 64 bits because row * k can
      // exceed INT32_MAX for large weights.
      const int row = nTile * kNTileSize + laneId / 4;
      const int64_t rowOffset = static_cast<int64_t>(row) * k / 8;
      const int shift = (laneId & 1) * 8;

      // 0x6408 is the fp16 bit pattern of 1024 + 8. OR-ing (code << 4) into its
      // mantissa gives 1024 + 16 * code + 8; scaling by 1/16 and adding
      // -(64 + 8) leaves code - 7.5, so no codebook read is needed.
      const uint32_t c0 = 0x64086408;
      const half y16_ = __float2half_rn(1.0f / 16.0f);
      const half2 y16 = __halves2half2(y16_, y16_);
      const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
      const half2 z16 = __halves2half2(z16_, z16_);

  #pragma unroll
      for (int i = 0; i < KTilesPerIteration; ++i) {
        const int col = (kTileStart + i) * kKTileSize / 8 + (laneId % 4) / 2;
        const uint32_t code = Bptr[rowOffset + col];
        const uint32_t qa = code >> shift;
        // q0 takes the low nibble of each 16-bit half, q1 the next nibble
        // (already sitting at bit 4, so it needs no shift).
        uint32_t q0 = (((qa & 0x000f000f) << 4) | c0);
        uint32_t q1 = ((qa & 0x00f000f0) | c0);
        *(half2*)(b[i].vals) = __hfma2(*((half2*)(&q0)), y16, z16);
        *(half2*)(b[i].vals + 1) = __hfma2(*((half2*)(&q1)), y16, z16);
      }
  #endif
    }
  };
  // B operand layout for 16-bit E8 codes: the high byte of a code indexes a
  // 256-entry table of packed 8x8-bit values, the low byte carries sign bits.
  struct BLayout_E8 {
    static constexpr bool use_codebook = true;

    // Decodes one 16-bit code into eight packed 8-bit values (one uint64_t).
    //
    // Steps, as performed below:
    //  * low byte  -> sign bits; its parity is XOR-ed into bit 0 to form
    //    sign_vec, and is also used for the final per-byte subtraction;
    //  * high byte -> index into codebook_abs;
    //  * the multiply / mask / shift moves bit (7 - j) of sign_vec to bit 0 of
    //    byte j; multiplying by 252 turns each set byte into 0xFC, which is
    //    XOR-ed into the table entry;
    //  * bit 0 of every byte is then forced to 1 and 2 * parity is subtracted
    //    from every byte.
    // For __CUDA_ARCH__ < 800 (and in the host pass) the function is a stub
    // returning 0 so that it never falls off the end of a non-void function.
    __device__ static inline uint64_t decode8weights(
        uint16_t weight_compressed, const int64_t* __restrict__ codebook_abs) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
      uint8_t bits_sign = weight_compressed & 0xff;
      uint8_t parity = __popc(bits_sign) & 1;
      uint8_t sign_vec = bits_sign ^ parity;
      uint8_t bits_abs = (weight_compressed >> 8);
      int64_t packed = codebook_abs[bits_abs];
      uint64_t decoded_sign = sign_vec * 0x8040201008040201ll;
      decoded_sign &= 0x8080808080808080;
      decoded_sign >>= 7;
      decoded_sign *= 255 - 3;
      packed ^= decoded_sign;
      packed |= 0x0101010101010101;
      packed -= parity * 0x0202020202020202;
      return packed;
  #else
      return 0;
  #endif
    }

    // 32-bit variant of the decoder above: produces only one half of the eight
    // values. idx (0 or 1) selects which 32-bit word of the 64-bit table entry
    // is read and which multiplier is used to spread the sign bits
    // (idx 0 uses sign_vec bits 7..4, idx 1 uses bits 3..0).
    // For __CUDA_ARCH__ < 800 (and in the host pass) the function is a stub
    // returning 0.
    __device__ static inline uint32_t decode8weights(
        uint16_t weight_compressed, const int64_t* __restrict__ codebook_abs,
        int idx) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
      uint8_t bits_sign = weight_compressed & 0xff;
      const uint32_t magic_nums[2] = {0x08040201ll, 0x80402010ll};
      uint8_t parity = __popc(bits_sign) & 1;
      uint8_t sign_vec = bits_sign ^ parity;
      uint16_t bits_abs = (weight_compressed >> 8);
      // The table is read as 32-bit words; const is kept on the cast.
      uint32_t packed =
          reinterpret_cast<const uint32_t*>(codebook_abs)[(bits_abs << 1) + idx];
      uint32_t magic_num = magic_nums[idx];
      uint32_t decoded_sign = sign_vec * magic_num;
      decoded_sign &= 0x80808080;
      decoded_sign >>= 7;
      decoded_sign *= 255 - 3;
      packed ^= decoded_sign;
      packed |= 0x01010101;
      packed -= parity * 0x02020202;
      return packed;
  #else
      return 0;
  #endif
    }

    // Loads the per-lane B fragments for KTilesPerIteration consecutive k-tiles.
    //
    // The lane handles row  nTile * kNTileSize + laneId / 4  of B, read as
    // 16-bit codes. Two neighbouring lanes share one code (column
    // kTile * kKTileSize / 8 + (laneId % 4) / 2) and decode one 32-bit half
    // each (laneId & 1). The four decoded bytes b0..b3 are converted to fp16
    // and stored as b[i].vals[0] = (b0, b2), b[i].vals[1] = (b1, b3).
    // Neither the row nor the column is bounds-checked.
    // n, nTiles and kTiles are not used.
    // The body is compiled only for __CUDA_ARCH__ >= 800.
    template <int KTilesPerIteration>
    static __device__ void load(const void* __restrict__ B,
                                const uint64_t* __restrict__ CB, int32_t n,
                                int32_t k, int32_t nTiles, int32_t nTile,
                                int32_t kTiles, int32_t kTileStart,
                                int32_t laneId,
                                f16x2x2_u32 b[KTilesPerIteration]) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
      const uint16_t* Bptr = reinterpret_cast<const uint16_t*>(B);
      const int64_t* codebook = reinterpret_cast<const int64_t*>(CB);

      // Loop-invariant pieces, computed once.
      // The code offset of the row is formed in 64 bits because row * k can
      // exceed INT32_MAX for large weights.
      const int row = nTile * kNTileSize + laneId / 4;
      const int64_t rowOffset = static_cast<int64_t>(row) * k / 8;

      // XOR with 0x5c80 places a byte in the mantissa of the fp16 value 256
      // while flipping its top bit; adding -288 afterwards leaves the byte,
      // read as a signed 8-bit integer, divided by 4.
      const half adjust_ = __float2half_rn(-288.0f);
      const half2 adjust = __halves2half2(adjust_, adjust_);

  #pragma unroll
      for (int i = 0; i < KTilesPerIteration; ++i) {
        const int col = (kTileStart + i) * kKTileSize / 8 + laneId % 4 / 2;
        const uint32_t decoded =
            decode8weights(Bptr[rowOffset + col], codebook, laneId & 1);

        half2 unpacked[2];
        uint32_t lower_half = decoded & 0x00ff00ff;
        lower_half = (lower_half ^ 0x5c805c80);
        memcpy(unpacked, &lower_half, sizeof(uint32_t));
        uint32_t upper_half = (decoded & 0xff00ff00) >> 8;
        upper_half = (upper_half ^ 0x5c805c80);
        memcpy(unpacked + 1, &upper_half, sizeof(uint32_t));

        unpacked[0] = __hadd2(unpacked[0], adjust);
        unpacked[1] = __hadd2(unpacked[1], adjust);
        *(reinterpret_cast<uint64_t*>(b[i].vals)) =
            *(reinterpret_cast<uint64_t*>(unpacked));
      }
  #endif
    }
  };
  // Tensor-core GEMM kernel: every thread block produces one 16x8 tile of C
  // from fp16 A and a quantized B that is de-quantized on the fly by BLayout.
  //
  // Launch layout expected by the indexing below:
  //   block = (kWarpSize, Warps)   threadIdx.x is the lane, threadIdx.y the warp
  //   grid  = (1, nTiles, mTiles)  blockIdx.y is the n-tile, blockIdx.z the m-tile
  // Static shared memory: a 256-entry uint64_t copy of the codebook plus
  // Warps * kWarpSize float4 partial sums.
  //
  // The k dimension is split across the warps of the block: warp w processes
  // groups of KTilesPerIteration k-tiles starting at w * KTilesPerIteration with
  // a stride of Warps * KTilesPerIteration; k-tiles left over after the last
  // full group are handled one per warp. Partial sums are then reduced through
  // shared memory by warp 0, which stores the tile via CLayout.
  //
  // CB must point to at least 256 uint64_t entries when BLayout::use_codebook
  // is true; it is not read otherwise.
  // The body is compiled only for __CUDA_ARCH__ >= 800; for older targets the
  // kernel is an empty stub.
  template <typename ALayout, typename BLayout, typename CLayout, int Warps,
            int KTilesPerIteration>
  __global__ __launch_bounds__(256) void tinygemm_m16n8k16_chunk_kernel(
      // Data for the A matrix, loaded as per ALayout
      const half* __restrict__ A, const void* __restrict__ B,
      const uint64_t* __restrict__ CB,
      // Output data for the C matrix, stored as per CLayout
      half* __restrict__ C,
      // The size of the matrix multiplication
      int32_t m, int32_t n, int32_t k,
      // The size of the matrix multiplication, in multiples of our TC tile size
      int32_t mTiles, int32_t nTiles, int32_t kTiles) {
    static_assert(Warps > 0 && Warps * kWarpSize <= 256,
                  "block size must respect __launch_bounds__(256)");
    static_assert(KTilesPerIteration > 0 && KTilesPerIteration % 2 == 0,
                  "k-tiles are consumed in pairs");
    static_assert(KTilesPerIteration - 1 <= Warps,
                  "each remainder k-tile needs its own warp");
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
    constexpr int kCodebookEntries = 256;
    __shared__ uint64_t CB_[kCodebookEntries];

    const int32_t warpId = static_cast<int32_t>(threadIdx.y);
    const int32_t laneId = static_cast<int32_t>(threadIdx.x);

    if (BLayout::use_codebook) {
      // Stage the codebook in shared memory. The strided loop fills all 256
      // entries for any block size, not only for exactly 256 threads.
      for (int idx = laneId + warpId * kWarpSize; idx < kCodebookEntries;
           idx += Warps * kWarpSize) {
        CB_[idx] = CB[idx];
      }
      // use_codebook is a compile-time constant, so the whole block takes the
      // same branch and the barrier is safe here.
      __syncthreads();
    }

    const int32_t mTile = blockIdx.z;
    const int32_t nTile = blockIdx.y;

    float4 c{0.0f, 0.0f, 0.0f, 0.0f};

    // First, handle whole multiples of KTilesPerIteration
    const auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);

    // Each warp handles a set of KTilesPerIteration under the above limit
    for (int32_t kTileBase = warpId * KTilesPerIteration; kTileBase < kTilesLimit;
         kTileBase += Warps * KTilesPerIteration) {
      // Load data from A
      f16x2x4_u32 a[KTilesPerIteration];
      ALayout::template load<KTilesPerIteration>(A, m, k, mTiles, mTile, kTiles,
                                                 kTileBase, laneId, a);

      // Load data from B and de-quantize as needed
      f16x2x2_u32 b[KTilesPerIteration];
      BLayout::template load<KTilesPerIteration>(B, CB_, n, k, nTiles, nTile,
                                                 kTiles, kTileBase, laneId, b);

      // Multiply two k-tiles at a time into independent zero-initialised
      // accumulators, then fold both into the running sum.
  #pragma unroll
      for (int i = 0; i < KTilesPerIteration / 2; ++i) {
        float4 cTmp[2];
  #pragma unroll
        for (int j = 0; j < 2; ++j) {
          cTmp[j] = float4{0.0f, 0.0f, 0.0f, 0.0f};
        }
  #pragma unroll
        for (int j = 0; j < 2; ++j) {
          asm volatile(
              "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
              "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
              : "=f"(cTmp[j].x), "=f"(cTmp[j].y), "=f"(cTmp[j].z), "=f"(cTmp[j].w)
              : "r"(a[i * 2 + j].vals[0]), "r"(a[i * 2 + j].vals[1]),
                "r"(a[i * 2 + j].vals[2]), "r"(a[i * 2 + j].vals[3]),
                "r"(b[i * 2 + j].vals[0]), "r"(b[i * 2 + j].vals[1]),
                "f"(cTmp[j].x), "f"(cTmp[j].y), "f"(cTmp[j].z), "f"(cTmp[j].w));
        }
  #pragma unroll
        for (int j = 0; j < 2; ++j) {
          c.x += cTmp[j].x;
          c.y += cTmp[j].y;
          c.z += cTmp[j].z;
          c.w += cTmp[j].w;
        }
      }
    }  // for all tiles under kTilesLimit

    // If there are remainder k-tiles, warp w handles tile kTilesLimit + w.
    // The condition depends only on the warp index, so all lanes of a warp
    // reach mma.sync together.
    const int32_t kTileBaseRemaining = kTilesLimit + warpId;
    if (kTileBaseRemaining < kTiles) {
      f16x2x4_u32 a;
      ALayout::template load<1>(A, m, k, mTiles, mTile, kTiles,
                                kTileBaseRemaining, laneId, &a);
      // Read the shared-memory codebook copy here as well, like the main loop.
      f16x2x2_u32 b;
      BLayout::template load<1>(B, CB_, n, k, nTiles, nTile, kTiles,
                                kTileBaseRemaining, laneId, &b);
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
          : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
          : "r"(a.vals[0]), "r"(a.vals[1]), "r"(a.vals[2]), "r"(a.vals[3]),
            "r"(b.vals[0]), "r"(b.vals[1]), "f"(c.x), "f"(c.y), "f"(c.z),
            "f"(c.w));
    }

    // Reduce independent k-tiles (same m/n) across warps
    __shared__ float4 smem_sum[Warps][kWarpSize];
    smem_sum[warpId][laneId] = c;
    __syncthreads();

    if (warpId == 0) {
      float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
      // Reduce across the block in the first warp
      for (int w = 0; w < Warps; ++w) {
        const float4 v = smem_sum[w][laneId];
        sum_f32.x += v.x;
        sum_f32.y += v.y;
        sum_f32.z += v.z;
        sum_f32.w += v.w;
      }
      // Write the reduced result (in the first warp) into the output
      CLayout::store(C, m, n, mTiles, mTile,
                     // n for C output becomes k for A input, so for m16n8k16,
                     // we need to halve the tiles
                     nTiles / 2, nTile, laneId, sum_f32);
    }
  #endif
  }
  // Computes C = A * W^T on the GPU, where W is the (n, k) weight matrix encoded
  // by B with the D4 codebook CB.
  //
  //   A  : (m, k) contiguous fp16 CUDA tensor.
  //   B  : contiguous CUDA tensor with n = B.size(0) rows, read by the kernel as
  //        one byte per four k-elements (n * k / 4 bytes in total).
  //   CB : contiguous CUDA tensor holding at least 256 64-bit codebook entries.
  // Returns a new (m, n) tensor with the dtype and device of A.
  //
  // Raises (via TORCH_CHECK) when the inputs violate the above, when k is not a
  // multiple of kKTileSize or n is not a multiple of kNTileSize (the kernel does
  // not bounds-check those dimensions), when the tile counts exceed the CUDA
  // grid limits, or when the device is older than compute capability 8.0 (the
  // kernel body is only compiled for sm_80+).
  at::Tensor d4_mm_origorder(const at::Tensor& A, const at::Tensor& B,
                             const at::Tensor& CB) {
    // NOTE: this is host code, so it must not be wrapped in a __CUDA_ARCH__
    // guard. That macro is only defined while compiling device code; guarding
    // the body with it removes the launch and the return statement from the
    // host build.
    constexpr int Warps = 8;
    constexpr int KTilesPerIteration = 8;
    constexpr int64_t kCodebookBytes = 256 * sizeof(uint64_t);
    constexpr int64_t kMaxGridYZ = 65535;  // CUDA limit for gridDim.y / gridDim.z

    TORCH_CHECK(A.is_cuda() && B.is_cuda() && CB.is_cuda(),
                "d4_mm_origorder: A, B and CB must be CUDA tensors");
    TORCH_CHECK(A.dim() == 2 && B.dim() >= 1,
                "d4_mm_origorder: A must be 2-D and B must have at least 1 dim");
    TORCH_CHECK(A.scalar_type() == at::kHalf,
                "d4_mm_origorder: A must be float16");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && CB.is_contiguous(),
                "d4_mm_origorder: A, B and CB must be contiguous");
    TORCH_CHECK(static_cast<int64_t>(CB.nbytes()) >= kCodebookBytes,
                "d4_mm_origorder: CB must hold 256 64-bit entries");

    c10::cuda::CUDAGuard g(A.device());
    TORCH_CHECK(at::cuda::getCurrentDeviceProperties()->major >= 8,
                "d4_mm_origorder: requires compute capability 8.0 or newer");
    auto stream = at::cuda::getCurrentCUDAStream();

    // A is row major (m, k); B has one row per output column.
    const int64_t m = A.size(0);
    const int64_t n = B.size(0);
    const int64_t k = A.size(1);
    TORCH_CHECK(k % kKTileSize == 0, "d4_mm_origorder: k (", k,
                ") must be a multiple of ", kKTileSize);
    TORCH_CHECK(n % kNTileSize == 0, "d4_mm_origorder: n (", n,
                ") must be a multiple of ", kNTileSize);
    TORCH_CHECK(k <= INT32_MAX, "d4_mm_origorder: k is too large");
    TORCH_CHECK(static_cast<int64_t>(B.nbytes()) >= n * (k / 4),
                "d4_mm_origorder: B is too small for n x k/4 byte codes");

    const int64_t mTiles = divUp(m, kMTileSize);
    const int64_t nTiles = divUp(n, kNTileSize);
    const int64_t kTiles = divUp(k, kKTileSize);
    TORCH_CHECK(mTiles <= kMaxGridYZ && nTiles <= kMaxGridYZ,
                "d4_mm_origorder: problem exceeds the CUDA grid size limits");

    // Output is a standard row-major matrix
    auto C_final = at::empty(
        {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));
    if (m == 0 || n == 0) {
      // Nothing to compute; a zero-sized grid would be an invalid launch.
      return C_final;
    }

    const dim3 grid(1, static_cast<unsigned>(nTiles),
                    static_cast<unsigned>(mTiles));
    const dim3 block(kWarpSize, Warps);
    auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_D4,
                                                 ALayout_RM, Warps,
                                                 KTilesPerIteration>;
    kernel<<<grid, block, 0, stream>>>(
        reinterpret_cast<const half*>(A.data_ptr()),
        static_cast<const void*>(B.data_ptr()),
        reinterpret_cast<const uint64_t*>(CB.data_ptr()),
        reinterpret_cast<half*>(C_final.data_ptr()), static_cast<int32_t>(m),
        static_cast<int32_t>(n), static_cast<int32_t>(k),
        static_cast<int32_t>(mTiles), static_cast<int32_t>(nTiles),
        static_cast<int32_t>(kTiles));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return C_final;
  }
  // Computes C = A * W^T on the GPU, where W is the (n, k) weight matrix encoded
  // by B as 16-bit E8 codes together with the table CB.
  //
  //   A  : (m, k) contiguous fp16 CUDA tensor.
  //   B  : contiguous CUDA tensor with n = B.size(0) rows, read by the kernel as
  //        one 16-bit code per eight k-elements (n * k / 8 codes in total).
  //   CB : contiguous CUDA tensor holding at least 256 64-bit table entries.
  // Returns a new (m, n) tensor with the dtype and device of A.
  //
  // Raises (via TORCH_CHECK) when the inputs violate the above, when k is not a
  // multiple of kKTileSize or n is not a multiple of kNTileSize (the kernel does
  // not bounds-check those dimensions), when the tile counts exceed the CUDA
  // grid limits, or when the device is older than compute capability 8.0 (the
  // kernel body is only compiled for sm_80+).
  at::Tensor e8p_mm_origorder(const at::Tensor& A, const at::Tensor& B,
                              const at::Tensor& CB) {
    // NOTE: this is host code, so it must not be wrapped in a __CUDA_ARCH__
    // guard. That macro is only defined while compiling device code; guarding
    // the body with it removes the launch and the return statement from the
    // host build.
    constexpr int Warps = 8;
    constexpr int KTilesPerIteration = 8;
    constexpr int64_t kCodebookBytes = 256 * sizeof(uint64_t);
    constexpr int64_t kMaxGridYZ = 65535;  // CUDA limit for gridDim.y / gridDim.z

    TORCH_CHECK(A.is_cuda() && B.is_cuda() && CB.is_cuda(),
                "e8p_mm_origorder: A, B and CB must be CUDA tensors");
    TORCH_CHECK(A.dim() == 2 && B.dim() >= 1,
                "e8p_mm_origorder: A must be 2-D and B must have at least 1 dim");
    TORCH_CHECK(A.scalar_type() == at::kHalf,
                "e8p_mm_origorder: A must be float16");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && CB.is_contiguous(),
                "e8p_mm_origorder: A, B and CB must be contiguous");
    TORCH_CHECK(static_cast<int64_t>(CB.nbytes()) >= kCodebookBytes,
                "e8p_mm_origorder: CB must hold 256 64-bit entries");

    c10::cuda::CUDAGuard g(A.device());
    TORCH_CHECK(at::cuda::getCurrentDeviceProperties()->major >= 8,
                "e8p_mm_origorder: requires compute capability 8.0 or newer");
    auto stream = at::cuda::getCurrentCUDAStream();

    // A is row major (m, k); B has one row per output column.
    const int64_t m = A.size(0);
    const int64_t n = B.size(0);
    const int64_t k = A.size(1);
    TORCH_CHECK(k % kKTileSize == 0, "e8p_mm_origorder: k (", k,
                ") must be a multiple of ", kKTileSize);
    TORCH_CHECK(n % kNTileSize == 0, "e8p_mm_origorder: n (", n,
                ") must be a multiple of ", kNTileSize);
    TORCH_CHECK(k <= INT32_MAX, "e8p_mm_origorder: k is too large");
    TORCH_CHECK(static_cast<int64_t>(B.nbytes()) >=
                    n * (k / 8) * static_cast<int64_t>(sizeof(uint16_t)),
                "e8p_mm_origorder: B is too small for n x k/8 16-bit codes");

    const int64_t mTiles = divUp(m, kMTileSize);
    const int64_t nTiles = divUp(n, kNTileSize);
    const int64_t kTiles = divUp(k, kKTileSize);
    TORCH_CHECK(mTiles <= kMaxGridYZ && nTiles <= kMaxGridYZ,
                "e8p_mm_origorder: problem exceeds the CUDA grid size limits");

    // Output is a standard row-major matrix
    auto C_final = at::empty(
        {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));
    if (m == 0 || n == 0) {
      // Nothing to compute; a zero-sized grid would be an invalid launch.
      return C_final;
    }

    const dim3 grid(1, static_cast<unsigned>(nTiles),
                    static_cast<unsigned>(mTiles));
    const dim3 block(kWarpSize, Warps);
    auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_E8,
                                                 ALayout_RM, Warps,
                                                 KTilesPerIteration>;
    kernel<<<grid, block, 0, stream>>>(
        reinterpret_cast<const half*>(A.data_ptr()),
        static_cast<const void*>(B.data_ptr()),
        reinterpret_cast<const uint64_t*>(CB.data_ptr()),
        reinterpret_cast<half*>(C_final.data_ptr()), static_cast<int32_t>(m),
        static_cast<int32_t>(n), static_cast<int32_t>(k),
        static_cast<int32_t>(mTiles), static_cast<int32_t>(nTiles),
        static_cast<int32_t>(kTiles));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return C_final;
  }
  // Computes C = A * W^T on the GPU, where W is the (n, k) weight matrix encoded
  // by B as 4-bit codes (each code c stands for the value c - 7.5).
  //
  //   A : (m, k) contiguous fp16 CUDA tensor.
  //   B : contiguous CUDA tensor with n = B.size(0) rows, read by the kernel as
  //       one 32-bit word per eight k-elements (n * k / 8 words in total).
  // Returns a new (m, n) tensor with the dtype and device of A.
  //
  // Raises (via TORCH_CHECK) when the inputs violate the above, when k is not a
  // multiple of kKTileSize or n is not a multiple of kNTileSize (the kernel does
  // not bounds-check those dimensions), when the tile counts exceed the CUDA
  // grid limits, or when the device is older than compute capability 8.0 (the
  // kernel body is only compiled for sm_80+).
  at::Tensor hi_mm_origorder(const at::Tensor& A, const at::Tensor& B) {
    // NOTE: this is host code, so it must not be wrapped in a __CUDA_ARCH__
    // guard. That macro is only defined while compiling device code; guarding
    // the body with it removes the launch and the return statement from the
    // host build.
    constexpr int Warps = 8;
    constexpr int KTilesPerIteration = 8;
    constexpr int64_t kMaxGridYZ = 65535;  // CUDA limit for gridDim.y / gridDim.z

    TORCH_CHECK(A.is_cuda() && B.is_cuda(),
                "hi_mm_origorder: A and B must be CUDA tensors");
    TORCH_CHECK(A.dim() == 2 && B.dim() >= 1,
                "hi_mm_origorder: A must be 2-D and B must have at least 1 dim");
    TORCH_CHECK(A.scalar_type() == at::kHalf,
                "hi_mm_origorder: A must be float16");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous(),
                "hi_mm_origorder: A and B must be contiguous");

    c10::cuda::CUDAGuard g(A.device());
    TORCH_CHECK(at::cuda::getCurrentDeviceProperties()->major >= 8,
                "hi_mm_origorder: requires compute capability 8.0 or newer");
    auto stream = at::cuda::getCurrentCUDAStream();

    // A is row major (m, k); B has one row per output column.
    const int64_t m = A.size(0);
    const int64_t n = B.size(0);
    const int64_t k = A.size(1);
    TORCH_CHECK(k % kKTileSize == 0, "hi_mm_origorder: k (", k,
                ") must be a multiple of ", kKTileSize);
    TORCH_CHECK(n % kNTileSize == 0, "hi_mm_origorder: n (", n,
                ") must be a multiple of ", kNTileSize);
    TORCH_CHECK(k <= INT32_MAX, "hi_mm_origorder: k is too large");
    TORCH_CHECK(static_cast<int64_t>(B.nbytes()) >=
                    n * (k / 8) * static_cast<int64_t>(sizeof(uint32_t)),
                "hi_mm_origorder: B is too small for n x k/8 32-bit words");

    const int64_t mTiles = divUp(m, kMTileSize);
    const int64_t nTiles = divUp(n, kNTileSize);
    const int64_t kTiles = divUp(k, kKTileSize);
    TORCH_CHECK(mTiles <= kMaxGridYZ && nTiles <= kMaxGridYZ,
                "hi_mm_origorder: problem exceeds the CUDA grid size limits");

    // Output is a standard row-major matrix
    auto C_final = at::empty(
        {m, n}, at::TensorOptions().dtype(A.dtype()).device(A.device()));
    if (m == 0 || n == 0) {
      // Nothing to compute; a zero-sized grid would be an invalid launch.
      return C_final;
    }

    const dim3 grid(1, static_cast<unsigned>(nTiles),
                    static_cast<unsigned>(mTiles));
    const dim3 block(kWarpSize, Warps);
    auto kernel = tinygemm_m16n8k16_chunk_kernel<ALayout_RM, BLayout_HI,
                                                 ALayout_RM, Warps,
                                                 KTilesPerIteration>;
    // This layout does not use a codebook, so none is passed.
    kernel<<<grid, block, 0, stream>>>(
        reinterpret_cast<const half*>(A.data_ptr()),
        static_cast<const void*>(B.data_ptr()), nullptr,
        reinterpret_cast<half*>(C_final.data_ptr()), static_cast<int32_t>(m),
        static_cast<int32_t>(n), static_cast<int32_t>(k),
        static_cast<int32_t>(mTiles), static_cast<int32_t>(nTiles),
        static_cast<int32_t>(kTiles));
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return C_final;
  }
  #define DECOMPRESS_D4_BLOCK_SIZE 256
  // Expands D4 codes into fp16 values through the codebook.
  //
  // Thread t = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x reads the
  // four code bytes YIs[4t .. 4t+3] and, for each, copies the 64-bit codebook
  // entry it selects to the matching 64-bit slot of Y.
  // There is no bounds check: the launch must supply exactly one thread per
  // group of four codes.
  // The body is compiled only for __CUDA_ARCH__ >= 800.
  __global__ void cuda_decompress_d4_origorder_kernel(
      const uint8_t* __restrict__ YIs,   // m x (n/4)
      const c10::Half* __restrict__ CB,  // 256 x 4
      c10::Half* __restrict__ Y          // m x n
  ) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
    const long tid = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x;
    const uint64_t* table = (const uint64_t*)CB;
    uint64_t* out = (uint64_t*)Y;
    const long first = tid * 4;
    for (long slot = first; slot < first + 4; slot++) {
      out[slot] = table[YIs[slot] & 255];
    }
  #endif
  }
  // Expands the D4-coded matrix YIs into the fp16 matrix Y using codebook CB.
  //
  //   YIs : (m, n/4) contiguous uint8 CUDA tensor of codes.
  //   CB  : contiguous fp16 tensor with 256 rows (four fp16 values per entry).
  //   Y   : (m, n) contiguous fp16 CUDA tensor, overwritten in place.
  //
  // Every kernel thread expands four codes (16 output values), so m * n must be
  // a multiple of 16. Raises (via TORCH_CHECK) on any violated precondition or
  // on a device older than compute capability 8.0 (the kernel body is only
  // compiled for sm_80+).
  void decompress_d4_origorder(torch::Tensor YIs,  // m x (n/4)
                               torch::Tensor CB,   // 256 x 4
                               torch::Tensor Y     // m x n
  ) {
    // NOTE: this is host code, so it must not be wrapped in a __CUDA_ARCH__
    // guard. That macro is only defined while compiling device code; guarding
    // the body with it removes the launch from the host build.
    constexpr int64_t kCodesPerThread = 4;
    constexpr int64_t kElemsPerThread = 16;
    constexpr int64_t kBlock = DECOMPRESS_D4_BLOCK_SIZE;

    TORCH_CHECK(YIs.is_cuda() && CB.is_cuda() && Y.is_cuda(),
                "decompress_d4_origorder: all tensors must be CUDA tensors");
    TORCH_CHECK(YIs.dim() == 2 && Y.dim() == 2,
                "decompress_d4_origorder: YIs and Y must be 2-D");
    TORCH_CHECK(YIs.is_contiguous(), "decompress_d4_origorder: YIs not contiguous");
    TORCH_CHECK(CB.is_contiguous(), "decompress_d4_origorder: CB not contiguous");
    TORCH_CHECK(Y.is_contiguous(), "decompress_d4_origorder: Y not contiguous");

    const int64_t m = Y.size(0);
    const int64_t n = Y.size(1);
    TORCH_CHECK(YIs.size(0) == m, "decompress_d4_origorder: row count mismatch");
    TORCH_CHECK(YIs.size(1) * 4 == n,
                "decompress_d4_origorder: YIs must have n/4 columns");
    TORCH_CHECK(CB.size(0) == 256,
                "decompress_d4_origorder: CB must have 256 entries");
    TORCH_CHECK(static_cast<int64_t>(CB.nbytes()) >=
                    256 * static_cast<int64_t>(sizeof(uint64_t)),
                "decompress_d4_origorder: CB entries must be 64 bits wide");

    const int64_t total = m * n;
    TORCH_CHECK(total % kElemsPerThread == 0,
                "decompress_d4_origorder: m * n must be a multiple of 16");
    const int64_t numThreads = total / kElemsPerThread;
    if (numThreads == 0) {
      return;  // empty output; a zero-sized grid would be an invalid launch
    }

    c10::cuda::CUDAGuard g(Y.device());
    TORCH_CHECK(at::cuda::getCurrentDeviceProperties()->major >= 8,
                "decompress_d4_origorder: requires compute capability 8.0+");
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const uint8_t* codes = YIs.data_ptr<uint8_t>();
    const c10::Half* codebook = CB.data_ptr<c10::Half>();
    c10::Half* out = Y.data_ptr<c10::Half>();

    // The kernel has no bounds check, so it is launched with exactly one thread
    // per work item: full blocks first, then one partial block for the tail
    // (instead of truncating the block count and dropping the tail).
    const int64_t fullBlocks = numThreads / kBlock;
    const int64_t tailThreads = numThreads % kBlock;
    TORCH_CHECK(fullBlocks <= INT32_MAX,
                "decompress_d4_origorder: tensor too large for one launch");

    if (fullBlocks > 0) {
      cuda_decompress_d4_origorder_kernel<<<
          dim3(static_cast<unsigned>(fullBlocks)),
          dim3(static_cast<unsigned>(kBlock)), 0, stream>>>(codes, codebook, out);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
    if (tailThreads > 0) {
      // With a single block the kernel's thread index is threadIdx.x, so the
      // tail is addressed by offsetting the pointers past the finished part.
      const int64_t done = fullBlocks * kBlock;
      cuda_decompress_d4_origorder_kernel<<<
          dim3(1), dim3(static_cast<unsigned>(tailThreads)), 0, stream>>>(
          codes + done * kCodesPerThread, codebook, out + done * kElemsPerThread);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
  #define DECOMPRESS_E8P_BLOCK_SIZE 256
  // Expands 16-bit E8 codes into fp16 values.
  //
  // Thread t = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x decodes code
  // YIs[t] into eight packed bytes b0..b7 with BLayout_E8::decode8weights,
  // converts them to fp16 and writes them as four half2 values at
  // Y[8t .. 8t+7] in the order (b0,b2) (b1,b3) (b4,b6) (b5,b7).
  // There is no bounds check: the launch must supply exactly one thread per
  // code. The body is compiled only for __CUDA_ARCH__ >= 800.
  __global__ void cuda_decompress_e8p_origorder_kernel(
      const int16_t* __restrict__ YIs,  // m x (n/8)
      const int64_t* __restrict__ CB,   // 256 x 8
      c10::Half* __restrict__ Y         // m x n
  ) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
    const long tid = threadIdx.x + DECOMPRESS_E8P_BLOCK_SIZE * blockIdx.x;
    const uint16_t code = ((uint16_t*)YIs)[tid];
    const uint64_t bytes = BLayout_E8::decode8weights(code, CB);

    // Even-numbered bytes and odd-numbered bytes are widened to 16 bits each
    // and XOR-ed with 0x5c80 so every 16-bit lane becomes an fp16 bit pattern.
    const uint64_t evenBits = (bytes & 0x00ff00ff00ff00ff) ^ 0x5c805c805c805c80;
    const uint64_t oddBits =
        ((bytes & 0xff00ff00ff00ff00) >> 8) ^ 0x5c805c805c805c80;

    half2 evenPairs[2];
    half2 oddPairs[2];
    memcpy(evenPairs, &evenBits, sizeof(uint64_t));
    memcpy(oddPairs, &oddBits, sizeof(uint64_t));

    const half offset_ = __float2half_rn(-288.0f);
    const half2 offset = __halves2half2(offset_, offset_);

    __half2* out = ((__half2*)Y) + tid * 4;
    out[0] = __hadd2(evenPairs[0], offset);  // 01
    out[1] = __hadd2(oddPairs[0], offset);   // 23
    out[2] = __hadd2(evenPairs[1], offset);  // 45
    out[3] = __hadd2(oddPairs[1], offset);   // 67
  #endif
  }
  // Expands the E8-coded matrix YIs into the fp16 matrix Y using table CB.
  //
  //   YIs : (m, n/8) contiguous int16 CUDA tensor of codes.
  //   CB  : contiguous int64 CUDA tensor with 256 entries.
  //   Y   : (m, n) contiguous fp16 CUDA tensor, overwritten in place.
  //
  // Every kernel thread expands one code (8 output values). Raises (via
  // TORCH_CHECK) on any violated precondition or on a device older than compute
  // capability 8.0 (the kernel body is only compiled for sm_80+).
  void decompress_e8p_origorder(torch::Tensor YIs,  // m x (n/8)
                                torch::Tensor CB,   // 256 x 8
                                torch::Tensor& Y    // m x n
  ) {
    // NOTE: this is host code, so it must not be wrapped in a __CUDA_ARCH__
    // guard. That macro is only defined while compiling device code; guarding
    // the body with it removes the launch from the host build.
    constexpr int64_t kElemsPerThread = 8;
    constexpr int64_t kBlock = DECOMPRESS_E8P_BLOCK_SIZE;

    TORCH_CHECK(YIs.is_cuda() && CB.is_cuda() && Y.is_cuda(),
                "decompress_e8p_origorder: all tensors must be CUDA tensors");
    TORCH_CHECK(YIs.dim() == 2 && Y.dim() == 2,
                "decompress_e8p_origorder: YIs and Y must be 2-D");
    TORCH_CHECK(YIs.is_contiguous(),
                "decompress_e8p_origorder: YIs not contiguous");
    TORCH_CHECK(CB.is_contiguous(), "decompress_e8p_origorder: CB not contiguous");
    TORCH_CHECK(Y.is_contiguous(), "decompress_e8p_origorder: Y not contiguous");

    const int64_t m = Y.size(0);
    const int64_t n = Y.size(1);
    TORCH_CHECK(YIs.size(0) == m, "decompress_e8p_origorder: row count mismatch");
    TORCH_CHECK(YIs.size(1) * 8 == n,
                "decompress_e8p_origorder: YIs must have n/8 columns");
    TORCH_CHECK(CB.size(0) == 256,
                "decompress_e8p_origorder: CB must have 256 entries");

    // One thread per code; n is a multiple of 8 by the check above.
    const int64_t numThreads = m * n / kElemsPerThread;
    if (numThreads == 0) {
      return;  // empty output; a zero-sized grid would be an invalid launch
    }

    c10::cuda::CUDAGuard g(Y.device());
    TORCH_CHECK(at::cuda::getCurrentDeviceProperties()->major >= 8,
                "decompress_e8p_origorder: requires compute capability 8.0+");
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const int16_t* codes = YIs.data_ptr<int16_t>();
    const int64_t* table = CB.data_ptr<int64_t>();
    c10::Half* out = Y.data_ptr<c10::Half>();

    // The kernel has no bounds check, so it is launched with exactly one thread
    // per work item: full blocks first, then one partial block for the tail
    // (instead of truncating the block count and dropping the tail).
    const int64_t fullBlocks = numThreads / kBlock;
    const int64_t tailThreads = numThreads % kBlock;
    TORCH_CHECK(fullBlocks <= INT32_MAX,
                "decompress_e8p_origorder: tensor too large for one launch");

    if (fullBlocks > 0) {
      cuda_decompress_e8p_origorder_kernel<<<
          dim3(static_cast<unsigned>(fullBlocks)),
          dim3(static_cast<unsigned>(kBlock)), 0, stream>>>(codes, table, out);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
    if (tailThreads > 0) {
      // With a single block the kernel's thread index is threadIdx.x, so the
      // tail is addressed by offsetting the pointers past the finished part.
      const int64_t done = fullBlocks * kBlock;
      cuda_decompress_e8p_origorder_kernel<<<
          dim3(1), dim3(static_cast<unsigned>(tailThreads)), 0, stream>>>(
          codes + done, table, out + done * kElemsPerThread);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
  #define DECOMPRESS_HI_BLOCK_SIZE 256
  // Expands packed 4-bit codes into fp16 values.
  //
  // Thread t = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x reads the
  // 32-bit word YIs[t] and writes four half2 values to Y[8t .. 8t+7]. Each
  // half2 is built from one nibble of the low 16 bits and the corresponding
  // nibble of the high 16 bits of the word.
  // There is no bounds check: the launch must supply exactly one thread per
  // word. The body is compiled only for __CUDA_ARCH__ >= 800.
  __global__ void cuda_decompress_hi_origorder_kernel(
      const uint32_t* __restrict__ YIs,  // m x (n/8)
      c10::Half* __restrict__ Y          // m x n
  ) {
  #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800
    const long tid = threadIdx.x + DECOMPRESS_HI_BLOCK_SIZE * blockIdx.x;

    // fp16 bias pattern plus the scale / offset applied by the fused
    // multiply-add below.
    const uint32_t bias = 0x64086408;
    const half scale_ = __float2half_rn(1.0f / 16.0f);
    const half2 scale = __halves2half2(scale_, scale_);
    const half shift_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 shift = __halves2half2(shift_, shift_);

    const uint32_t word = YIs[tid];
    const uint32_t wordHi = word >> 8;

    // Nibbles at bit 0 are moved up to bit 4; nibbles already at bit 4 stay.
    uint32_t bits0 = ((word & 0x000f000f) << 4) | bias;
    uint32_t bits1 = (word & 0x00f000f0) | bias;
    uint32_t bits2 = ((wordHi & 0x000f000f) << 4) | bias;
    uint32_t bits3 = (wordHi & 0x00f000f0) | bias;

    __half2* out = ((__half2*)Y) + tid * 4;
    out[0] = __hfma2(*((half2*)(&bits0)), scale, shift);
    out[1] = __hfma2(*((half2*)(&bits1)), scale, shift);
    out[2] = __hfma2(*((half2*)(&bits2)), scale, shift);
    out[3] = __hfma2(*((half2*)(&bits3)), scale, shift);
  #endif
  }
  // Expands the 4-bit-coded matrix YIs into the fp16 matrix Y.
  //
  //   YIs : (m, n/8) contiguous int32 CUDA tensor, eight codes per word.
  //   Y   : (m, n) contiguous fp16 CUDA tensor, overwritten in place.
  //
  // Every kernel thread expands one 32-bit word (8 output values). Raises (via
  // TORCH_CHECK) on any violated precondition or on a device older than compute
  // capability 8.0 (the kernel body is only compiled for sm_80+).
  void decompress_hi_origorder(torch::Tensor YIs,  // m x (n/8)
                               torch::Tensor Y     // m x n
  ) {
    // NOTE: this is host code, so it must not be wrapped in a __CUDA_ARCH__
    // guard. That macro is only defined while compiling device code; guarding
    // the body with it removes the launch from the host build.
    constexpr int64_t kElemsPerThread = 8;
    constexpr int64_t kBlock = DECOMPRESS_HI_BLOCK_SIZE;

    TORCH_CHECK(YIs.is_cuda() && Y.is_cuda(),
                "decompress_hi_origorder: all tensors must be CUDA tensors");
    TORCH_CHECK(YIs.dim() == 2 && Y.dim() == 2,
                "decompress_hi_origorder: YIs and Y must be 2-D");
    TORCH_CHECK(YIs.is_contiguous(), "decompress_hi_origorder: YIs not contiguous");
    TORCH_CHECK(Y.is_contiguous(), "decompress_hi_origorder: Y not contiguous");

    const int64_t m = Y.size(0);
    const int64_t n = Y.size(1);
    TORCH_CHECK(YIs.size(0) == m, "decompress_hi_origorder: row count mismatch");
    TORCH_CHECK(YIs.size(1) * 8 == n,
                "decompress_hi_origorder: YIs must have n/8 columns");

    // One thread per 32-bit word; n is a multiple of 8 by the check above.
    const int64_t numThreads = m * n / kElemsPerThread;
    if (numThreads == 0) {
      return;  // empty output; a zero-sized grid would be an invalid launch
    }

    c10::cuda::CUDAGuard g(Y.device());
    TORCH_CHECK(at::cuda::getCurrentDeviceProperties()->major >= 8,
                "decompress_hi_origorder: requires compute capability 8.0+");
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const uint32_t* words =
        reinterpret_cast<const uint32_t*>(YIs.data_ptr<int32_t>());
    c10::Half* out = Y.data_ptr<c10::Half>();

    // The kernel has no bounds check, so it is launched with exactly one thread
    // per work item: full blocks first, then one partial block for the tail
    // (instead of truncating the block count and dropping the tail).
    const int64_t fullBlocks = numThreads / kBlock;
    const int64_t tailThreads = numThreads % kBlock;
    TORCH_CHECK(fullBlocks <= INT32_MAX,
                "decompress_hi_origorder: tensor too large for one launch");

    if (fullBlocks > 0) {
      cuda_decompress_hi_origorder_kernel<<<
          dim3(static_cast<unsigned>(fullBlocks)),
          dim3(static_cast<unsigned>(kBlock)), 0, stream>>>(words, out);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
    if (tailThreads > 0) {
      // With a single block the kernel's thread index is threadIdx.x, so the
      // tail is addressed by offsetting the pointers past the finished part.
      const int64_t done = fullBlocks * kBlock;
      cuda_decompress_hi_origorder_kernel<<<
          dim3(1), dim3(static_cast<unsigned>(tailThreads)), 0, stream>>>(
          words + done, out + done * kElemsPerThread);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }