// gemm_s4_f16.cu
/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_bf16.h>  // __nv_bfloat16 / __nv_bfloat162 (may also be pulled in by the headers below)
#include <cuda_fp16.h>  // half / half2

#include <algorithm>
#include <iomanip>
#include <ios>
#include <iostream>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <tuple>
#include <vector>

#include "gemm_s4_f16.h"
#include "gemm_s4_f16_kernel.h"
#include "metric.h"
#include "common.h"

namespace aphrodite {
namespace autoquant {

bool g_dump_kernel_info_once = false;

namespace ops {
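
// Output ops consume one 32-bit fragment holding two adjacent fp16/bf16
// elements of the result C (indexed as C[n * M + m]) and decide how the
// fragment is written back. `Identity` stores the pair verbatim;
// `SiluActivation` treats the pair as (gate, up) of a gated FFN and stores
// silu(gate) * up, halving the M dimension of the output.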
struct Identity {
    static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
    {
        if (n < N) {
            (uint&)C[n * M + m] = (uint&)data;
        }
    }
    static __inline__ __device__ void apply(uint data, int m, int n, __nv_bfloat16* C, int M, int N)
    {
        if (n < N) {
            (uint&)C[n * M + m] = (uint&)data;
        }
    }
};

struct SiluActivation {
    static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
    {
        auto  u    = __half22float2((half2&)data);
        float silu = u.x / (1.f + __expf(-u.x));
        half  val  = __float2half_rn(silu * u.y);
        if (n < N) {
            C[n * (M / 2) + m / 2] = val;
        }
    }
    static __inline__ __device__ void apply(uint data, int m, int n, __nv_bfloat16* C, int M, int N)
    {
        auto          u    = __bfloat1622float2((__nv_bfloat162&)data);
        float         silu = u.x / (1.f + __expf(-u.x));
        __nv_bfloat16 val  = __float2bfloat16_rn(silu * u.y);
        if (n < N) {
            C[n * (M / 2) + m / 2] = val;
        }
    }
};

}  // namespace ops
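
// Compile-time list of output ops; a kernel selects one via the template
// index, so a single kernel instantiation supports both the plain GEMM
// epilogue and the fused-SiLU epilogue.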

template<typename... Ts>
struct OutputOps {
    template<int index>
    static __inline__ __device__ void apply(uint data, int m, int n, half* C, int M, int N)
    {
        std::tuple_element_t<index, std::tuple<Ts...>>::apply(data, m, n, C, M, N);
    }
    template<int index>
    static __inline__ __device__ void apply(uint data, int m, int n, __nv_bfloat16* C, int M, int N)
    {
        std::tuple_element_t<index, std::tuple<Ts...>>::apply(data, m, n, C, M, N);
    }
};
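
// Generate enumerates the candidate kernels the autotuner chooses among.
// Going by the template parameter names in gemm_s4_f16_kernel.h, the
// arguments are: CTA tile Shape<M, N, K>, warp tile Shape<M, N, K>, pipeline
// stage count, quantization group size, output-op list, and the B/C and
// quant-param element types.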

template<typename T_BC, typename T_Q>
void Impl<T_BC, T_Q>::Generate(std::vector<Kernels>& kernels)
{
    // max shared memory per block (KB):
    // sm75: 64
    // sm80: 163
    // sm86: 99
    // sm89: 99
    // sm90: 227
    using Op = OutputOps<ops::Identity, ops::SiluActivation>;

    const int GS = 128;

    Kernels k;
    // CTA_M = 256
    k.emplace_back(new GemmKernel<Shape<256, 128, 32>, Shape<32, 128, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<256, 64, 64>, Shape<64, 64, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<256, 64, 32>, Shape<64, 64, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<256, 32, 64>, Shape<64, 32, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<256, 16, 256>, Shape<32, 16, 128>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<256, 8, 256>, Shape<32, 8, 128>, 3, GS, Op, T_BC, T_Q>{});
    // CTA_M = 128
    k.emplace_back(new GemmKernel<Shape<128, 128, 64>, Shape<32, 128, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 128, 32>, Shape<32, 128, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 96, 64>, Shape<32, 96, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 64, 64>, Shape<32, 64, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 64, 32>, Shape<32, 64, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 32, 128>, Shape<32, 32, 64>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 16, 256>, Shape<32, 16, 64>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 8, 512>, Shape<32, 8, 128>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<128, 8, 512>, Shape<32, 8, 128>, 2, GS, Op, T_BC, T_Q>{});  // 2 stages, for sm86/sm89
    // CTA_M = 64
    k.emplace_back(new GemmKernel<Shape<64, 16, 256>, Shape<32, 16, 32>, 3, GS, Op, T_BC, T_Q>{});
    k.emplace_back(new GemmKernel<Shape<64, 8, 256>, Shape<32, 8, 32>, 3, GS, Op, T_BC, T_Q>{});
    kernels.push_back(std::move(k));
}
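
// Measure autotunes by timing every kernel in the group's candidate list:
// kWarmup untimed launches, then kMeasure launches bracketed by CUDA events.
// On return `metrics` is sorted fastest-first (infeasible kernels get
// infinite time), and the heuristically preferred entry is flagged `best`.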

// Defined below; ranks two kernels by their static metrics.
static bool Compare(const Metric& a, const Metric& b);

template<typename T_BC, typename T_Q>
void Impl<T_BC, T_Q>::Measure(T_BC*                 C,
                              const uint*           A,
                              const T_BC*           B,
                              const T_Q*            Q,
                              int                   m,
                              int                   n,
                              int                   k,
                              int                   group_size,
                              Type                  type,
                              std::vector<Metric>&  metrics,
                              cudaStream_t          st,
                              std::vector<Kernels>& _kernels)
{
    int gid = -1;
    for (size_t i = 0; i < group_sizes_.size(); ++i) {
        if (group_sizes_[i] == group_size) {
            gid = i;
            break;
        }
    }
    if (gid < 0) {
        throw std::runtime_error("unsupported group size");
    }
    const auto& kernels = _kernels[gid];

    metrics = std::vector<Metric>(kernels.size());

    int best = 0;
    for (size_t i = 0; i < kernels.size(); ++i) {
        metrics[i].id = i;
        kernels[i]->GetMetric(metrics[i], m, n, k);
        if (!metrics[i].feasible) {
            metrics[i].time  = std::numeric_limits<float>::infinity();
            metrics[i].count = 1;
            continue;
        }
        if (Compare(metrics[i], metrics[best])) {
            best = i;
        }
        for (size_t j = 0; j < kWarmup + kMeasure; ++j) {
            if (j == kWarmup) {
                cudaEventRecord(ev_start_, st);
            }
            kernels[i]->Launch(C, A, B, Q, m, n, k, type, st);
        }
        cudaEventRecord(ev_end_, st);
        cudaEventSynchronize(ev_end_);
        float ms{};
        cudaEventElapsedTime(&ms, ev_start_, ev_end_);
        metrics[i].time  = ms;
        metrics[i].count = kMeasure;
    }
    metrics[best].best = 1;

    // sort metrics by measured time
    std::vector<int> indices(kernels.size());
    std::iota(indices.begin(), indices.end(), 0);
    std::stable_sort(
        indices.begin(), indices.end(), [&](int i, int j) { return metrics[i].time < metrics[j].time; });
    if (g_dump_kernel_info_once) {
        DumpMetrics(std::cerr, metrics, indices);
        g_dump_kernel_info_once = false;
    }
    std::vector<Metric> tmp;
    for (size_t i = 0; i < indices.size(); ++i) {
        tmp.push_back(metrics[indices[i]]);
    }
    metrics.swap(tmp);
}
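
// Ranking heuristic used by Measure and Estimate: feasible kernels first,
// then kernels the metric marks as preferred, then the smaller normalized
// grid. Returns true when `a` ranks ahead of `b`.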

static bool Compare(const Metric& a, const Metric& b)
{
    if (a.feasible != b.feasible) {
        return a.feasible > b.feasible;
    }
    if (a.prefer != b.prefer) {
        return a.prefer > b.prefer;
    }
    return a.grid_norm < b.grid_norm;
}
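
// Estimate picks a kernel from static metrics alone, with no timing; Run
// falls back to it when the caller does not pass an explicit algo id.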

template<typename T_BC, typename T_Q>
int Impl<T_BC, T_Q>::Estimate(int m, int n, int k, Kernels& kernels)
{
    int best = 0;
    std::vector<Metric> metrics(kernels.size());
    for (size_t i = 0; i < kernels.size(); ++i) {
        metrics[i].id = i;
        kernels[i]->GetMetric(metrics[i], m, n, k);
        if (Compare(metrics[i], metrics[best])) {
            best = i;
        }
    }
    if (g_dump_kernel_info_once) {
        std::vector<int> indices(kernels.size());
        std::iota(indices.begin(), indices.end(), 0);
        std::stable_sort(
            indices.begin(), indices.end(), [&](int i, int j) { return Compare(metrics[i], metrics[j]); });
        DumpMetrics(std::cerr, metrics, indices);
        g_dump_kernel_info_once = false;
    }
    return best;
}
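
// Run launches one kernel for the given group size; algo_id < 0 means
// "let Estimate choose".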

template<typename T_BC, typename T_Q>
void Impl<T_BC, T_Q>::Run(T_BC*                 C,
                          const uint*           A,
                          const T_BC*           B,
                          const T_Q*            Q,
                          int                   m,
                          int                   n,
                          int                   k,
                          int                   group_size,
                          Type                  type,
                          int                   algo_id,
                          cudaStream_t          st,
                          std::vector<Kernels>& kernels)
{
    for (size_t i = 0; i < group_sizes_.size(); ++i) {
        if (group_sizes_[i] == group_size) {
            if (algo_id < 0) {
                algo_id = Estimate(m, n, k, kernels[i]);
                // printf("**** m: %d, n: %d, k: %d, Run algo_id: %d \n", m, n, k, algo_id);
            }
            if (algo_id < 0) {
                throw std::runtime_error("no feasible kernel found");
            }
            kernels[i].at(algo_id)->Launch(C, A, B, Q, m, n, k, type, st);
            return;
        }
    }
    throw std::runtime_error("unsupported group size");
}
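
// Impl owns the CUDA events used for timing and the generated kernel lists;
// currently only group size 128 is populated.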

template<typename T_BC, typename T_Q>
Impl<T_BC, T_Q>::Impl()
{
    cudaEventCreate(&ev_start_);
    cudaEventCreate(&ev_end_);

    using Ops = OutputOps<ops::Identity, ops::SiluActivation>;

    /// TODO: add more group sizes
    // Generate<128, Ops>(kernels_);
    Generate(kernels_);
    group_sizes_.push_back(128);
}

template<typename T_BC, typename T_Q>
Impl<T_BC, T_Q>::~Impl()
{
    cudaEventDestroy(ev_end_);
    cudaEventDestroy(ev_start_);
}

template struct Impl<half, half2>;
template struct Impl<__nv_bfloat16, __nv_bfloat162>;

template<typename T_BC, typename T_Q>
GemmS4F16<T_BC, T_Q>::GemmS4F16(): impl_(std::make_unique<Impl<T_BC, T_Q>>()) {}

template<typename T_BC, typename T_Q>
GemmS4F16<T_BC, T_Q>::~GemmS4F16() = default;

template<typename T_BC, typename T_Q>
void GemmS4F16<T_BC, T_Q>::Measure(T_BC*                C,
                                   const uint*          A,
                                   const T_BC*          B,
                                   const T_Q*           Q,
                                   int                  m,
                                   int                  n,
                                   int                  k,
                                   int                  group_size,
                                   Type                 type,
                                   std::vector<Metric>& metrics,
                                   cudaStream_t         st)
{
    impl_->Measure(C, A, B, Q, m, n, k, group_size, type, metrics, st, impl_->kernels_);
}

template<typename T_BC, typename T_Q>
void GemmS4F16<T_BC, T_Q>::Run(T_BC*        C,
                               const uint*  A,
                               const T_BC*  B,
                               const T_Q*   Q,
                               int          m,
                               int          n,
                               int          k,
                               int          group_size,
                               Type         type,
                               int          algo_id,
                               cudaStream_t st)
{
    impl_->Run(C, A, B, Q, m, n, k, group_size, type, algo_id, st, impl_->kernels_);
}

template class GemmS4F16<half, half2>;
template class GemmS4F16<__nv_bfloat16, __nv_bfloat162>;

}  // namespace autoquant
}  // namespace aphrodite
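
// Usage sketch (not part of the original file; assumes valid device buffers,
// a CUDA stream, and a `Type` enumerator as declared in gemm_s4_f16.h):
//
//   aphrodite::autoquant::GemmS4F16<half, half2> gemm;
//   std::vector<aphrodite::autoquant::Metric> metrics;
//   gemm.Measure(C, A, B, Q, m, n, k, /*group_size=*/128, type, metrics, stream);
//   int algo_id = metrics.at(0).id;  // metrics come back sorted fastest-first
//   gemm.Run(C, A, B, Q, m, n, k, /*group_size=*/128, type, algo_id, stream);
//   // or pass algo_id = -1 to let Estimate pick a kernel heuristically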