/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
- #pragma once
- #include <torch/all.h>
// Case list for the floating-point ATen scalar types: Float, Half, BFloat16.
// Expands to one AT_DISPATCH_CASE per type, all forwarding the same callable
// (__VA_ARGS__). Intended to be placed inside AT_DISPATCH_SWITCH, as
// APHRODITE_DISPATCH_FLOATING_TYPES does.
// NOTE(review): per ATen's Dispatch.h, each case aliases `scalar_t` to the
// matching C++ type before invoking the callable -- confirm against the
// torch version this is built with.
#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)     \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
// Dispatches TYPE over the floating-point case list
// (APHRODITE_DISPATCH_CASE_FLOATING_TYPES: Float, Half, BFloat16).
//   TYPE        - value switched on; the cases are at::ScalarType values, so
//                 this is expected to be an at::ScalarType.
//   NAME        - forwarded unchanged to AT_DISPATCH_SWITCH (in ATen it labels
//                 the error raised for an unsupported type -- verify).
//   __VA_ARGS__ - the callable run for the matching type; presumably a lambda
//                 that refers to `scalar_t`.
#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                            \
                     APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// Case list for the floating-point ATen scalar types plus Byte:
// Float, Half, BFloat16, Byte (Byte is ATen's uint8 type).
// Expands to one AT_DISPATCH_CASE per type, all forwarding the same callable
// (__VA_ARGS__). Intended to be placed inside AT_DISPATCH_SWITCH, as
// APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES does.
// NOTE(review): per ATen's Dispatch.h, each case aliases `scalar_t` to the
// matching C++ type before invoking the callable -- confirm for this torch.
#define APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
// Dispatches TYPE over the floating-point-plus-Byte case list
// (APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES: Float, Half, BFloat16,
// Byte).
//   TYPE        - value switched on; the cases are at::ScalarType values, so
//                 this is expected to be an at::ScalarType.
//   NAME        - forwarded unchanged to AT_DISPATCH_SWITCH (in ATen it labels
//                 the error raised for an unsupported type -- verify).
//   __VA_ARGS__ - the callable run for the matching type; presumably a lambda
//                 that refers to `scalar_t`.
#define APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                                \
      TYPE, NAME,                                                    \
      APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
// Case list for the integral ATen scalar types: Byte, Char, Short, Int, Long
// (uint8, int8, int16, int32, int64 in ATen's ScalarType mapping).
// Bool is not included.
// Expands to one AT_DISPATCH_CASE per type, all forwarding the same callable
// (__VA_ARGS__). Intended to be placed inside AT_DISPATCH_SWITCH, as
// APHRODITE_DISPATCH_INTEGRAL_TYPES does.
// NOTE(review): per ATen's Dispatch.h, each case aliases `scalar_t` to the
// matching C++ type before invoking the callable -- confirm for this torch.
#define APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(...)    \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
// Dispatches TYPE over the integral case list
// (APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES: Byte, Char, Short, Int, Long).
//   TYPE        - value switched on; the cases are at::ScalarType values, so
//                 this is expected to be an at::ScalarType.
//   NAME        - forwarded unchanged to AT_DISPATCH_SWITCH (in ATen it labels
//                 the error raised for an unsupported type -- verify).
//   __VA_ARGS__ - the callable run for the matching type; presumably a lambda
//                 that refers to `scalar_t`.
#define APHRODITE_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                            \
                     APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|