static_switch.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // Inspired by
  2. // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
  3. // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
  4. #pragma once
  /// Runs a callable with a compile-time bool that mirrors a runtime condition.
  ///
  /// @param COND       - runtime expression, contextually converted to bool by
  ///                     an `if` and evaluated exactly once.
  /// @param CONST_NAME - name of the `static constexpr bool` visible to the
  ///                     callable (true on the taken-true path, false otherwise).
  /// @param ...        - a nullary callable (typically a `[&]` lambda). It is
  ///                     instantiated for both values and invoked once; its
  ///                     result is the value of the whole expression.
  ///
  /// Usage:
  /// ```
  /// BOOL_SWITCH(flag, BoolConst, [&] {
  ///   some_function<BoolConst>(...);
  /// });
  /// ```
  #define BOOL_SWITCH(COND, CONST_NAME, ...) \
    [&] { \
      if (COND) { \
        static constexpr bool CONST_NAME = true; \
        return __VA_ARGS__(); \
      } \
      static constexpr bool CONST_NAME = false; \
      return __VA_ARGS__(); \
    }()
  /// Maps a runtime precision code to an element-type alias for a callable.
  ///
  /// @param PRECTYPE - runtime integer code, evaluated exactly once:
  ///                   3 -> cutlass::float_e4m3_t, 2 -> cutlass::bfloat16_t,
  ///                   any other value -> cutlass::half_t.
  /// @param NAME     - name of the type alias visible to the callable.
  /// @param ...      - nullary callable (typically a `[&]` lambda), invoked
  ///                   once; its result is the value of the whole expression.
  ///
  /// The argument is parenthesized and cached so that expressions such as
  /// `flags & 3` or `a ? b : c` are compared as a whole, and side effects in
  /// the argument happen only once.
  #define PREC_SWITCH(PRECTYPE, NAME, ...) \
    [&] { \
      const auto static_switch_prec_ = (PRECTYPE); \
      if (static_switch_prec_ == 3) { \
        using NAME = cutlass::float_e4m3_t; \
        return __VA_ARGS__(); \
      } else if (static_switch_prec_ == 2) { \
        using NAME = cutlass::bfloat16_t; \
        return __VA_ARGS__(); \
      } else { \
        using NAME = cutlass::half_t; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Maps a runtime head dimension to a compile-time int for a callable.
  ///
  /// @param HEADDIM    - runtime value, evaluated exactly once:
  ///                     64 -> 64, 128 -> 128, any other value -> 256
  ///                     (no error is raised for unsupported values here).
  /// @param CONST_NAME - name of the `static constexpr int` visible to the
  ///                     callable.
  /// @param ...        - nullary callable (typically a `[&]` lambda), invoked
  ///                     once; its result is the value of the whole expression.
  ///
  /// The argument is parenthesized and cached so that expressions such as
  /// `b ? 128 : 64` are compared as a whole, and side effects happen once.
  #define HEADDIM_SWITCH(HEADDIM, CONST_NAME, ...) \
    [&] { \
      const auto static_switch_headdim_ = (HEADDIM); \
      if (static_switch_headdim_ == 64) { \
        constexpr static int CONST_NAME = 64; \
        return __VA_ARGS__(); \
      } else if (static_switch_headdim_ == 128) { \
        constexpr static int CONST_NAME = 128; \
        return __VA_ARGS__(); \
      } else { \
        constexpr static int CONST_NAME = 256; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Selects sequence-length traits for a callable from a runtime condition.
  ///
  /// @param USE_VAR_SEQ_LEN - runtime expression, contextually converted to
  ///                          bool and evaluated exactly once; true selects
  ///                          flash::VarSeqLenTraits, false selects
  ///                          flash::FixedSeqLenTraits.
  /// @param NAME            - name of the type alias visible to the callable.
  /// @param ...             - nullary callable (typically a `[&]` lambda),
  ///                          invoked once; its result is returned.
  ///
  /// The condition is tested directly rather than copied into a named local.
  /// A local such as `bool useSeqLen = USE_VAR_SEQ_LEN;` would self-initialize
  /// (undefined behavior) if the caller passed a variable named `useSeqLen`.
  #define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, NAME, ...) \
    [&] { \
      if (USE_VAR_SEQ_LEN) { \
        using NAME = flash::VarSeqLenTraits; \
        return __VA_ARGS__(); \
      } else { \
        using NAME = flash::FixedSeqLenTraits; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Selects query/key sequence-length traits for a callable.
  ///
  /// Both conditions are converted to bool and evaluated exactly once, Q first,
  /// before dispatching:
  ///   - VAR_SEQ_LEN_Q true  -> NAME_Q = NAME_K = flash::VarSeqLenTraits
  ///                            (SEQ_USED_K is then ignored);
  ///   - else SEQ_USED_K true -> NAME_Q = flash::FixedSeqLenTraits,
  ///                            NAME_K = flash::FixedSeqLenTraitsDynamic;
  ///   - otherwise            -> NAME_Q = NAME_K = flash::FixedSeqLenTraits.
  ///
  /// @param NAME_Q, NAME_K - names of the type aliases visible to the callable.
  /// @param ...            - nullary callable (typically a `[&]` lambda),
  ///                         invoked once; its result is returned.
  ///
  /// The cached conditions use collision-resistant local names. Generic names
  /// would self-initialize (undefined behavior) if a caller passed a variable
  /// of the same name, and would leak into the callable's scope.
  #define SEQLEN_SWITCH_FWD(VAR_SEQ_LEN_Q, SEQ_USED_K, NAME_Q, NAME_K, ...) \
    [&] { \
      const bool static_switch_var_seqlen_q_ = (VAR_SEQ_LEN_Q); \
      const bool static_switch_seqused_k_ = (SEQ_USED_K); \
      if (static_switch_var_seqlen_q_) { \
        using NAME_Q = flash::VarSeqLenTraits; \
        using NAME_K = flash::VarSeqLenTraits; \
        return __VA_ARGS__(); \
      } else if (static_switch_seqused_k_) { \
        using NAME_Q = flash::FixedSeqLenTraits; \
        using NAME_K = flash::FixedSeqLenTraitsDynamic; \
        return __VA_ARGS__(); \
      } else { \
        using NAME_Q = flash::FixedSeqLenTraits; \
        using NAME_K = flash::FixedSeqLenTraits; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Rounds a runtime query-head count up to a compile-time bucket.
  ///
  /// @param QUERYHEADS - runtime value, evaluated exactly once:
  ///                     <= 2 -> 2, <= 4 -> 4, <= 8 -> 8, <= 16 -> 16,
  ///                     anything larger -> 32 (values above 32 are not
  ///                     rejected here).
  /// @param CONST_NAME - name of the `static constexpr int` visible to the
  ///                     callable.
  /// @param ...        - nullary callable (typically a `[&]` lambda), invoked
  ///                     once; its result is the value of the whole expression.
  ///
  /// The argument is parenthesized and cached so that expressions such as
  /// `h & 0x7` are compared as a whole, and side effects happen once.
  #define QUERYHEAD_SWITCH(QUERYHEADS, CONST_NAME, ...) \
    [&] { \
      const auto static_switch_queryheads_ = (QUERYHEADS); \
      if (static_switch_queryheads_ <= 2) { \
        constexpr static int CONST_NAME = 2; \
        return __VA_ARGS__(); \
      } else if (static_switch_queryheads_ <= 4) { \
        constexpr static int CONST_NAME = 4; \
        return __VA_ARGS__(); \
      } else if (static_switch_queryheads_ <= 8) { \
        constexpr static int CONST_NAME = 8; \
        return __VA_ARGS__(); \
      } else if (static_switch_queryheads_ <= 16) { \
        constexpr static int CONST_NAME = 16; \
        return __VA_ARGS__(); \
      } else { \
        constexpr static int CONST_NAME = 32; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Maps a runtime length to a compile-time count in {1, 2, 3}.
  ///
  /// @param QLEN       - runtime value, evaluated exactly once:
  ///                     <= 64 -> 1, <= 128 -> 2, anything larger -> 3.
  /// @param CONST_NAME - name of the `static constexpr int` visible to the
  ///                     callable (presumably a number of MMA warpgroups,
  ///                     judging by the macro name; confirm at call sites).
  /// @param ...        - nullary callable (typically a `[&]` lambda), invoked
  ///                     once; its result is the value of the whole expression.
  ///
  /// The argument is parenthesized and cached so that expressions such as
  /// `b ? x : y` are compared as a whole, and side effects happen once.
  #define MMA_3WG_SWITCH(QLEN, CONST_NAME, ...) \
    [&] { \
      const auto static_switch_qlen3_ = (QLEN); \
      if (static_switch_qlen3_ <= 64) { \
        constexpr static int CONST_NAME = 1; \
        return __VA_ARGS__(); \
      } else if (static_switch_qlen3_ <= 128) { \
        constexpr static int CONST_NAME = 2; \
        return __VA_ARGS__(); \
      } else { \
        constexpr static int CONST_NAME = 3; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Maps a runtime length to a compile-time count in {1, 2}.
  ///
  /// @param QLEN       - runtime value, evaluated exactly once:
  ///                     <= 64 -> 1, anything larger -> 2.
  /// @param CONST_NAME - name of the `static constexpr int` visible to the
  ///                     callable (presumably a number of MMA warpgroups,
  ///                     judging by the macro name; confirm at call sites).
  /// @param ...        - nullary callable (typically a `[&]` lambda), invoked
  ///                     once; its result is the value of the whole expression.
  ///
  /// The argument is parenthesized so that expressions such as `b ? x : y`
  /// are compared as a whole rather than binding `<=` to their last operand.
  #define MMA_2WG_SWITCH(QLEN, CONST_NAME, ...) \
    [&] { \
      if ((QLEN) <= 64) { \
        constexpr static int CONST_NAME = 1; \
        return __VA_ARGS__(); \
      } else { \
        constexpr static int CONST_NAME = 2; \
        return __VA_ARGS__(); \
      } \
    }()
  /// Maps a runtime split count to a compile-time log2 bound.
  ///
  /// @param NUM_SPLITS     - runtime value, evaluated exactly once.
  /// @param LOG_MAX_SPLITS - name of the `static constexpr int` visible to the
  ///                         callable: the smallest k in 1..6 such that
  ///                         NUM_SPLITS <= 2^k, or 7 if there is none (so any
  ///                         value <= 2, including 0, gives 1, and values
  ///                         above 128 are clamped to 7).
  /// @param ...            - nullary callable (typically a `[&]` lambda),
  ///                         invoked once; its result is returned.
  ///
  /// The argument is parenthesized and cached so that expressions such as
  /// `b ? x : y` are compared as a whole, and side effects happen once.
  #define NUM_SPLITS_SWITCH(NUM_SPLITS, LOG_MAX_SPLITS, ...) \
    [&] { \
      const auto static_switch_num_splits_ = (NUM_SPLITS); \
      if (static_switch_num_splits_ <= 2) { \
        constexpr static int LOG_MAX_SPLITS = 1; \
        return __VA_ARGS__(); \
      } else if (static_switch_num_splits_ <= 4) { \
        constexpr static int LOG_MAX_SPLITS = 2; \
        return __VA_ARGS__(); \
      } else if (static_switch_num_splits_ <= 8) { \
        constexpr static int LOG_MAX_SPLITS = 3; \
        return __VA_ARGS__(); \
      } else if (static_switch_num_splits_ <= 16) { \
        constexpr static int LOG_MAX_SPLITS = 4; \
        return __VA_ARGS__(); \
      } else if (static_switch_num_splits_ <= 32) { \
        constexpr static int LOG_MAX_SPLITS = 5; \
        return __VA_ARGS__(); \
      } else if (static_switch_num_splits_ <= 64) { \
        constexpr static int LOG_MAX_SPLITS = 6; \
        return __VA_ARGS__(); \
      } else { \
        constexpr static int LOG_MAX_SPLITS = 7; \
        return __VA_ARGS__(); \
      } \
    }()