scaled_mm_c2x_sm75_dispatch.cuh 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. #pragma once
  2. #include "scaled_mm_c2x.cuh"
/**
 * This file defines CUTLASS 2.x int8 GEMM kernel configurations for SM75 and
 * a dispatcher that selects among them based on the GEMM's M dimension.
 */
  7. namespace aphrodite {
  8. template <typename InType, typename OutType,
  9. template <typename, typename> typename Epilogue>
  10. struct sm75_config_default {
  11. // This config is used in 2 cases,
  12. // - M in (256, inf]
  13. // - M in (64, 128]
  14. // Shared memory required by this Gemm 32768
  15. static_assert(std::is_same<InType, int8_t>());
  16. using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  17. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  18. using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  19. using Cutlass2xGemm =
  20. cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
  21. Epilogue, TileShape, WarpShape, InstructionShape, 2>;
  22. };
  23. template <typename InType, typename OutType,
  24. template <typename, typename> typename Epilogue>
  25. struct sm75_config_M256 {
  26. // M in (128, 256]
  27. // Shared memory required by this Gemm 65536
  28. static_assert(std::is_same<InType, int8_t>());
  29. using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
  30. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  31. using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  32. using Cutlass2xGemm =
  33. cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
  34. Epilogue, TileShape, WarpShape, InstructionShape, 2>;
  35. };
  36. template <typename InType, typename OutType,
  37. template <typename, typename> typename Epilogue>
  38. struct sm75_config_M64 {
  39. // M in (32, 64]
  40. // Shared memory required by this Gemm 49152
  41. static_assert(std::is_same<InType, int8_t>());
  42. using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  43. using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  44. using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  45. using Cutlass2xGemm =
  46. cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
  47. Epilogue, TileShape, WarpShape, InstructionShape, 2>;
  48. };
  49. template <typename InType, typename OutType,
  50. template <typename, typename> typename Epilogue>
  51. struct sm75_config_M32 {
  52. // M in [1, 32]
  53. // Shared memory required by this Gemm 49152
  54. static_assert(std::is_same<InType, int8_t>());
  55. using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
  56. using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  57. using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  58. using Cutlass2xGemm =
  59. cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
  60. Epilogue, TileShape, WarpShape, InstructionShape, 2>;
  61. };
  62. template <typename InType, typename OutType,
  63. template <typename, typename> typename Epilogue,
  64. typename... EpilogueArgs>
  65. inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
  66. torch::Tensor const& a,
  67. torch::Tensor const& b,
  68. EpilogueArgs&&... args) {
  69. static_assert(std::is_same<InType, int8_t>());
  70. TORCH_CHECK(a.dtype() == torch::kInt8);
  71. TORCH_CHECK(b.dtype() == torch::kInt8);
  72. using Cutlass2xGemmDefault =
  73. typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  74. using Cutlass2xGemmM256 =
  75. typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
  76. using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
  77. using Cutlass2xGemmM64 =
  78. typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  79. using Cutlass2xGemmM32 =
  80. typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  81. // Due to shared memory requirements, some Gemms may fail to run on some
  82. // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  83. // in such cases.
  84. // sm75_config_default has the least shared-memory requirements.
  85. using FallbackGemm = Cutlass2xGemmDefault;
  86. uint32_t const m = a.size(0);
  87. uint32_t const mp2 =
  88. std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
  89. if (mp2 <= 32) {
  90. // M in [1, 32]
  91. return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
  92. out, a, b, std::forward<EpilogueArgs>(args)...);
  93. } else if (mp2 <= 64) {
  94. // M in (32, 64]
  95. return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
  96. out, a, b, std::forward<EpilogueArgs>(args)...);
  97. } else if (mp2 <= 128) {
  98. // M in (64, 128]
  99. return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
  100. out, a, b, std::forward<EpilogueArgs>(args)...);
  101. } else if (mp2 <= 256) {
  102. // M in (128, 256]
  103. return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
  104. out, a, b, std::forward<EpilogueArgs>(args)...);
  105. } else {
  106. // M in (256, inf)
  107. return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
  108. out, a, b, std::forward<EpilogueArgs>(args)...);
  109. }
  110. }
  111. } // namespace aphrodite