fpA_intB_gemm.cu 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #include "fpA_intB_gemm.h"
  2. #include "fpA_intB_gemm/fpA_intB_gemm_template.h"
  3. namespace fastertransformer
  4. {
  5. ActivationType get_activation(const std::string &activation_name)
  6. {
  7. if (activation_name == "identity")
  8. return ActivationType::Identity;
  9. if (activation_name == "relu")
  10. return ActivationType::Relu;
  11. if (activation_name == "silu")
  12. return ActivationType::Silu;
  13. if (activation_name == "gelu")
  14. return ActivationType::Gelu;
  15. // todo: more
  16. return ActivationType::InvalidType;
  17. }
  18. void gemm_fp16_int(const half *A,
  19. const uint8_t *B,
  20. const half *weight_scales,
  21. half *C,
  22. int m, int n, int k,
  23. char *workspace_ptr,
  24. size_t workspace_bytes,
  25. cudaStream_t stream)
  26. {
  27. CutlassFpAIntBGemmRunner<half, uint8_t> runner;
  28. runner.gemm(A, B, weight_scales,
  29. C, m, n, k, workspace_ptr, workspace_bytes, stream);
  30. }
  31. template <typename WeightType>
  32. void gemm_fp16_int_bias_act(const half *A,
  33. const WeightType *B,
  34. const half *weight_scales,
  35. const half *bias,
  36. half *C,
  37. std::optional<std::string> activation,
  38. int m, int n, int k, int bias_stride, char *workspace_ptr,
  39. size_t workspace_bytes, cudaStream_t stream)
  40. {
  41. CutlassFpAIntBGemmRunner<half, WeightType> runner;
  42. if (!activation && bias == nullptr)
  43. {
  44. runner.gemm(A, B, weight_scales,
  45. C, m, n, k, workspace_ptr, workspace_bytes, stream);
  46. }
  47. else if (!activation)
  48. {
  49. runner.gemm_bias_act(A, B, weight_scales, bias,
  50. C, m, n, k, bias_stride, ActivationType::Identity, workspace_ptr, workspace_bytes, stream);
  51. }
  52. else
  53. {
  54. runner.gemm_bias_act(A, B, weight_scales, bias,
  55. C, m, n, k, bias_stride, get_activation(*activation), workspace_ptr, workspace_bytes, stream);
  56. }
  57. }
  58. template <typename WeightType>
  59. void gemm_fp16_int_bias_act_residual(
  60. const half *A, const WeightType *B, const half *weight_scales,
  61. const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
  62. const std::string &unary_op, int m, int n,
  63. int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream)
  64. {
  65. CutlassFpAIntBGemmRunner<half, WeightType> runner;
  66. runner.gemm_bias_act_residual(A, B, weight_scales, bias, residual,
  67. C, m, n, k, activation, binary_op, unary_op, workspace_ptr, workspace_bytes, stream);
  68. }
  69. template void gemm_fp16_int_bias_act<uint4b_t>(const half *A, const uint4b_t *B,
  70. const half *weight_scales, const half *bias,
  71. half *C, std::optional<std::string> activation, int m,
  72. int n, int k, int bias_stride, char *workspace_ptr,
  73. size_t workspace_bytes, cudaStream_t stream);
  74. template void gemm_fp16_int_bias_act_residual<uint4b_t>(
  75. const half *A, const uint4b_t *B, const half *weight_scales,
  76. const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
  77. const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
  78. template void gemm_fp16_int_bias_act<uint8_t>(const half *A, const uint8_t *B,
  79. const half *weight_scales, const half *bias,
  80. half *C, std::optional<std::string> activation, int m,
  81. int n, int k, int bias_stride, char *workspace_ptr,
  82. size_t workspace_bytes, cudaStream_t stream);
  83. template void gemm_fp16_int_bias_act_residual<uint8_t>(
  84. const half *A, const uint8_t *B, const half *weight_scales,
  85. const half *bias, const half *residual, half *C, const std::string &activation, const std::string &binary_op,
  86. const std::string &unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
  87. } // namespace fastertransformer