gemm_kernels.cu

/*
 * Modified by Neural Magic
 * Adapted from https://github.com/Vahe1994/AQLM
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include <cstdlib>
#include <cassert>  // assert() in accumulate_sizes and aqlm_dequant
#include <numeric>  // std::accumulate in aqlm_dequant

namespace aphrodite {
namespace aqlm {
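
// Matrix-vector product C = A * B for the 1x16 AQLM scheme: every 16-bit code
// in A indexes one int4 entry (8 fp16 values) of a single codebook. One warp
// produces one output element; the input vector B is staged through padded
// shared memory to avoid bank conflicts.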
__global__ void Code1x16MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
    const int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // Advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  int b_gl_rd = 0;
  int c_gl_wr = a_gl_rd;
  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;

  __shared__ int4 sh_b[32 * 9];
  float res = 0;

  int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (iters--) {
    // We pad shared memory to avoid bank conflicts during reads.
    __syncthreads();
    for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
      if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
    }
    __syncthreads();
    b_gl_rd += 32 * 8;

    int b_sh_rd = 9 * (threadIdx.x % 32);
    if (pred && a_gl_rd < a_gl_end) {
      const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
      for (int i = 0; i < 8; i++) {
        uint32_t dec[4];
        // We bypass the L1 cache to avoid massive amounts of memory streaming
        // that doesn't actually help us; this brings > 2x speedup.
        asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
                     : "l"((void*)&codebook[enc[i]]));
        half2* a = reinterpret_cast<half2*>(&dec);
        half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
        half2 res2 = {};
#pragma unroll
        for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
        res += __half2float(res2.x) + __half2float(res2.y);
        b_sh_rd++;
      }
      a_gl_rd += 32;
    }
  }

  if (pred) {
#pragma unroll
    for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}
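
// Matrix-vector product C = A * B for the 2x8 AQLM scheme: each group of 8
// weights is encoded by two 8-bit codes into two 256-entry codebooks whose
// entries are summed. Both codebooks are replicated 8x into shared memory so
// that each lane of 8 reads its own copy without bank conflicts.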
__global__ void Code2x8MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
    int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // Advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  int b_gl_rd = 0;
  int c_gl_wr = a_gl_rd;
  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
  int lane = threadIdx.x % 8;

  extern __shared__ int4 sh[];
  int4* sh_b = sh;
  int4* sh_code = sh_b + 32 * 9;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;

  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

  float res = 0;

  int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (iters--) {
    // We pad shared memory to avoid bank conflicts during reads.
    __syncthreads();
    for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
      if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
    }
    __syncthreads();
    b_gl_rd += 32 * 8;

    int b_sh_rd = 9 * (threadIdx.x % 32);
    if (pred && a_gl_rd < a_gl_end) {
      const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
      for (int i = 0; i < 8; i++) {
        half2* a0 =
            reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
        half2* a1 =
            reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
        half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
        half2 res2 = {};
#pragma unroll
        for (int j = 0; j < 4; j++)
          res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
        res += __half2float(res2.x) + __half2float(res2.y);
        b_sh_rd++;
      }
      a_gl_rd += 32;
    }
  }

  if (pred) {
#pragma unroll
    for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}
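
// Dequantizes 1x16 AQLM codes back into a dense fp16 weight matrix: every
// 16-bit code is expanded to its int4 codebook entry (8 fp16 values) and
// written straight to C. One warp decodes one output row.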
__global__ void Code1x16Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long, sums to m.
    const int codebook_stride     // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // Advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;

  int c_gl_stride = prob_k / 8;
  int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;

  int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
  while (iters--) {
    if (pred && a_gl_rd < a_gl_end) {
      const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
      for (int i = 0; i < 8; i++) {
        int4 chunk;
        auto dec = reinterpret_cast<uint32_t*>(&chunk);
        // We bypass the L1 cache to avoid massive amounts of memory streaming
        // that doesn't actually help us; this brings > 2x speedup.
        asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
                     : "l"((void*)&codebook[enc[i]]));
        C[a_gl_rd * 8 + i] = chunk;
      }
    }
    a_gl_rd += 32;
  }
}
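
// Dequantizes 2x8 AQLM codes into dense fp16 weights: the two 256-entry
// codebooks are staged in shared memory and each pair of 8-bit codes selects
// two entries whose element-wise sum is written to C.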
__global__ void Code2x8Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,      // cumulative sizes of A spanning each codebook,
                               // at most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // Advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
  int lane = threadIdx.x % 8;

  int c_gl_stride = prob_k / 8;
  int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;

  extern __shared__ int4 sh[];
  int4* sh_code = sh;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;

  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

  int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
  while (iters--) {
    if (pred && a_gl_rd < a_gl_end) {
      const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
      for (int i = 0; i < 8; i++) {
        int4 chunk;
        half2* a0 =
            reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
        half2* a1 =
            reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
        for (int j = 0; j < 4; j++)
          reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
        C[a_gl_rd * 8 + i] = chunk;
      }
    }
    a_gl_rd += 32;
  }
}
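
// Host-side launch helpers: the wave loop below picks the smallest number of
// waves over the available SMs such that each block handles at most THREAD_M
// rows (one warp per row, i.e. 32 * thread_m threads per block).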
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }

const int THREAD_M = 16;

void code1x16_matvec_cuda(const void* __restrict__ A,
                          const void* __restrict__ B, void* __restrict__ C,
                          const void* __restrict__ codebook, int prob_m,
                          int prob_k, const int4 codebook_a_sizes,
                          const int codebook_stride) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
}

void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
                         void* __restrict__ C,
                         const void* __restrict__ codebook, int prob_m,
                         int prob_k, const int4 codebook_a_sizes,
                         const int codebook_stride) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaFuncSetAttribute(Code2x8MatVec,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code2x8MatVec<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
}

void code1x16_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16Dequant<<<blocks, threads, 0, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                         // most 3 long.
      codebook_stride    // as int4.
  );
}

// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,      // cumulative sizes of A spanning each codebook,
                               // at most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  cudaFuncSetAttribute(Code2x8Dequant,
                       cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  Code2x8Dequant<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes, codebook_stride);
}
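
// Stride of one codebook, expressed in int4 units: how far the device pointer
// must be advanced to reach the next codebook.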
int codebook_stride(const torch::Tensor& codebooks) {
  return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}

void code1x16_matvec(
    const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
    const torch::Tensor& codebook,
    const int4 codebook_a_sizes  // cumulative sizes of A spanning each
                                 // codebook, at most 3 long.
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);

  code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                       codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
                       codebook_stride(codebook));
}
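
// The "matmat" entry points flatten the input to 2D, launch one matvec per
// input row, then apply the per-output-channel scales and the optional bias
// before restoring the original leading dimensions.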
torch::Tensor code1x16_matmat(const torch::Tensor& input,
                              const torch::Tensor& codes,
                              const torch::Tensor& codebooks,
                              const torch::Tensor& scales,
                              const int4 codebook_a_sizes,
                              const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                    codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  output_sizes.push_back(-1);
  auto output = flat_output.reshape(output_sizes);
  return output;
}

void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
                    torch::Tensor& C, const torch::Tensor& codebook,
                    const int4 codebook_a_sizes) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);
  code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                      codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
                      2 * codebook_stride(codebook));
}

torch::Tensor code2x8_matmat(const torch::Tensor& input,
                             const torch::Tensor& codes,
                             const torch::Tensor& codebooks,
                             const torch::Tensor& scales,
                             const int4 codebook_a_sizes,
                             const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
                   codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  output_sizes.push_back(-1);
  auto output = flat_output.reshape(output_sizes);
  return output;
}

// Accumulate the partition sizes into cumulative row boundaries, packed into
// an int4 (at most 4 partitions).
int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
  int4 cumulative_sizes;
  auto cumulative_size = &cumulative_sizes.x;
  size_t i = 0;
  int last = 0;
  assert(codebook_partition_sizes.size() <= 4);
  for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
    *cumulative_size = codebook_partition_sizes[i] + last;
    last = *cumulative_size;
  }
  // Fill in the rest with unreachable values.
  for (; i < 4; ++i, ++cumulative_size) {
    *cumulative_size = last * 10;
  }
  return cumulative_sizes;
}

}  // namespace aqlm
}  // namespace aphrodite
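
// Public entry point: dispatches on the codebook configuration. Only the
// 1 codebook x 2^16 entries and 2 codebooks x 2^8 entries layouts are
// currently supported.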
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const std::vector<int64_t>& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias) {
  int4 cumulative_sizes =
      aphrodite::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
  int const entries = codebooks.size(1);

  if (nbooks == 1 && entries == (1 << 16)) {
    return aphrodite::aqlm::code1x16_matmat(input, codes, codebooks, scales,
                                            cumulative_sizes, bias);
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    return aphrodite::aqlm::code2x8_matmat(input, codes, codebooks, scales,
                                           cumulative_sizes, bias);
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}
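
// Public entry point: decodes the full (out_features x in_features) weight
// matrix in the codebooks' dtype so it can be consumed by a regular dense
// GEMM. Scales are not applied here; see the commented-out alternative below.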
torch::Tensor aqlm_dequant(
    const torch::Tensor& codes, const torch::Tensor& codebooks,
    const std::vector<int64_t>& codebook_partition_sizes) {
  int4 cumulative_sizes =
      aphrodite::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
  int const entries = codebooks.size(1);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
  int rows = codes.size(1);
  int cols = codes.size(0);

  auto in_features = codes.size(1) * 8;
  auto out_features = codes.size(0);

  assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
                                         codebook_partition_sizes.end(), 0));

  auto weights = torch::empty({out_features, in_features},
                              torch::TensorOptions()
                                  .dtype(codebooks.dtype())
                                  .device(codebooks.device()));

  if (nbooks == 1 && entries == (1 << 16)) {
    aphrodite::aqlm::code1x16_dequant_cuda(
        codes.data_ptr(), weights.data_ptr(), codebooks.data_ptr(),
        out_features, in_features, cumulative_sizes,
        aphrodite::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish
    // slower and not consistent with gemv implementation.)
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    aphrodite::aqlm::code2x8_dequant_cuda(
        codes.data_ptr(), weights.data_ptr(), codebooks.data_ptr(),
        out_features, in_features, cumulative_sizes,
        aphrodite::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish
    // slower and not consistent with gemv implementation)
    // weights *= scales.index({"...", 0, 0});

    return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.")
  return {};
}