1
0

selective_scan_fwd.cu 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717
  1. #include <torch/all.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include "selective_scan.h"
  5. #include <c10/util/BFloat16.h>
  6. #include <c10/util/Half.h>
  7. #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
  8. #ifndef USE_ROCM
  9. #include <cub/block/block_load.cuh>
  10. #include <cub/block/block_store.cuh>
  11. #include <cub/block/block_scan.cuh>
  12. #else
  13. #include <hipcub/hipcub.hpp>
  14. namespace cub = hipcub;
  15. #endif
  16. #include "selective_scan.h"
  17. #include "static_switch.h"
  18. template <int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
  19. bool kIsVariableB_, bool kIsVariableC_, bool kHasZ_, bool kUseIndex_,
  20. typename input_t_, typename weight_t_>
  21. struct Selective_Scan_fwd_kernel_traits {
  22. static_assert(kNItems_ % 4 == 0);
  23. using input_t = input_t_;
  24. using weight_t = weight_t_;
  25. static constexpr int kNThreads = kNThreads_;
  26. // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves
  27. // occupancy.
  28. static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
  29. static constexpr int kNItems = kNItems_;
  30. static constexpr int kNRows = kNRows_;
  31. static constexpr int kNBytes = sizeof(input_t);
  32. static_assert(kNBytes == 2 || kNBytes == 4);
  33. static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
  34. static_assert(kNItems % kNElts == 0);
  35. static constexpr int kNLoads = kNItems / kNElts;
  36. static constexpr bool kIsEvenLen = kIsEvenLen_;
  37. static constexpr bool kIsVariableB = kIsVariableB_;
  38. static constexpr bool kIsVariableC = kIsVariableC_;
  39. static constexpr bool kHasZ = kHasZ_;
  40. static constexpr bool kUseIndex = kUseIndex_;
  41. static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
  42. static constexpr int kNLoadsIndex = kNItems / 4;
  43. using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
  44. using scan_t = float2;
  45. using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems,
  46. cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  47. using BlockLoadVecT =
  48. cub::BlockLoad<vec_t, kNThreads, kNLoads,
  49. !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE
  50. : cub::BLOCK_LOAD_DIRECT>;
  51. using BlockLoadIndexT =
  52. cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  53. using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
  54. !(kIsEvenLen && kNLoadsIndex == 1)
  55. ? cub::BLOCK_LOAD_WARP_TRANSPOSE
  56. : cub::BLOCK_LOAD_DIRECT>;
  57. using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems,
  58. cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  59. using BlockLoadWeightVecT =
  60. cub::BlockLoad<vec_t, kNThreads, kNLoads,
  61. !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE
  62. : cub::BLOCK_LOAD_DIRECT>;
  63. using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems,
  64. cub::BLOCK_STORE_WARP_TRANSPOSE>;
  65. using BlockStoreVecT =
  66. cub::BlockStore<vec_t, kNThreads, kNLoads,
  67. !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE
  68. : cub::BLOCK_STORE_DIRECT>;
  69. // using BlockScanT = cub::BlockScan<scan_t, kNThreads,
  70. // cub::BLOCK_SCAN_RAKING_MEMOIZE>; using BlockScanT = cub::BlockScan<scan_t,
  71. // kNThreads, cub::BLOCK_SCAN_RAKING>;
  72. using BlockScanT =
  73. cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
  74. static constexpr int kSmemIOSize =
  75. custom_max({sizeof(typename BlockLoadT::TempStorage),
  76. sizeof(typename BlockLoadVecT::TempStorage),
  77. sizeof(typename BlockLoadIndexT::TempStorage),
  78. sizeof(typename BlockLoadIndexVecT::TempStorage),
  79. (int(kIsVariableB) + int(kIsVariableC)) *
  80. sizeof(typename BlockLoadWeightT::TempStorage),
  81. (int(kIsVariableB) + int(kIsVariableC)) *
  82. sizeof(typename BlockLoadWeightVecT::TempStorage),
  83. sizeof(typename BlockStoreT::TempStorage),
  84. sizeof(typename BlockStoreVecT::TempStorage)});
  85. static constexpr int kSmemSize =
  86. kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
  87. };
  88. template <typename Ktraits>
  89. __global__ __launch_bounds__(
  90. Ktraits::kNThreads,
  91. Ktraits::kMinBlocks) void selective_scan_fwd_kernel(SSMParamsBase params) {
  92. constexpr bool kIsVariableB = Ktraits::kIsVariableB;
  93. constexpr bool kIsVariableC = Ktraits::kIsVariableC;
  94. constexpr bool kHasZ = Ktraits::kHasZ;
  95. constexpr bool kUseIndex = Ktraits::kUseIndex;
  96. constexpr int kNThreads = Ktraits::kNThreads;
  97. constexpr int kNItems = Ktraits::kNItems;
  98. constexpr int kNRows = Ktraits::kNRows;
  99. constexpr bool kDirectIO = Ktraits::kDirectIO;
  100. using input_t = typename Ktraits::input_t;
  101. using weight_t = typename Ktraits::weight_t;
  102. using scan_t = typename Ktraits::scan_t;
  103. // Shared memory.
  104. extern __shared__ char smem_[];
  105. // cast to lvalue reference of expected type
  106. // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
  107. // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_
  108. // + 2 * MAX_DSTATE * sizeof(weight_t)); auto& smem_load =
  109. // reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
  110. auto& smem_load =
  111. reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
  112. [[maybe_unused]] auto& smem_load_weight =
  113. reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
  114. [[maybe_unused]] auto& smem_load_index =
  115. reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
  116. [[maybe_unused]] auto& smem_load_weight1 =
  117. *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(
  118. smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
  119. auto& smem_store =
  120. reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
  121. auto& smem_scan =
  122. *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(
  123. smem_ + Ktraits::kSmemIOSize);
  124. // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ +
  125. // smem_loadstorescan_size); weight_t *smem_bc = reinterpret_cast<weight_t
  126. // *>(smem_a + MAX_DSTATE);
  127. scan_t* smem_running_prefix =
  128. reinterpret_cast<scan_t*>(smem_ + Ktraits::kSmemSize);
  129. const int batch_id = blockIdx.x;
  130. const int dim_id = blockIdx.y;
  131. const int group_id = dim_id / (params.dim_ngroups_ratio);
  132. input_t* u = reinterpret_cast<input_t*>(params.u_ptr) +
  133. batch_id * params.u_batch_stride +
  134. dim_id * kNRows * params.u_d_stride;
  135. input_t* delta = reinterpret_cast<input_t*>(params.delta_ptr) +
  136. batch_id * params.delta_batch_stride +
  137. dim_id * kNRows * params.delta_d_stride;
  138. weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr) +
  139. dim_id * kNRows * params.A_d_stride;
  140. [[maybe_unused]] weight_t* B = reinterpret_cast<weight_t*>(params.B_ptr) +
  141. dim_id * kNRows * params.B_d_stride;
  142. input_t* Bvar = reinterpret_cast<input_t*>(params.B_ptr) +
  143. batch_id * params.B_batch_stride +
  144. group_id * params.B_group_stride;
  145. [[maybe_unused]] weight_t* C = reinterpret_cast<weight_t*>(params.C_ptr) +
  146. dim_id * kNRows * params.C_d_stride;
  147. input_t* Cvar = reinterpret_cast<input_t*>(params.C_ptr) +
  148. batch_id * params.C_batch_stride +
  149. group_id * params.C_group_stride;
  150. scan_t* x = reinterpret_cast<scan_t*>(params.x_ptr) +
  151. (batch_id * params.dim + dim_id * kNRows) * params.n_chunks *
  152. params.dstate;
  153. [[maybe_unused]] int* index =
  154. !kUseIndex
  155. ? nullptr
  156. : reinterpret_cast<int*>(params.index_ptr) + batch_id * params.seqlen;
  157. float D_val[kNRows] = {0};
  158. if (params.D_ptr != nullptr) {
  159. #pragma unroll
  160. for (int r = 0; r < kNRows; ++r) {
  161. D_val[r] = reinterpret_cast<float*>(params.D_ptr)[dim_id * kNRows + r];
  162. }
  163. }
  164. float delta_bias[kNRows] = {0};
  165. if (params.delta_bias_ptr != nullptr) {
  166. #pragma unroll
  167. for (int r = 0; r < kNRows; ++r) {
  168. delta_bias[r] =
  169. reinterpret_cast<float*>(params.delta_bias_ptr)[dim_id * kNRows + r];
  170. }
  171. }
  172. // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx +=
  173. // blockDim.x) {
  174. // smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
  175. // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] *
  176. // C[state_idx * params.C_dstate_stride];
  177. // }
  178. constexpr int kChunkSize = kNThreads * kNItems;
  179. for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
  180. input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
  181. [[maybe_unused]] int index_vals_load[kNRows][kNItems];
  182. __syncthreads();
  183. #pragma unroll
  184. for (int r = 0; r < kNRows; ++r) {
  185. if constexpr (!kDirectIO) {
  186. if (r > 0) {
  187. __syncthreads();
  188. }
  189. }
  190. load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load,
  191. params.seqlen - chunk * kChunkSize);
  192. if constexpr (!kDirectIO) {
  193. __syncthreads();
  194. }
  195. load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r],
  196. smem_load, params.seqlen - chunk * kChunkSize);
  197. if constexpr (kUseIndex) {
  198. load_index<Ktraits>(index + r * params.delta_d_stride,
  199. index_vals_load[r], smem_load_index,
  200. params.seqlen - chunk * kChunkSize);
  201. }
  202. }
  203. if constexpr (kUseIndex) {
  204. index += kChunkSize;
  205. }
  206. u += kChunkSize;
  207. delta += kChunkSize;
  208. float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems],
  209. out_vals[kNRows][kNItems];
  210. #pragma unroll
  211. for (int r = 0; r < kNRows; ++r) {
  212. #pragma unroll
  213. for (int i = 0; i < kNItems; ++i) {
  214. float u_val = float(u_vals[r][i]);
  215. delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
  216. if (params.delta_softplus) {
  217. delta_vals[r][i] = delta_vals[r][i] <= 20.f
  218. ? log1pf(expf(delta_vals[r][i]))
  219. : delta_vals[r][i];
  220. }
  221. delta_u_vals[r][i] = delta_vals[r][i] * u_val;
  222. out_vals[r][i] = D_val[r] * u_val;
  223. }
  224. }
  225. __syncthreads();
  226. for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
  227. weight_t A_val[kNRows];
  228. #pragma unroll
  229. for (int r = 0; r < kNRows; ++r) {
  230. A_val[r] =
  231. A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
  232. // Multiply the real part of A with LOG2E so we can use exp2f instead of
  233. // expf.
  234. constexpr float kLog2e = M_LOG2E;
  235. A_val[r] *= kLog2e;
  236. }
  237. // This variable holds B * C if both B and C are constant across seqlen.
  238. // If only B varies across seqlen, this holds C. If only C varies across
  239. // seqlen, this holds B. If both B and C vary, this is unused.
  240. weight_t BC_val[kNRows];
  241. weight_t B_vals[kNItems], C_vals[kNItems];
  242. if constexpr (kIsVariableB) {
  243. load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
  244. smem_load_weight,
  245. (params.seqlen - chunk * kChunkSize) * (1));
  246. if constexpr (!kIsVariableC) {
  247. #pragma unroll
  248. for (int r = 0; r < kNRows; ++r) {
  249. BC_val[r] =
  250. C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
  251. }
  252. }
  253. }
  254. if constexpr (kIsVariableC) {
  255. auto& smem_load_weight_C =
  256. !kIsVariableB ? smem_load_weight : smem_load_weight1;
  257. load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
  258. smem_load_weight_C,
  259. (params.seqlen - chunk * kChunkSize) * (1));
  260. if constexpr (!kIsVariableB) {
  261. #pragma unroll
  262. for (int r = 0; r < kNRows; ++r) {
  263. BC_val[r] =
  264. B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
  265. }
  266. }
  267. }
  268. if constexpr (!kIsVariableB && !kIsVariableC) {
  269. #pragma unroll
  270. for (int r = 0; r < kNRows; ++r) {
  271. BC_val[r] =
  272. B[state_idx * params.B_dstate_stride + r * params.B_d_stride] *
  273. C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
  274. }
  275. }
  276. #pragma unroll
  277. for (int r = 0; r < kNRows; ++r) {
  278. if (r > 0) {
  279. __syncthreads();
  280. } // Scan could be using the same smem
  281. scan_t thread_data[kNItems];
  282. #pragma unroll
  283. for (int i = 0; i < kNItems; ++i) {
  284. thread_data[i] =
  285. make_float2(exp2f(delta_vals[r][i] * A_val[r]),
  286. !kIsVariableB ? delta_u_vals[r][i]
  287. : B_vals[i] * delta_u_vals[r][i]);
  288. // Reset A bar for cumulative sequences (Real)
  289. if constexpr (kUseIndex) {
  290. if (index_vals_load[r][i] == 0) {
  291. thread_data[i].x = 0.f;
  292. }
  293. }
  294. if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is
  295. // correct
  296. if (threadIdx.x * kNItems + i >=
  297. params.seqlen - chunk * kChunkSize) {
  298. thread_data[i] = make_float2(1.f, 0.f);
  299. }
  300. }
  301. }
  302. // Initialize running total
  303. scan_t running_prefix;
  304. // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0)
  305. // needs to read
  306. running_prefix =
  307. chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx]
  308. : (threadIdx.x % 32 == 0
  309. ? smem_running_prefix[state_idx + r * MAX_DSTATE]
  310. : make_float2(1.f, 0.f));
  311. // running_prefix = chunk > 0 && threadIdx.x == 0 ?
  312. // smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
  313. SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
  314. typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
  315. thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op);
  316. // There's a syncthreads in the scan op, so we don't need to sync here.
  317. // Unless there's only 1 warp, but then it's the same thread (0) reading
  318. // and writing.
  319. if (threadIdx.x == 0) {
  320. smem_running_prefix[state_idx] = prefix_op.running_prefix;
  321. x[(r * params.n_chunks + chunk) * params.dstate + state_idx] =
  322. prefix_op.running_prefix;
  323. }
  324. #pragma unroll
  325. for (int i = 0; i < kNItems; ++i) {
  326. const weight_t C_val =
  327. !kIsVariableC
  328. ? BC_val[r]
  329. : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
  330. out_vals[r][i] += thread_data[i].y * C_val;
  331. }
  332. }
  333. }
  334. input_t* out = reinterpret_cast<input_t*>(params.out_ptr) +
  335. batch_id * params.out_batch_stride +
  336. dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
  337. __syncthreads();
  338. #pragma unroll
  339. for (int r = 0; r < kNRows; ++r) {
  340. if constexpr (!kDirectIO) {
  341. if (r > 0) {
  342. __syncthreads();
  343. }
  344. }
  345. store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r],
  346. smem_store, params.seqlen - chunk * kChunkSize);
  347. }
  348. if constexpr (kHasZ) {
  349. input_t* z = reinterpret_cast<input_t*>(params.z_ptr) +
  350. batch_id * params.z_batch_stride +
  351. dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
  352. input_t* out_z = reinterpret_cast<input_t*>(params.out_z_ptr) +
  353. batch_id * params.out_z_batch_stride +
  354. dim_id * kNRows * params.out_z_d_stride +
  355. chunk * kChunkSize;
  356. #pragma unroll
  357. for (int r = 0; r < kNRows; ++r) {
  358. input_t z_vals[kNItems];
  359. __syncthreads();
  360. load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load,
  361. params.seqlen - chunk * kChunkSize);
  362. #pragma unroll
  363. for (int i = 0; i < kNItems; ++i) {
  364. float z_val = z_vals[i];
  365. out_vals[r][i] *= z_val / (1 + expf(-z_val));
  366. }
  367. __syncthreads();
  368. store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r],
  369. smem_store, params.seqlen - chunk * kChunkSize);
  370. }
  371. }
  372. Bvar += kChunkSize * 1;
  373. Cvar += kChunkSize * 1;
  374. }
  375. }
  376. template <int kNThreads, int kNItems, typename input_t, typename weight_t>
  377. void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) {
  378. // Only kNRows == 1 is tested for now, which ofc doesn't differ from
  379. // previously when we had each block processing 1 row.
  380. static constexpr int kNRows = 1;
  381. BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
  382. BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
  383. BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
  384. BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
  385. BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
  386. using Ktraits = Selective_Scan_fwd_kernel_traits<
  387. kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB,
  388. kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
  389. // constexpr int kSmemSize = Ktraits::kSmemSize;
  390. constexpr int kSmemSize =
  391. Ktraits::kSmemSize +
  392. kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
  393. dim3 grid(params.batch, params.dim / kNRows);
  394. auto kernel = &selective_scan_fwd_kernel<Ktraits>;
  395. if (kSmemSize >= 48 * 1024) {
  396. C10_CUDA_CHECK(cudaFuncSetAttribute(
  397. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
  398. kSmemSize));
  399. }
  400. kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
  401. C10_CUDA_KERNEL_LAUNCH_CHECK();
  402. });
  403. });
  404. });
  405. });
  406. });
  407. }
  408. template <typename input_t, typename weight_t>
  409. void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream) {
  410. #ifndef USE_ROCM
  411. if (params.seqlen <= 128) {
  412. selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
  413. } else if (params.seqlen <= 256) {
  414. selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
  415. } else if (params.seqlen <= 512) {
  416. selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
  417. } else if (params.seqlen <= 1024) {
  418. selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
  419. } else {
  420. selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
  421. }
  422. #else
  423. if (params.seqlen <= 256) {
  424. selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
  425. } else if (params.seqlen <= 512) {
  426. selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
  427. } else if (params.seqlen <= 1024) {
  428. selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
  429. } else {
  430. selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
  431. }
  432. #endif
  433. }
  434. template void selective_scan_fwd_cuda<at::BFloat16, float>(
  435. SSMParamsBase& params, cudaStream_t stream);
  436. template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase& params,
  437. cudaStream_t stream);
  438. template void selective_scan_fwd_cuda<float, float>(SSMParamsBase& params,
  439. cudaStream_t stream);
  440. #define CHECK_SHAPE(x, ...) \
  441. TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
  442. #x " must have shape (" #__VA_ARGS__ ")")
  443. #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
  444. if (ITYPE == at::ScalarType::Half) { \
  445. using input_t = at::Half; \
  446. __VA_ARGS__(); \
  447. } else if (ITYPE == at::ScalarType::BFloat16) { \
  448. using input_t = at::BFloat16; \
  449. __VA_ARGS__(); \
  450. } else if (ITYPE == at::ScalarType::Float) { \
  451. using input_t = float; \
  452. __VA_ARGS__(); \
  453. } else { \
  454. AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \
  455. "'"); \
  456. }
  457. #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
  458. if (WTYPE == at::ScalarType::Half) { \
  459. using weight_t = at::Half; \
  460. __VA_ARGS__(); \
  461. } else if (WTYPE == at::ScalarType::BFloat16) { \
  462. using weight_t = at::BFloat16; \
  463. __VA_ARGS__(); \
  464. } else if (WTYPE == at::ScalarType::Float) { \
  465. using weight_t = float; \
  466. __VA_ARGS__(); \
  467. } else { \
  468. AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \
  469. "'"); \
  470. }
  471. #define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...) \
  472. if (WTYPE == at::ScalarType::Float) { \
  473. using weight_t = float; \
  474. __VA_ARGS__(); \
  475. } else { \
  476. AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \
  477. "'"); \
  478. }
  479. template <typename input_t, typename weight_t>
  480. void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream);
  481. void set_ssm_params_fwd(SSMParamsBase& params,
  482. // sizes
  483. const size_t batch, const size_t dim,
  484. const size_t seqlen, const size_t dstate,
  485. const size_t n_groups, const size_t n_chunks,
  486. const bool is_variable_B, const bool is_variable_C,
  487. // device pointers
  488. const torch::Tensor u, const torch::Tensor delta,
  489. const torch::Tensor A, const torch::Tensor B,
  490. const torch::Tensor C, const torch::Tensor out,
  491. const torch::Tensor z, const torch::Tensor out_z,
  492. void* D_ptr, void* delta_bias_ptr, void* x_ptr,
  493. bool has_z, bool delta_softplus, void* index_ptr) {
  494. // Reset the parameters
  495. memset(&params, 0, sizeof(params));
  496. params.batch = batch;
  497. params.dim = dim;
  498. params.seqlen = seqlen;
  499. params.dstate = dstate;
  500. params.n_groups = n_groups;
  501. params.n_chunks = n_chunks;
  502. params.dim_ngroups_ratio = dim / n_groups;
  503. params.delta_softplus = delta_softplus;
  504. params.is_variable_B = is_variable_B;
  505. params.is_variable_C = is_variable_C;
  506. // Set the pointers and strides.
  507. params.u_ptr = u.data_ptr();
  508. params.delta_ptr = delta.data_ptr();
  509. params.A_ptr = A.data_ptr();
  510. params.B_ptr = B.data_ptr();
  511. params.C_ptr = C.data_ptr();
  512. params.D_ptr = D_ptr;
  513. params.delta_bias_ptr = delta_bias_ptr;
  514. params.out_ptr = out.data_ptr();
  515. params.x_ptr = x_ptr;
  516. params.z_ptr = has_z ? z.data_ptr() : nullptr;
  517. params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
  518. params.index_ptr = index_ptr;
  519. // All stride are in elements, not bytes.
  520. params.A_d_stride = A.stride(0);
  521. params.A_dstate_stride = A.stride(1);
  522. if (!is_variable_B) {
  523. params.B_d_stride = B.stride(0);
  524. } else {
  525. params.B_batch_stride = B.stride(0);
  526. params.B_group_stride = B.stride(1);
  527. }
  528. params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
  529. if (!is_variable_C) {
  530. params.C_d_stride = C.stride(0);
  531. } else {
  532. params.C_batch_stride = C.stride(0);
  533. params.C_group_stride = C.stride(1);
  534. }
  535. params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
  536. params.u_batch_stride = u.stride(0);
  537. params.u_d_stride = u.stride(1);
  538. params.delta_batch_stride = delta.stride(0);
  539. params.delta_d_stride = delta.stride(1);
  540. if (has_z) {
  541. params.z_batch_stride = z.stride(0);
  542. params.z_d_stride = z.stride(1);
  543. params.out_z_batch_stride = out_z.stride(0);
  544. params.out_z_d_stride = out_z.stride(1);
  545. }
  546. params.out_batch_stride = out.stride(0);
  547. params.out_d_stride = out.stride(1);
  548. }
  549. std::vector<torch::Tensor> selective_scan_fwd(
  550. const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
  551. const torch::Tensor& B, const torch::Tensor& C,
  552. const c10::optional<torch::Tensor>& D_,
  553. const c10::optional<torch::Tensor>& z_,
  554. const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
  555. const c10::optional<torch::Tensor>& index_,
  556. const c10::optional<torch::Tensor>& x) {
  557. auto input_type = u.scalar_type();
  558. auto weight_type = A.scalar_type();
  559. TORCH_CHECK(input_type == at::ScalarType::Float ||
  560. input_type == at::ScalarType::Half ||
  561. input_type == at::ScalarType::BFloat16);
  562. TORCH_CHECK(weight_type == at::ScalarType::Float ||
  563. weight_type == at::ScalarType::ComplexFloat);
  564. const bool is_variable_B = B.dim() >= 3;
  565. const bool is_variable_C = C.dim() >= 3;
  566. const bool is_complex = weight_type == at::ScalarType::ComplexFloat;
  567. TORCH_CHECK(delta.scalar_type() == input_type);
  568. TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
  569. TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));
  570. TORCH_CHECK(u.is_cuda());
  571. TORCH_CHECK(delta.is_cuda());
  572. TORCH_CHECK(A.is_cuda());
  573. TORCH_CHECK(B.is_cuda());
  574. TORCH_CHECK(C.is_cuda());
  575. TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
  576. TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
  577. const auto sizes = u.sizes();
  578. const int batch_size = sizes[0];
  579. const int dim = sizes[1];
  580. const int seqlen = sizes[2];
  581. const int dstate = A.size(1);
  582. const int n_groups = is_variable_B ? B.size(1) : 1;
  583. TORCH_CHECK(dstate <= 256,
  584. "selective_scan only supports state dimension <= 256");
  585. CHECK_SHAPE(u, batch_size, dim, seqlen);
  586. CHECK_SHAPE(delta, batch_size, dim, seqlen);
  587. CHECK_SHAPE(A, dim, dstate);
  588. if (!is_variable_B) {
  589. CHECK_SHAPE(B, dim, dstate);
  590. } else {
  591. CHECK_SHAPE(B, batch_size, n_groups, dstate,
  592. !is_complex ? seqlen : seqlen * 2);
  593. TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
  594. }
  595. if (!is_variable_C) {
  596. CHECK_SHAPE(C, dim, dstate);
  597. } else {
  598. CHECK_SHAPE(C, batch_size, n_groups, dstate,
  599. !is_complex ? seqlen : seqlen * 2);
  600. TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
  601. }
  602. if (D_.has_value()) {
  603. auto D = D_.value();
  604. TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
  605. TORCH_CHECK(D.is_cuda());
  606. TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
  607. CHECK_SHAPE(D, dim);
  608. }
  609. if (delta_bias_.has_value()) {
  610. auto delta_bias = delta_bias_.value();
  611. TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
  612. TORCH_CHECK(delta_bias.is_cuda());
  613. TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
  614. CHECK_SHAPE(delta_bias, dim);
  615. }
  616. if (index_.has_value()) {
  617. auto index = index_.value();
  618. TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
  619. TORCH_CHECK(index.is_cuda());
  620. CHECK_SHAPE(index, batch_size, seqlen);
  621. }
  622. at::Tensor z, out_z;
  623. const bool has_z = z_.has_value();
  624. if (has_z) {
  625. z = z_.value();
  626. TORCH_CHECK(z.scalar_type() == input_type);
  627. TORCH_CHECK(z.is_cuda());
  628. TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
  629. CHECK_SHAPE(z, batch_size, dim, seqlen);
  630. out_z = torch::empty_like(z);
  631. }
  632. const int n_chunks = (seqlen + 2048 - 1) / 2048;
  633. // const int n_chunks = (seqlen + 1024 - 1) / 1024;
  634. // at::Tensor out = torch::empty_like(u);
  635. // Right now u has BHL layout and delta has HBL layout, and we want out to
  636. // have HBL layout
  637. at::Tensor out = torch::empty_like(delta);
  638. if (x.has_value()) {
  639. auto _x = x.value();
  640. TORCH_CHECK(_x.scalar_type() == weight_type);
  641. TORCH_CHECK(_x.is_cuda());
  642. TORCH_CHECK(_x.stride(-1) == 1);
  643. CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
  644. }
  645. SSMParamsBase params;
  646. set_ssm_params_fwd(
  647. params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
  648. is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z,
  649. D_.has_value() ? D_.value().data_ptr() : nullptr,
  650. delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
  651. x.value().data_ptr(), has_z, delta_softplus,
  652. index_.has_value() ? index_.value().data_ptr() : nullptr);
  653. // Otherwise the kernel will be launched from cuda:0 device
  654. // Cast to char to avoid compiler warning about narrowing
  655. at::cuda::CUDAGuard device_guard{(char)u.get_device()};
  656. auto stream = at::cuda::getCurrentCUDAStream().stream();
  657. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  658. u.scalar_type(), "selective_scan_fwd", [&] {
  659. DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] {
  660. selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
  661. });
  662. });
  663. std::vector<at::Tensor> result = {out, x.value()};
  664. if (has_z) {
  665. result.push_back(out_z);
  666. }
  667. return result;
  668. }