// static_switch.h
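  // Helpers that turn runtime values into compile-time constants / type
  // aliases by instantiating a callable once per supported value.
  // Inspired by
  // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
  // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h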
  #pragma once

  #include <stdexcept>
  /// Turns a runtime boolean into a compile-time constant.
  ///
  /// @param COND       - boolean expression; evaluated exactly once.
  /// @param CONST_NAME - name of the `static constexpr bool` visible inside the
  ///                     callable: `true` when COND holds, `false` otherwise.
  /// @param ...        - nullary callable (typically a `[&]` lambda). It is
  ///                     invoked exactly once and its result is returned.
  ///
  /// Usage:
  /// ```
  /// BOOL_SWITCH(flag, BoolConst, [&] {
  ///   some_function<BoolConst>(...);
  /// });
  /// ```
  ///
  /// The callable is spelled out in both branches, so it must compile for
  /// either value of CONST_NAME and yield the same return type in both.
  #define BOOL_SWITCH(COND, CONST_NAME, ...)       \
    [&] {                                          \
      if (COND) {                                  \
        static constexpr bool CONST_NAME = true;   \
        return __VA_ARGS__();                      \
      }                                            \
      static constexpr bool CONST_NAME = false;    \
      return __VA_ARGS__();                        \
    }()
  /// Turns two runtime flags into a pair of mutually exclusive compile-time
  /// constants.
  ///
  /// @param CAUSAL_COND       - checked first.
  /// @param LOCAL_COND        - checked only when CAUSAL_COND is false.
  /// @param CAUSAL_CONST_NAME - name of the `static constexpr bool` that is
  ///                            true only when CAUSAL_COND holds.
  /// @param LOCAL_CONST_NAME  - name of the `static constexpr bool` that is
  ///                            true only when LOCAL_COND holds and
  ///                            CAUSAL_COND does not.
  /// @param ...               - nullary callable invoked exactly once; its
  ///                            result is returned.
  ///
  /// Causal takes precedence: if both conditions hold, only CAUSAL_CONST_NAME
  /// is true, so the two constants are never true at the same time.
  #define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
    [&] {                                                   \
      if (CAUSAL_COND) {                                    \
        static constexpr bool CAUSAL_CONST_NAME = true;     \
        static constexpr bool LOCAL_CONST_NAME = false;     \
        return __VA_ARGS__();                               \
      }                                                     \
      if (LOCAL_COND) {                                     \
        static constexpr bool CAUSAL_CONST_NAME = false;    \
        static constexpr bool LOCAL_CONST_NAME = true;      \
        return __VA_ARGS__();                               \
      }                                                     \
      static constexpr bool CAUSAL_CONST_NAME = false;      \
      static constexpr bool LOCAL_CONST_NAME = false;       \
      return __VA_ARGS__();                                 \
    }()
  /// Turns a runtime precision code into compile-time precision settings.
  ///
  /// @param PRECTYPE - precision code, compared against the integer literals
  ///                   below (parenthesized, so compound expressions are safe):
  ///                   1 -> kPrecType = cutlass::half_t,       kSoftFp16 = false, kHybrid = false
  ///                   2 -> kPrecType = cutlass::float_e4m3_t, kSoftFp16 = false, kHybrid = false
  ///                   3 -> kPrecType = cutlass::float_e4m3_t, kSoftFp16 = false, kHybrid = true
  ///                   4 -> kPrecType = cutlass::float_e4m3_t, kSoftFp16 = true,  kHybrid = false
  /// @param ...      - nullary callable invoked exactly once; inside it the
  ///                   type alias `kPrecType` and the `static constexpr bool`s
  ///                   `kSoftFp16` / `kHybrid` are in scope. Its result is
  ///                   returned.
  /// @throws std::invalid_argument (from <stdexcept>) for any other PRECTYPE.
  ///
  /// Why the explicit throw: without a final `else`, an unsupported code falls
  /// off the end of the lambda -- a silent no-op for void callables and
  /// undefined behavior for value-returning ones.
  #define PREC_SWITCH(PRECTYPE, ...)                                             \
    [&] {                                                                        \
      if ((PRECTYPE) == 1) {                                                     \
        using kPrecType = cutlass::half_t;                                       \
        constexpr static bool kSoftFp16 = false;                                 \
        constexpr static bool kHybrid = false;                                   \
        return __VA_ARGS__();                                                    \
      } else if ((PRECTYPE) == 2) {                                              \
        using kPrecType = cutlass::float_e4m3_t;                                 \
        constexpr static bool kSoftFp16 = false;                                 \
        constexpr static bool kHybrid = false;                                   \
        return __VA_ARGS__();                                                    \
      } else if ((PRECTYPE) == 3) {                                              \
        using kPrecType = cutlass::float_e4m3_t;                                 \
        constexpr static bool kSoftFp16 = false;                                 \
        constexpr static bool kHybrid = true;                                    \
        return __VA_ARGS__();                                                    \
      } else if ((PRECTYPE) == 4) {                                              \
        using kPrecType = cutlass::float_e4m3_t;                                 \
        constexpr static bool kSoftFp16 = true;                                  \
        constexpr static bool kHybrid = false;                                   \
        return __VA_ARGS__();                                                    \
      } else {                                                                   \
        throw std::invalid_argument("PREC_SWITCH: unsupported precision type");  \
      }                                                                          \
    }()
  /// Turns a runtime head dimension into a compile-time `kHeadSize`.
  ///
  /// @param HEADDIM - head dimension, compared against 64, 96, 128 and 256
  ///                  (parenthesized, so compound expressions are safe).
  /// @param ...     - nullary callable invoked exactly once with
  ///                  `static constexpr int kHeadSize` in scope; its result
  ///                  is returned.
  /// @throws std::invalid_argument (from <stdexcept>) for any other HEADDIM,
  ///         instead of falling off the end of the lambda (silent no-op for
  ///         void callables, undefined behavior otherwise).
  ///
  /// NOTE(review): a second, unreachable `== 96` branch was dropped; if a
  /// different head size was intended there, add it as its own case.
  #define HEADDIM_SWITCH(HEADDIM, ...)                                            \
    [&] {                                                                         \
      if ((HEADDIM) == 64) {                                                      \
        constexpr static int kHeadSize = 64;                                      \
        return __VA_ARGS__();                                                     \
      } else if ((HEADDIM) == 96) {                                               \
        constexpr static int kHeadSize = 96;                                      \
        return __VA_ARGS__();                                                     \
      } else if ((HEADDIM) == 128) {                                              \
        constexpr static int kHeadSize = 128;                                     \
        return __VA_ARGS__();                                                     \
      } else if ((HEADDIM) == 256) {                                              \
        constexpr static int kHeadSize = 256;                                     \
        return __VA_ARGS__();                                                     \
      } else {                                                                    \
        throw std::invalid_argument("HEADDIM_SWITCH: unsupported head dimension"); \
      }                                                                           \
    }()
  /// Turns sequence-length options into a compile-time traits type.
  ///
  /// @param USE_VAR_SEQ_LEN            - when true, `kSeqLenTraitsType` is
  ///                                     `VarSeqLenTraits`.
  /// @param SEQ_LEN_OUT_OF_BOUND_CHECK - consulted only when USE_VAR_SEQ_LEN
  ///                                     is false; selects
  ///                                     `FixedSeqLenTraits<true>` or
  ///                                     `FixedSeqLenTraits<false>`.
  /// @param ...                        - nullary callable invoked exactly once
  ///                                     with the alias `kSeqLenTraitsType` in
  ///                                     scope; its result is returned.
  ///
  /// USE_VAR_SEQ_LEN is parenthesized before negation so that a compound
  /// argument such as `a && b` is negated as a whole (unparenthesized it
  /// would expand to `!a && b`).
  #define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, SEQ_LEN_OUT_OF_BOUND_CHECK, ...)  \
    [&] {                                                                  \
      if (!(USE_VAR_SEQ_LEN)) {                                            \
        if (SEQ_LEN_OUT_OF_BOUND_CHECK) {                                  \
          using kSeqLenTraitsType = FixedSeqLenTraits<true>;               \
          return __VA_ARGS__();                                            \
        } else {                                                           \
          using kSeqLenTraitsType = FixedSeqLenTraits<false>;              \
          return __VA_ARGS__();                                            \
        }                                                                  \
      } else {                                                             \
        using kSeqLenTraitsType = VarSeqLenTraits;                         \
        return __VA_ARGS__();                                              \
      }                                                                    \
    }()