1
0

causal_conv1d.h 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #pragma once
  5. #include <cuda_bf16.h>
  6. #include <cuda_fp16.h>
  7. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameter bundle for the causal conv1d kernels: problem sizes, per-tensor
  // strides, and untyped data pointers. It is a plain aggregate whose member
  // order and types define its layout, so keep both stable.
  struct ConvParamsBase {
    // Integer type used for every stride field below.
    // NOTE(review): 32-bit unsigned -- strides, and any offset arithmetic done
    // in this type by callers, must stay below 2^32; confirm for very large
    // tensors.
    typedef uint32_t index_t;

    // Problem sizes.
    int batch;
    int dim;
    int seqlen;
    int width;
    // Flag for applying SiLU (presumably to the conv output -- confirm in the
    // kernels).
    bool silu_activation;

    // Strides of x (units presumably elements -- TODO confirm).
    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    // Strides of the weight tensor.
    index_t weight_c_stride;
    index_t weight_width_stride;
    // Strides of out.
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;
    // Strides of conv_state.
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Untyped data pointers, declared __restrict__ (promised not to alias).
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ conv_state_ptr;
    void *__restrict__ seq_idx_ptr;
    void *__restrict__ seq_pos_idx_ptr;

    // Intentionally NOT __restrict__: initial_states_ptr may point at the same
    // buffer as final_states_ptr.
    void *initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;
    void *final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;
  };
  // Compile-time map from a byte count to an unsigned type of exactly that
  // size, exposed as ::Type. Only the sizes specialized below (1, 2, 4, 8, 16)
  // define ::Type; any other BYTES yields an empty struct.
  template <int BYTES>
  struct BytesToType {};

  template <>
  struct BytesToType<1> {
    typedef uint8_t Type;
    static_assert(sizeof(Type) == 1, "BytesToType<1>::Type must be 1 byte");
  };

  template <>
  struct BytesToType<2> {
    typedef uint16_t Type;
    static_assert(sizeof(Type) == 2, "BytesToType<2>::Type must be 2 bytes");
  };

  template <>
  struct BytesToType<4> {
    typedef uint32_t Type;
    static_assert(sizeof(Type) == 4, "BytesToType<4>::Type must be 4 bytes");
  };

  template <>
  struct BytesToType<8> {
    typedef uint64_t Type;
    static_assert(sizeof(Type) == 8, "BytesToType<8>::Type must be 8 bytes");
  };

  template <>
  struct BytesToType<16> {
    // CUDA vector type uint4; its 16-byte size is checked just below.
    typedef uint4 Type;
    static_assert(sizeof(Type) == 16, "BytesToType<16>::Type must be 16 bytes");
  };
  68. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Binary functor returning x + y via T's operator+.
  //
  // operator() is const so the functor can be invoked through a const object
  // or a const reference. It is __host__ __device__ so it can also be used in
  // host code. Device-side results are unchanged.
  template <typename T>
  struct SumOp {
    __host__ __device__ inline T operator()(T const& x, T const& y) const {
      return x + y;
    }
  };
  // Butterfly all-reduce over aligned groups of THREADS lanes within a warp.
  // Each step combines the running value with the value of the lane whose id
  // differs by XOR with `offset` (__shfl_xor_sync). The offset starts at
  // THREADS / 2 and halves down to 1, so after log2(THREADS) steps each lane
  // holds the combination of all THREADS inputs in its group. The result is
  // identical across the group when op is commutative.
  //
  // The shuffle mask is the full warp (all 32 bits set), so every lane of the
  // warp must call run() together, even when THREADS < 32.
  template <int THREADS>
  struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4 || THREADS == 2,
                  "Allreduce supports THREADS in {2, 4, 8, 16, 32}");

    template <typename T, typename Operator>
    static __device__ inline T run(T x, Operator& op) {
      #pragma unroll
      for (int offset = THREADS / 2; offset >= 1; offset /= 2) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, offset));
      }
      return x;
    }
  };