fpA_intB_gemm_wrapper.cu

#include <torch/extension.h>
#include "cub/cub.cuh"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h> // declares at::cuda::getCurrentCUDAStream()
#include "fpA_intB_gemm_wrapper.h"
#include "fpA_intB_gemm.h"
#include "cuda_utils.h"
#include "weightOnlyBatchedGemv/enabled.h"
#include "weightOnlyBatchedGemv/kernelLauncher.h"
#include "torch_utils.h"
#include <vector>

namespace ft = fastertransformer;

int getWorkspaceSize(const int m, const int n, const int k)
{
    // These are the min tile sizes for each config, which would launch the maximum number of blocks
    const int max_grid_m = (m + 31) / 32;
    const int max_grid_n = (n + 127) / 128;
    const int split_k_limit = 7;
    // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
    return max_grid_m * max_grid_n * split_k_limit * 4;
}
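
// Worked example (illustrative): for m = n = k = 4096, max_grid_m = (4096 + 31) / 32 = 128
// and max_grid_n = (4096 + 127) / 128 = 32, so the worst-case split-k workspace is
// 128 * 32 * 7 * 4 = 114688 bytes (~112 KB).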

torch::Tensor w8_a16_gemm_forward_cuda(torch::Tensor &input,
                                       torch::Tensor &weight,
                                       torch::Tensor &scale)
{
    c10::cuda::CUDAGuard device_guard(input.device());
    TORCH_CHECK(input.dim() == 3 || input.dim() == 2, "Invalid input dim: ", input.dim());
    const int m = input.dim() == 2 ? input.size(0) : input.size(0) * input.size(1);
    const int k = input.size(-1);
    const int n = weight.size(-1);
    auto options = torch::TensorOptions().dtype(input.dtype()).device(input.device());
    torch::Tensor output = input.dim() == 2
                               ? torch::empty({m, n}, options)
                               : torch::empty({input.size(0), input.size(1), n}, options);
    const ft::half *input_ptr = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
    // const int max_size = std::max(n, k);
    // size_t workspace_size = getWorkspaceSize(m, max_size, max_size);
    // void *ptr = nullptr;
    // char *workspace_ptr = workspace_size > 0 ? (char *)cudaMalloc((void **)&ptr, workspace_size) : nullptr;
    // Small batches take the batched-GEMV fast path; larger ones fall through to the GEMM.
    const bool use_cuda_kernel = m <= SMALL_M_FAST_PATH;
    // const bool use_cuda_kernel = false;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (use_cuda_kernel)
    {
        tensorrt_llm::kernels::WeightOnlyActivationType weight_only_act_type =
            tensorrt_llm::kernels::WeightOnlyActivationType::FP16;
        tensorrt_llm::kernels::WeightOnlyQuantType weight_only_quant_type =
            tensorrt_llm::kernels::WeightOnlyQuantType::Int8b;
        tensorrt_llm::kernels::WeightOnlyParams params{weight_ptr,
                                                       reinterpret_cast<const uint8_t *>(scale.data_ptr()),
                                                       nullptr,
                                                       reinterpret_cast<half *>(input.data_ptr()),
                                                       nullptr,
                                                       nullptr,
                                                       reinterpret_cast<half *>(output.data_ptr()),
                                                       m, n, k, 0,
                                                       weight_only_quant_type,
                                                       tensorrt_llm::kernels::WeightOnlyType::PerChannel,
                                                       tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity,
                                                       weight_only_act_type};
        tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
    }
    else
    {
        ft::gemm_fp16_int(input_ptr,
                          weight_ptr,
                          scale_ptr,
                          output_ptr,
                          m, n, k,
                          nullptr, // workspace pointer and size (unused here)
                          0,
                          stream);
    }
    return output;
}
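
// Minimal call-shape sketch (illustrative only; the weight tensor must be
// preprocessed offline into the kernel's packed/interleaved int8 layout, so
// random bytes will not produce meaningful outputs):
//   auto opts_h = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto x = torch::randn({8, 4096}, opts_h);                          // m = 8, k = 4096
//   auto w = torch::empty({4096, 4096}, opts_h.dtype(torch::kUInt8));  // k x n packed int8
//   auto s = torch::randn({4096}, opts_h);                             // per-channel scales
//   auto y = w8_a16_gemm_forward_cuda(x, w, s); // m <= SMALL_M_FAST_PATH -> GEMV fast path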

torch::Tensor w8_a16_gemm_forward_cuda_(torch::Tensor &input,
                                        torch::Tensor &weight,
                                        torch::Tensor &scale,
                                        torch::Tensor &output,
                                        const int m,
                                        const int n,
                                        const int k)
{
    c10::cuda::CUDAGuard device_guard(input.device());
    const ft::half *input_ptr = reinterpret_cast<const ft::half *>(input.data_ptr());
    const uint8_t *weight_ptr = reinterpret_cast<const uint8_t *>(weight.data_ptr());
    const ft::half *scale_ptr = reinterpret_cast<const ft::half *>(scale.data_ptr());
    ft::half *output_ptr = reinterpret_cast<ft::half *>(output.data_ptr());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    ft::gemm_fp16_int(input_ptr,
                      weight_ptr,
                      scale_ptr,
                      output_ptr,
                      m, n, k,
                      nullptr, // workspace pointer and size (unused here)
                      0,
                      stream);
    return output;
}
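
// Hedged binding sketch: how entry points like these are typically exposed to
// Python via pybind11 in a PyTorch extension. The op names below are
// assumptions for illustration, not taken from this file; the real
// registration may live in a separate binding translation unit.
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// {
//     m.def("w8_a16_gemm", &w8_a16_gemm_forward_cuda,
//           "fp16-activation x int8-weight GEMM (allocates output)");
//     m.def("w8_a16_gemm_", &w8_a16_gemm_forward_cuda_,
//           "fp16-activation x int8-weight GEMM into a preallocated output");
// }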