/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
#pragma once

#ifndef USE_ROCM
    #include <cuda_bf16.h>
#else
    #include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>

#include <algorithm>         // std::max, std::min, std::max_element (for the constexpr helpers below)
#include <initializer_list>  // std::initializer_list (argument type of custom_max)
#include <type_traits>       // std::conditional_t, std::is_same_v (used by SSMScanPrefixCallbackOp)

////////////////////////////////////////////////////////////////////////////////////////////////////
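
// Launch parameters shared by the selective-scan kernels: problem sizes,
// per-tensor strides (in elements), and type-erased device data pointers.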
struct SSMParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, dstate, n_groups, n_chunks;
    int dim_ngroups_ratio;
    bool is_variable_B;
    bool is_variable_C;

    bool delta_softplus;

    index_t A_d_stride;
    index_t A_dstate_stride;
    index_t B_batch_stride;
    index_t B_d_stride;
    index_t B_dstate_stride;
    index_t B_group_stride;
    index_t C_batch_stride;
    index_t C_d_stride;
    index_t C_dstate_stride;
    index_t C_group_stride;
    index_t u_batch_stride;
    index_t u_d_stride;
    index_t delta_batch_stride;
    index_t delta_d_stride;
    index_t z_batch_stride;
    index_t z_d_stride;
    index_t out_batch_stride;
    index_t out_d_stride;
    index_t out_z_batch_stride;
    index_t out_z_d_stride;

    // Common data pointers.
    void *__restrict__ A_ptr;
    void *__restrict__ B_ptr;
    void *__restrict__ C_ptr;
    void *__restrict__ D_ptr;
    void *__restrict__ u_ptr;
    void *__restrict__ delta_ptr;
    void *__restrict__ delta_bias_ptr;
    void *__restrict__ out_ptr;
    void *__restrict__ x_ptr;
    void *__restrict__ z_ptr;
    void *__restrict__ out_z_ptr;
    void *__restrict__ index_ptr;
};
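
// Compile-time max/min helpers (presumably used when computing shared-memory
// sizes in the kernel traits); the CUDA and ROCm builds differ slightly in how
// they are implemented.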
#ifndef USE_ROCM

constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    return std::max(ilist);
}

template<typename T>
constexpr T constexpr_min(T a, T b) {
    return std::min(a, b);
}

#else

constexpr size_t custom_max(std::initializer_list<size_t> ilist)
{
    return *std::max_element(ilist.begin(), ilist.end());
}

template<typename T>
constexpr T constexpr_min(T a, T b) {
    return a < b ? a : b;
}

#endif

#define MAX_DSTATE 256
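
// Element-wise addition for the CUDA built-in vector types float2, float3 and float4.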
inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3 &a, const float3 &b) {
    return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4 & a, const float4 & b){
    return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////
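
// Maps a byte count to an unsigned integer type of exactly that size, e.g.
// BytesToType<16>::Type == uint4 for 128-bit vectorized loads and stores.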
template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////
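
// Converts a per-thread register array of input_t values to float. The
// at::Half and at::BFloat16 specializations reinterpret the arrays as packed
// half2 / nv_bfloat162 pairs so the conversion uses the two-element intrinsics.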
template<typename scalar_t, int N>
struct Converter{
    static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
        #pragma unroll
        for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
    }
};

template<int N>
struct Converter<at::Half, N>{
    static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
    }
};

#if __CUDA_ARCH__ >= 800
template<int N>
struct Converter<at::BFloat16, N>{
    static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
        static_assert(N % 2 == 0);
        auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
        auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
        #pragma unroll
        for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
    }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
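
// Associative operator for the selective scan. Each element (a, b) encodes the
// affine state update h -> a * h + b; composing two consecutive updates gives
//   (a1, b1) o (a0, b0) = (a1 * a0, a1 * b0 + b1),
// which is exactly what the float2 specialization computes, so the recurrence
// can be evaluated with a parallel block-wide scan.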
template<typename scalar_t> struct SSMScanOp;

template<>
struct SSMScanOp<float> {
    __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
        return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
    }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
    using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
    scan_t running_prefix;
    // Constructor
    __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
    // Callback operator to be entered by the first warp of threads in the block.
    // Thread-0 is responsible for returning a value for seeding the block-wide scan.
    __device__ scan_t operator()(scan_t block_aggregate) {
        scan_t old_prefix = running_prefix;
        running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
        return old_prefix;
    }
};
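
// Rough usage sketch (illustrative only, not part of this header): a scan
// kernel seeds cub's BlockScan with the running prefix carried over from the
// previous chunk, e.g.
//
//   SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
//   typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
//       thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op);
//   running_prefix = prefix_op.running_prefix;
//
// Names such as Ktraits::BlockScanT, smem_scan and thread_data are assumed to
// be supplied by the kernel traits / kernel that includes this header.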

////////////////////////////////////////////////////////////////////////////////////////////////////
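
// Block-wide load/store helpers built on cub BlockLoad / BlockStore. When the
// sequence length evenly fills the tile (Ktraits::kIsEvenLen) the data is
// reinterpreted as wider vec_t elements for vectorized accesses; otherwise the
// helpers fall back to guarded loads/stores with an explicit valid length and
// a fill value.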
template<typename Ktraits>
inline __device__ void load_input(typename Ktraits::input_t *u,
                                  typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadT::TempStorage &smem_load,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
            reinterpret_cast<vec_t*>(u),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
#ifdef USE_ROCM
            , Ktraits::kNThreads * Ktraits::kNLoads
#endif
        );
    } else {
        typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
    }
}

template<typename Ktraits>
inline __device__ void load_index(int *u,
                                  int (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
        typename Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
            reinterpret_cast<uint4*>(u),
            reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
        );
    } else {
        typename Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
    }
}
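
// Loads a tile of B or C values and converts them to float via the Converter above.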
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
                                   typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
                                   typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
                                   int seqlen) {
    constexpr int kNItems = Ktraits::kNItems;
    typename Ktraits::input_t B_vals_load[kNItems];
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
            reinterpret_cast<vec_t*>(Bvar),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
        );
    } else {
        typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
    }
    // #pragma unroll
    // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
    Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}
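
// Converts accumulated float results back to input_t and writes them out,
// using vectorized stores when the sequence length is even.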
template<typename Ktraits>
inline __device__ void store_output(typename Ktraits::input_t *out,
                                    const float (&out_vals)[Ktraits::kNItems],
                                    typename Ktraits::BlockStoreT::TempStorage &smem_store,
                                    int seqlen) {
    typename Ktraits::input_t write_vals[Ktraits::kNItems];
    #pragma unroll
    for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
    if constexpr (Ktraits::kIsEvenLen) {
        auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
        using vec_t = typename Ktraits::vec_t;
        typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
            reinterpret_cast<vec_t*>(out),
            reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
        );
    } else {
        typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
    }
}