1
0

static_switch.h 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
  4. #pragma once
  5. /// @param COND - a boolean expression to switch by
  6. /// @param CONST_NAME - a name given for the constexpr bool variable.
  7. /// @param ... - code to execute for true and false
  8. ///
  9. /// Usage:
  10. /// ```
  11. /// BOOL_SWITCH(flag, BoolConst, [&] {
  12. /// some_function<BoolConst>(...);
  13. /// });
  14. /// ```
  15. #define BOOL_SWITCH(COND, CONST_NAME, ...) \
  16. [&] { \
  17. if (COND) { \
  18. constexpr static bool CONST_NAME = true; \
  19. return __VA_ARGS__(); \
  20. } else { \
  21. constexpr static bool CONST_NAME = false; \
  22. return __VA_ARGS__(); \
  23. } \
  24. }()
  25. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  26. #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
  27. [&] { \
  28. constexpr static bool CONST_NAME = false; \
  29. return __VA_ARGS__(); \
  30. }()
  31. #else
  32. #define DROPOUT_SWITCH BOOL_SWITCH
  33. #endif
  34. #ifdef FLASHATTENTION_DISABLE_ALIBI
  35. #define ALIBI_SWITCH(COND, CONST_NAME, ...) \
  36. [&] { \
  37. constexpr static bool CONST_NAME = false; \
  38. return __VA_ARGS__(); \
  39. }()
  40. #else
  41. #define ALIBI_SWITCH BOOL_SWITCH
  42. #endif
  43. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  44. #define EVENK_SWITCH(COND, CONST_NAME, ...) \
  45. [&] { \
  46. constexpr static bool CONST_NAME = true; \
  47. return __VA_ARGS__(); \
  48. }()
  49. #else
  50. #define EVENK_SWITCH BOOL_SWITCH
  51. #endif
  52. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  53. #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \
  54. [&] { \
  55. constexpr static bool CONST_NAME = false; \
  56. return __VA_ARGS__(); \
  57. }()
  58. #else
  59. #define SOFTCAP_SWITCH BOOL_SWITCH
  60. #endif
  61. #ifdef FLASHATTENTION_DISABLE_LOCAL
  62. #define LOCAL_SWITCH(COND, CONST_NAME, ...) \
  63. [&] { \
  64. constexpr static bool CONST_NAME = false; \
  65. return __VA_ARGS__(); \
  66. }()
  67. #else
  68. #define LOCAL_SWITCH BOOL_SWITCH
  69. #endif
  70. #define FP16_SWITCH(COND, ...) \
  71. [&] { \
  72. if (COND) { \
  73. using elem_type = cutlass::half_t; \
  74. return __VA_ARGS__(); \
  75. } else { \
  76. using elem_type = cutlass::bfloat16_t; \
  77. return __VA_ARGS__(); \
  78. } \
  79. }()
  80. #define HEADDIM_SWITCH(HEADDIM, ...) \
  81. [&] { \
  82. if (HEADDIM <= 32) { \
  83. constexpr static int kHeadDim = 32; \
  84. return __VA_ARGS__(); \
  85. } else if (HEADDIM <= 64) { \
  86. constexpr static int kHeadDim = 64; \
  87. return __VA_ARGS__(); \
  88. } else if (HEADDIM <= 96) { \
  89. constexpr static int kHeadDim = 96; \
  90. return __VA_ARGS__(); \
  91. } else if (HEADDIM <= 128) { \
  92. constexpr static int kHeadDim = 128; \
  93. return __VA_ARGS__(); \
  94. } else if (HEADDIM <= 160) { \
  95. constexpr static int kHeadDim = 160; \
  96. return __VA_ARGS__(); \
  97. } else if (HEADDIM <= 192) { \
  98. constexpr static int kHeadDim = 192; \
  99. return __VA_ARGS__(); \
  100. } else if (HEADDIM <= 224) { \
  101. constexpr static int kHeadDim = 224; \
  102. return __VA_ARGS__(); \
  103. } else if (HEADDIM <= 256) { \
  104. constexpr static int kHeadDim = 256; \
  105. return __VA_ARGS__(); \
  106. } \
  107. }()