#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
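// Each 32-bit word of the packed weight matrix holds eight 4-bit LUT
// indices, so BLOCKHEIGHT4 = BLOCKWIDTH / 8 packed rows cover one
// BLOCKWIDTH-element tile of the input vector.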
namespace aphrodite {
namespace squeezellm {

// Reinterpret a signed int's bits as unsigned so the nibble shifts and
// masks below are well-defined.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
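// Layout, as inferred from the indexing below: `mat` is [height, width]
// int32 with eight 4-bit LUT indices packed along the input dimension,
// `lookup_table` is [width, 16] fp16 (one dequantization table per output
// column), and `vec` is read as half2 pairs, so each batch row must hold
// an even number of fp16 elements.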
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table,
    int height,
    int width,
    int batch,
    int vec_height
) {
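  // One thread per output column within a BLOCKWIDTH-wide tile; partial
  // sums from the tiles along blockIdx.x are combined with atomicAdd at
  // the end of each batch iteration.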
  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  int column_offset = col * 16;
  // Stage this column's 16-entry lookup table in shared memory. Each
  // thread reads back only the slot it wrote, so no barrier is needed.
  for (int val = 0; val < 16; val += 1) {
    int lut_index = column_offset + val;
    deq2[val][off] = lookup_table[lut_index];
  }
  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;

  for (int b = 0; b < batch; ++b) {
    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

    // Stage this tile's slice of the input vector as half2 pairs.
    __syncthreads();
    if (threadIdx.x < blockwidth2)
      blockvec[threadIdx.x] = vec[b * vec_height / 2
                                  + (row / BLOCKHEIGHT4) * blockwidth2
                                  + threadIdx.x];
    __syncthreads();
    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      // Unpack eight 4-bit indices from the 32-bit word, two at a time,
      // dequantize through the shared LUT, and accumulate with fused
      // half2 multiply-adds.
      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

      // Fold the half2 partial sums into the scalar accumulator.
#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
#endif

      i += width;
      k += 4;
    }
    // col % 2 -> set only one lane of the half2, so the atomicAdd below
    // effectively updates this thread's output element (the other lane
    // adds zero).
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

} // namespace squeezellm
} // namespace aphrodite
// Host-side launcher for the 4-bit LUT-based matvec kernel.
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table
) {
  int height = mat.size(0);
  int width = mat.size(1);
  int batch = vec.size(0);
  int vec_height = vec.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
    (half2*) vec.data_ptr<at::Half>(),
#else
    (__half2*) vec.data_ptr<at::Half>(),
#endif
    mat.data_ptr<int>(),
#ifndef USE_ROCM
    (half2*) mul.data_ptr<at::Half>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#else
    (float2*) mul.data_ptr<float>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#endif
    height, width, batch, vec_height
  );
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
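
// A minimal sketch of how this entry point might be registered for
// Python, assuming a pybind11 TORCH_EXTENSION_NAME module; the actual
// binding lives elsewhere in the build, so this is illustrative only:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("squeezellm_gemm", &squeezellm_gemm,
//           "SqueezeLLM 4-bit LUT-based matvec");
//   }
//
// Expected shapes, inferred from the kernel's indexing: vec is
// [batch, vec_height] fp16, mat is [vec_height / 8, width] int32,
// lookup_table is [width, 16] fp16, and mul is [batch, width] (fp16 on
// CUDA, fp32 on ROCm). Since the kernel only accumulates via atomicAdd,
// mul presumably needs to be zero-initialized by the caller.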