static_switch.h 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. // Inspired by
  2. // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
  3. // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
  4. #pragma once
  5. /// @param COND - a boolean expression to switch by
  6. /// @param CONST_NAME - a name given for the constexpr bool variable.
  7. /// @param ... - code to execute for true and false
  8. ///
  9. /// Usage:
  10. /// ```
  11. /// BOOL_SWITCH(flag, BoolConst, [&] {
  12. /// some_function<BoolConst>(...);
  13. /// });
  14. /// ```
  15. //
  16. #define BOOL_SWITCH(COND, CONST_NAME, ...) \
  17. [&] { \
  18. if (COND) { \
  19. constexpr static bool CONST_NAME = true; \
  20. return __VA_ARGS__(); \
  21. } else { \
  22. constexpr static bool CONST_NAME = false; \
  23. return __VA_ARGS__(); \
  24. } \
  25. }()
  26. #ifdef FLASHATTENTION_DISABLE_LOCAL
  27. #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  28. [&] { \
  29. constexpr static bool LOCAL_CONST_NAME = false; \
  30. if (CAUSAL_COND) { \
  31. constexpr static bool CAUSAL_CONST_NAME = true; \
  32. return __VA_ARGS__(); \
  33. } else { \
  34. constexpr static bool CAUSAL_CONST_NAME = false; \
  35. return __VA_ARGS__(); \
  36. } \
  37. }()
  38. #else
  39. #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  40. [&] { \
  41. if (CAUSAL_COND) { \
  42. constexpr static bool CAUSAL_CONST_NAME = true; \
  43. constexpr static bool LOCAL_CONST_NAME = false; \
  44. return __VA_ARGS__(); \
  45. } else if (LOCAL_COND) { \
  46. constexpr static bool CAUSAL_CONST_NAME = false; \
  47. constexpr static bool LOCAL_CONST_NAME = true; \
  48. return __VA_ARGS__(); \
  49. } else { \
  50. constexpr static bool CAUSAL_CONST_NAME = false; \
  51. constexpr static bool LOCAL_CONST_NAME = false; \
  52. return __VA_ARGS__(); \
  53. } \
  54. }()
  55. #endif
  56. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  57. #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
  58. [&] { \
  59. constexpr static bool CONST_NAME = false; \
  60. return __VA_ARGS__(); \
  61. }()
  62. #else
  63. #define SOFTCAP_SWITCH BOOL_SWITCH
  64. #endif
  65. #ifdef FLASHATTENTION_DISABLE_PAGEDKV
  66. #define PAGEDKV_SWITCH(COND, CONST_NAME, ...) \
  67. [&] { \
  68. constexpr static bool CONST_NAME = false; \
  69. return __VA_ARGS__(); \
  70. }()
  71. #else
  72. #define PAGEDKV_SWITCH BOOL_SWITCH
  73. #endif
  74. #ifdef FLASHATTENTION_DISABLE_SPLIT
  75. #define SPLIT_SWITCH(COND, CONST_NAME, ...) \
  76. [&] { \
  77. constexpr static bool CONST_NAME = false; \
  78. return __VA_ARGS__(); \
  79. }()
  80. #else
  81. #define SPLIT_SWITCH BOOL_SWITCH
  82. #endif
  83. #ifdef FLASHATTENTION_DISABLE_APPENDKV
  84. #define APPENDKV_SWITCH(COND, CONST_NAME, ...) \
  85. [&] { \
  86. constexpr static bool CONST_NAME = false; \
  87. return __VA_ARGS__(); \
  88. }()
  89. #else
  90. #define APPENDKV_SWITCH BOOL_SWITCH
  91. #endif
  92. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  93. #define PACKGQA_SWITCH(COND, CONST_NAME, ...) \
  94. [&] { \
  95. constexpr static bool CONST_NAME = false; \
  96. return __VA_ARGS__(); \
  97. }()
  98. #else
  99. #define PACKGQA_SWITCH BOOL_SWITCH
  100. #endif
  101. #ifdef FLASHATTENTION_DISABLE_VARLEN
  102. #define VARLEN_SWITCH(COND, CONST_NAME, ...) \
  103. [&] { \
  104. constexpr static bool CONST_NAME = false; \
  105. return __VA_ARGS__(); \
  106. }()
  107. #else
  108. #define VARLEN_SWITCH BOOL_SWITCH
  109. #endif
  110. #ifdef FLASHATTENTION_DISABLE_CLUSTER
  111. #define CLUSTER_SWITCH(COND, CONST_NAME, ...) \
  112. [&] { \
  113. constexpr static bool CONST_NAME = false; \
  114. return __VA_ARGS__(); \
  115. }()
  116. #else
  117. #define CLUSTER_SWITCH BOOL_SWITCH
  118. #endif
  119. #ifndef FLASHATTENTION_ENABLE_VCOLMAJOR
  120. #define VCOLMAJOR_SWITCH(COND, CONST_NAME, ...) \
  121. [&] { \
  122. constexpr static bool CONST_NAME = false; \
  123. return __VA_ARGS__(); \
  124. }()
  125. #else
  126. #define VCOLMAJOR_SWITCH BOOL_SWITCH
  127. #endif
  128. #define HEADDIM_SWITCH(HEADDIM, ...) \
  129. [&] { \
  130. if (HEADDIM == 64) { \
  131. constexpr static int kHeadSize = 64; \
  132. return __VA_ARGS__(); \
  133. } else if (HEADDIM == 96) { \
  134. constexpr static int kHeadSize = 96; \
  135. return __VA_ARGS__(); \
  136. } else if (HEADDIM == 128) { \
  137. constexpr static int kHeadSize = 128; \
  138. return __VA_ARGS__(); \
  139. } else if (HEADDIM == 96) { \
  140. constexpr static int kHeadSize = 96; \
  141. return __VA_ARGS__(); \
  142. } else if (HEADDIM == 256) { \
  143. constexpr static int kHeadSize = 256; \
  144. return __VA_ARGS__(); \
  145. } \
  146. }()