/*
 * Adapted from https://github.com/turboderp/exllamav2
 * Copyright (c) 2024 turboderp
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

#include "q_matrix.cuh"
#include "matrix_view.cuh"

#include "quant/qdq_2.cuh"
#include "quant/qdq_3.cuh"
#include "quant/qdq_4.cuh"
#include "quant/qdq_5.cuh"
#include "quant/qdq_6.cuh"
#include "quant/qdq_8.cuh"

#include "q_gemm_kernel.cuh"
namespace aphrodite {
namespace exl2 {

// Above this many input rows, dequantize and fall back to cuBLAS/hipBLAS
// instead of the custom quantized kernel (see gemm_half_q_half_cuda below).
#define MAX_Q_GEMM_ROWS 32

#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
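
// On ROCm builds, route the cuBLAS-style Hgemm call below through hipBLAS,
// casting the half* arguments to the hipblasHalf* type hipblasHgemm expects.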
#if defined(USE_ROCM)
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
    hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB,
    int m, int n, int k, const half* alpha, const half* AP, int lda,
    const half* BP, int ldb, const half* beta, half* CP, int ldc) {
  return hipblasHgemm(handle, transA, transB, m, n, k,
                      reinterpret_cast<const hipblasHalf*>(alpha),
                      reinterpret_cast<const hipblasHalf*>(AP), lda,
                      reinterpret_cast<const hipblasHalf*>(BP), ldb,
                      reinterpret_cast<const hipblasHalf*>(beta),
                      reinterpret_cast<hipblasHalf*>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm
#endif

#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
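
// Launches the quantized GEMM kernel over a slab of at most m_count rows of
// `a`. The grid tiles the output columns (x, EXL2_BLOCK_KN_SIZE * 4 per
// block), the row slab (y), and the reduction dimension k (z); the kernel
// variant is selected by pick_gemm_half_q_half_kernel from m_count.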
void gemm_half_q_half_cuda_part(const half* a, QMatrix* b, half* c, int size_m,
                                int size_n, int size_k, int m_count,
                                bool clear) {
  dim3 blockDim, gridDim;
  blockDim.x = EXL2_BLOCK_KN_SIZE;
  blockDim.y = 1;
  blockDim.z = 1;
  gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
  gridDim.y = DIVIDE(size_m, m_count);
  gridDim.z = DIVIDE(b->height, EXL2_BLOCK_KN_SIZE);

  fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  kernel<<<gridDim, blockDim, 0, stream>>>(
      a, b->cuda_q_weight, b->cuda_q_scale, b->cuda_q_scale_max, c, size_m,
      size_n, size_k, b->height, b->groups, b->cuda_q_group_map,
      b->cuda_q_perm, b->rows_8, b->rows_6, b->rows_5, b->rows_4, b->rows_3,
      b->rows_2, clear);
}
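
// Top-level dispatch. For size_m > MAX_Q_GEMM_ROWS the packed weights are
// reconstructed to FP16 in temp_dq and the product is computed with
// cuBLAS/hipBLAS; otherwise the rows of `a` are processed in chunks of
// EXL2_BLOCK_M_SIZE_MAX by the quantized kernel above.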
void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
                           QMatrix* b, half* c, int size_m, int size_n,
                           int size_k, bool clear, half* temp_dq) {
  if (size_m > MAX_Q_GEMM_ROWS) {
    // Reconstruct FP16 matrix, then cuBLAS
    b->reconstruct(temp_dq);

    // cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);

    const half alpha = __float2half(1.0f);
    const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
    cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, size_n, size_m, size_k,
                &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n);
  } else {
    // Quantized matmul
    int block_m_size_max = EXL2_BLOCK_M_SIZE_MAX;
    int max_chunks = size_m / block_m_size_max;
    int last_chunk = max_chunks * block_m_size_max;
    int last_chunk_size = size_m - last_chunk;

    if (max_chunks) {
      gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k,
                                 block_m_size_max, clear);
    }

    if (last_chunk_size) {
      gemm_half_q_half_cuda_part(a + last_chunk * size_k, b,
                                 c + last_chunk * size_n, last_chunk_size,
                                 size_n, size_k, last_chunk_size, clear);
    }
  }
}

}  // namespace exl2
}  // namespace aphrodite
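
// Torch-facing entry point: multiplies FP16 activations `a` by the packed
// EXL2 weight matrix referenced by the opaque handle `b` (as returned by
// make_q_matrix below) and returns a new [m, n] FP16 tensor.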
torch::Tensor exl2_gemm(torch::Tensor a, uintptr_t b) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  aphrodite::exl2::QMatrix* qm = reinterpret_cast<aphrodite::exl2::QMatrix*>(b);

  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  at::Tensor c = torch::empty({a.size(0), qm->width}, options);

  at::Tensor temp_dq;
  if (c.size(0) > MAX_Q_GEMM_ROWS) {
    // Scratch buffer for the reconstructed FP16 weight matrix (cuBLAS path).
    temp_dq = torch::zeros({a.size(1), qm->width}, options);
  }

  aphrodite::exl2::gemm_half_q_half_cuda(
      at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(), qm,
      (half*)c.data_ptr(),
      c.size(0),  // m
      c.size(1),  // n
      a.size(1),  // k
      true,
      c.size(0) > MAX_Q_GEMM_ROWS ? (half*)temp_dq.data_ptr() : NULL);

  return c;
}

uintptr_t make_q_matrix(torch::Tensor q_weight, torch::Tensor q_perm,
                        torch::Tensor q_invperm, torch::Tensor q_scale,
                        torch::Tensor q_scale_max, torch::Tensor q_groups,
                        torch::Tensor q_group_map) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
  int device = q_weight.device().index();
  int width = q_weight.size(1);
  int groups = q_scale.size(0);
  int height = q_perm.size(0);

  aphrodite::exl2::QMatrix* m = new aphrodite::exl2::QMatrix(
      device, height, width, groups, (uint32_t*)q_weight.data_ptr(),
      (uint16_t*)q_perm.data_ptr(), (uint16_t*)q_invperm.data_ptr(),
      (uint32_t*)q_scale.data_ptr(), (half*)q_scale_max.data_ptr(),
      (uint16_t*)q_groups.data_ptr(), (uint16_t*)q_group_map.data_ptr());

  return reinterpret_cast<uintptr_t>(m);
}
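
// Expected call sequence from the binding layer (a sketch of how the two
// functions above fit together; the surrounding registration code is not
// part of this file):
//
//   uintptr_t handle = make_q_matrix(q_weight, q_perm, q_invperm, q_scale,
//                                    q_scale_max, q_groups, q_group_map);
//   torch::Tensor out = exl2_gemm(activations, handle);  // [m, k] -> [m, n]
//
// The handle owns a heap-allocated QMatrix; nothing in this file deletes it.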