/*
 * Modified by Neural Magic
 * Adapted from https://github.com/Vahe1994/AQLM
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <c10/cuda/CUDAStream.h>
#include <iostream>
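
// This file implements fused dequantize + matrix-vector multiply kernels for
// AQLM-quantized weights, plus their host-side launchers:
//   - Code1x16MatVec: each group of 8 weights is a single 16-bit index into a
//     codebook whose entries are int4 (8 fp16 values).
//   - Code2x8MatVec:  each group is a pair of 8-bit indices into two 256-entry
//     codebooks; the two selected entries are summed before use.
// In both kernels, one warp produces one fp16 output element of C = A * B.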

__global__ void Code1x16MatVec(
  const int4* __restrict__ A,
  const int4* __restrict__ B,
  int4* __restrict__ C,
  const int4* __restrict__ codebook,
  const int prob_m,
  const int prob_k,
  const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride    // as int4.
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred)
  {
    // Advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size)
    {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }
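
  // Per-warp/-lane read offsets: each warp owns one output row; within the
  // warp, lane l starts at the l-th int4 of that row of A (each int4 packs
  // eight 16-bit codes, i.e. 64 weights) and advances by 32 int4 per outer
  // iteration.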
  int b_gl_rd = 0;
  int c_gl_wr = a_gl_rd;
  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;

  __shared__ int4 sh_b[32 * 9];
  float res = 0;

  int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (iters--) {
    // We pad shared memory to avoid bank conflicts during reads
    __syncthreads();
    for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
      if (b_gl_rd + i < prob_k / 8)
        sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
    }
    __syncthreads();
    b_gl_rd += 32 * 8;

    int b_sh_rd = 9 * (threadIdx.x % 32);
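    // Each lane decodes the eight 16-bit codes in its current int4 of A: every
    // code selects one int4 (8 fp16 values) from the codebook, which is then
    // FMA'd against the matching slice of B staged in shared memory.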
    if (pred && a_gl_rd < a_gl_end) {
      const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        uint32_t dec[4];
        // We bypass the L1 cache to avoid massive amounts of memory streaming
        // that doesn't actually help us; this brings > 2x speedup.
        asm volatile (
          "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
          : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
          : "l"((void*) &codebook[enc[i]])
        );
        half2* a = reinterpret_cast<half2*>(&dec);
        half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
        half2 res2 = {};
        #pragma unroll
        for (int j = 0; j < 4; j++)
          res2 = __hfma2(a[j], b[j], res2);
        res += __half2float(res2.x) + __half2float(res2.y);
        b_sh_rd++;
      }
      a_gl_rd += 32;
    }
  }
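
  // Tree-reduce the per-lane partial sums across the warp; lane 0 then holds
  // the full dot product for this output row and writes it out as fp16.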
  if (pred) {
    #pragma unroll
    for (int i = 16; i > 0; i /= 2)
      res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}
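
// Same matvec scheme as Code1x16MatVec, but for 2x8 codes: each group of 8
// weights is encoded by two 8-bit indices into two 256-entry codebooks, and
// the two selected entries are added together before the multiply. Because
// the codebooks are small, they are first copied into dynamic shared memory.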
__global__ void Code2x8MatVec(
  const int4* __restrict__ A,
  const int4* __restrict__ B,
  int4* __restrict__ C,
  const int4* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes, // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride    // as int4.
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred)
  {
    // Advance to the correct codebook; this is easy because we only multiply
    // one column of the codebook.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size)
    {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  int b_gl_rd = 0;
  int c_gl_wr = a_gl_rd;
  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
  int lane = threadIdx.x % 8;

  extern __shared__ int4 sh[];
  int4* sh_b = sh;
  int4* sh_code = sh_b + 32 * 9;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;
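
  // Stage both 256-entry codebooks into shared memory. Each entry is
  // replicated 8 times so that, during decoding, lane l of every group of 8
  // threads reads its own copy (offset by `lane`) and neighboring lanes do
  // not collide on the same shared-memory bank; the (j + lane) % 8 rotation
  // spreads the replicating writes as well.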
  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
    #pragma unroll
    for (int j = 0; j < 8; j++)
      sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

  float res = 0;

  int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (iters--) {
    // We pad shared memory to avoid bank conflicts during reads
    __syncthreads();
    for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
      if (b_gl_rd + i < prob_k / 8)
        sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
    }
    __syncthreads();
    b_gl_rd += 32 * 8;

    int b_sh_rd = 9 * (threadIdx.x % 32);
    if (pred && a_gl_rd < a_gl_end) {
      const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
        half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
        half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
        half2 res2 = {};
        #pragma unroll
        for (int j = 0; j < 4; j++)
          res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
        res += __half2float(res2.x) + __half2float(res2.y);
        b_sh_rd++;
      }
      a_gl_rd += 32;
    }
  }

  if (pred) {
    #pragma unroll
    for (int i = 16; i > 0; i /= 2)
      res += __shfl_down_sync(0xffffffff, res, i);
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}

inline int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

const int THREAD_M = 16;
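
// Launch heuristic shared by both launchers below: pick the smallest number
// of "waves" over the available SMs such that each block handles at most
// THREAD_M output rows; with one warp per row this caps blocks at
// 32 * THREAD_M = 512 threads.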
void code1x16_matvec_cuda(
  const void* __restrict__ A,
  const void* __restrict__ B,
  void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,
  const int codebook_stride
) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16MatVec<<<blocks, threads, 16*32*9, stream>>>(
    (const int4*) A,
    (const int4*) B,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k,
    codebook_a_sizes,
    codebook_stride
  );
}

void code2x8_matvec_cuda(
  const void* __restrict__ A,
  const void* __restrict__ B,
  void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,
  const int codebook_stride
) {
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
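
  // Dynamic shared memory: sizeof(int4) = 16 bytes times (two codebooks of
  // 256 entries replicated 8x, plus the padded 32 x 9 staging buffer for B).
  // This is ~68.5 KB, above the default 48 KB limit, so opt in explicitly.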
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaFuncSetAttribute(
    Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code2x8MatVec<<<blocks, threads, shared, stream>>>(
    (const int4*) A,
    (const int4*) B,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k,
    codebook_a_sizes,
    codebook_stride
  );
}
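
// Illustrative usage sketch (not part of the upstream file): how a host-side
// caller with a single 1x16 codebook could invoke the launcher above. The
// pointer names are placeholders for device buffers (fp16 B and C, 16-bit
// codes in A, fp16 codebook entries). With a single codebook, every entry of
// codebook_a_sizes just needs to cover all prob_m rows, and codebook_stride
// is never applied, so it can be zero.
//
//   const int4 codebook_a_sizes = {prob_m, prob_m, prob_m, prob_m};
//   code1x16_matvec_cuda(d_A, d_B, d_C, d_codebook,
//                        prob_m, prob_k, codebook_a_sizes,
//                        /*codebook_stride=*/0);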