/*
 * dispatch_utils.h
 *
 * Scalar-type dispatch macros built on PyTorch's AT_DISPATCH_SWITCH /
 * AT_DISPATCH_CASE.
 *
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
  5. #pragma once
  6. #include <torch/all.h>
  // Case list for the floating-point scalar types Float, Half and BFloat16.
  // Each entry forwards the variadic arguments (in ATen usage, the lambda to
  // run) unchanged to AT_DISPATCH_CASE. It can be used on its own inside a
  // hand-written AT_DISPATCH_SWITCH when extra cases need to be appended.
  #define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)          \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)       \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)        \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

  // Switches on the scalar type TYPE and runs the callable for Float, Half or
  // BFloat16. NAME is passed straight through to AT_DISPATCH_SWITCH. Per
  // ATen/Dispatch.h, NAME labels the error raised for an unlisted type, and
  // the callable sees `scalar_t` aliased to the matching C++ type, e.g.:
  //
  //   APHRODITE_DISPATCH_FLOATING_TYPES(x.scalar_type(), "my_op", [&] {
  //     my_op_kernel<scalar_t>(...);
  //   });
  #define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)  \
    AT_DISPATCH_SWITCH(TYPE, NAME,                           \
                       APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
  // Case list for Float, Half, BFloat16 and Byte. The variadic arguments
  // (in ATen usage, the lambda to run) are forwarded unchanged to every
  // AT_DISPATCH_CASE entry.
  #define APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)  \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)        \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)         \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)     \
    AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

  // Switches on the scalar type TYPE and runs the callable for Float, Half,
  // BFloat16 or Byte. NAME is passed straight through to AT_DISPATCH_SWITCH.
  // Per ATen/Dispatch.h, any other type reaches that switch's default branch.
  #define APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)  \
    AT_DISPATCH_SWITCH(                                               \
        TYPE, NAME,                                                   \
        APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
  // Case list for the integral scalar types Byte, Char, Short, Int and Long.
  // Bool is not included. The variadic arguments (in ATen usage, the lambda
  // to run) are forwarded unchanged to every AT_DISPATCH_CASE entry.
  #define APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(...)        \
    AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \
    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \
    AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \
    AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \
    AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

  // Switches on the scalar type TYPE and runs the callable for any of the
  // integral types listed above. NAME is passed straight through to
  // AT_DISPATCH_SWITCH. Per ATen/Dispatch.h, any other type (floating
  // types, Bool, ...) reaches that switch's default branch.
  #define APHRODITE_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)  \
    AT_DISPATCH_SWITCH(TYPE, NAME,                           \
                       APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))