#pragma once

#include "scaled_mm_c2x.cuh"

/**
 * This file defines Gemm kernel configurations for SM80 based on the Gemm
 * shape.
 */

namespace aphrodite {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_default {
  // This config is used in 2 cases,
  //  - M in (128, inf)
  //  - M in (64, 128] and N >= 8192
  // Shared Memory required by this Gemm - 81920 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
                      OutType, Epilogue, TileShape, WarpShape,
                      InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
  // This config is used in 2 cases,
  //  - M in (32, 64]
  //  - M in (64, 128] and N < 8192
  // Shared Memory required by this Gemm - 122880 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
                      OutType, Epilogue, TileShape, WarpShape,
                      InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
  // M in (16, 32]
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
                      OutType, Epilogue, TileShape, WarpShape,
                      InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
  // M in [1, 16]
  // Shared Memory required by this Gemm - 51200 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType,
                      OutType, Epilogue, TileShape, WarpShape,
                      InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128BigN =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128SmallN =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM64 =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM16 =
      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm80_config_M16 has the smallest shared-memory requirement. However,
  // based on some profiling, we select sm80_config_M32 as the better
  // alternative performance-wise.
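  // For reference, a rough sketch of what fallback_cutlass_gemm_caller
  // (provided by scaled_mm_c2x.cuh) is assumed to do; the actual helper may
  // differ in details:
  //
  //   template <typename Gemm, typename FallbackGemm, typename... Args>
  //   void fallback_cutlass_gemm_caller(torch::Tensor& out, ..., Args&&... a) {
  //     // Compare the Gemm's shared-memory footprint against the device's
  //     // opt-in per-block shared-memory limit.
  //     if (sizeof(typename Gemm::KernelType::SharedStorage) <= smem_limit)
  //       return cutlass_gemm_caller<Gemm>(out, ...);          // preferred
  //     return cutlass_gemm_caller<FallbackGemm>(out, ...);    // fallback
  //   }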
  using FallbackGemm =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
  if (mp2 <= 16) {
    // M in [1, 16]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    uint32_t const n = out.size(1);
    bool const small_n = n < 8192;
    if (small_n) {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    // M in (128, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace aphrodite
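// Illustrative usage (an assumption, not part of this header's contract):
// a caller picks the input/output element types and an epilogue, e.g. an
// int8 x int8 -> bf16 scaled GEMM. `ScaledEpilogue` and the scale tensors
// below are hypothetical placeholders for whatever scaled_mm_c2x.cuh
// actually provides; consult the call sites for the authoritative form.
//
//   aphrodite::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
//                                         aphrodite::ScaledEpilogue>(
//       out, a, b, a_scales, b_scales);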