// alt_matmul_kernel.cu

#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "cu_compat.cuh"  // project-local compatibility helpers

// Each thread block runs BLOCKWIDTH threads (one per output column) and
// consumes BLOCKHEIGHT packed rows of the weight matrix. A packed row holds
// eight 4-bit values per int32, so BLOCKHEIGHT * 8 == BLOCKWIDTH unpacked
// input elements are covered per block.
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT = 32;
// Reinterpret the bits of an int as unsigned (and back) without numeric
// conversion; used for raw bit manipulation on the packed quantized values.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

__device__ inline int as_int(int i) {
  return *reinterpret_cast<int*>(&i);
}
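
// VecQuant4MatMulKernel: for every batch row b and output column w it computes
//   mul[b, w] += sum_k scales[g_idx[k], w] * (q[k, w] - (z[g_idx[k], w] + 1)) * vec[b, k]
// where q[k, w] is a 4-bit weight unpacked from `mat` (eight values per int32)
// and z is the packed 4-bit zero point. `height` counts packed int32 rows, so
// the kernel expects vec_height == height * 8.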
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,     // input activations, [batch, vec_height]
    const int* __restrict__ mat,          // packed 4-bit weights, [height, width]
    scalar_t* __restrict__ mul,           // output, [batch, width] (accumulated into)
    const scalar_t* __restrict__ scales,  // per-group scales, [groups, width]
    const int* __restrict__ zeros,        // packed 4-bit zero points, [groups, zero_width]
    const int* __restrict__ g_idx,        // quantization group of each unpacked input row
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width
) {
  int h = BLOCKHEIGHT * blockIdx.x;               // first packed weight row for this block
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;  // output column handled by this thread
  int h_end = min(h + BLOCKHEIGHT, height);       // clamp the last block of rows

  __shared__ scalar_t blockvec[BLOCKWIDTH];       // slice of the input vector shared by the block

  int i = width * h + w;                          // linear index of this thread's first packed weight
  int g_h = h * 8;                                // first unpacked row covered by this block
  int h_range = (h_end - h) * 8;                  // number of unpacked rows in this block

  int k;
  unsigned int g;
  scalar_t w_tmp;

  int z_w = w / 8;                                // int32 holding this column's zero point
  int z_mod = (w % 8) * 4;                        // bit offset of that 4-bit zero point

  float weight[BLOCKWIDTH];                       // dequantized weights for this thread's column
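
  // Dequantization pass: each thread unpacks the 4-bit weights of its column
  // for every unpacked row k in this block. Eight weights share one int32, so
  // k / 8 selects the packed row and (k % 8) * 4 the nibble inside it.
  // Example: k = 11 -> k_w = 1, k_bit = 12, i.e. bits 12..15 of the second
  // packed row for this column.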
  if (w < width) {
    for (k = 0; k < h_range; ++k) {
      int k_w = (k / 8);           // packed row containing unpacked row k
      int k_bit = (k % 8) * 4;     // bit offset of that 4-bit weight
      g = as_int(g_idx[g_h + k]);  // quantization group of unpacked row k
      scalar_t scale = scales[g * width + w];
      // Stored zero points are offset by one when unpacked.
      scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
      weight[k] = scale * (w_tmp - zero);
    }
  }
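
  // Batch loop: for every batch row, stage a BLOCKWIDTH-wide slice of the
  // input vector in shared memory (only the first h_range entries are valid
  // for the last block), take the dot product with the dequantized column,
  // and accumulate into the output with atomicAdd. Because partial sums from
  // different row blocks are added atomically, the caller is expected to pass
  // a zero-initialized `mul`.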
  scalar_t res;
  for (int b = 0; b < batch; ++b) {
    res = 0;
    if (threadIdx.x < h_range) {
      blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
    }
    __syncthreads();
    if (w < width) {
      for (k = 0; k < h_range; ++k) {
        res += weight[k] * blockvec[k];
      }
      atomicAdd(&mul[b * width + w], res);
    }
    // Keep blockvec intact until every thread has consumed it before the
    // next batch row overwrites it.
    __syncthreads();
  }
}
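
// Host-side launcher. The grid covers the packed weight matrix with
// ceil(height / BLOCKHEIGHT) blocks along x (packed rows) and
// ceil(width / BLOCKWIDTH) blocks along y (output columns), each block
// running BLOCKWIDTH threads, one per output column.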
void vecquant4matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros,
    torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
      (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
      (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
      vec.scalar_type(), "vecquant4matmul_cuda", ([&] {
        VecQuant4MatMulKernel<<<blocks, threads>>>(
            vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
            scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(), g_idx.data_ptr<int>(),
            batch, vec_height, height, width, zero_width
        );
      })
  );
}
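
// --------------------------------------------------------------------------
// Illustrative only, not part of the original file: this translation unit
// defines the launcher but no Python binding, so a companion binding file is
// presumably built alongside it. The sketch below shows one way such a file
// could look, assuming the extension is compiled with torch.utils.cpp_extension
// (which defines TORCH_EXTENSION_NAME). The exposed name "vecquant4matmul"
// and the file name are assumptions, not taken from the project.
//
//   // alt_matmul_binding.cpp (hypothetical)
//   #include <torch/extension.h>
//
//   // Launcher implemented in alt_matmul_kernel.cu.
//   void vecquant4matmul_cuda(
//       torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
//       torch::Tensor scales, torch::Tensor zeros, torch::Tensor g_idx);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("vecquant4matmul", &vecquant4matmul_cuda,
//           "4-bit quantized matmul (CUDA); accumulates into mul");
//   }
//
// From Python, `mul` should be zero-initialized before the call, since the
// kernel accumulates into it with atomicAdd.
// --------------------------------------------------------------------------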