// selective_scan.cpp

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <vector>

#include "selective_scan.h"
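
// Shape-check helper: verify a tensor has exactly the given sizes, with an
// error message naming the tensor and the expected shape.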
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
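
// Type-dispatch macros: map the runtime ScalarType onto a concrete C++ type
// (input_t / weight_t) and invoke the given lambda with that type in scope.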
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
    if (ITYPE == at::ScalarType::Half) { \
        using input_t = at::Half; \
        __VA_ARGS__(); \
    } else if (ITYPE == at::ScalarType::BFloat16) { \
        using input_t = at::BFloat16; \
        __VA_ARGS__(); \
    } else if (ITYPE == at::ScalarType::Float) { \
        using input_t = float; \
        __VA_ARGS__(); \
    } else { \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }
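
// Note: this half/bf16 weight-type dispatch is not used in this file; the
// forward path below only dispatches weights over float and complex<float>.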
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
    if (WTYPE == at::ScalarType::Half) { \
        using weight_t = at::Half; \
        __VA_ARGS__(); \
    } else if (WTYPE == at::ScalarType::BFloat16) { \
        using weight_t = at::BFloat16; \
        __VA_ARGS__(); \
    } else if (WTYPE == at::ScalarType::Float) { \
        using weight_t = float; \
        __VA_ARGS__(); \
    } else { \
        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
    }

#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \
    if (WTYPE == at::ScalarType::Float) { \
        using weight_t = float; \
        __VA_ARGS__(); \
    } else if (WTYPE == at::ScalarType::ComplexFloat) { \
        using weight_t = c10::complex<float>; \
        __VA_ARGS__(); \
    } else { \
        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
    }
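
// Forward kernel launcher. The definition lives in the CUDA translation
// units, which instantiate it for each (input_t, weight_t) pair dispatched
// below.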
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);
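
// Fill in the SSMParamsBase struct consumed by the kernel: problem sizes,
// device pointers, and per-dimension element strides for every tensor.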
void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const size_t n_chunks,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const at::Tensor u,
                        const at::Tensor delta,
                        const at::Tensor A,
                        const at::Tensor B,
                        const at::Tensor C,
                        const at::Tensor out,
                        const at::Tensor z,
                        const at::Tensor out_z,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        bool has_z,
                        bool delta_softplus) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.n_chunks = n_chunks;
    params.dim_ngroups_ratio = dim / n_groups;

    params.delta_softplus = delta_softplus;

    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D_ptr;
    params.delta_bias_ptr = delta_bias_ptr;
    params.out_ptr = out.data_ptr();
    params.x_ptr = x_ptr;
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
    // All strides are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);
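    // Fixed B/C are laid out as (dim, dstate); variable (input-dependent) B/C
    // as (batch, n_groups, dstate, seqlen), so the leading strides differ.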
    if (!is_variable_B) {
        params.B_d_stride = B.stride(0);
    } else {
        params.B_batch_stride = B.stride(0);
        params.B_group_stride = B.stride(1);
    }
    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
    if (!is_variable_C) {
        params.C_d_stride = C.stride(0);
    } else {
        params.C_batch_stride = C.stride(0);
        params.C_group_stride = C.stride(1);
    }
    params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
    params.u_batch_stride = u.stride(0);
    params.u_d_stride = u.stride(1);
    params.delta_batch_stride = delta.stride(0);
    params.delta_d_stride = delta.stride(1);
    if (has_z) {
        params.z_batch_stride = z.stride(0);
        params.z_d_stride = z.stride(1);
        params.out_z_batch_stride = out_z.stride(0);
        params.out_z_d_stride = out_z.stride(1);
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
}
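
// Python-facing entry point: validates dtypes, devices, and shapes, allocates
// the outputs, then dispatches to the templated CUDA kernel.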
std::vector<at::Tensor>
selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
                   const at::Tensor &A, const at::Tensor &B, const at::Tensor &C,
                   const c10::optional<at::Tensor> &D_,
                   const c10::optional<at::Tensor> &z_,
                   const c10::optional<at::Tensor> &delta_bias_,
                   bool delta_softplus) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat);
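
    // B/C count as "variable" (input-dependent, varying along the sequence)
    // when they carry batch and group dimensions.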
    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;
    const bool is_complex = weight_type == at::ScalarType::ComplexFloat;

    TORCH_CHECK(delta.scalar_type() == input_type);
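    // Variable B/C are supplied in the input dtype; fixed B/C share A's
    // weight dtype.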
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());
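
    // The kernel reads the last (sequence) dimension contiguously, so require
    // stride 1 there unless that dimension has size 1.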
    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int dstate = A.size(1);
    const int n_groups = is_variable_B ? B.size(1) : 1;

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim, dstate);
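    // With complex weights, variable B/C arrive as real tensors whose last
    // dimension interleaves real and imaginary parts, hence seqlen * 2.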
    if (!is_variable_B) {
        CHECK_SHAPE(B, dim, dstate);
    } else {
        CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
        TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
    }
    if (!is_variable_C) {
        CHECK_SHAPE(C, dim, dstate);
    } else {
        CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
        TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
    }

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    if (has_z) {
        z = z_.value();
        TORCH_CHECK(z.scalar_type() == input_type);
        TORCH_CHECK(z.is_cuda());
        TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
        CHECK_SHAPE(z, batch_size, dim, seqlen);
        out_z = torch::empty_like(z);
    }
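
    // The kernel processes the sequence in chunks of 2048 timesteps; x
    // (allocated below) stores the intermediate scan state for each chunk.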
    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // const int n_chunks = (seqlen + 1024 - 1) / 1024;

    // at::Tensor out = torch::empty_like(u);
    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = torch::empty_like(delta);
    at::Tensor x;
    x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type));
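    // Each state entry keeps a pair of values (hence dstate * 2), presumably
    // the carry of the associative scan; stored in the weight dtype so
    // complex states fit as well.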

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x.data_ptr(),
                       has_z,
                       delta_softplus);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
        DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] {
            selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
        });
    });
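
    // Return the scan output, the saved chunk states, and, if z was supplied,
    // the gated output.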
    std::vector<at::Tensor> result = {out, x};
    if (has_z) { result.push_back(out_z); }
    return result;
}
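
// Built as a torch extension; assuming the module is compiled under the name
// "selective_scan_cuda" (illustrative, the actual name is set by the build),
// Python usage looks like:
//   out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)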

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fwd", &selective_scan_fwd, "Selective scan forward");
}