#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

namespace aphrodite {
namespace squeezellm {

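// Reinterpret a signed int's bits as unsigned so the packed 4-bit codes
// can be extracted with shifts and masks.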
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
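  // Each block covers a BLOCKHEIGHT4 x BLOCKWIDTH tile of the packed
  // matrix; each thread owns one output column of the tile.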
  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif
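
  // Stage this column's 16-entry dequantization table in shared memory:
  // deq2[v][tid] is the dequantized value of 4-bit code v for the column
  // owned by thread tid.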
  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  int column_offset = col * 16;
  for (int val = 0; val < 16; val += 1) {
    int lut_index = column_offset + val;
    deq2[val][off] = lookup_table[lut_index];
  }
  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;
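
  // Process each batch element in turn; every thread accumulates its
  // tile's partial dot product for one output column.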
  for (int b = 0; b < batch; ++b) {
    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

    __syncthreads();
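    // Cooperatively stage this row-block's slice of the input vector
    // (as half2 pairs) in shared memory.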
    if (threadIdx.x < blockwidth2)
      blockvec[threadIdx.x] =
          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
              threadIdx.x];
    __syncthreads();
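
    // Each iteration unpacks one 32-bit word of mat: eight 4-bit codes
    // are dequantized through the LUT and folded into res2 with four
    // half2 FMAs against consecutive vector pairs.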
    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);
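
      // Fold both half lanes of the accumulator into the scalar sum.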
#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
                   res);
#endif

      i += width;
      k += 4;
    }
    // col % 2 selects which lane of the half2 result this thread owns;
    // the other lane stays zero, so two adjacent columns share one add.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

}  // namespace squeezellm
}  // namespace aphrodite
// Launcher for the 4-bit LUT-based matvec kernel above
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  int height = mat.size(0);
  int width = mat.size(1);

  int batch = vec.size(0);
  int vec_height = vec.size(1);
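
  // One thread block per BLOCKHEIGHT4 x BLOCKWIDTH tile of the packed matrix.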
  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      (half2*)vec.data_ptr<at::Half>(),
#else
      (__half2*)vec.data_ptr<at::Half>(),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      (half2*)mul.data_ptr<at::Half>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#else
      (float2*)mul.data_ptr<float>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#endif
      height, width, batch, vec_height);
}
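
// Shape expectations, inferred from the kernel's indexing (not checked here):
//   vec:          [batch, vec_height] fp16, with vec_height == 8 * height
//   mat:          [height, width] int32, eight 4-bit codes per word
//   lookup_table: [width, 16] fp16, one LUT per output column
//   mul:          [batch, width] zero-initialized accumulator
//                 (fp16 on CUDA, fp32 on ROCm); the kernel adds into it.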
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4