// causal_conv1d_common.h
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
  4. #pragma once
  5. #include <cuda_bf16.h>
  6. #include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
  8. template<int BYTES> struct BytesToType {};
  9. template<> struct BytesToType<16> {
  10. using Type = uint4;
  11. static_assert(sizeof(Type) == 16);
  12. };
  13. template<> struct BytesToType<8> {
  14. using Type = uint64_t;
  15. static_assert(sizeof(Type) == 8);
  16. };
  17. template<> struct BytesToType<4> {
  18. using Type = uint32_t;
  19. static_assert(sizeof(Type) == 4);
  20. };
  21. template<> struct BytesToType<2> {
  22. using Type = uint16_t;
  23. static_assert(sizeof(Type) == 2);
  24. };
  25. template<> struct BytesToType<1> {
  26. using Type = uint8_t;
  27. static_assert(sizeof(Type) == 1);
  28. };
////////////////////////////////////////////////////////////////////////////////////////////////////
  30. template<typename T>
  31. struct SumOp {
  32. __device__ inline T operator()(T const & x, T const & y) { return x + y; }
  33. };
  34. template<int THREADS>
  35. struct Allreduce {
  36. static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
  37. template<typename T, typename Operator>
  38. static __device__ inline T run(T x, Operator &op) {
  39. constexpr int OFFSET = THREADS / 2;
  40. x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
  41. return Allreduce<OFFSET>::run(x, op);
  42. }
  43. };
  44. template<>
  45. struct Allreduce<2> {
  46. template<typename T, typename Operator>
  47. static __device__ inline T run(T x, Operator &op) {
  48. x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
  49. return x;
  50. }
  51. };