// static_switch.h
  1. // Inspired by
  2. // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
  3. // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
  4. #pragma once
  /// Turns a runtime boolean into a compile-time constant.
  ///
  /// Expands to an immediately-invoked `[&]` lambda. Inside it, CONST_NAME is
  /// declared as a `static constexpr bool` that is `true` when COND holds and
  /// `false` otherwise; the trailing callable is then invoked with that
  /// constant in scope. The macro evaluates to whatever the callable returns,
  /// so the callable must yield the same type for both values of CONST_NAME.
  ///
  /// @param COND - a boolean expression to switch by (evaluated once)
  /// @param CONST_NAME - a name given for the constexpr bool variable
  /// @param ... - a nullary callable (typically a lambda); its text is expanded
  ///              once for the true case and once for the false case
  ///
  /// Usage:
  /// ```
  /// BOOL_SWITCH(flag, BoolConst, [&] {
  ///   some_function<BoolConst>(...);
  /// });
  /// ```
  #define BOOL_SWITCH(COND, CONST_NAME, ...) \
    [&] { \
      if (COND) { \
        static constexpr bool CONST_NAME = true; \
        return __VA_ARGS__(); \
      } \
      static constexpr bool CONST_NAME = false; \
      return __VA_ARGS__(); \
    }()
  /// Dispatches a runtime precision code to compile-time precision settings.
  ///
  /// Expands to an immediately-invoked `[&]` lambda that picks one of the
  /// configurations below, brings `kPrecType`, `kSoftFp16` and `kHybrid` into
  /// scope, and invokes the trailing callable. The macro evaluates to whatever
  /// the callable returns.
  ///
  ///   PRECTYPE | kPrecType              | kSoftFp16 | kHybrid
  ///   ---------+------------------------+-----------+--------
  ///   1        | cutlass::half_t        | false     | false
  ///   2        | cutlass::float_e4m3_t  | false     | false
  ///   3        | cutlass::float_e4m3_t  | false     | true
  ///   4        | cutlass::float_e4m3_t  | true      | false
  ///
  /// @param PRECTYPE - integral expression selecting the configuration. It is
  ///                   used as a `switch` condition, so it is evaluated exactly
  ///                   once and may be any expression (e.g. `flags & 7`)
  ///                   without extra parentheses.
  /// @param ... - a nullary callable (typically a lambda); its text is expanded
  ///              once per configuration.
  ///
  /// NOTE: for any other PRECTYPE value the callable is not invoked and the
  /// lambda ends without a return statement. If the callable returns a value,
  /// callers must pass only the codes listed above.
  #define PREC_SWITCH(PRECTYPE, ...) \
    [&] { \
      switch (PRECTYPE) { \
        case 1: { \
          using kPrecType = cutlass::half_t; \
          static constexpr bool kSoftFp16 = false; \
          static constexpr bool kHybrid = false; \
          return __VA_ARGS__(); \
        } \
        case 2: { \
          using kPrecType = cutlass::float_e4m3_t; \
          static constexpr bool kSoftFp16 = false; \
          static constexpr bool kHybrid = false; \
          return __VA_ARGS__(); \
        } \
        case 3: { \
          using kPrecType = cutlass::float_e4m3_t; \
          static constexpr bool kSoftFp16 = false; \
          static constexpr bool kHybrid = true; \
          return __VA_ARGS__(); \
        } \
        case 4: { \
          using kPrecType = cutlass::float_e4m3_t; \
          static constexpr bool kSoftFp16 = true; \
          static constexpr bool kHybrid = false; \
          return __VA_ARGS__(); \
        } \
        default: \
          break; /* unsupported code: nothing is dispatched */ \
      } \
    }()
  /// Dispatches a runtime head dimension to a compile-time constant.
  ///
  /// Expands to an immediately-invoked `[&]` lambda. When HEADDIM is 64, 128
  /// or 256, `kHeadSize` is declared as a `static constexpr int` with that
  /// value and the trailing callable is invoked with it in scope. The macro
  /// evaluates to whatever the callable returns.
  ///
  /// @param HEADDIM - integral expression selecting the head size. It is used
  ///                  as a `switch` condition, so it is evaluated exactly once
  ///                  and may be any expression without extra parentheses.
  /// @param ... - a nullary callable (typically a lambda); its text is expanded
  ///              once per supported head size.
  ///
  /// NOTE: for any other HEADDIM value the callable is not invoked and the
  /// lambda ends without a return statement. If the callable returns a value,
  /// callers must pass only 64, 128 or 256.
  #define HEADDIM_SWITCH(HEADDIM, ...) \
    [&] { \
      switch (HEADDIM) { \
        case 64: { \
          static constexpr int kHeadSize = 64; \
          return __VA_ARGS__(); \
        } \
        case 128: { \
          static constexpr int kHeadSize = 128; \
          return __VA_ARGS__(); \
        } \
        case 256: { \
          static constexpr int kHeadSize = 256; \
          return __VA_ARGS__(); \
        } \
        default: \
          break; /* unsupported head size: nothing is dispatched */ \
      } \
    }()
  /// Selects a sequence-length traits type from a runtime flag.
  ///
  /// Expands to an immediately-invoked `[&]` lambda. NAME is declared as an
  /// alias of `flash::VarSeqLenTraits` when USE_VAR_SEQ_LEN is true and of
  /// `flash::FixedSeqLenTraits` otherwise; the trailing callable is then
  /// invoked with that alias in scope. The macro evaluates to whatever the
  /// callable returns.
  ///
  /// The flag is tested directly instead of being copied into a helper local:
  /// a local declared inside the macro would shadow any caller variable of the
  /// same name, both in the flag expression and inside the callable.
  ///
  /// @param USE_VAR_SEQ_LEN - boolean expression (evaluated once)
  /// @param NAME - name given to the selected traits type alias
  /// @param ... - a nullary callable (typically a lambda); its text is expanded
  ///              once per traits type
  #define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, NAME, ...) \
    [&] { \
      if (USE_VAR_SEQ_LEN) { \
        using NAME = flash::VarSeqLenTraits; \
        return __VA_ARGS__(); \
      } else { \
        using NAME = flash::FixedSeqLenTraits; \
        return __VA_ARGS__(); \
      } \
    }()