/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h

#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <algorithm>         // std::max, std::min, std::max_element
#include <cstdint>           // fixed-width integer types used for strides and indices
#include <initializer_list>  // std::initializer_list taken by custom_max

////////////////////////////////////////////////////////////////////////////////////////////////////
struct ConvParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, width;
    bool silu_activation;

    // Strides (in elements) for the input x, laid out as (batch, dim, seqlen).
    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    // Strides for the (dim, width) convolution weight.
    index_t weight_c_stride;
    index_t weight_width_stride;
    // Strides for the output, same layout as x.
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;
    // Strides for the cached convolution state used in incremental decoding.
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ conv_state_ptr;

    // For the continuous batching case. Makes it so that the mamba state for
    // the current batch doesn't need to be a contiguous tensor.
    int32_t *__restrict__ conv_state_indices_ptr;

    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;
};
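
/*
  Usage sketch (illustrative, not part of the adapted source): how the stride
  fields are typically filled for contiguous tensors, assuming x and out are
  laid out as (batch, dim, seqlen) and weight as (dim, width), with strides
  expressed in elements. The local variable names below are hypothetical.

      ConvParamsBase params;
      params.batch = batch; params.dim = dim;
      params.seqlen = seqlen; params.width = width;
      // x: (batch, dim, seqlen), contiguous.
      params.x_batch_stride = dim * seqlen;
      params.x_c_stride     = seqlen;
      params.x_l_stride     = 1;
      // weight: (dim, width), contiguous.
      params.weight_c_stride     = width;
      params.weight_width_stride = 1;
      // out mirrors the layout of x.
      params.out_batch_stride = dim * seqlen;
      params.out_c_stride     = seqlen;
      params.out_l_stride     = 1;
*/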
#ifndef USE_ROCM
    #include <cuda_bf16.h>

    // CUDA path: warp-synchronous xor shuffle over the full warp mask.
    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor_sync(uint32_t(-1), val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return std::max(ilist);
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return std::min(a, b);
    }

#else
    #include <hip/hip_bf16.h>

    // ROCm/HIP path: __shfl_xor takes no sync-mask argument.
    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor(val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return *std::max_element(ilist.begin(), ilist.end());
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return a < b ? a : b;
    }
#endif
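
/*
  Usage sketch (illustrative): custom_max / constexpr_min are meant for
  compile-time sizing, e.g. picking a shared-memory footprint as the maximum of
  several candidate layouts. kNThreads / kNElts below are hypothetical constants.

      static constexpr int    kNThreads = 128;
      static constexpr int    kNElts    = 4;
      static constexpr size_t kSmemSize =
          custom_max({kNThreads * kNElts * sizeof(float),
                      kNThreads * sizeof(float)});
      static constexpr int    kLoadElts = constexpr_min(kNElts, 8);
*/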
////////////////////////////////////////////////////////////////////////////////////////////////////

// Maps a size in bytes to an unsigned type of exactly that size,
// used for vectorized loads and stores.
template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};
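
/*
  Usage sketch (illustrative): BytesToType selects an unsigned type whose size
  matches the number of bytes to move, so several elements can be copied with a
  single vectorized load/store. input_t, kNElts, x and out are hypothetical names.

      using vec_t = typename BytesToType<sizeof(input_t) * kNElts>::Type;
      vec_t v = *reinterpret_cast<const vec_t *>(x);   // one vectorized load
      *reinterpret_cast<vec_t *>(out) = v;             // one vectorized store
*/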
////////////////////////////////////////////////////////////////////////////////////////////////////

// Binary sum functor used as the reduction operator in Allreduce.
template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

// Butterfly (xor-shuffle) all-reduce: after run(), every participating lane
// holds op applied across the values of all THREADS lanes.
template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        // Combine with the lane OFFSET away via the platform-portable xor shuffle,
        // then recurse on the half-width reduction.
        x = op(x, shuffle_xor(x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, shuffle_xor(x, 1));
        return x;
    }
};
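
/*
  Usage sketch (illustrative): summing a per-lane partial value across a 32-lane
  warp with SumOp + Allreduce; every lane ends up holding the warp-wide sum.
  `partial` is a hypothetical per-thread value computed earlier in a kernel.

      SumOp<float> op;
      float partial = ...;                              // per-thread contribution
      float total = Allreduce<32>::run(partial, op);    // full-warp butterfly reduce
*/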