// selective_scan_fwd.cu

// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "selective_scan.h"

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "selective_scan.h"
#include "static_switch.h"
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
         bool kIsVariableB_, bool kIsVariableC_,
         bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
    static_assert(kNItems_ % 4 == 0);
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
    static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
    static constexpr int kNItems = kNItems_;
    static constexpr int kNRows = kNRows_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
    static_assert(kNItems % kNElts == 0);
    static constexpr int kNLoads = kNItems / kNElts;
    static constexpr bool kIsEvenLen = kIsEvenLen_;
    static constexpr bool kIsVariableB = kIsVariableB_;
    static constexpr bool kIsVariableC = kIsVariableC_;
    static constexpr bool kHasZ = kHasZ_;
    static constexpr bool kUseIndex = kUseIndex_;
    static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
    static constexpr int kNLoadsIndex = kNItems / 4;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using scan_t = float2;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
        !(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
        !kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
    // using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
    using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
    static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
                                                   sizeof(typename BlockLoadVecT::TempStorage),
                                                   sizeof(typename BlockLoadIndexT::TempStorage),
                                                   sizeof(typename BlockLoadIndexVecT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
                                                   (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
                                                   sizeof(typename BlockStoreT::TempStorage),
                                                   sizeof(typename BlockStoreVecT::TempStorage)});
    static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
};
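
// Illustrative example only (not part of the upstream source): with fp16 inputs and the
// <64, 16> launch configuration used below for seqlen <= 1024, the traits above resolve to
//   Selective_Scan_fwd_kernel_traits<64, 16, 1, kIsEvenLen, true, true, true, kUseIndex, at::Half, float>
// giving kNBytes = 2, kNElts = 8, kNLoads = 2, kDirectIO = false, and a 16-byte vec_t,
// i.e. each thread handles 16 sequence positions per chunk in two vectorized loads.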

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr bool kUseIndex = Ktraits::kUseIndex;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * kNRows * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * kNRows * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
    int *index = !kUseIndex ? nullptr : reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
        }
    }

    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
    // }

    constexpr int kChunkSize = kNThreads * kNItems;
    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
        int index_vals_load[kNRows][kNItems];

        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (kUseIndex) {
                load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
            }
        }
        if constexpr (kUseIndex) {
            index += kChunkSize;
        }
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
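        // A sketch of the math the block scan below implements (SSMScanOp and
        // SSMScanPrefixCallbackOp are defined in selective_scan.h, not shown here):
        // with a_t = exp(delta_t * A) and b_t = delta_t * B_t * u_t, the recurrence
        //     h_t = a_t * h_{t-1} + b_t
        // composes associatively as (a1, b1) o (a2, b2) = (a1 * a2, a2 * b1 + b2),
        // so an inclusive scan over thread_data = (a_t, b_t) leaves h_t in the .y slot,
        // and the output accumulates C_t * h_t on top of the D * u_t term computed above.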
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                A_val[r] *= kLog2e;
            }
            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
            // If both B and C vary, this is unused.
            weight_t BC_val[kNRows];
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (kIsVariableB) {
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
                if constexpr (!kIsVariableC) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                    }
                }
            }
            if constexpr (kIsVariableC) {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1));
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
                    }
                }
            }
            if constexpr (!kIsVariableB && !kIsVariableC) {
                #pragma unroll
                for (int r = 0; r < kNRows; ++r) {
                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                }
            }

            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                                                 !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
                    // Reset A bar for cumulative sequences (Real)
                    if constexpr (kUseIndex) {
                        if (index_vals_load[r][i] == 0) {
                            thread_data[i].x = 0.f;
                        }
                    }
                    if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                        if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                            thread_data[i] = make_float2(1.f, 0.f);
                        }
                    }
                }
                // Initialize running total
                scan_t running_prefix;
                // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
                running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : (threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
                // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    smem_running_prefix[state_idx] = prefix_op.running_prefix;
                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const weight_t C_val = !kIsVariableC
                        ? BC_val[r]
                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
                    out_vals[r][i] += thread_data[i].y * C_val;
                }
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
        }
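
        // Explanatory comment (not from upstream): `out` above receives the un-gated scan
        // output; when z is present, the block below also writes the gated output
        // out_z = out * silu(z), where silu(z) = z * sigmoid(z) = z / (1 + exp(-z)).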
        if constexpr (kHasZ) {
            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                input_t z_vals[kNItems];
                __syncthreads();
                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    float z_val = z_vals[i];
                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
                }
                __syncthreads();
                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
            }
        }

        Bvar += kChunkSize * 1;
        Cvar += kChunkSize * 1;
    }
}

template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
    // processing 1 row.
    constexpr int kNRows = 1;
    // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
    constexpr bool kIsVariableB = true;
    constexpr bool kIsVariableC = true;
    constexpr bool kHasZ = true;
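
    // BOOL_SWITCH comes from static_switch.h (not shown here). Roughly, it is expected to turn a
    // runtime boolean into a compile-time constant by instantiating the lambda body twice:
    //     if (cond) { constexpr bool kName = true;  body(); }
    //     else      { constexpr bool kName = false; body(); }
    // so the nested switches below compile four kernel variants per (kNThreads, kNItems) pair.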
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
            dim3 grid(params.batch, params.dim / kNRows);
            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
            if (kSmemSize >= 48 * 1024) {
                C10_CUDA_CHECK(cudaFuncSetAttribute(
                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            }
            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    });
}

template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {
#ifndef USE_ROCM
    if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
#else
    if (params.seqlen <= 256) {
        selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
#endif
}
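
// Worked out from the configurations above (not an upstream comment): each chunk covers
// kNThreads * kNItems positions, so the CUDA buckets are 32*4 = 128, 32*8 = 256, 32*16 = 512,
// 64*16 = 1024, and 128*16 = 2048 positions per chunk; longer sequences run over multiple
// chunks, with the running prefix carried between chunks inside the kernel.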

template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)              \
    if (ITYPE == at::ScalarType::Half) {                                            \
        using input_t = at::Half;                                                   \
        using weight_t = float;                                                     \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
        using input_t = at::BFloat16;                                               \
        using weight_t = float;                                                     \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::Float) {                                    \
        using input_t = float;                                                      \
        using weight_t = float;                                                     \
        __VA_ARGS__();                                                              \
    } else {                                                                        \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }

template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const size_t n_chunks,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const torch::Tensor u,
                        const torch::Tensor delta,
                        const torch::Tensor A,
                        const torch::Tensor B,
                        const torch::Tensor C,
                        const torch::Tensor out,
                        const torch::Tensor z,
                        const torch::Tensor out_z,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        bool has_z,
                        bool delta_softplus,
                        void* index_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.n_chunks = n_chunks;
    params.dim_ngroups_ratio = dim / n_groups;
    params.delta_softplus = delta_softplus;
    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D_ptr;
    params.delta_bias_ptr = delta_bias_ptr;
    params.out_ptr = out.data_ptr();
    params.x_ptr = x_ptr;
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
    params.index_ptr = index_ptr;

    // All strides are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);
    if (!is_variable_B) {
        params.B_d_stride = B.stride(0);
    } else {
        params.B_batch_stride = B.stride(0);
        params.B_group_stride = B.stride(1);
    }
    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
    if (!is_variable_C) {
        params.C_d_stride = C.stride(0);
    } else {
        params.C_batch_stride = C.stride(0);
        params.C_group_stride = C.stride(1);
    }
    params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
    params.u_batch_stride = u.stride(0);
    params.u_d_stride = u.stride(1);
    params.delta_batch_stride = delta.stride(0);
    params.delta_d_stride = delta.stride(1);
    if (has_z) {
        params.z_batch_stride = z.stride(0);
        params.z_d_stride = z.stride(1);
        params.out_z_batch_stride = out_z.stride(0);
        params.out_z_d_stride = out_z.stride(1);
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
}
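
// Layout assumed by the stride assignments above, summarized here for convenience
// (it matches the CHECK_SHAPE calls in selective_scan_fwd below, not taken from upstream comments):
//   u, delta, z, out, out_z : (batch, dim, seqlen)
//   A                       : (dim, dstate)
//   B, C (variable)         : (batch, n_groups, dstate, seqlen)
//   D, delta_bias           : (dim)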

std::vector<torch::Tensor>
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                   const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
                   const c10::optional<torch::Tensor> &D_,
                   const c10::optional<torch::Tensor> &z_,
                   const c10::optional<torch::Tensor> &delta_bias_,
                   bool delta_softplus,
                   const c10::optional<torch::Tensor> &index_,
                   const c10::optional<torch::Tensor> &x) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int dstate = A.size(1);
    const int n_groups = is_variable_B ? B.size(1) : 1;

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim, dstate);
    TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size");
    CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
    TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size");
    CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }
    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }
    if (index_.has_value()) {
        auto index = index_.value();
        TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(index.is_cuda());
        CHECK_SHAPE(index, batch_size, seqlen);
    }

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size");
    z = z_.value();
    TORCH_CHECK(z.scalar_type() == input_type);
    TORCH_CHECK(z.is_cuda());
    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
    CHECK_SHAPE(z, batch_size, dim, seqlen);
    out_z = torch::empty_like(z);

    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
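    // Note (added, not from upstream): 2048 matches the largest per-chunk coverage of the
    // launch configurations in selective_scan_fwd_cuda (128 threads * 16 items), so any
    // sequence up to 2048 positions fits in a single chunk of the intermediate state x.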

    // at::Tensor out = torch::empty_like(u);
    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = torch::empty_like(delta);

    if (x.has_value()) {
        auto _x = x.value();
        TORCH_CHECK(_x.scalar_type() == weight_type);
        TORCH_CHECK(_x.is_cuda());
        TORCH_CHECK(_x.stride(-1) == 1);
        CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
    }

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x.value().data_ptr(),
                       has_z,
                       delta_softplus,
                       index_.has_value() ? index_.value().data_ptr() : nullptr);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
        selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
    });
    std::vector<at::Tensor> result = {out, x.value()};
    if (has_z) { result.push_back(out_z); }
    return result;
}