cutlass_heuristic.cc 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. /*
  2. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "cutlass_heuristic.h"
#include "cutlass/gemm/gemm.h"
#include <cuda_runtime_api.h>

#include <climits>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>
  21. namespace fastertransformer {
  /// Output extent of one CTA tile, as decoded from a CutlassTileConfig
  /// (the <M>x<N> part of its CtaShape<M>x<N>x<K> name). The K extent is
  /// deliberately not stored; callers only need M and N to count CTAs.
  struct TileShape {
      int m;  ///< Rows of the output tile covered by one CTA.
      int n;  ///< Columns of the output tile covered by one CTA.
  };
  26. TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
  27. {
  28. switch (tile_config) {
  29. case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
  30. return TileShape{32, 128};
  31. case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
  32. case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
  33. return TileShape{64, 128};
  34. case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
  35. case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
  36. case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
  37. return TileShape{128, 128};
  38. default:
  39. throw std::runtime_error("[FT Error][get_grid_shape_for_config] Invalid config");
  40. }
  41. }
  42. bool is_valid_split_k_factor(const int64_t m,
  43. const int64_t n,
  44. const int64_t k,
  45. const TileShape tile_shape,
  46. const int split_k_factor,
  47. const size_t workspace_bytes,
  48. const bool is_weight_only)
  49. {
  50. // All tile sizes have a k_tile of 64.
  51. static constexpr int k_tile = 64;
  52. // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k
  53. if (is_weight_only) {
  54. if ((k % k_tile) != 0) {
  55. return false;
  56. }
  57. if ((k % split_k_factor) != 0) {
  58. return false;
  59. }
  60. const int k_elements_per_split = k / split_k_factor;
  61. if ((k_elements_per_split % k_tile) != 0) {
  62. return false;
  63. }
  64. }
  65. // Check that the workspace has sufficient space for this split-k factor
  66. const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
  67. const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
  68. const size_t required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
  69. if (required_ws_bytes > workspace_bytes) {
  70. return false;
  71. }
  72. return true;
  73. }
  74. std::vector<CutlassTileConfig> get_candidate_tiles(const bool is_weight_only, const bool simt_configs_only)
  75. {
  76. std::vector<CutlassTileConfig> simt_configs{CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
  77. std::vector<CutlassTileConfig> square_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
  78. CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
  79. CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64};
  80. std::vector<CutlassTileConfig> quant_B_configs{CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
  81. CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
  82. CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64};
  83. const std::vector<CutlassTileConfig> allowed_configs = is_weight_only ? quant_B_configs : square_configs;
  84. return simt_configs_only ? simt_configs : allowed_configs;
  85. }
  86. std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only)
  87. {
  88. std::vector<CutlassTileConfig> tiles = get_candidate_tiles(is_weight_only, simt_configs_only);
  89. std::vector<CutlassGemmConfig> candidate_configs;
  90. const int min_stages = 2;
  91. const int max_stages = sm >= 80 ? 4 : 2;
  92. for (const auto& tile_config : tiles) {
  93. for (int stages = min_stages; stages <= max_stages; ++stages) {
  94. CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
  95. candidate_configs.push_back(config);
  96. }
  97. }
  98. return candidate_configs;
  99. }
  100. CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
  101. const std::vector<int>& occupancies,
  102. const int64_t m,
  103. const int64_t n,
  104. const int64_t k,
  105. const int64_t num_experts,
  106. const int split_k_limit,
  107. const size_t workspace_bytes,
  108. const int multi_processor_count,
  109. const int is_weight_only)
  110. {
  111. if (occupancies.size() != candidate_configs.size()) {
  112. throw std::runtime_error("[FT Error][estimate_best_config_from_occupancies] occpancies and "
  113. "candidate configs vectors must have equal length.");
  114. }
  115. CutlassGemmConfig best_config;
  116. // Score will be [0, 1]. The objective is to minimize this score.
  117. // It represents the fraction of SM resources unused in the last wave.
  118. float config_score = 1.0f;
  119. int config_waves = INT_MAX;
  120. int current_m_tile = 0;
  121. const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
  122. for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
  123. CutlassGemmConfig candidate_config = candidate_configs[ii];
  124. TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config);
  125. int occupancy = occupancies[ii];
  126. if (occupancy == 0) {
  127. continue;
  128. }
  129. // Keep small tile sizes when possible.
  130. if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile
  131. && current_m_tile < tile_shape.m) {
  132. continue;
  133. }
  134. const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
  135. const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
  136. for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) {
  137. if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) {
  138. const int ctas_per_wave = occupancy * multi_processor_count;
  139. const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
  140. const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
  141. const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
  142. const float current_score = float(num_waves_total) - num_waves_fractional;
  143. const float score_slack = 0.1f;
  144. if (current_score < config_score
  145. || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) {
  146. config_score = current_score;
  147. config_waves = num_waves_total;
  148. SplitKStyle split_style =
  149. split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
  150. best_config = CutlassGemmConfig{
  151. candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
  152. current_m_tile = tile_shape.m;
  153. }
  154. else if (current_score == config_score
  155. && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor
  156. || current_m_tile < tile_shape.m)) {
  157. // Prefer deeper pipeline or smaller split-k
  158. SplitKStyle split_style =
  159. split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
  160. best_config = CutlassGemmConfig{
  161. candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
  162. current_m_tile = tile_shape.m;
  163. config_waves = num_waves_total;
  164. }
  165. }
  166. }
  167. }
  168. if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
  169. throw std::runtime_error("[FT Error] Heurisitc failed to find a valid config.");
  170. }
  171. return best_config;
  172. }
  173. } // namespace fastertransformer