selective_scan.h

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
#pragma once

#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>

// Standard and PyTorch headers for the utilities used below (uint32_t,
// std::initializer_list, std::max / std::min, std::conditional_t, at::Half,
// at::BFloat16).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <type_traits>

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

struct SSMParamsBase {
  using index_t = uint32_t;

  int batch, dim, seqlen, dstate, n_groups, n_chunks;
  int dim_ngroups_ratio;
  bool is_variable_B;
  bool is_variable_C;

  bool delta_softplus;

  index_t A_d_stride;
  index_t A_dstate_stride;
  index_t B_batch_stride;
  index_t B_d_stride;
  index_t B_dstate_stride;
  index_t B_group_stride;
  index_t C_batch_stride;
  index_t C_d_stride;
  index_t C_dstate_stride;
  index_t C_group_stride;
  index_t u_batch_stride;
  index_t u_d_stride;
  index_t delta_batch_stride;
  index_t delta_d_stride;
  index_t z_batch_stride;
  index_t z_d_stride;
  index_t out_batch_stride;
  index_t out_d_stride;
  index_t out_z_batch_stride;
  index_t out_z_d_stride;

  // Common data pointers.
  void* __restrict__ A_ptr;
  void* __restrict__ B_ptr;
  void* __restrict__ C_ptr;
  void* __restrict__ D_ptr;
  void* __restrict__ u_ptr;
  void* __restrict__ delta_ptr;
  void* __restrict__ delta_bias_ptr;
  void* __restrict__ out_ptr;
  void* __restrict__ x_ptr;
  void* __restrict__ z_ptr;
  void* __restrict__ out_z_ptr;
  void* __restrict__ index_ptr;
};
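
// Note: the *_stride fields above are element strides (typically filled in
// from tensor.stride() by the host-side launcher), not byte offsets, and the
// type-erased void* pointers are cast to the kernel's input_t / weight_t types
// inside the selective-scan kernels.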

#ifndef USE_ROCM
constexpr size_t custom_max(std::initializer_list<size_t> ilist) {
  return std::max(ilist);
}

template <typename T>
constexpr T constexpr_min(T a, T b) {
  return std::min(a, b);
}
#else
constexpr size_t custom_max(std::initializer_list<size_t> ilist) {
  return *std::max_element(ilist.begin(), ilist.end());
}

template <typename T>
constexpr T constexpr_min(T a, T b) {
  return a < b ? a : b;
}
#endif

#define MAX_DSTATE 256
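
// MAX_DSTATE bounds the SSM state dimension: the kernels that include this
// header use it to size fixed-length state arrays, so dstate is expected to
// be at most MAX_DSTATE.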

inline __device__ float2 operator+(const float2& a, const float2& b) {
  return {a.x + b.x, a.y + b.y};
}

inline __device__ float3 operator+(const float3& a, const float3& b) {
  return {a.x + b.x, a.y + b.y, a.z + b.z};
}

inline __device__ float4 operator+(const float4& a, const float4& b) {
  return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
}

////////////////////////////////////////////////////////////////////////////////////////////////////
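
// BytesToType maps a byte width to an unsigned integer type of that exact
// size; the load/store helpers below rely on such types (e.g. Ktraits::vec_t)
// to issue wide, vectorized memory transactions.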

template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename scalar_t, int N>
struct Converter {
  static inline __device__ void to_float(const scalar_t (&src)[N],
                                         float (&dst)[N]) {
#pragma unroll
    for (int i = 0; i < N; ++i) {
      dst[i] = src[i];
    }
  }
};

template <int N>
struct Converter<at::Half, N> {
  static inline __device__ void to_float(const at::Half (&src)[N],
                                         float (&dst)[N]) {
    static_assert(N % 2 == 0);
    auto& src2 = reinterpret_cast<const half2(&)[N / 2]>(src);
    auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
    for (int i = 0; i < N / 2; ++i) {
      dst2[i] = __half22float2(src2[i]);
    }
  }
};

#if __CUDA_ARCH__ >= 800
template <int N>
struct Converter<at::BFloat16, N> {
  static inline __device__ void to_float(const at::BFloat16 (&src)[N],
                                         float (&dst)[N]) {
    static_assert(N % 2 == 0);
    auto& src2 = reinterpret_cast<const nv_bfloat162(&)[N / 2]>(src);
    auto& dst2 = reinterpret_cast<float2(&)[N / 2]>(dst);
#pragma unroll
    for (int i = 0; i < N / 2; ++i) {
      dst2[i] = __bfloat1622float2(src2[i]);
    }
  }
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
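
// SSMScanOp is the associative combine for the selective-scan recurrence
// x_t = a_t * x_{t-1} + b_t: each element carries an (a, b) pair, and
// composing (a0, b0) followed by (a1, b1) yields (a1 * a0, a1 * b0 + b1),
// which is what the float specialization below computes on float2.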

template <typename scalar_t>
struct SSMScanOp;

template <>
struct SSMScanOp<float> {
  __device__ __forceinline__ float2 operator()(const float2& ab0,
                                               const float2& ab1) const {
    return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
  }
};

// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template <typename scalar_t>
struct SSMScanPrefixCallbackOp {
  using scan_t =
      std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
  scan_t running_prefix;
  // Constructor
  __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_)
      : running_prefix(running_prefix_) {}
  // Callback operator to be entered by the first warp of threads in the block.
  // Thread-0 is responsible for returning a value for seeding the block-wide
  // scan.
  __device__ scan_t operator()(scan_t block_aggregate) {
    scan_t old_prefix = running_prefix;
    running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
    return old_prefix;
  }
};
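
// Usage sketch (hedged; the names follow the kernel traits used by the
// selective-scan kernels, e.g. a cub::BlockScan alias BlockScanT and its
// smem_scan temp storage): the callback seeds each chunk's block-wide scan
// with the running prefix of the previous chunks and folds in the new block
// aggregate for the next chunk.
//
//   SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
//   typename Ktraits::BlockScanT(smem_scan)
//       .InclusiveScan(thread_data, thread_data, SSMScanOp<weight_t>(),
//                      prefix_op);
//   // prefix_op.running_prefix now carries the state into the next chunk.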

////////////////////////////////////////////////////////////////////////////////////////////////////
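
// The helpers below share one pattern: when the sequence length is an even
// multiple of the block's elements (Ktraits::kIsEvenLen), the data is
// reinterpreted as Ktraits::vec_t and loaded/stored with the vectorized cub
// BlockLoad/BlockStore specializations; otherwise the guarded, element-wise
// overloads are used with the valid length and an out-of-bounds fill value.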

template <typename Ktraits>
inline __device__ void load_input(
    typename Ktraits::input_t* u,
    typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadT::TempStorage& smem_load, int seqlen) {
  if constexpr (Ktraits::kIsEvenLen) {
    auto& smem_load_vec =
        reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(
            smem_load);
    using vec_t = typename Ktraits::vec_t;
    typename Ktraits::BlockLoadVecT(smem_load_vec)
        .Load(reinterpret_cast<vec_t*>(u),
              reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
#ifdef USE_ROCM
                  ,
              Ktraits::kNThreads * Ktraits::kNLoads
#endif
        );
  } else {
    typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
  }
}

template <typename Ktraits>
inline __device__ void load_index(
    int* u, int (&u_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadIndexT::TempStorage& smem_load_index,
    int seqlen) {
  if constexpr (Ktraits::kIsEvenLen) {
    auto& smem_load_index_vec =
        reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(
            smem_load_index);
    typename Ktraits::BlockLoadIndexVecT(smem_load_index_vec)
        .Load(reinterpret_cast<uint4*>(u),
              reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals));
  } else {
    typename Ktraits::BlockLoadIndexT(smem_load_index)
        .Load(u, u_vals, seqlen, 0);
  }
}

template <typename Ktraits>
inline __device__ void load_weight(
    typename Ktraits::input_t* Bvar,
    typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
    typename Ktraits::BlockLoadWeightT::TempStorage& smem_load_weight,
    int seqlen) {
  constexpr int kNItems = Ktraits::kNItems;
  typename Ktraits::input_t B_vals_load[kNItems];
  if constexpr (Ktraits::kIsEvenLen) {
    auto& smem_load_weight_vec =
        reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(
            smem_load_weight);
    using vec_t = typename Ktraits::vec_t;
    typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec)
        .Load(reinterpret_cast<vec_t*>(Bvar),
              reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load));
  } else {
    typename Ktraits::BlockLoadWeightT(smem_load_weight)
        .Load(Bvar, B_vals_load, seqlen, 0.f);
  }
  // #pragma unroll
  // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
  Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
}

template <typename Ktraits>
inline __device__ void store_output(
    typename Ktraits::input_t* out, const float (&out_vals)[Ktraits::kNItems],
    typename Ktraits::BlockStoreT::TempStorage& smem_store, int seqlen) {
  typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
  for (int i = 0; i < Ktraits::kNItems; ++i) {
    write_vals[i] = out_vals[i];
  }
  if constexpr (Ktraits::kIsEvenLen) {
    auto& smem_store_vec =
        reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(
            smem_store);
    using vec_t = typename Ktraits::vec_t;
    typename Ktraits::BlockStoreVecT(smem_store_vec)
        .Store(reinterpret_cast<vec_t*>(out),
               reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals));
  } else {
    typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
  }
}