int4_fp16_gemm_kernels.cu 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. #include <torch/extension.h>
  2. #include <cuda_fp16.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include <vector>
  5. #include "format.h"
  6. #include "gemm_s4_f16.h"
  7. // in_feats: M, IC [float16]
  8. // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
  9. // scaling_factors: IC // G, OC [float16]
  10. // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
  11. // assume that batch_size < 16 for now
  // Converts quantized 4-bit weights plus their per-group scales and zero
  // points into the destination buffers by calling
  // aphrodite::autoquant::convert_s4_k_m8 on the current CUDA stream of the
  // device that owns `_quant_scales`.
  //
  // The dtype of `_quant_scales` selects the instantiation:
  //   float16  -> half / half2
  //   bfloat16 -> __nv_bfloat16 / __nv_bfloat162
  // Any other dtype is rejected with a clear error.
  //
  // Parameters:
  //   _weight_dest             int32 output buffer for the converted weights.
  //   _quant_scales_zeros_dest int32 output buffer. It is reinterpreted as
  //                            pairs of 16-bit values (half2 / __nv_bfloat162).
  //   _workspace               scratch buffer. It must have the same dtype as
  //                            _quant_scales.
  //   _quant_weight_src        int32 tensor holding the packed source weights.
  //   _quant_scales            float16 or bfloat16 per-group scales.
  //   _quant_zeros             int32 tensor holding the packed zero points.
  //   m, k, group_size         forwarded unchanged to convert_s4_k_m8. Their
  //                            exact meaning is defined there (TODO confirm).
  //
  // NOTE(review): all tensors are assumed to live on the same CUDA device as
  // _quant_scales and to be contiguous. Confirm with callers.
  void autoquant_convert_s4_k_m8(
      torch::Tensor _weight_dest,
      torch::Tensor _quant_scales_zeros_dest,
      torch::Tensor _workspace,
      torch::Tensor _quant_weight_src,
      torch::Tensor _quant_scales,
      torch::Tensor _quant_zeros,
      int m,
      int k,
      int group_size) {
    const auto st_ = _quant_scales.scalar_type();
    TORCH_CHECK(st_ == at::ScalarType::Half || st_ == at::ScalarType::BFloat16,
                "autoquant_convert_s4_k_m8: _quant_scales must be float16 or bfloat16, got ",
                st_);

    // Make the tensors' device current before fetching its stream. Otherwise
    // the work lands on whichever device happens to be current on this thread.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_quant_scales));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // The integer buffers are reinterpreted identically for both dtypes.
    auto weight_dest = reinterpret_cast<uint32_t*>(_weight_dest.data_ptr<int32_t>());
    auto quant_weight_src = reinterpret_cast<uint32_t*>(_quant_weight_src.data_ptr<int32_t>());
    auto quant_zeros = reinterpret_cast<uint32_t*>(_quant_zeros.data_ptr<int32_t>());

    if (st_ == at::ScalarType::Half) {
      auto quant_scales_zeros_dest =
          reinterpret_cast<half2*>(_quant_scales_zeros_dest.data_ptr<int32_t>());
      auto workspace = reinterpret_cast<half*>(_workspace.data_ptr<at::Half>());
      auto quant_scales = reinterpret_cast<half*>(_quant_scales.data_ptr<at::Half>());
      aphrodite::autoquant::convert_s4_k_m8(weight_dest, quant_scales_zeros_dest, workspace,
                                            quant_weight_src, quant_scales, quant_zeros,
                                            m, k, group_size, stream);
    } else {
      // st_ is guaranteed to be BFloat16 by the TORCH_CHECK above.
      auto quant_scales_zeros_dest =
          reinterpret_cast<__nv_bfloat162*>(_quant_scales_zeros_dest.data_ptr<int32_t>());
      auto workspace = reinterpret_cast<__nv_bfloat16*>(_workspace.data_ptr<at::BFloat16>());
      auto quant_scales =
          reinterpret_cast<__nv_bfloat16*>(_quant_scales.data_ptr<at::BFloat16>());
      aphrodite::autoquant::convert_s4_k_m8(weight_dest, quant_scales_zeros_dest, workspace,
                                            quant_weight_src, quant_scales, quant_zeros,
                                            m, k, group_size, stream);
    }
  }
  // Runs the 4-bit-weight GEMM implemented by
  // aphrodite::autoquant::GemmS4F16 (mode kGemm) on the current CUDA stream of
  // `_in_feats`' device. It returns a freshly allocated output tensor.
  //
  // Parameters:
  //   _in_feats     2-D activations [M, IC], float16 or bfloat16. The dtype
  //                 selects the kernel instantiation and the output dtype.
  //   _kernel       int32 tensor of packed weights. The output has
  //                 _kernel.size(1) * 8 columns (presumably eight 4-bit values
  //                 per int32, per the layout comment above; confirm).
  //   _scales_zeros int32 tensor reinterpreted as half2 / __nv_bfloat162 pairs.
  //                 Its size(0) is the number of quantization groups along IC,
  //                 and group_size = IC / size(0).
  //
  // Returns a [M, _kernel.size(1) * 8] tensor with _in_feats' dtype and device.
  //
  // Raises (via TORCH_CHECK) on an unsupported dtype, non-2-D input, or a
  // group count that is zero or does not evenly divide IC.
  //
  // NOTE(review): each GemmS4F16 instantiation is a function-local static, so
  // a single object is shared by every call (all threads and devices). Confirm
  // this is safe for GemmS4F16.
  torch::Tensor autoquant_s4_f16_gemm(
      torch::Tensor _in_feats,
      torch::Tensor _kernel,
      torch::Tensor _scales_zeros)
  {
    const auto st_ = _in_feats.scalar_type();
    TORCH_CHECK(st_ == at::ScalarType::Half || st_ == at::ScalarType::BFloat16,
                "autoquant_s4_f16_gemm: _in_feats must be float16 or bfloat16, got ", st_);
    TORCH_CHECK(_in_feats.dim() == 2,
                "autoquant_s4_f16_gemm: _in_feats must be 2-D [M, IC], got ",
                _in_feats.dim(), " dims");

    const int num_in_feats = _in_feats.size(0);
    const int num_in_channels = _in_feats.size(1);

    // Guard the integer division below. Zero groups would divide by zero, and a
    // non-dividing count would silently truncate group_size.
    const int64_t num_groups = _scales_zeros.size(0);
    TORCH_CHECK(num_groups > 0 && num_in_channels % num_groups == 0,
                "autoquant_s4_f16_gemm: _scales_zeros.size(0) (", num_groups,
                ") must be positive and divide the input channel count (",
                num_in_channels, ")");
    const int group_size = num_in_channels / num_groups;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(1) * 8}, options);

    // A zero-row GEMM has an empty result. There is nothing to launch.
    if (num_in_feats == 0) {
      return _out_feats;
    }

    const int num_out_channels = _out_feats.size(-1);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    auto kernel = reinterpret_cast<const uint*>(_kernel.data_ptr<int32_t>());

    if (st_ == at::ScalarType::Half) {
      static aphrodite::autoquant::GemmS4F16<half, half2> gemm_s4_f16_;
      auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
      auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
      auto scales_zeros = reinterpret_cast<half2*>(_scales_zeros.data_ptr<int32_t>());
      // The -1 argument is forwarded as-is. Its meaning is defined by
      // GemmS4F16::Run (TODO confirm).
      gemm_s4_f16_.Run(out_feats, kernel, in_feats, scales_zeros,
                       num_out_channels, num_in_feats, num_in_channels, group_size,
                       aphrodite::autoquant::kGemm, -1, stream);
    } else {
      // st_ is guaranteed to be BFloat16 by the TORCH_CHECK above.
      static aphrodite::autoquant::GemmS4F16<__nv_bfloat16, __nv_bfloat162> gemm_s4_bf16_;
      auto in_feats = reinterpret_cast<__nv_bfloat16*>(_in_feats.data_ptr<at::BFloat16>());
      auto out_feats = reinterpret_cast<__nv_bfloat16*>(_out_feats.data_ptr<at::BFloat16>());
      auto scales_zeros =
          reinterpret_cast<__nv_bfloat162*>(_scales_zeros.data_ptr<int32_t>());
      gemm_s4_bf16_.Run(out_feats, kernel, in_feats, scales_zeros,
                        num_out_channels, num_in_feats, num_in_channels, group_size,
                        aphrodite::autoquant::kGemm, -1, stream);
    }
    return _out_feats;
  }