// enabled.h
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "common.h"
#include "cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"

#include <cassert>
#include <cuda_runtime.h>
  20. inline int getSMVersion()
  21. {
  22. int device{-1};
  23. cudaGetDevice(&device);
  24. int sm_major = 0;
  25. int sm_minor = 0;
  26. cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device);
  27. cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device);
  28. return sm_major * 10 + sm_minor;
  29. }
  30. namespace tensorrt_llm
  31. {
  32. namespace kernels
  33. {
  34. template <typename TypeB, typename Layout>
  35. struct SupportedLayout
  36. {
  37. static constexpr bool value = false;
  38. };
  39. template <>
  40. struct SupportedLayout<uint8_t, cutlass::layout::ColumnMajorTileInterleave<64, 2>>
  41. {
  42. static constexpr bool value = true;
  43. };
  44. template <>
  45. struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajorTileInterleave<64, 4>>
  46. {
  47. static constexpr bool value = true;
  48. };
  49. template <typename TypeB, typename Arch>
  50. bool isEnabled()
  51. {
  52. using Layout = typename cutlass::gemm::kernel::LayoutDetailsB<TypeB, Arch>::Layout;
  53. return SupportedLayout<TypeB, Layout>::value;
  54. }
  55. template <typename TypeB>
  56. bool isEnabledForArch(int arch)
  57. {
  58. if (arch >= 70 && arch < 75)
  59. {
  60. return isEnabled<TypeB, cutlass::arch::Sm70>();
  61. }
  62. else if (arch >= 75 && arch < 80)
  63. {
  64. return isEnabled<TypeB, cutlass::arch::Sm75>();
  65. }
  66. else if (arch >= 80 && arch <= 90)
  67. {
  68. return isEnabled<TypeB, cutlass::arch::Sm80>();
  69. }
  70. else
  71. {
  72. // TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");
  73. assert(0);
  74. return false;
  75. }
  76. }
  77. inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype)
  78. {
  79. const int arch = getSMVersion();
  80. if (qtype == WeightOnlyQuantType::Int4b)
  81. {
  82. return isEnabledForArch<cutlass::uint4b_t>(arch);
  83. }
  84. else if (qtype == WeightOnlyQuantType::Int8b)
  85. {
  86. return isEnabledForArch<uint8_t>(arch);
  87. }
  88. else
  89. {
  90. assert(0);
  91. // TLLM_CHECK_WITH_INFO(false, "Unsupported WeightOnlyQuantType");
  92. return false;
  93. }
  94. }
  95. } // namespace kernels
  96. } // namespace tensorrt_llm