123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- // Inspired by
- // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
- // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
- #pragma once
/// Turns a runtime boolean into a compile-time constant.
///
/// Expands to an immediately-invoked lambda that evaluates COND once and then
/// invokes the callable passed in `...` with CONST_NAME declared in scope as a
/// `static constexpr bool` holding the value of the branch taken. The callable
/// text is therefore instantiated twice (once for `true`, once for `false`),
/// and both instantiations must return the same type, which becomes the type
/// of the whole expression. Because CONST_NAME has static storage duration,
/// nested lambdas can read it without capturing it.
///
/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - name of the `static constexpr bool` seen by the callable
/// @param ...        - callable (usually a `[&]` lambda) to invoke; variadic so
///                     that top-level commas inside it (e.g. `f<A, B>`) survive
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///   some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)     \
  [&] {                                        \
    if (COND) {                                \
      static constexpr bool CONST_NAME = true; \
      return __VA_ARGS__();                    \
    }                                          \
    static constexpr bool CONST_NAME = false;  \
    return __VA_ARGS__();                      \
  }()
/// Maps the causal / local flags to two compile-time booleans.
///
/// The callable in `...` is invoked with CAUSAL_CONST_NAME and LOCAL_CONST_NAME
/// declared in scope as `static constexpr bool`. At most one of them is `true`:
/// when both CAUSAL_COND and LOCAL_COND hold, the causal branch wins, and
/// LOCAL_COND is only evaluated when CAUSAL_COND is false.
///
/// When FLASHATTENTION_DISABLE_LOCAL is defined, LOCAL_COND is ignored (never
/// evaluated) and LOCAL_CONST_NAME is always `false`, so no instantiation with
/// LOCAL_CONST_NAME == true is ever compiled.
#ifndef FLASHATTENTION_DISABLE_LOCAL
#define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  [&] {                                                \
    if (CAUSAL_COND) {                                 \
      static constexpr bool CAUSAL_CONST_NAME = true;  \
      static constexpr bool LOCAL_CONST_NAME = false;  \
      return __VA_ARGS__();                            \
    }                                                  \
    if (LOCAL_COND) {                                  \
      static constexpr bool CAUSAL_CONST_NAME = false; \
      static constexpr bool LOCAL_CONST_NAME = true;   \
      return __VA_ARGS__();                            \
    }                                                  \
    static constexpr bool CAUSAL_CONST_NAME = false;   \
    static constexpr bool LOCAL_CONST_NAME = false;    \
    return __VA_ARGS__();                              \
  }()
#else
#define CAUSAL_LOCAL_SWITCH(CAUSAL_COND, LOCAL_COND, CAUSAL_CONST_NAME, LOCAL_CONST_NAME, ...) \
  [&] {                                               \
    static constexpr bool LOCAL_CONST_NAME = false;   \
    if (CAUSAL_COND) {                                \
      static constexpr bool CAUSAL_CONST_NAME = true; \
      return __VA_ARGS__();                           \
    }                                                 \
    static constexpr bool CAUSAL_CONST_NAME = false;  \
    return __VA_ARGS__();                             \
  }()
#endif
/// Stand-in for BOOL_SWITCH used when a feature is compiled out.
///
/// Same call shape as BOOL_SWITCH, but COND is ignored (it does not appear in
/// the expansion, so it is never evaluated) and CONST_NAME is always `false`.
/// Only the `false` instantiation of the callable is compiled. A `true` COND
/// is silently treated as `false`; callers must reject such configurations
/// themselves if that matters.
#define FLASH_DISABLED_SWITCH(COND, CONST_NAME, ...) \
  [&] {                                              \
    constexpr static bool CONST_NAME = false;        \
    return __VA_ARGS__();                            \
  }()

// Per-feature switches. Defining FLASHATTENTION_DISABLE_<FEATURE> replaces the
// corresponding switch with FLASH_DISABLED_SWITCH; otherwise it is BOOL_SWITCH.
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
#define SOFTCAP_SWITCH FLASH_DISABLED_SWITCH
#else
#define SOFTCAP_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_PAGEDKV
#define PAGEDKV_SWITCH FLASH_DISABLED_SWITCH
#else
#define PAGEDKV_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_SPLIT
#define SPLIT_SWITCH FLASH_DISABLED_SWITCH
#else
#define SPLIT_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_APPENDKV
#define APPENDKV_SWITCH FLASH_DISABLED_SWITCH
#else
#define APPENDKV_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_PACKGQA
#define PACKGQA_SWITCH FLASH_DISABLED_SWITCH
#else
#define PACKGQA_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_VARLEN
#define VARLEN_SWITCH FLASH_DISABLED_SWITCH
#else
#define VARLEN_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_CLUSTER
#define CLUSTER_SWITCH FLASH_DISABLED_SWITCH
#else
#define CLUSTER_SWITCH BOOL_SWITCH
#endif

// VCOLMAJOR is opt-in (the inverse of the features above): it is only a real
// switch when FLASHATTENTION_ENABLE_VCOLMAJOR is defined.
#ifdef FLASHATTENTION_ENABLE_VCOLMAJOR
#define VCOLMAJOR_SWITCH BOOL_SWITCH
#else
#define VCOLMAJOR_SWITCH FLASH_DISABLED_SWITCH
#endif
/// Turns a runtime head dimension into a compile-time constant.
///
/// Expands to an immediately-invoked lambda that evaluates HEADDIM exactly once
/// and, for the supported values 64, 96, 128 and 256, invokes the callable in
/// `...` with `kHeadSize` declared in scope as a `static constexpr int`.
///
/// Any other value (e.g. 192) matches no branch, and the callable is not
/// invoked. If the callable returns a non-void type, control then reaches the
/// end of the generated lambda, which is undefined behavior. Callers should
/// therefore validate the head dimension first.
///
/// @param HEADDIM - integral expression holding the head dimension
/// @param ...     - callable (usually a `[&]` lambda) to invoke
///
/// Usage:
/// ```
/// HEADDIM_SWITCH(head_dim, [&] {
///   some_function<kHeadSize>(...);
/// });
/// ```
#define HEADDIM_SWITCH(HEADDIM, ...)                                                   \
  [&] {                                                                                \
    /* Evaluate once; parentheses guard low-precedence args such as `c ? a : b`. */   \
    const auto flash_headdim_switch_value_ = (HEADDIM);                                \
    if (flash_headdim_switch_value_ == 64) {                                           \
      constexpr static int kHeadSize = 64;                                             \
      return __VA_ARGS__();                                                            \
    } else if (flash_headdim_switch_value_ == 96) {                                    \
      constexpr static int kHeadSize = 96;                                             \
      return __VA_ARGS__();                                                            \
    } else if (flash_headdim_switch_value_ == 128) {                                   \
      constexpr static int kHeadSize = 128;                                            \
      return __VA_ARGS__();                                                            \
    } else if (flash_headdim_switch_value_ == 256) {                                   \
      constexpr static int kHeadSize = 256;                                            \
      return __VA_ARGS__();                                                            \
    }                                                                                  \
  }()
|