static_switch.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. // Inspired by
  2. // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
  3. // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
  4. #pragma once
  5. /// @param COND - a boolean expression to switch by
  6. /// @param CONST_NAME - a name given for the constexpr bool variable.
  7. /// @param ... - code to execute for true and false
  8. ///
  9. /// Usage:
  10. /// ```
  11. /// BOOL_SWITCH(flag, BoolConst, [&] {
  12. /// some_function<BoolConst>(...);
  13. /// });
  14. /// ```
  15. //
  16. #define BOOL_SWITCH(COND, CONST_NAME, ...) \
  17. [&] { \
  18. if (COND) { \
  19. constexpr static bool CONST_NAME = true; \
  20. return __VA_ARGS__(); \
  21. } else { \
  22. constexpr static bool CONST_NAME = false; \
  23. return __VA_ARGS__(); \
  24. } \
  25. }()
  26. #ifdef FLASHATTENTION_DISABLE_LOCAL
  27. #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  28. [&] { \
  29. constexpr static bool LOCAL_CONST_NAME = false; \
  30. if (CAUSAL_COND) { \
  31. constexpr static bool CAUSAL_CONST_NAME = true; \
  32. return __VA_ARGS__(); \
  33. } else { \
  34. constexpr static bool CAUSAL_CONST_NAME = false; \
  35. return __VA_ARGS__(); \
  36. } \
  37. }()
  38. #else
  39. #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  40. [&] { \
  41. if (CAUSAL_COND) { \
  42. constexpr static bool CAUSAL_CONST_NAME = true; \
  43. constexpr static bool LOCAL_CONST_NAME = false; \
  44. return __VA_ARGS__(); \
  45. } else if (LOCAL_COND) { \
  46. constexpr static bool CAUSAL_CONST_NAME = false; \
  47. constexpr static bool LOCAL_CONST_NAME = true; \
  48. return __VA_ARGS__(); \
  49. } else { \
  50. constexpr static bool CAUSAL_CONST_NAME = false; \
  51. constexpr static bool LOCAL_CONST_NAME = false; \
  52. return __VA_ARGS__(); \
  53. } \
  54. }()
  55. #endif
  56. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  57. #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
  58. [&] { \
  59. constexpr static bool CONST_NAME = false; \
  60. return __VA_ARGS__(); \
  61. }()
  62. #else
  63. #define SOFTCAP_SWITCH BOOL_SWITCH
  64. #endif
  65. #ifdef FLASHATTENTION_DISABLE_PAGEDKV
  66. #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \
  67. [&] { \
  68. constexpr static bool CONST_NAME = false; \
  69. return __VA_ARGS__(); \
  70. }()
  71. #else
  72. #define PAGEDKV_SWITCH BOOL_SWITCH
  73. #endif
  74. #ifdef FLASHATTENTION_DISABLE_SPLIT
  75. #define SPLIT_SWITCH(COND, CONST_NAME, ...) \
  76. [&] { \
  77. constexpr static bool CONST_NAME = false; \
  78. return __VA_ARGS__(); \
  79. }()
  80. #else
  81. #define SPLIT_SWITCH BOOL_SWITCH
  82. #endif
  83. #ifdef FLASHATTENTION_DISABLE_APPENDKV
  84. #define APPENDKV_SWITCH(COND, CONST_NAME, ...) \
  85. [&] { \
  86. constexpr static bool CONST_NAME = false; \
  87. return __VA_ARGS__(); \
  88. }()
  89. #else
  90. #define APPENDKV_SWITCH BOOL_SWITCH
  91. #endif
  92. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  93. #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \
  94. [&] { \
  95. constexpr static bool CONST_NAME = false; \
  96. return __VA_ARGS__(); \
  97. }()
  98. #else
  99. #define PACKGQA_SWITCH BOOL_SWITCH
  100. #endif
  101. #ifdef FLASHATTENTION_DISABLE_VARLEN
  102. #define VARLEN_SWITCH(COND, CONST_NAME, ...) \
  103. [&] { \
  104. constexpr static bool CONST_NAME = false; \
  105. return __VA_ARGS__(); \
  106. }()
  107. #else
  108. #define VARLEN_SWITCH BOOL_SWITCH
  109. #endif
  110. #ifdef FLASHATTENTION_DISABLE_CLUSTER
  111. #define CLUSTER_SWITCH(COND, CONST_NAME, ...) \
  112. [&] { \
  113. constexpr static bool CONST_NAME = false; \
  114. return __VA_ARGS__(); \
  115. }()
  116. #else
  117. #define CLUSTER_SWITCH BOOL_SWITCH
  118. #endif
  119. #ifdef FLASHATTENTION_DISABLE_SM8x
  120. #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \
  121. [&] { \
  122. constexpr static int ARCH_NAME = 90; \
  123. return __VA_ARGS__(); \
  124. }()
  125. #else
  126. #define ARCH_SWITCH(ARCH, ARCH_NAME, ...) \
  127. [&] { \
  128. if (ARCH == 86 || ARCH == 89) { \
  129. constexpr static int ARCH_NAME = 86; \
  130. return __VA_ARGS__(); \
  131. } else if (ARCH < 90) { \
  132. constexpr static int ARCH_NAME = 80; \
  133. return __VA_ARGS__(); \
  134. } else { \
  135. constexpr static int ARCH_NAME = 90; \
  136. return __VA_ARGS__(); \
  137. } \
  138. }()
  139. #endif
  140. #ifndef FLASHATTENTION_ENABLE_VCOLMAJOR
  141. #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \
  142. [&] { \
  143. constexpr static bool CONST_NAME = false; \
  144. return __VA_ARGS__(); \
  145. }()
  146. #else
  147. #define VCOLMAJOR_SWITCH BOOL_SWITCH
  148. #endif
  149. #define HEADDIM_SWITCH(HEADDIM, ...) \
  150. [&] { \
  151. if (HEADDIM == 64) { \
  152. constexpr static int kHeadSize = 64; \
  153. return __VA_ARGS__(); \
  154. } else if (HEADDIM == 96) { \
  155. constexpr static int kHeadSize = 96; \
  156. return __VA_ARGS__(); \
  157. } else if (HEADDIM == 128) { \
  158. constexpr static int kHeadSize = 128; \
  159. return __VA_ARGS__(); \
  160. } else if (HEADDIM == 96) { \
  161. constexpr static int kHeadSize = 96; \
  162. return __VA_ARGS__(); \
  163. } else if (HEADDIM == 256) { \
  164. constexpr static int kHeadSize = 256; \
  165. return __VA_ARGS__(); \
  166. } \
  167. }()