gemm_s4_f16.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  17. #pragma once
  18. #include <cuda_fp16.h>
  19. #include <cuda_bf16.h>
  20. #include <cuda_runtime.h>
  21. #include <memory>
  22. #include <vector>
  23. #include "gemm_s4_f16_kernel.h"
  24. #include "metric.h"
  25. namespace aphrodite {
  26. namespace autoquant {
  // Flag defined in another translation unit. Judging by the name it gates a
  // one-time dump of kernel information -- confirm at the definition site.
  extern bool g_dump_kernel_info_once;

  // Selects which operation a GEMM call performs. Deliberately left as an
  // unscoped enum so existing callers can keep writing `kGemm` /
  // `kFusedSiluFfn` without qualification.
  enum Type {
    kGemm         = 0,  // plain GEMM
    kFusedSiluFfn = 1   // presumably GEMM fused with a SiLU feed-forward step -- confirm in the kernels
  };
  33. template <typename T_BC, typename T_Q>
  34. struct Impl{
  35. using Kernels = std::vector<std::unique_ptr<IGemmKernel<T_BC, T_Q>>>;
  36. void Generate(std::vector<Kernels>& kernels);
  37. void Measure(T_BC* C,
  38. const uint* A,
  39. const T_BC* B,
  40. const T_Q* Q,
  41. int m,
  42. int n,
  43. int k,
  44. int group_size,
  45. Type type,
  46. std::vector<Metric>& metrics,
  47. cudaStream_t st,
  48. std::vector<Kernels>& _kernels);
  49. static bool Compare(const Metric& a, const Metric& b)
  50. {
  51. if (a.feasible != b.feasible) {
  52. return a.feasible > b.feasible;
  53. }
  54. if (a.prefer != b.prefer) {
  55. return a.prefer > b.prefer;
  56. }
  57. return a.grid_norm < b.grid_norm;
  58. }
  59. int Estimate(int m, int n, int k, Kernels& kernels);
  60. void Run(T_BC* C,
  61. const uint* A,
  62. const T_BC* B,
  63. const T_Q* Q,
  64. int m,
  65. int n,
  66. int k,
  67. int group_size,
  68. Type type,
  69. int algo_id,
  70. cudaStream_t st,
  71. std::vector<Kernels>& kernels);
  72. Impl();
  73. ~Impl();
  74. std::vector<Kernels> kernels_;
  75. std::vector<int> group_sizes_;
  76. static constexpr int kWarmup = 10;
  77. static constexpr int kMeasure = 100;
  78. cudaEvent_t ev_start_{};
  79. cudaEvent_t ev_end_{};
  80. };
  81. template <typename T_BC, typename T_Q>
  82. class GemmS4F16 {
  83. public:
  84. GemmS4F16();
  85. ~GemmS4F16();
  86. void Measure(T_BC* C,
  87. const uint* A,
  88. const T_BC* B,
  89. const T_Q* Q,
  90. int m,
  91. int n,
  92. int k,
  93. int group_size,
  94. Type type,
  95. std::vector<Metric>& metrics,
  96. cudaStream_t st);
  97. void Run(T_BC* C,
  98. const uint* A,
  99. const T_BC* B,
  100. const T_Q* Q,
  101. int m,
  102. int n,
  103. int k,
  104. int group_size,
  105. Type type,
  106. int algo_id,
  107. cudaStream_t st);
  108. private:
  109. //struct Impl<T_BC, T_Q>;
  110. std::unique_ptr<Impl<T_BC, T_Q>> impl_;
  111. };
  112. } // namespace autoquant
  113. } // namespace aphrodite