// selective_scan_fwd.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "selective_scan.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "selective_scan.h"
#include "static_switch.h"

template <int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
          bool kIsVariableB_, bool kIsVariableC_, bool kHasZ_, bool kUseIndex_,
          typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
  static_assert(kNItems_ % 4 == 0);
  using input_t = input_t_;
  using weight_t = weight_t_;
  static constexpr int kNThreads = kNThreads_;
  // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves
  // occupancy.
  static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
  static constexpr int kNItems = kNItems_;
  static constexpr int kNRows = kNRows_;
  static constexpr int kNBytes = sizeof(input_t);
  static_assert(kNBytes == 2 || kNBytes == 4);
  static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
  static_assert(kNItems % kNElts == 0);
  static constexpr int kNLoads = kNItems / kNElts;
  static constexpr bool kIsEvenLen = kIsEvenLen_;
  static constexpr bool kIsVariableB = kIsVariableB_;
  static constexpr bool kIsVariableC = kIsVariableC_;
  static constexpr bool kHasZ = kHasZ_;
  static constexpr bool kUseIndex = kUseIndex_;
  static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
  static constexpr int kNLoadsIndex = kNItems / 4;
  using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
  using scan_t = float2;
  using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems,
                                    cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  using BlockLoadVecT =
      cub::BlockLoad<vec_t, kNThreads, kNLoads,
                     !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE
                                : cub::BLOCK_LOAD_DIRECT>;
  using BlockLoadIndexT =
      cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
                                            !(kIsEvenLen && kNLoadsIndex == 1)
                                                ? cub::BLOCK_LOAD_WARP_TRANSPOSE
                                                : cub::BLOCK_LOAD_DIRECT>;
  using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems,
                                          cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  using BlockLoadWeightVecT =
      cub::BlockLoad<vec_t, kNThreads, kNLoads,
                     !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE
                                : cub::BLOCK_LOAD_DIRECT>;
  using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems,
                                      cub::BLOCK_STORE_WARP_TRANSPOSE>;
  using BlockStoreVecT =
      cub::BlockStore<vec_t, kNThreads, kNLoads,
                      !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE
                                 : cub::BLOCK_STORE_DIRECT>;
  // using BlockScanT = cub::BlockScan<scan_t, kNThreads,
  // cub::BLOCK_SCAN_RAKING_MEMOIZE>; using BlockScanT = cub::BlockScan<scan_t,
  // kNThreads, cub::BLOCK_SCAN_RAKING>;
  using BlockScanT =
      cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
  static constexpr int kSmemIOSize =
      custom_max({sizeof(typename BlockLoadT::TempStorage),
                  sizeof(typename BlockLoadVecT::TempStorage),
                  sizeof(typename BlockLoadIndexT::TempStorage),
                  sizeof(typename BlockLoadIndexVecT::TempStorage),
                  (int(kIsVariableB) + int(kIsVariableC)) *
                      sizeof(typename BlockLoadWeightT::TempStorage),
                  (int(kIsVariableB) + int(kIsVariableC)) *
                      sizeof(typename BlockLoadWeightVecT::TempStorage),
                  sizeof(typename BlockStoreT::TempStorage),
                  sizeof(typename BlockStoreVecT::TempStorage)});
  static constexpr int kSmemSize =
      kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};
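
// Worked example of the derived constants above (a sketch, not part of the
// original file): for a half-precision input with kNThreads = 128 and
// kNItems = 16, kNBytes = 2, so kNElts = min(8, 16) = 8 elements per 16-byte
// vector, kNLoads = 16 / 8 = 2 vectorized loads per thread, and
// kNLoadsIndex = 16 / 4 = 4. kDirectIO is then false (kNLoads != 1), so the
// WARP_TRANSPOSE load/store paths are used, and kSmemSize is the largest
// TempStorage among the load/store collectives (counting both weight loaders
// when B and C both vary) plus the BlockScan TempStorage.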

template <typename Ktraits>
__global__ __launch_bounds__(
    Ktraits::kNThreads,
    Ktraits::kMinBlocks) void selective_scan_fwd_kernel(SSMParamsBase params) {
  constexpr bool kIsVariableB = Ktraits::kIsVariableB;
  constexpr bool kIsVariableC = Ktraits::kIsVariableC;
  constexpr bool kHasZ = Ktraits::kHasZ;
  constexpr bool kUseIndex = Ktraits::kUseIndex;
  constexpr int kNThreads = Ktraits::kNThreads;
  constexpr int kNItems = Ktraits::kNItems;
  constexpr int kNRows = Ktraits::kNRows;
  constexpr bool kDirectIO = Ktraits::kDirectIO;
  using input_t = typename Ktraits::input_t;
  using weight_t = typename Ktraits::weight_t;
  using scan_t = typename Ktraits::scan_t;

  // Shared memory.
  extern __shared__ char smem_[];
  // cast to lvalue reference of expected type
  // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
  // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_
  // + 2 * MAX_DSTATE * sizeof(weight_t)); auto& smem_load =
  // reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
  auto& smem_load =
      reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
  auto& smem_load_weight =
      reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
  auto& smem_load_index =
      reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
  auto& smem_load_weight1 =
      *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(
          smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
  auto& smem_store =
      reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
  auto& smem_scan =
      *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(
          smem_ + Ktraits::kSmemIOSize);
  // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ +
  // smem_loadstorescan_size); weight_t *smem_bc = reinterpret_cast<weight_t
  // *>(smem_a + MAX_DSTATE);
  scan_t* smem_running_prefix =
      reinterpret_cast<scan_t*>(smem_ + Ktraits::kSmemSize);
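  // Dynamic shared memory layout (sketch): bytes [0, kSmemIOSize) are aliased
  // by the BlockLoad/BlockStore TempStorage objects above (smem_load_weight1
  // is placed right after smem_load_weight), bytes [kSmemIOSize, kSmemSize)
  // hold the BlockScan TempStorage, and the tail starting at kSmemSize holds
  // the per-state running prefixes. The launch below allocates
  // Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(scan_t) bytes to cover
  // that tail.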
  const int batch_id = blockIdx.x;
  const int dim_id = blockIdx.y;
  const int group_id = dim_id / (params.dim_ngroups_ratio);
  input_t* u = reinterpret_cast<input_t*>(params.u_ptr) +
               batch_id * params.u_batch_stride +
               dim_id * kNRows * params.u_d_stride;
  input_t* delta = reinterpret_cast<input_t*>(params.delta_ptr) +
                   batch_id * params.delta_batch_stride +
                   dim_id * kNRows * params.delta_d_stride;
  weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr) +
                dim_id * kNRows * params.A_d_stride;
  weight_t* B = reinterpret_cast<weight_t*>(params.B_ptr) +
                dim_id * kNRows * params.B_d_stride;
  input_t* Bvar = reinterpret_cast<input_t*>(params.B_ptr) +
                  batch_id * params.B_batch_stride +
                  group_id * params.B_group_stride;
  weight_t* C = reinterpret_cast<weight_t*>(params.C_ptr) +
                dim_id * kNRows * params.C_d_stride;
  input_t* Cvar = reinterpret_cast<input_t*>(params.C_ptr) +
                  batch_id * params.C_batch_stride +
                  group_id * params.C_group_stride;
  scan_t* x = reinterpret_cast<scan_t*>(params.x_ptr) +
              (batch_id * params.dim + dim_id * kNRows) * params.n_chunks *
                  params.dstate;
  int* index = !kUseIndex ? nullptr
                          : reinterpret_cast<int*>(params.index_ptr) +
                                batch_id * params.seqlen;

  float D_val[kNRows] = {0};
  if (params.D_ptr != nullptr) {
#pragma unroll
    for (int r = 0; r < kNRows; ++r) {
      D_val[r] = reinterpret_cast<float*>(params.D_ptr)[dim_id * kNRows + r];
    }
  }
  float delta_bias[kNRows] = {0};
  if (params.delta_bias_ptr != nullptr) {
#pragma unroll
    for (int r = 0; r < kNRows; ++r) {
      delta_bias[r] =
          reinterpret_cast<float*>(params.delta_bias_ptr)[dim_id * kNRows + r];
    }
  }

  // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx +=
  // blockDim.x) {
  //   smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
  //   smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] *
  //   C[state_idx * params.C_dstate_stride];
  // }

  constexpr int kChunkSize = kNThreads * kNItems;
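  // Each block walks its (batch, dim-row) slice in chunks of
  // kChunkSize = kNThreads * kNItems timesteps; with the largest launch
  // configuration below (128 threads x 16 items) that is 2048 timesteps,
  // which is why the host side computes n_chunks = ceil(seqlen / 2048).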
  for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
    input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
    int index_vals_load[kNRows][kNItems];
    __syncthreads();
#pragma unroll
    for (int r = 0; r < kNRows; ++r) {
      if constexpr (!kDirectIO) {
        if (r > 0) {
          __syncthreads();
        }
      }
      load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load,
                          params.seqlen - chunk * kChunkSize);
      if constexpr (!kDirectIO) {
        __syncthreads();
      }
      load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r],
                          smem_load, params.seqlen - chunk * kChunkSize);
      if constexpr (kUseIndex) {
        load_index<Ktraits>(index + r * params.delta_d_stride,
                            index_vals_load[r], smem_load_index,
                            params.seqlen - chunk * kChunkSize);
      }
    }
    if constexpr (kUseIndex) {
      index += kChunkSize;
    }
    u += kChunkSize;
    delta += kChunkSize;

    float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems],
        out_vals[kNRows][kNItems];
#pragma unroll
    for (int r = 0; r < kNRows; ++r) {
#pragma unroll
      for (int i = 0; i < kNItems; ++i) {
        float u_val = float(u_vals[r][i]);
        delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
        if (params.delta_softplus) {
          delta_vals[r][i] = delta_vals[r][i] <= 20.f
                                 ? log1pf(expf(delta_vals[r][i]))
                                 : delta_vals[r][i];
        }
        delta_u_vals[r][i] = delta_vals[r][i] * u_val;
        out_vals[r][i] = D_val[r] * u_val;
      }
    }
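    // At this point delta_vals holds dt = softplus(delta + delta_bias) (the
    // softplus is skipped above 20, where log1p(exp(x)) is numerically ~= x),
    // delta_u_vals holds dt * u, and out_vals is seeded with the skip term
    // D * u.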
    __syncthreads();
    for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
      weight_t A_val[kNRows];
#pragma unroll
      for (int r = 0; r < kNRows; ++r) {
        A_val[r] =
            A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
        // Multiply A by LOG2E so we can use exp2f instead of expf.
        constexpr float kLog2e = M_LOG2E;
        A_val[r] *= kLog2e;
      }
      // This variable holds B * C if both B and C are constant across seqlen.
      // If only B varies across seqlen, this holds C. If only C varies across
      // seqlen, this holds B. If both B and C vary, this is unused.
      weight_t BC_val[kNRows];
      weight_t B_vals[kNItems], C_vals[kNItems];
      if constexpr (kIsVariableB) {
        load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                             smem_load_weight,
                             (params.seqlen - chunk * kChunkSize) * (1));
        if constexpr (!kIsVariableC) {
#pragma unroll
          for (int r = 0; r < kNRows; ++r) {
            BC_val[r] =
                C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
          }
        }
      }
      if constexpr (kIsVariableC) {
        auto& smem_load_weight_C =
            !kIsVariableB ? smem_load_weight : smem_load_weight1;
        load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                             smem_load_weight_C,
                             (params.seqlen - chunk * kChunkSize) * (1));
        if constexpr (!kIsVariableB) {
#pragma unroll
          for (int r = 0; r < kNRows; ++r) {
            BC_val[r] =
                B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
          }
        }
      }
      if constexpr (!kIsVariableB && !kIsVariableC) {
#pragma unroll
        for (int r = 0; r < kNRows; ++r) {
          BC_val[r] =
              B[state_idx * params.B_dstate_stride + r * params.B_d_stride] *
              C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
        }
      }

#pragma unroll
      for (int r = 0; r < kNRows; ++r) {
        if (r > 0) {
          __syncthreads();
        }  // Scan could be using the same smem
        scan_t thread_data[kNItems];
#pragma unroll
        for (int i = 0; i < kNItems; ++i) {
          thread_data[i] =
              make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                          !kIsVariableB ? delta_u_vals[r][i]
                                        : B_vals[i] * delta_u_vals[r][i]);
          // Reset A bar for cumulative sequences (Real)
          if constexpr (kUseIndex) {
            if (index_vals_load[r][i] == 0) {
              thread_data[i].x = 0.f;
            }
          }
          if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is
                                                 // correct
            if (threadIdx.x * kNItems + i >=
                params.seqlen - chunk * kChunkSize) {
              thread_data[i] = make_float2(1.f, 0.f);
            }
          }
        }
        // Initialize running total.
        scan_t running_prefix;
        // With BLOCK_SCAN_WARP_SCANS, lane 0 of every warp (not just thread 0)
        // needs to read the running prefix.
        running_prefix =
            chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx]
                       : (threadIdx.x % 32 == 0
                              ? smem_running_prefix[state_idx + r * MAX_DSTATE]
                              : make_float2(1.f, 0.f));
        // running_prefix = chunk > 0 && threadIdx.x == 0 ?
        // smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
        SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
        typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
            thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op);
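        // Each thread_data[i] encodes the affine update h -> a * h + b with
        // a = exp(dt * A) (computed as exp2f of dt * A * log2(e)) and
        // b = B * dt * u (when B is constant over time, B is instead folded
        // into the output coefficient BC_val). SSMScanOp, declared in
        // selective_scan.h, is expected to compose these maps, so after the
        // inclusive scan thread_data[i].y holds the state h_t for this
        // (row, state_idx), with the previous chunk's state carried in
        // through prefix_op.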
        // There's a syncthreads in the scan op, so we don't need to sync here.
        // Unless there's only 1 warp, but then it's the same thread (0)
        // reading and writing.
        if (threadIdx.x == 0) {
          smem_running_prefix[state_idx] = prefix_op.running_prefix;
          x[(r * params.n_chunks + chunk) * params.dstate + state_idx] =
              prefix_op.running_prefix;
        }
#pragma unroll
        for (int i = 0; i < kNItems; ++i) {
          const weight_t C_val =
              !kIsVariableC
                  ? BC_val[r]
                  : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
          out_vals[r][i] += thread_data[i].y * C_val;
        }
      }
    }
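    // After the loop over state_idx, out_vals[r][i] holds the full output
    // y_t = sum over states of C * h_t, plus the D * u_t skip term it was
    // seeded with above.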

    input_t* out = reinterpret_cast<input_t*>(params.out_ptr) +
                   batch_id * params.out_batch_stride +
                   dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
    __syncthreads();
#pragma unroll
    for (int r = 0; r < kNRows; ++r) {
      if constexpr (!kDirectIO) {
        if (r > 0) {
          __syncthreads();
        }
      }
      store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r],
                            smem_store, params.seqlen - chunk * kChunkSize);
    }

    if constexpr (kHasZ) {
      input_t* z = reinterpret_cast<input_t*>(params.z_ptr) +
                   batch_id * params.z_batch_stride +
                   dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
      input_t* out_z = reinterpret_cast<input_t*>(params.out_z_ptr) +
                       batch_id * params.out_z_batch_stride +
                       dim_id * kNRows * params.out_z_d_stride +
                       chunk * kChunkSize;
#pragma unroll
      for (int r = 0; r < kNRows; ++r) {
        input_t z_vals[kNItems];
        __syncthreads();
        load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load,
                            params.seqlen - chunk * kChunkSize);
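        // Gate the output with SiLU: out *= z * sigmoid(z), written below as
        // z / (1 + exp(-z)).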
#pragma unroll
        for (int i = 0; i < kNItems; ++i) {
          float z_val = z_vals[i];
          out_vals[r][i] *= z_val / (1 + expf(-z_val));
        }
        __syncthreads();
        store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r],
                              smem_store, params.seqlen - chunk * kChunkSize);
      }
    }

    Bvar += kChunkSize * 1;
    Cvar += kChunkSize * 1;
  }
}

template <int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) {
  // Only kNRows == 1 is tested for now; this doesn't differ from the previous
  // behavior, where each block processed one row.
  constexpr int kNRows = 1;
  BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
    BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
      BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
        BOOL_SWITCH(params.z_ptr != nullptr, kHasZ, [&] {
          BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
            using Ktraits = Selective_Scan_fwd_kernel_traits<
                kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB,
                kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
            // constexpr int kSmemSize = Ktraits::kSmemSize;
            constexpr int kSmemSize =
                Ktraits::kSmemSize +
                kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
            dim3 grid(params.batch, params.dim / kNRows);
            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
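            // Dynamic shared memory beyond the default 48 KB per block has to
            // be requested explicitly via cudaFuncSetAttribute.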
            if (kSmemSize >= 48 * 1024) {
              C10_CUDA_CHECK(cudaFuncSetAttribute(
                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                  kSmemSize));
            }
            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
        });
      });
    });
  });
}

template <typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream) {
#ifndef USE_ROCM
  if (params.seqlen <= 128) {
    selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
  } else if (params.seqlen <= 256) {
    selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
  } else if (params.seqlen <= 512) {
    selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
  } else if (params.seqlen <= 1024) {
    selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
  } else {
    selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
  }
#else
  if (params.seqlen <= 256) {
    selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
  } else if (params.seqlen <= 512) {
    selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
  } else if (params.seqlen <= 1024) {
    selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
  } else {
    selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
  }
#endif
}

template void selective_scan_fwd_cuda<at::BFloat16, float>(
    SSMParamsBase& params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase& params,
                                                       cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase& params,
                                                    cudaStream_t stream);

#define CHECK_SHAPE(x, ...)                                     \
  TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}),   \
              #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)          \
  if (ITYPE == at::ScalarType::Half) {                                    \
    using input_t = at::Half;                                             \
    __VA_ARGS__();                                                        \
  } else if (ITYPE == at::ScalarType::BFloat16) {                         \
    using input_t = at::BFloat16;                                         \
    __VA_ARGS__();                                                        \
  } else if (ITYPE == at::ScalarType::Float) {                            \
    using input_t = float;                                                \
    __VA_ARGS__();                                                        \
  } else {                                                                \
    AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \
             "'");                                                        \
  }

#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)           \
  if (WTYPE == at::ScalarType::Half) {                                     \
    using weight_t = at::Half;                                             \
    __VA_ARGS__();                                                         \
  } else if (WTYPE == at::ScalarType::BFloat16) {                          \
    using weight_t = at::BFloat16;                                         \
    __VA_ARGS__();                                                         \
  } else if (WTYPE == at::ScalarType::Float) {                             \
    using weight_t = float;                                                \
    __VA_ARGS__();                                                         \
  } else {                                                                 \
    AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \
             "'");                                                         \
  }

#define DISPATCH_WTYPE_FLOAT(WTYPE, NAME, ...)                             \
  if (WTYPE == at::ScalarType::Float) {                                    \
    using weight_t = float;                                                \
    __VA_ARGS__();                                                         \
  } else {                                                                 \
    AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \
             "'");                                                         \
  }

template <typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase& params, cudaStream_t stream);

void set_ssm_params_fwd(SSMParamsBase& params,
                        // sizes
                        const size_t batch, const size_t dim,
                        const size_t seqlen, const size_t dstate,
                        const size_t n_groups, const size_t n_chunks,
                        const bool is_variable_B, const bool is_variable_C,
                        // device pointers
                        const torch::Tensor u, const torch::Tensor delta,
                        const torch::Tensor A, const torch::Tensor B,
                        const torch::Tensor C, const torch::Tensor out,
                        const torch::Tensor z, const torch::Tensor out_z,
                        void* D_ptr, void* delta_bias_ptr, void* x_ptr,
                        bool has_z, bool delta_softplus, void* index_ptr) {
  // Reset the parameters.
  memset(&params, 0, sizeof(params));
  params.batch = batch;
  params.dim = dim;
  params.seqlen = seqlen;
  params.dstate = dstate;
  params.n_groups = n_groups;
  params.n_chunks = n_chunks;
  params.dim_ngroups_ratio = dim / n_groups;
  params.delta_softplus = delta_softplus;
  params.is_variable_B = is_variable_B;
  params.is_variable_C = is_variable_C;

  // Set the pointers and strides.
  params.u_ptr = u.data_ptr();
  params.delta_ptr = delta.data_ptr();
  params.A_ptr = A.data_ptr();
  params.B_ptr = B.data_ptr();
  params.C_ptr = C.data_ptr();
  params.D_ptr = D_ptr;
  params.delta_bias_ptr = delta_bias_ptr;
  params.out_ptr = out.data_ptr();
  params.x_ptr = x_ptr;
  params.z_ptr = has_z ? z.data_ptr() : nullptr;
  params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
  params.index_ptr = index_ptr;

  // All strides are in elements, not bytes.
  params.A_d_stride = A.stride(0);
  params.A_dstate_stride = A.stride(1);
  if (!is_variable_B) {
    params.B_d_stride = B.stride(0);
  } else {
    params.B_batch_stride = B.stride(0);
    params.B_group_stride = B.stride(1);
  }
  params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
  if (!is_variable_C) {
    params.C_d_stride = C.stride(0);
  } else {
    params.C_batch_stride = C.stride(0);
    params.C_group_stride = C.stride(1);
  }
  params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
  params.u_batch_stride = u.stride(0);
  params.u_d_stride = u.stride(1);
  params.delta_batch_stride = delta.stride(0);
  params.delta_d_stride = delta.stride(1);
  if (has_z) {
    params.z_batch_stride = z.stride(0);
    params.z_d_stride = z.stride(1);
    params.out_z_batch_stride = out_z.stride(0);
    params.out_z_d_stride = out_z.stride(1);
  }
  params.out_batch_stride = out.stride(0);
  params.out_d_stride = out.stride(1);
}
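
// Layout assumed by the kernels above (enforced by the shape checks in
// selective_scan_fwd below): u, delta, z, and out are (batch, dim, seqlen)
// with a stride-1 last dimension; A is (dim, dstate); when B/C are
// time-varying they are (batch, n_groups, dstate, seqlen), otherwise
// (dim, dstate); x is (batch, dim, n_chunks, 2 * dstate), read as one float2
// running state per (chunk, state).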

std::vector<torch::Tensor> selective_scan_fwd(
    const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
    const torch::Tensor& B, const torch::Tensor& C,
    const c10::optional<torch::Tensor>& D_,
    const c10::optional<torch::Tensor>& z_,
    const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
    const c10::optional<torch::Tensor>& index_,
    const c10::optional<torch::Tensor>& x) {
  auto input_type = u.scalar_type();
  auto weight_type = A.scalar_type();
  TORCH_CHECK(input_type == at::ScalarType::Float ||
              input_type == at::ScalarType::Half ||
              input_type == at::ScalarType::BFloat16);
  TORCH_CHECK(weight_type == at::ScalarType::Float ||
              weight_type == at::ScalarType::ComplexFloat);

  const bool is_variable_B = B.dim() >= 3;
  const bool is_variable_C = C.dim() >= 3;
  const bool is_complex = weight_type == at::ScalarType::ComplexFloat;

  TORCH_CHECK(delta.scalar_type() == input_type);
  TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
  TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

  TORCH_CHECK(u.is_cuda());
  TORCH_CHECK(delta.is_cuda());
  TORCH_CHECK(A.is_cuda());
  TORCH_CHECK(B.is_cuda());
  TORCH_CHECK(C.is_cuda());

  TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
  TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

  const auto sizes = u.sizes();
  const int batch_size = sizes[0];
  const int dim = sizes[1];
  const int seqlen = sizes[2];
  const int dstate = A.size(1);
  const int n_groups = is_variable_B ? B.size(1) : 1;

  TORCH_CHECK(dstate <= 256,
              "selective_scan only supports state dimension <= 256");

  CHECK_SHAPE(u, batch_size, dim, seqlen);
  CHECK_SHAPE(delta, batch_size, dim, seqlen);
  CHECK_SHAPE(A, dim, dstate);
  if (!is_variable_B) {
    CHECK_SHAPE(B, dim, dstate);
  } else {
    CHECK_SHAPE(B, batch_size, n_groups, dstate,
                !is_complex ? seqlen : seqlen * 2);
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
  }
  if (!is_variable_C) {
    CHECK_SHAPE(C, dim, dstate);
  } else {
    CHECK_SHAPE(C, batch_size, n_groups, dstate,
                !is_complex ? seqlen : seqlen * 2);
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
  }

  if (D_.has_value()) {
    auto D = D_.value();
    TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
    TORCH_CHECK(D.is_cuda());
    TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
    CHECK_SHAPE(D, dim);
  }

  if (delta_bias_.has_value()) {
    auto delta_bias = delta_bias_.value();
    TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
    TORCH_CHECK(delta_bias.is_cuda());
    TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
    CHECK_SHAPE(delta_bias, dim);
  }

  if (index_.has_value()) {
    auto index = index_.value();
    TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
    TORCH_CHECK(index.is_cuda());
    CHECK_SHAPE(index, batch_size, seqlen);
  }

  at::Tensor z, out_z;
  const bool has_z = z_.has_value();
  if (has_z) {
    z = z_.value();
    TORCH_CHECK(z.scalar_type() == input_type);
    TORCH_CHECK(z.is_cuda());
    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
    CHECK_SHAPE(z, batch_size, dim, seqlen);
    out_z = torch::empty_like(z);
  }

  const int n_chunks = (seqlen + 2048 - 1) / 2048;
  // const int n_chunks = (seqlen + 1024 - 1) / 1024;
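  // 2048 matches the largest per-chunk capacity of the kernel
  // (128 threads * 16 items); x stores one float2 running state per
  // (chunk, state) so later chunks can resume the scan.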
  // at::Tensor out = torch::empty_like(u);
  // Right now u has BHL layout and delta has HBL layout, and we want out to
  // have HBL layout
  at::Tensor out = torch::empty_like(delta);
  if (x.has_value()) {
    auto _x = x.value();
    TORCH_CHECK(_x.scalar_type() == weight_type);
    TORCH_CHECK(_x.is_cuda());
    TORCH_CHECK(_x.stride(-1) == 1);
    CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
  }

  SSMParamsBase params;
  set_ssm_params_fwd(
      params, batch_size, dim, seqlen, dstate, n_groups, n_chunks,
      is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z,
      D_.has_value() ? D_.value().data_ptr() : nullptr,
      delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
      x.value().data_ptr(), has_z, delta_softplus,
      index_.has_value() ? index_.value().data_ptr() : nullptr);

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)u.get_device()};
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
      u.scalar_type(), "selective_scan_fwd", [&] {
        DISPATCH_WTYPE_FLOAT(A.scalar_type(), "selective_scan_fwd", [&] {
          selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
        });
      });
  std::vector<at::Tensor> result = {out, x.value()};
  if (has_z) {
    result.push_back(out_z);
  }
  return result;
}
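
// Usage sketch (hypothetical caller, not part of this file): assuming this op
// is exposed to the host framework elsewhere, the inputs would look like
//   u, delta, z   : (batch, dim, seqlen), fp16/bf16/fp32, stride-1 last dim
//   A             : (dim, dstate), fp32
//   B, C          : (batch, n_groups, dstate, seqlen) when time-varying
//   D, delta_bias : (dim,), fp32, optional
//   index         : (batch, seqlen), int32, optional
//   x             : (batch, dim, n_chunks, 2 * dstate), fp32 running state
// and selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
// /*delta_softplus=*/true, index, x) returns {out, x}, plus out_z when z is
// provided.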