gemm_s4_f16_kernel.h 10 KB


  1. /*
  2. * Adapted from https://github.com/InternLM/lmdeploy
  3. * Copyright (c) OpenMMLab. All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #pragma once
  18. #include <iostream>
  19. #include <memory>
  20. #include <sstream>
  21. #include "gemm_template.h"
  22. #include "metric.h"
  23. namespace aphrodite {
  24. namespace autoquant {
  /// Common, type-erased base for GemmKernel instantiations.
  ///
  /// GemmKernel<...> derives from this interface, so kernels that differ only in their
  /// compile-time tiling can be held and driven through a single base pointer.
  ///
  /// T_BC - element type of the B input and the C output.
  /// T_Q  - element type of the Q buffer (presumably quantization scales/zeros; confirm with callers).
  template<typename T_BC, typename T_Q>
  struct IGemmKernel {
      virtual ~IGemmKernel() = default;

      /// Populate `metric` with feasibility and cost estimates for an (m, n, k) problem.
      virtual void GetMetric(Metric& metric, int m, int n, int k) = 0;

      /// Enqueue the GEMM computing C from A, B and Q on `stream`.
      /// `output_op_idx` is forwarded unchanged to the implementation.
      virtual void Launch(T_BC*        C,
                          const uint*  A,
                          const T_BC*  B,
                          const T_Q*   Q,
                          int          M,
                          int          N,
                          int          K,
                          int          output_op_idx,
                          cudaStream_t stream) = 0;

      /// Write a human-readable description of the kernel configuration to `os`.
      virtual void Dump(std::ostream& os) = 0;
  };
  /// Concrete IGemmKernel wrapping one compile-time tiling of `gemm_s4_f16_nn`.
  ///
  /// Template parameters:
  ///   CtaShape, WarpShape - tile shapes exposing constexpr m(), n(), k().
  ///   Stages              - stage count forwarded to Gemm (also reported as Metric::stages).
  ///   GroupSize           - group size forwarded to Gemm (presumably the quantization group; confirm).
  ///   OutputOps           - forwarded to Gemm (presumably selected by `output_op_idx` at launch).
  ///   T_BC, T_Q           - element types of B/C and of Q.
  ///
  /// The constructor raises the kernel's dynamic shared-memory limit to kSmemByteSize and caches
  /// the maximum number of resident CTAs per SM. If any CUDA call in that setup fails, the
  /// configuration is recorded as unusable (max_active_ctas_ == 0) and GetMetric reports it infeasible.
  template<typename CtaShape, typename WarpShape, int Stages, int GroupSize, typename OutputOps, typename T_BC, typename T_Q>
  struct GemmKernel: public IGemmKernel<T_BC, T_Q> {
      static constexpr CtaShape  cta_shape{};
      static constexpr WarpShape warp_shape{};

      using GemmType = Gemm<cta_shape.m(),
                            cta_shape.n(),
                            cta_shape.k(),
                            warp_shape.m(),
                            warp_shape.n(),
                            warp_shape.k(),
                            Stages,
                            GroupSize,
                            OutputOps,
                            T_BC,
                            T_Q>;

      decltype(&gemm_s4_f16_nn<GemmType, T_BC, T_Q>) kernel_func_;
      std::shared_ptr<cudaDeviceProp>                props_;
      // Max CTAs of this kernel resident per SM; 0 marks the configuration as unusable.
      int max_active_ctas_{};

      static constexpr int kSlices    = GemmType::SLICES;
      static constexpr int kSmemSizeA = GemmType::IteratorA::kSmemByteSize * kSlices;
      static constexpr int kSmemSizeB = GemmType::IteratorB::kSmemByteSize * kSlices;
      // float staging for a full CTA tile of C; presumably the epilogue reuses the A/B buffer.
      static constexpr int kSmemSizeC    = sizeof(float) * cta_shape.m() * cta_shape.n();
      static constexpr int kSmemByteSize = std::max(kSmemSizeA + kSmemSizeB, kSmemSizeC);
      // static shared memory size of Q
      static constexpr int kSmemSizeQ = sizeof(typename GemmType::IteratorQ::Storage);

      /// @param props Device properties to use; when empty, the current device is queried.
      explicit GemmKernel(std::shared_ptr<cudaDeviceProp> props = {}): props_(std::move(props))
      {
          cudaError_t status = cudaSuccess;
          if (!props_) {
              props_         = std::make_shared<cudaDeviceProp>();
              int device_id  = -1;
              status         = cudaGetDevice(&device_id);
              if (status == cudaSuccess) {
                  status = cudaGetDeviceProperties(props_.get(), device_id);
              }
          }
          kernel_func_ = gemm_s4_f16_nn<GemmType, T_BC, T_Q>;
          // Required before launching with more dynamic shared memory than the default per-block limit.
          if (status == cudaSuccess) {
              status = cudaFuncSetAttribute(kernel_func_, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
          }
          if (status == cudaSuccess) {
              status = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                  &max_active_ctas_, kernel_func_, GemmType::kWarpCount * WARP_SIZE, kSmemByteSize);
          }
          if (status != cudaSuccess) {
              // Treat the configuration as infeasible and consume the recorded error so that a later,
              // unrelated cudaGetLastError() check is not blamed for this setup failure.
              max_active_ctas_ = 0;
              (void)cudaGetLastError();
          }
      }

      /// Shape-divisibility requirement: m must be a multiple of the CTA m-tile and
      /// k a multiple of the CTA k-tile (n is not constrained here).
      bool is_feasible(int m, int n, int k)
      {
          return m % cta_shape.m() == 0 && k % cta_shape.k() == 0;
      }

      /// Fill `metric` with tiling info, feasibility and a heuristic cost estimate for (m, n, k).
      /// When the problem is infeasible, only the fields up to `prefer` are written.
      void GetMetric(Metric& metric, int m, int n, int k) override
      {
          metric.cta_shape  = {cta_shape.m(), cta_shape.n(), cta_shape.k()};
          metric.warp_shape = {warp_shape.m(), warp_shape.n(), warp_shape.k()};
          metric.warps      = GemmType::kWarpCount;
          metric.stages     = Stages;
          metric.smem       = (kSmemByteSize + kSmemSizeQ) / 1024.f;
          // Empty problems (grid_size == 0) and a missing SM count would make the ratios below
          // divide by zero, so such cases are reported as infeasible.
          metric.feasible = is_feasible(m, n, k) && max_active_ctas_ > 0 && m > 0 && n > 0
                            && props_->multiProcessorCount > 0;
          metric.prefer = cta_shape.m() != 64 || m <= k;
          if (!metric.feasible) {
              return;
          }

          const int sm_count  = props_->multiProcessorCount;
          int       grid_size = ((m + cta_shape.m() - 1) / cta_shape.m()) * ((n + cta_shape.n() - 1) / cta_shape.n());

          metric.grid_size       = grid_size;
          metric.max_active_ctas = max_active_ctas_;
          metric.active_ctas     = std::min(max_active_ctas_, (grid_size + sm_count - 1) / sm_count);
          metric.waves           = (float)grid_size / (sm_count * metric.active_ctas);
          metric.occupancy =
              (metric.active_ctas * GemmType::kWarpCount) / (float)(props_->maxThreadsPerMultiProcessor / props_->warpSize);
          metric.cta_cnt_m       = (m + cta_shape.m() - 1) / cta_shape.m();
          metric.cta_cnt_n       = (n + cta_shape.n() - 1) / cta_shape.n();
          metric.cta_iter_k      = (k + cta_shape.k() - 1) / cta_shape.k();
          metric.tile_efficiency = (float)n / (metric.cta_cnt_n * cta_shape.n());
          metric.wave_efficiency = metric.waves / std::ceil(metric.waves);

          // Heuristic per-term cost model; the weights (0.25, 1.00, 0.65, ...) are empirical
          // constants whose derivation is not visible in this file.
          const int m_pad = (m + cta_shape.m() - 1) / cta_shape.m() * cta_shape.m();
          const int n_pad = (n + cta_shape.n() - 1) / cta_shape.n() * cta_shape.n();
          metric.grid_a0  = 0.25f * m * n_pad / cta_shape.n();          // Ta0 * M * [N / ctaN]
          metric.grid_b0  = 1.00f * n * m_pad / cta_shape.m();          // Tb0 * N * [M / ctaM]
          metric.grid_a1  = 0.65f * m_pad * n_pad / warp_shape.n();     // Ta1 * [M] * [N] / warpN
          metric.grid_b1  = 0.25f * m_pad * n_pad / warp_shape.m();     // Tb1 * [M] * [N] / warpM
          metric.grid_mm  = 1.00f * m_pad * n_pad / 64;                 // Tm * [M] * [N]
          metric.grid_sum = metric.grid_a0 + metric.grid_b0 + metric.grid_a1 + metric.grid_b1 + metric.grid_mm;
          metric.cta_sum  = metric.grid_sum / grid_size;
          metric.waves1   = (float)grid_size / (sm_count * metric.active_ctas);
          metric.cta_wave = std::ceil(metric.waves1) * metric.active_ctas;
          metric.grid_norm = metric.cta_wave * metric.cta_sum;
      }

      /// Enqueue the kernel on stream `st`. The grid tiles M along x (cta_shape.m()) and N along y
      /// (cta_shape.n()); each block has kWarpCount * WARP_SIZE threads and kSmemByteSize bytes of
      /// dynamic shared memory. Launch errors are not checked here; callers should query
      /// cudaGetLastError() if they need them.
      void Launch(
          T_BC* C, const uint* A, const T_BC* B, const T_Q* Q, int M, int N, int K, int output_op_idx, cudaStream_t st)
          override
      {
          // An empty problem would yield a zero-sized grid, which CUDA rejects with
          // cudaErrorInvalidConfiguration; there is no work to do, so skip the launch.
          if (M <= 0 || N <= 0) {
              return;
          }
          constexpr int block_size = GemmType::kWarpCount * WARP_SIZE;
          const dim3    grid_size((M + cta_shape.m() - 1) / cta_shape.m(), (N + cta_shape.n() - 1) / cta_shape.n());
          kernel_func_<<<grid_size, block_size, kSmemByteSize, st>>>(C, A, B, Q, M, N, K, output_op_idx);
      }

      /// Print the tiling, iterator layouts and shared-memory footprint of this configuration.
      void Dump(std::ostream& os) override
      {
          {
              os << "[Gemm] CTA shape: " << cta_shape.m() << "x" << cta_shape.n() << "x" << cta_shape.k() << std::endl;
              os << "[Gemm] warp shape: " << warp_shape.m() << "x" << warp_shape.n() << "x" << warp_shape.k()
                 << std::endl;
              os << "[Gemm] warp count: " << GemmType::kWarpCountM << "x" << GemmType::kWarpCountN << "x"
                 << GemmType::kWarpCountK << " (" << GemmType::kWarpCount << ")" << std::endl;
              os << std::endl;
          }
          {
              using Iter = typename GemmType::IteratorA;
              os << "[A] shape: " << Iter::kShapeM << " " << Iter::kShapeK << std::endl;
              os << "[A] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
              os << "[A] warp shape per access: " << Iter::kWarpAccessM << " " << Iter::kWarpAccessK << std::endl;
              os << "[A] warp access iters: " << Iter::kWarpIterM << " " << Iter::kWarpIterK << std::endl;
              os << "[A] warp arrangement: " << Iter::kWarpM << " " << Iter::kWarpK << std::endl;
              os << "[A] iterations: " << Iter::kIterM << " " << Iter::kIterK << std::endl;
              os << "[A] iters per tile: " << Iter::kIterCount << std::endl;
              os << "[A] warp footprint: " << Iter::kWarpFootprintM << " " << Iter::kWarpFootprintK << std::endl;
              os << "[A] shared memory: " << Iter::kSmemByteSize << std::endl;
              os << std::endl;
          }
          {
              using Iter = typename GemmType::IteratorB;
              os << "[B] shape: " << Iter::kShapeK << " " << Iter::kShapeN << std::endl;
              os << "[B] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
              os << "[B] warp shape per access: " << Iter::kWarpAccessK << " " << Iter::kWarpAccessN << std::endl;
              os << "[B] warp access iters: " << Iter::kWarpIterK << " " << Iter::kWarpIterN << std::endl;
              os << "[B] warp arrangement: " << Iter::kWarpK << " " << Iter::kWarpN << std::endl;
              os << "[B] iterations: " << Iter::kIterK << " " << Iter::kIterN << std::endl;
              os << "[B] iters per tile: " << Iter::kIterCount << std::endl;
              os << "[B] warp footprint: " << Iter::kWarpFootprintK << " " << Iter::kWarpFootprintN << std::endl;
              os << "[B] shared memory: " << Iter::kSmemByteSize << std::endl;
              os << std::endl;
          }
          {
              using Iter = typename GemmType::IteratorQ;
              // os << "[Q] shape: " << CTA_M << " " << Iter::SLICE_K << std::endl;
              os << "[Q] warp thread arrangement: " << Iter::kWarpThreadC << " " << Iter::kWarpThreadS << std::endl;
              os << "[Q] warp shape per access: " << Iter::kWarpAccessM << " " << Iter::kWarpAccessK << std::endl;
              os << "[Q] warp access iters: " << Iter::kWarpIterM << " " << Iter::kWarpIterK << std::endl;
              os << "[Q] warp arrangement: " << Iter::kWarpM << " " << Iter::kWarpK << std::endl;
              os << "[Q] iterations: " << Iter::kIterM << " " << Iter::kIterK << std::endl;
              os << "[Q] iters per tile: " << Iter::kIterCount << std::endl;
              os << "[Q] warp footprint: " << Iter::kWarpFootprintM << " " << Iter::kWarpFootprintK << std::endl;
              os << "[Q] size per stage: " << Iter::kSizePerStage << std::endl;
              os << "[Q] shared memory: " << Iter::kSmemByteSize << std::endl;
              os << std::endl;
          }
          os << "Dynamic shared memory size: " << kSmemByteSize << std::endl;
      }
  };
  181. } // namespace autoquant
  182. } // namespace aphrodite