// fpA_intB_gemm.h (1.2 KB)
#pragma once

#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>

#include <cuda_runtime.h>

#include "cutlass/numeric_types.h"
#include "cutlass/half.h"
#include "cutlass/integer_subbyte.h"
  8. namespace fastertransformer {
  9. using half = cutlass::half_t;
  10. using uint4b_t = cutlass::uint4b_t;
  11. // TODO: Support more general bias shape
  12. // base gemm
  13. void gemm_fp16_int(const half *A, const uint8_t * B, const half *weight_scales,
  14. half *C, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
  15. template <typename WeightType>
  16. void gemm_fp16_int_bias_act(const half *A, const WeightType *B,
  17. const half *weight_scales, const half *bias,
  18. half *C, std::optional<std::string> activation, int m,
  19. int n, int k, int bias_stride, char *workspace_ptr,
  20. size_t workspace_bytes, cudaStream_t stream);
  21. template <typename WeightType>
  22. void gemm_fp16_int_bias_act_residual(
  23. const half *A, const WeightType *B, const half *weight_scales,
  24. const half *bias, const half *residual, half *C, const std::string& activation, const std::string& binary_op,
  25. const std::string& unary_op, int m, int n, int k, char *workspace_ptr, size_t workspace_bytes, cudaStream_t stream);
  26. } // namespace fastertransformer