// causal_conv1d.cpp
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <vector>

#include "causal_conv1d.h"
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)   \
    if (ITYPE == at::ScalarType::Half) {                           \
        using input_t = at::Half;                                  \
        __VA_ARGS__();                                             \
    } else if (ITYPE == at::ScalarType::BFloat16) {                \
        using input_t = at::BFloat16;                              \
        __VA_ARGS__();                                             \
    } else if (ITYPE == at::ScalarType::Float) {                   \
        using input_t = float;                                     \
        __VA_ARGS__();                                             \
    } else {                                                       \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }

#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...)   \
    if (WTYPE == at::ScalarType::Half) {                           \
        using weight_t = at::Half;                                 \
        __VA_ARGS__();                                             \
    } else if (WTYPE == at::ScalarType::BFloat16) {                \
        using weight_t = at::BFloat16;                             \
        __VA_ARGS__();                                             \
    } else if (WTYPE == at::ScalarType::Float) {                   \
        using weight_t = float;                                    \
        __VA_ARGS__();                                             \
    } else {                                                       \
        AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
    }
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
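
// Fill the ConvParamsBase struct with the sizes, device pointers, and element
// strides that the CUDA kernels read. Shared by the full-sequence forward and
// the single-step update entry points below.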
void set_conv_params_fwd(ConvParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t seqlen,
                         const size_t width,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor weight,
                         const at::Tensor out,
                         void* bias_ptr,
                         bool silu_activation) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.width = width;
    params.silu_activation = silu_activation;

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
    params.weight_ptr = weight.data_ptr();
    params.bias_ptr = bias_ptr;
    params.out_ptr = out.data_ptr();
    // All strides are in elements, not bytes.
    params.x_batch_stride = x.stride(0);
    params.x_c_stride = x.stride(1);
    params.x_l_stride = x.stride(-1);
    params.weight_c_stride = weight.stride(0);
    params.weight_width_stride = weight.stride(1);
    params.out_batch_stride = out.stride(0);
    params.out_c_stride = out.stride(1);
    params.out_l_stride = out.stride(-1);
}
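
// Forward pass over a full sequence. Expected shapes (checked below):
//   x:      (batch, dim, seqlen), contiguous along either seqlen or dim (channel-last)
//   weight: (dim, width) with 2 <= width <= 4
//   bias:   (dim), optional
// seq_idx, initial_states, and final_states_out are only accepted for the
// channel-last layout. Returns a tensor with the same shape as x.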
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                  const c10::optional<at::Tensor> &bias_,
                  const c10::optional<at::Tensor> &seq_idx_,
                  const c10::optional<at::Tensor> &initial_states_,
                  c10::optional<at::Tensor> &final_states_out_,
                  bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
    if (is_channel_last) {
        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
        TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
    }
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (seq_idx_.has_value()) {
        TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
        auto seq_idx = seq_idx_.value();
        TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
        TORCH_CHECK(seq_idx.is_cuda());
        TORCH_CHECK(seq_idx.is_contiguous());
        CHECK_SHAPE(seq_idx, batch_size, seqlen);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);

    if (seq_idx_.has_value()) {
        params.seq_idx_ptr = seq_idx_.value().data_ptr();
    } else {
        params.seq_idx_ptr = nullptr;
    }

    if (initial_states_.has_value()) {
        TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
        auto initial_states = initial_states_.value();
        TORCH_CHECK(initial_states.scalar_type() == input_type);
        TORCH_CHECK(initial_states.is_cuda());
        CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
        TORCH_CHECK(initial_states.stride(1) == 1);
        params.initial_states_ptr = initial_states.data_ptr();
        params.initial_states_batch_stride = initial_states.stride(0);
        params.initial_states_c_stride = initial_states.stride(1);
        params.initial_states_l_stride = initial_states.stride(2);
    } else {
        params.initial_states_ptr = nullptr;
    }

    if (final_states_out_.has_value()) {
        TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
        auto final_states = final_states_out_.value();
        TORCH_CHECK(final_states.scalar_type() == input_type);
        TORCH_CHECK(final_states.is_cuda());
        CHECK_SHAPE(final_states, batch_size, dim, width - 1);
        TORCH_CHECK(final_states.stride(1) == 1);
        params.final_states_ptr = final_states.data_ptr();
        params.final_states_batch_stride = final_states.stride(0);
        params.final_states_c_stride = final_states.stride(1);
        params.final_states_l_stride = final_states.stride(2);
    } else {
        params.final_states_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
            if (!is_channel_last) {
                causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
            } else {
                causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
            }
        });
    });
    return out;
}
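
// Single-step update used at decode time. Expected shapes (checked below):
//   x:          (batch, dim)  -- one new timestep per sequence
//   conv_state: (batch, dim, width)
//   weight:     (dim, width) with 2 <= width <= 4, bias: (dim) optional
// conv_state's pointer and strides are handed to the kernel as the running
// convolution state; seqlen is fixed to 1.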
at::Tensor
causal_conv1d_update(const at::Tensor &x,
                     const at::Tensor &conv_state,
                     const at::Tensor &weight,
                     const c10::optional<at::Tensor> &bias_,
                     bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim);
    CHECK_SHAPE(conv_state, batch_size, dim, width);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);
    params.conv_state_ptr = conv_state.data_ptr();
    // All strides are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
            causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
        });
    });
    return out;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
    m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
}
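
// ---------------------------------------------------------------------------
// Usage sketch (illustration only, not part of the original source): roughly
// what a call into the forward op could look like from C++ with libtorch,
// assuming the extension above has been built and linked. All variable names
// below are hypothetical; shapes follow the checks in causal_conv1d_fwd.
//
//   auto opts   = torch::dtype(torch::kFloat16).device(torch::kCUDA);
//   auto x      = torch::randn({2, 64, 128}, opts);   // (batch, dim, seqlen)
//   auto weight = torch::randn({64, 4}, opts);        // (dim, width), 2 <= width <= 4
//   auto bias   = torch::randn({64}, opts);           // (dim)
//   c10::optional<at::Tensor> final_states;           // left empty: only valid for channel-last x
//   auto out = causal_conv1d_fwd(x, weight, bias,
//                                /*seq_idx_=*/c10::nullopt,
//                                /*initial_states_=*/c10::nullopt,
//                                final_states, /*silu_activation=*/true);
// ---------------------------------------------------------------------------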