fused_dense.cpp

// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
// We make it work for bfloat16
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                              \
    switch (TYPE) {                                                          \
        case at::ScalarType::Half: {                                         \
            using scalar_t = at::Half;                                       \
            __VA_ARGS__();                                                   \
            break;                                                           \
        }                                                                    \
        case at::ScalarType::BFloat16: {                                     \
            using scalar_t = at::BFloat16;                                   \
            __VA_ARGS__();                                                   \
            break;                                                           \
        }                                                                    \
        default:                                                             \
            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
    }

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);

template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);

template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);

template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);

template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);

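// The *_cuda templates above are implemented in a companion CUDA translation unit
// (presumably fused_dense_cuda.cu; the exact file name is an assumption here).
// Conventions used throughout this file: input is expected to be a
// [batch_size, in_features] tensor, weight follows the PyTorch nn.Linear layout
// [out_features, in_features], and each *_cuda call returns 0 on success.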
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        auto result = linear_bias_forward_cuda<scalar_t>(
            input,
            w_ptr,
            bias,
            in_features,
            batch_size,
            out_features,
            out,
            //out.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_forward failed.")
    });
    return out;
}

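// Backward of out = input @ weight^T + bias: mathematically,
// d_input = d_output @ weight, d_weight = d_output^T @ input, and
// d_bias = d_output summed over the batch dimension. The pre-cuBLAS-11.6 branch
// below computes that sum directly in PyTorch, presumably because older cuBLASLt
// versions lack the fused bias-gradient epilogue.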
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    auto d_bias = at::empty({out_features}, opts);
#endif
    auto d_input = at::empty({batch_size, in_features}, opts);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        auto result = linear_bias_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            w_ptr,
            d_output.data_ptr<scalar_t>(),
            in_features,
            batch_size,
            out_features,
            d_weight.data_ptr<scalar_t>(),
            d_bias.data_ptr<scalar_t>(),
            d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/false,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_backward failed.")
    });
    return {d_input, d_weight, d_bias};
}

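// linear_bias_wgrad computes only the weight and bias gradients (no d_input),
// e.g. for cases where the input gradient is not needed or is computed separately.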
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = d_output.size(1);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    auto d_bias = at::empty({out_features}, opts);
#endif
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
        auto result = linear_bias_wgrad_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            d_output.data_ptr<scalar_t>(),
            in_features,
            batch_size,
            out_features,
            d_weight.data_ptr<scalar_t>(),
            d_bias.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
    });
    return {d_weight, d_bias};
}

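// Same as linear_bias_backward, but d_input is supplied by the caller and the
// CUDA kernel is invoked with residual=true, presumably so the input gradient is
// added on top of the gradient already stored in d_input (residual connection).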
std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    auto d_bias = at::empty({out_features}, opts);
#endif
    CHECK_SHAPE(d_input, batch_size, in_features);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_residual_backward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        auto result = linear_bias_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            w_ptr,
            d_output.data_ptr<scalar_t>(),
            in_features,
            batch_size,
            out_features,
            d_weight.data_ptr<scalar_t>(),
            d_bias.data_ptr<scalar_t>(),
            d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/true,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
    });
    return {d_input, d_weight, d_bias};
}

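// Fused first linear layer + GELU: output = GELU(input @ weight^T + bias).
// When save_gelu_in is set, the pre-GELU activation is also returned so the
// backward pass can recompute the GELU derivative. The heuristic argument is
// forwarded to the CUDA side, presumably to pick among cuBLASLt algorithm
// heuristics.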
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
                                            bool save_gelu_in, int heuristic) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int out_features = weight.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto output = at::empty({batch_size, out_features}, opts);
    at::Tensor gelu_in;
    if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
        scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        scalar_t* b_ptr = bias.data_ptr<scalar_t>();
        auto result = linear_gelu_forward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            w_ptr,
            b_ptr,
            in_features,
            batch_size,
            out_features,
            heuristic,
            output.data_ptr<scalar_t>(),
            save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
            // reserved_space.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
    });
    std::vector<at::Tensor> result = {output};
    if (save_gelu_in) { result.push_back(gelu_in); }
    return result;
}

std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int hidden_features = weight1.size(0);
    int out_features = weight2.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight1 = at::empty({hidden_features, in_features}, opts);
    auto d_weight2 = at::empty({out_features, hidden_features}, opts);
    auto d_bias1 = at::empty({hidden_features}, opts);
    auto d_bias2 = at::empty({out_features}, opts);
    auto d_input = at::empty({batch_size, in_features}, opts);
    auto d_output1 = at::empty({batch_size, hidden_features}, opts);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_linear_backward", [&] {
        //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
        auto result = linear_gelu_linear_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            gelu_in.data_ptr<scalar_t>(),
            output1.data_ptr<scalar_t>(),
            weight1.data_ptr<scalar_t>(),
            weight2.data_ptr<scalar_t>(),
            d_output1.data_ptr<scalar_t>(),
            d_output2.data_ptr<scalar_t>(),
            in_features,
            batch_size,
            hidden_features,
            out_features,
            heuristic,
            d_weight1.data_ptr<scalar_t>(),
            d_weight2.data_ptr<scalar_t>(),
            d_bias1.data_ptr<scalar_t>(),
            d_bias2.data_ptr<scalar_t>(),
            d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/false,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
    });
    return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}

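// Residual variant of linear_gelu_linear_backward: d_input comes from the caller
// and the kernel runs with residual=true, analogous to linear_bias_residual_backward.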
std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
    auto batch_size = input.size(0);
    auto in_features = input.size(1);
    int hidden_features = weight1.size(0);
    int out_features = weight2.size(0);
    //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight1 = at::empty({hidden_features, in_features}, opts);
    auto d_weight2 = at::empty({out_features, hidden_features}, opts);
    auto d_bias1 = at::empty({hidden_features}, opts);
    auto d_bias2 = at::empty({out_features}, opts);
    CHECK_SHAPE(d_input, batch_size, in_features);
    auto d_output1 = at::empty({batch_size, hidden_features}, opts);
    //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
    // allocate a fixed workspace for cublasLt for now; 1 << 22 elements is at least 4 MB for fp16/bf16
    auto lt_workspace = at::empty({1 << 22}, opts);
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_residual_gelu_linear_backward", [&] {
        //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
        //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
        auto result = linear_gelu_linear_backward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            gelu_in.data_ptr<scalar_t>(),
            output1.data_ptr<scalar_t>(),
            weight1.data_ptr<scalar_t>(),
            weight2.data_ptr<scalar_t>(),
            d_output1.data_ptr<scalar_t>(),
            d_output2.data_ptr<scalar_t>(),
            in_features,
            batch_size,
            hidden_features,
            out_features,
            heuristic,
            d_weight1.data_ptr<scalar_t>(),
            d_weight2.data_ptr<scalar_t>(),
            d_bias1.data_ptr<scalar_t>(),
            d_bias2.data_ptr<scalar_t>(),
            d_input.data_ptr<scalar_t>(),
            // reserved_space.data_ptr<scalar_t>(),
            /*residual=*/true,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
    });
    return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
    m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
    m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
    m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
    m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
    m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
    m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
}
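
// Illustrative Python-side usage once this extension is compiled (the module and
// tensor names below are assumptions; the actual module name is whatever
// TORCH_EXTENSION_NAME resolves to in the build):
//   out = fused_dense_lib.linear_bias_forward(x, weight, bias)
//   d_x, d_w, d_b = fused_dense_lib.linear_bias_backward(x, weight, d_out)
// Inputs must be fp16 or bf16 CUDA tensors with the shapes described above.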