  // static_switch.h
  //
  // Macros that dispatch a runtime value to one of a fixed set of compile-time
  // constants or type aliases, then invoke a callable with that binding.
  //
  // Inspired by
  // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
  // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
  #pragma once  // include-once guard; this header only defines macros
  /// Turns a runtime boolean into a compile-time constant.
  ///
  /// Evaluates COND once. It then invokes the callable given in `...` exactly
  /// once, with CONST_NAME declared as a `constexpr static bool` equal to the
  /// result, so the callable can use CONST_NAME as a template argument.
  /// Because CONST_NAME has static storage duration, lambdas can refer to it
  /// without capturing it.
  ///
  /// @param COND       - a boolean expression to switch by.
  /// @param CONST_NAME - the name given to the constexpr bool variable.
  /// @param ...        - a nullary callable (typically a `[&] { ... }` lambda).
  ///                     It is expanded once for true and once for false, and
  ///                     the enclosing lambda's return type is deduced from
  ///                     both, so both expansions must return the same type.
  ///                     Its return value is the value of the macro expression.
  ///
  /// Usage:
  /// ```
  /// BOOL_SWITCH(flag, BoolConst, [&] {
  ///   some_function<BoolConst>(...);
  /// });
  /// ```
  #define BOOL_SWITCH(COND, CONST_NAME, ...)          \
    [&] {                                             \
      if (COND) {                                     \
        constexpr static bool CONST_NAME = true;      \
        return __VA_ARGS__();                         \
      } else {                                        \
        constexpr static bool CONST_NAME = false;     \
        return __VA_ARGS__();                         \
      }                                               \
    }()
  /// Maps a runtime precision code to an element-type alias.
  ///
  /// PRECTYPE is evaluated exactly once and parenthesized. Compound arguments
  /// such as `a & b` are therefore compared as a whole, and side effects are
  /// not repeated.
  ///   3         -> NAME = cutlass::float_e4m3_t
  ///   2         -> NAME = cutlass::bfloat16_t
  ///   otherwise -> NAME = cutlass::half_t
  /// NOTE(review): any unrecognized code silently falls back to half_t.
  /// Presumably callers validate PRECTYPE first — confirm.
  ///
  /// @param PRECTYPE - a precision code compared against the integers 3 and 2.
  /// @param NAME     - the name given to the type alias.
  /// @param ...      - a nullary callable. It is expanded once per branch, and
  ///                   every expansion must return the same type.
  #define PREC_SWITCH(PRECTYPE, NAME, ...)                  \
    [&] {                                                   \
      const auto static_switch_prec_ = (PRECTYPE);          \
      if (static_switch_prec_ == 3) {                       \
        using NAME = cutlass::float_e4m3_t;                 \
        return __VA_ARGS__();                               \
      } else if (static_switch_prec_ == 2) {                \
        using NAME = cutlass::bfloat16_t;                   \
        return __VA_ARGS__();                               \
      } else {                                              \
        using NAME = cutlass::half_t;                       \
        return __VA_ARGS__();                               \
      }                                                     \
    }()
  /// Maps a runtime head dimension to a compile-time int constant.
  ///
  /// HEADDIM is evaluated exactly once and parenthesized. This avoids
  /// `x & mask == 64` parsing as `x & (mask == 64)`, and side effects are not
  /// repeated.
  ///   64        -> CONST_NAME = 64
  ///   128       -> CONST_NAME = 128
  ///   otherwise -> CONST_NAME = 256
  /// NOTE(review): every value other than 64 or 128 (e.g. 96) selects the
  /// 256 instantiation. Presumably callers validate HEADDIM — confirm.
  ///
  /// @param HEADDIM    - integral head dimension.
  /// @param CONST_NAME - the name given to the `constexpr static int`.
  /// @param ...        - a nullary callable. It is expanded once per branch,
  ///                     and every expansion must return the same type.
  #define HEADDIM_SWITCH(HEADDIM, CONST_NAME, ...)          \
    [&] {                                                   \
      const auto static_switch_headdim_ = (HEADDIM);        \
      if (static_switch_headdim_ == 64) {                   \
        constexpr static int CONST_NAME = 64;               \
        return __VA_ARGS__();                               \
      } else if (static_switch_headdim_ == 128) {           \
        constexpr static int CONST_NAME = 128;              \
        return __VA_ARGS__();                               \
      } else {                                              \
        constexpr static int CONST_NAME = 256;              \
        return __VA_ARGS__();                               \
      }                                                     \
    }()
  /// Selects sequence-length trait type aliases from a params object.
  ///
  /// PARAMS is evaluated exactly once and bound by const reference.
  /// Parenthesizing it means arguments such as `*params_ptr` work, where the
  /// original `*params_ptr.cu_seqlens_q` would misparse. Selection:
  ///   cu_seqlens_q truthy, page_block_size > 0
  ///       -> NAME = flash::PagedSeqLenTraits, NAME_Q = flash::VarSeqLenTraits
  ///   cu_seqlens_q truthy, page_block_size <= 0
  ///       -> NAME = NAME_Q = flash::VarSeqLenTraits
  ///   cu_seqlens_q falsy (page_block_size does not affect the choice)
  ///       -> NAME = NAME_Q = flash::FixedSeqLenTraits
  /// NOTE(review): cu_seqlens_q is converted to bool. Presumably it is a
  /// pointer that is null when unused — confirm.
  ///
  /// @param PARAMS - object exposing `cu_seqlens_q` and `page_block_size`.
  /// @param NAME   - alias name for the (presumably K/V) traits type.
  /// @param NAME_Q - alias name for the Q traits type.
  /// @param ...    - a nullary callable. It is expanded once per branch, and
  ///                 every expansion must return the same type.
  #define SEQLEN_SWITCH(PARAMS, NAME, NAME_Q, ...)                          \
    [&] {                                                                   \
      const auto& static_switch_params_ = (PARAMS);                         \
      const bool useSeqLen = static_switch_params_.cu_seqlens_q;            \
      const bool usePagedKV = static_switch_params_.page_block_size > 0;    \
      if (useSeqLen) {                                                      \
        if (usePagedKV) {                                                   \
          using NAME = flash::PagedSeqLenTraits;                            \
          using NAME_Q = flash::VarSeqLenTraits;                            \
          return __VA_ARGS__();                                             \
        } else {                                                            \
          using NAME = flash::VarSeqLenTraits;                              \
          using NAME_Q = flash::VarSeqLenTraits;                            \
          return __VA_ARGS__();                                             \
        }                                                                   \
      } else {                                                              \
        using NAME = flash::FixedSeqLenTraits;                              \
        using NAME_Q = flash::FixedSeqLenTraits;                            \
        return __VA_ARGS__();                                               \
      }                                                                     \
    }()
  /// Selects Q and K sequence-length trait type aliases from two flags.
  ///
  /// Each flag is evaluated exactly once, in order VAR_SEQ_LEN_Q, then SEQ_USED_K:
  ///   VAR_SEQ_LEN_Q true (SEQ_USED_K is then ignored)
  ///       -> NAME_Q = NAME_K = flash::VarSeqLenTraits
  ///   VAR_SEQ_LEN_Q false, SEQ_USED_K true
  ///       -> NAME_Q = flash::FixedSeqLenTraits,
  ///          NAME_K = flash::FixedSeqLenTraitsDynamic
  ///   both false
  ///       -> NAME_Q = NAME_K = flash::FixedSeqLenTraits
  ///
  /// @param VAR_SEQ_LEN_Q - expression convertible to bool.
  /// @param SEQ_USED_K    - expression convertible to bool.
  /// @param NAME_Q        - alias name for the Q traits type.
  /// @param NAME_K        - alias name for the K traits type.
  /// @param ...           - a nullary callable. It is expanded once per
  ///                        branch, and every expansion must return the same
  ///                        type.
  #define SEQLEN_SWITCH_FWD(VAR_SEQ_LEN_Q, SEQ_USED_K, NAME_Q, NAME_K, ...) \
    [&] {                                                                   \
      const bool useVarSeqLenQ = (VAR_SEQ_LEN_Q);                           \
      const bool useSeqUsedK = (SEQ_USED_K);                                \
      if (useVarSeqLenQ) {                                                  \
        using NAME_Q = flash::VarSeqLenTraits;                              \
        using NAME_K = flash::VarSeqLenTraits;                              \
        return __VA_ARGS__();                                               \
      } else if (useSeqUsedK) {                                             \
        using NAME_Q = flash::FixedSeqLenTraits;                            \
        using NAME_K = flash::FixedSeqLenTraitsDynamic;                     \
        return __VA_ARGS__();                                               \
      } else {                                                              \
        using NAME_Q = flash::FixedSeqLenTraits;                            \
        using NAME_K = flash::FixedSeqLenTraits;                            \
        return __VA_ARGS__();                                               \
      }                                                                     \
    }()
  /// Rounds a runtime query-head count up to a compile-time bucket.
  ///
  /// QUERYHEADS is evaluated exactly once and parenthesized, so side effects
  /// are not repeated and compound expressions compare as a whole.
  ///   <= 2 -> 2,  <= 4 -> 4,  <= 8 -> 8,  <= 16 -> 16,  otherwise -> 32
  /// Values <= 2 (including 0 or negatives) map to 2.
  /// NOTE(review): values above 32 also map to 32 with no diagnostic.
  /// Presumably callers bound QUERYHEADS — confirm.
  ///
  /// @param QUERYHEADS - integral head count.
  /// @param CONST_NAME - the name given to the `constexpr static int`.
  /// @param ...        - a nullary callable. It is expanded once per branch,
  ///                     and every expansion must return the same type.
  #define QUERYHEAD_SWITCH(QUERYHEADS, CONST_NAME, ...)     \
    [&] {                                                   \
      const auto static_switch_qheads_ = (QUERYHEADS);      \
      if (static_switch_qheads_ <= 2) {                     \
        constexpr static int CONST_NAME = 2;                \
        return __VA_ARGS__();                               \
      } else if (static_switch_qheads_ <= 4) {              \
        constexpr static int CONST_NAME = 4;                \
        return __VA_ARGS__();                               \
      } else if (static_switch_qheads_ <= 8) {              \
        constexpr static int CONST_NAME = 8;                \
        return __VA_ARGS__();                               \
      } else if (static_switch_qheads_ <= 16) {             \
        constexpr static int CONST_NAME = 16;               \
        return __VA_ARGS__();                               \
      } else {                                              \
        constexpr static int CONST_NAME = 32;               \
        return __VA_ARGS__();                               \
      }                                                     \
    }()
  /// Buckets a runtime length into a compile-time count in {1, 2, 3}.
  ///
  /// QLEN is evaluated exactly once and parenthesized. For example,
  /// `f ? q : 0 <= 64` no longer parses as `f ? q : (0 <= 64)`.
  ///   <= 64 -> 1,  <= 128 -> 2,  otherwise -> 3
  /// NOTE(review): per the macro name, CONST_NAME is presumably the number of
  /// MMA warpgroups to use — confirm.
  ///
  /// @param QLEN       - integral length.
  /// @param CONST_NAME - the name given to the `constexpr static int`.
  /// @param ...        - a nullary callable. It is expanded once per branch,
  ///                     and every expansion must return the same type.
  #define MMA_3WG_SWITCH(QLEN, CONST_NAME, ...)             \
    [&] {                                                   \
      const auto static_switch_qlen_ = (QLEN);              \
      if (static_switch_qlen_ <= 64) {                      \
        constexpr static int CONST_NAME = 1;                \
        return __VA_ARGS__();                               \
      } else if (static_switch_qlen_ <= 128) {              \
        constexpr static int CONST_NAME = 2;                \
        return __VA_ARGS__();                               \
      } else {                                              \
        constexpr static int CONST_NAME = 3;                \
        return __VA_ARGS__();                               \
      }                                                     \
    }()
  /// Buckets a runtime length into a compile-time count in {1, 2}.
  ///
  /// QLEN is evaluated exactly once and parenthesized. For example,
  /// `f ? q : 0 <= 64` no longer parses as `f ? q : (0 <= 64)`.
  ///   <= 64 -> 1,  otherwise -> 2
  /// NOTE(review): per the macro name, CONST_NAME is presumably the number of
  /// MMA warpgroups to use — confirm.
  ///
  /// @param QLEN       - integral length.
  /// @param CONST_NAME - the name given to the `constexpr static int`.
  /// @param ...        - a nullary callable. It is expanded once per branch,
  ///                     and every expansion must return the same type.
  #define MMA_2WG_SWITCH(QLEN, CONST_NAME, ...)             \
    [&] {                                                   \
      const auto static_switch_qlen_ = (QLEN);              \
      if (static_switch_qlen_ <= 64) {                      \
        constexpr static int CONST_NAME = 1;                \
        return __VA_ARGS__();                               \
      } else {                                              \
        constexpr static int CONST_NAME = 2;                \
        return __VA_ARGS__();                               \
      }                                                     \
    }()
  /// Maps a runtime split count to LOG_MAX_SPLITS.
  ///
  /// LOG_MAX_SPLITS is the smallest L in [1, 7] with 2^L >= NUM_SPLITS, or 7
  /// if NUM_SPLITS exceeds 128. NUM_SPLITS is evaluated exactly once and
  /// parenthesized, so side effects are not repeated (the original could
  /// evaluate it up to six times). Compound expressions such as `n | 1` also
  /// compare as a whole.
  ///   <= 2 -> 1,  <= 4 -> 2,  <= 8 -> 3,  <= 16 -> 4,
  ///   <= 32 -> 5, <= 64 -> 6, otherwise -> 7
  /// NOTE(review): counts above 128 are silently clamped to 7.
  /// Presumably callers bound NUM_SPLITS — confirm.
  ///
  /// @param NUM_SPLITS     - integral split count.
  /// @param LOG_MAX_SPLITS - the name given to the `constexpr static int`.
  /// @param ...            - a nullary callable. It is expanded once per
  ///                         branch, and every expansion must return the same
  ///                         type.
  #define NUM_SPLITS_SWITCH(NUM_SPLITS, LOG_MAX_SPLITS, ...)  \
    [&] {                                                     \
      const auto static_switch_num_splits_ = (NUM_SPLITS);    \
      if (static_switch_num_splits_ <= 2) {                   \
        constexpr static int LOG_MAX_SPLITS = 1;              \
        return __VA_ARGS__();                                 \
      } else if (static_switch_num_splits_ <= 4) {            \
        constexpr static int LOG_MAX_SPLITS = 2;              \
        return __VA_ARGS__();                                 \
      } else if (static_switch_num_splits_ <= 8) {            \
        constexpr static int LOG_MAX_SPLITS = 3;              \
        return __VA_ARGS__();                                 \
      } else if (static_switch_num_splits_ <= 16) {           \
        constexpr static int LOG_MAX_SPLITS = 4;              \
        return __VA_ARGS__();                                 \
      } else if (static_switch_num_splits_ <= 32) {           \
        constexpr static int LOG_MAX_SPLITS = 5;              \
        return __VA_ARGS__();                                 \
      } else if (static_switch_num_splits_ <= 64) {           \
        constexpr static int LOG_MAX_SPLITS = 6;              \
        return __VA_ARGS__();                                 \
      } else {                                                \
        constexpr static int LOG_MAX_SPLITS = 7;              \
        return __VA_ARGS__();                                 \
      }                                                       \
    }()