scaled_mm_c2x_sm80_dispatch.cuh 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. #pragma once
  2. #include "scaled_mm_c2x.cuh"
  3. /**
  4. * This file defines Gemm kernel configurations for SM80 based on the Gemm
  5. * shape.
  6. */
  7. namespace aphrodite {
  8. template <typename InType, typename OutType,
  9. template <typename, typename> typename Epilogue>
  10. struct sm80_config_default {
  11. // This config is used in 2 cases,
  12. // - M in (128, inf)
  13. // - M in (64, 128] and N >= 8192
  14. // Shared Memory required by this Gemm - 81920 bytes
  15. static_assert(std::is_same<InType, int8_t>());
  16. using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  17. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  18. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  19. using Cutlass2xGemm =
  20. cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
  21. Epilogue, TileShape, WarpShape, InstructionShape, 5>;
  22. };
  23. template <typename InType, typename OutType,
  24. template <typename, typename> typename Epilogue>
  25. struct sm80_config_M64 {
  26. // This config is used in 2 cases,
  27. // - M in (32, 64]
  28. // - M in (64, 128] and N < 8192
  29. // Shared Memory required by this Gemm - 122880 bytes
  30. static_assert(std::is_same<InType, int8_t>());
  31. using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  32. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  33. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  34. using Cutlass2xGemm =
  35. cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
  36. Epilogue, TileShape, WarpShape, InstructionShape, 5>;
  37. };
  38. template <typename InType, typename OutType,
  39. template <typename, typename> typename Epilogue>
  40. struct sm80_config_M32 {
  41. // M in (16, 32]
  42. // Shared Memory required by this Gemm - 61440 bytes
  43. static_assert(std::is_same<InType, int8_t>());
  44. using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
  45. using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  46. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  47. using Cutlass2xGemm =
  48. cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
  49. Epilogue, TileShape, WarpShape, InstructionShape, 5>;
  50. };
  51. template <typename InType, typename OutType,
  52. template <typename, typename> typename Epilogue>
  53. struct sm80_config_M16 {
  54. // M in [1, 16]
  55. // Shared Memory required by this Gemm - 51200 bytes
  56. static_assert(std::is_same<InType, int8_t>());
  57. using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
  58. using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  59. using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  60. using Cutlass2xGemm =
  61. cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
  62. Epilogue, TileShape, WarpShape, InstructionShape, 5>;
  63. };
  64. template <typename InType, typename OutType,
  65. template <typename, typename> typename Epilogue,
  66. typename... EpilogueArgs>
  67. inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
  68. torch::Tensor const& a,
  69. torch::Tensor const& b,
  70. EpilogueArgs&&... args) {
  71. static_assert(std::is_same<InType, int8_t>());
  72. TORCH_CHECK(a.dtype() == torch::kInt8);
  73. TORCH_CHECK(b.dtype() == torch::kInt8);
  74. using Cutlass2xGemmDefault =
  75. typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  76. using Cutlass2xGemmM128BigN =
  77. typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  78. using Cutlass2xGemmM128SmallN =
  79. typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  80. using Cutlass2xGemmM64 =
  81. typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  82. using Cutlass2xGemmM32 =
  83. typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  84. using Cutlass2xGemmM16 =
  85. typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
  86. // Due to shared memory requirements, some Gemms may fail to run on some
  87. // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  88. // in such cases.
  89. // sm80_config_M16 has the least shared-memory requirement. However,
  90. // based on some profiling, we select sm80_config_M32 as a better alternative
  91. // performance wise.
  92. using FallbackGemm =
  93. typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  94. uint32_t const m = a.size(0);
  95. uint32_t const mp2 =
  96. std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
  97. if (mp2 <= 16) {
  98. // M in [1, 16]
  99. return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
  100. out, a, b, std::forward<EpilogueArgs>(args)...);
  101. } else if (mp2 <= 32) {
  102. // M in (16, 32]
  103. return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
  104. out, a, b, std::forward<EpilogueArgs>(args)...);
  105. } else if (mp2 <= 64) {
  106. // M in (32, 64]
  107. return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
  108. out, a, b, std::forward<EpilogueArgs>(args)...);
  109. } else if (mp2 <= 128) {
  110. // M in (64, 128]
  111. uint32_t const n = out.size(1);
  112. bool const small_n = n < 8192;
  113. if (small_n) {
  114. return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
  115. FallbackGemm>(
  116. out, a, b, std::forward<EpilogueArgs>(args)...);
  117. } else {
  118. return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
  119. out, a, b, std::forward<EpilogueArgs>(args)...);
  120. }
  121. } else {
  122. // M in (128, inf)
  123. return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
  124. out, a, b, std::forward<EpilogueArgs>(args)...);
  125. }
  126. }
  127. } // namespace aphrodite