selective_scan_fwd_kernel.cuh

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>

#include "selective_scan.h"
#include "selective_scan_common.h"
#include "static_switch.h"
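
// Compile-time configuration for the forward kernel: thread/item tiling, the
// vectorized load width, the CUB BlockLoad/BlockStore/BlockScan specializations,
// and the shared-memory footprint those collective operations require.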
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
         bool kIsVariableB_, bool kIsVariableC_,
         bool kHasZ_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsComplex = std::is_same_v<weight_t, complex_t>;
    static constexpr bool kIsEvenLen = kIsEvenLen_;
    static constexpr bool kIsVariableB = kIsVariableB_;
    static constexpr bool kIsVariableC = kIsVariableC_;
    static constexpr bool kHasZ = kHasZ_;
    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = std::conditional_t<!kIsComplex, float2, float4>;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
                                         !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, !kIsComplex ? kNItems : kNItems * 2, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, !kIsComplex ? kNLoads : kNLoads * 2,
                                               !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
                                           !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage),
                                                 sizeof(typename BlockLoadVecT::TempStorage),
                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                 (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                 sizeof(typename BlockStoreT::TempStorage),
                                                 sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};
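
// Forward kernel. Each thread block handles one (batch, channel-slice) pair:
// blockIdx.x indexes the batch and blockIdx.y a group of kNRows channels. The
// sequence is swept in chunks of kNThreads * kNItems timesteps; the running
// scan prefix for every state is carried across chunks in shared memory, and
// the per-chunk prefix is also written out to params.x_ptr.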
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr bool kIsComplex = Ktraits::kIsComplex;
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * kNRows * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * kNRows * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
        }
    }

    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
    // }
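
    // Process the sequence in chunks of kChunkSize = kNThreads * kNItems timesteps.
    // For each chunk: load u and delta, apply the delta bias / softplus, accumulate
    // the per-state scans into out_vals, then store the (optionally z-gated) output.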
    constexpr int kChunkSize = kNThreads * kNItems;
    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
        }
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
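        // For each state dimension n, build per-element pairs
        //   (exp(delta_t * A[n]), delta_t * B_t[n] * u_t)
        // and run an inclusive BlockScan with SSMScanOp, which realizes the recurrence
        //   h_t = exp(delta_t * A[n]) * h_{t-1} + delta_t * B_t[n] * u_t,
        // then accumulate C_t[n] * h_t into out_vals.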
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                if constexpr (!kIsComplex) {
                    A_val[r] *= kLog2e;
                } else {
                    A_val[r].real_ *= kLog2e;
                }
            }
            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
            // If both B and C vary, this is unused.
            weight_t BC_val[kNRows];
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (kIsVariableB) {
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
                if constexpr (!kIsVariableC) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                    }
                }
            }
            if constexpr (kIsVariableC) {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
                    }
                }
            }
            if constexpr (!kIsVariableB && !kIsVariableC) {
                #pragma unroll
                for (int r = 0; r < kNRows; ++r) {
                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                }
            }

            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    if constexpr (!kIsComplex) {
                        thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                     !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                                thread_data[i] = make_float2(1.f, 0.f);
                            }
                        }
                    } else {
                        // Pytorch's implementation of complex exp (which calls thrust) is very slow
                        complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]);
                        weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i];
                        thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_);
                        if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                            if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                                thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f);
                            }
                        }
                    }
                }
                // Initialize running total
                scan_t running_prefix;
                if constexpr (!kIsComplex) {
                    // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f);
                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
                } else {
                    running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f);
                    // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f);
                }
                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix;
                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const weight_t C_val = !kIsVariableC
                        ? BC_val[r]
                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
                    if constexpr (!kIsComplex) {
                        out_vals[r][i] += thread_data[i].y * C_val;
                    } else {
                        out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2;
                    }
                }
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
        }

        if constexpr (kHasZ) {
            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                input_t z_vals[kNItems];
                __syncthreads();
                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    float z_val = z_vals[i];
                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
                }
                __syncthreads();
                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
            }
        }

        Bvar += kChunkSize * (!kIsComplex ? 1 : 2);
        Cvar += kChunkSize * (!kIsComplex ? 1 : 2);
    }
}
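
// Dispatcher: BOOL_SWITCH turns the runtime flags (even sequence length,
// variable B/C, presence of z) into compile-time template parameters, sizes
// the dynamic shared memory (IO/scan temp storage plus the running-prefix
// buffer), raises the max-dynamic-shared-memory attribute when the request
// exceeds the 48 KB default, and launches a grid of (batch, dim / kNRows) blocks.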
template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
    // processing 1 row.
    constexpr int kNRows = 1;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
            BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
                    using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
                    // constexpr int kSmemSize = Ktraits::kSmemSize;
                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                    // printf("smem_size = %d\n", kSmemSize);
                    dim3 grid(params.batch, params.dim / kNRows);
                    auto kernel = &selective_scan_fwd_kernel<Ktraits>;
                    if (kSmemSize >= 48 * 1024) {
                        C10_CUDA_CHECK(cudaFuncSetAttribute(
                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
                    }
                    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
            });
        });
    });
}
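
// Pick the (kNThreads, kNItems) tile from the sequence length so that one chunk
// (kNThreads * kNItems timesteps) covers the whole sequence for short inputs
// (32x4 = 128 up to 64x16 = 1024); longer sequences use the largest tile
// (128x16 = 2048 timesteps per chunk) and loop over n_chunks.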
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
    if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
}