fused_dense.cpp

// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
// We make it work for bfloat16
#include <torch/extension.h>
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <vector>
#include <stdio.h>

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
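// For example, CHECK_SHAPE(input, batch_size, in_features) fails with a readable
// message unless input.sizes() == {batch_size, in_features}.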

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                              \
    switch (TYPE) {                                                          \
    case at::ScalarType::Half: {                                             \
        using scalar_t = at::Half;                                           \
        __VA_ARGS__();                                                       \
        break;                                                               \
    }                                                                        \
    case at::ScalarType::BFloat16: {                                         \
        using scalar_t = at::BFloat16;                                       \
        __VA_ARGS__();                                                       \
        break;                                                               \
    }                                                                        \
    default:                                                                 \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");      \
    }
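// Usage (see the dispatch calls below): because this is a textual macro, the lambda
// passed as __VA_ARGS__ is expanded inside each case, so scalar_t resolves to
// at::Half or at::BFloat16 there.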

template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output,
                           int64_t in_features, int64_t batch_size, int64_t out_features,
                           T *d_weight, T *d_bias,
                           void *lt_workspace, size_t workspaceSize);

template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias,
                            int64_t in_features, int64_t batch_size, int64_t out_features,
                            bool is_gelu, int heuristic,
                            T *output, void *pre_act,
                            void *lt_workspace, size_t workspaceSize);

template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act,
                                     int64_t in_features, int64_t batch_size, int64_t out_features,
                                     bool is_gelu, int heuristic,
                                     T *d_input, T *d_bias,
                                     void *lt_workspace, size_t workspaceSize);
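
// Weight/bias gradient for a linear layer y = x @ W^T + b: in effect
// d_weight = d_output^T @ input and (if has_d_bias) d_bias = d_output.sum(0).
// The GEMM and bias reduction are performed by linear_bias_wgrad_cuda via cuBLASLt.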
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
    int64_t batch_size = input.size(0);
    int64_t in_features = input.size(1);
    int64_t out_features = d_output.size(1);

    TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
    TORCH_CHECK(input.dtype() == d_output.dtype());
    TORCH_CHECK(input.is_cuda());
    TORCH_CHECK(d_output.is_cuda());
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(d_output.is_contiguous());
    CHECK_SHAPE(input, batch_size, in_features);
    CHECK_SHAPE(d_output, batch_size, out_features);

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{input.device()};

    // create output/workspace tensor
    auto opts = input.options();
    auto d_weight = at::empty({out_features, in_features}, opts);
    at::Tensor d_bias;
    if (has_d_bias) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
        d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
        d_bias = at::empty({out_features}, opts);
#endif
    }
    // See https://github.com/pytorch/pytorch/issues/73328 for the reasoning behind the original 1M default.
    // Apex sets it to 4M; TransformerEngine uses 32M for Hopper and 4M for other GPUs, which is what we follow here:
    // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
    size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));

    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
        auto result = linear_bias_wgrad_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            d_output.data_ptr<scalar_t>(),
            in_features,
            batch_size,
            out_features,
            d_weight.data_ptr<scalar_t>(),
            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
            (void*) (lt_workspace.data_ptr()),
            workspaceSize);
        TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
    });

    return {d_weight, d_bias};
}
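
// Fused forward for a linear layer followed by GELU or ReLU:
// roughly output = act(input @ weight^T + bias). When save_pre_act is true, the
// pre-activation (GELU) or a packed ReLU bit-mask is also returned for the backward pass.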
std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
                                           c10::optional<at::Tensor> bias_,
                                           bool is_gelu, bool save_pre_act, int heuristic) {
    int64_t batch_size = input.size(0);
    int64_t in_features = input.size(1);
    int64_t out_features = weight.size(0);

    TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
    TORCH_CHECK(input.dtype() == weight.dtype());
    TORCH_CHECK(input.is_cuda());
    TORCH_CHECK(weight.is_cuda());
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(weight.is_contiguous());
    CHECK_SHAPE(input, batch_size, in_features);
    CHECK_SHAPE(weight, out_features, in_features);
    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.dtype() == input.dtype());
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.is_contiguous());
        CHECK_SHAPE(bias, out_features);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{input.device()};

    // create output/workspace tensor
    auto opts = input.options();
    auto output = at::empty({batch_size, out_features}, opts);
    at::Tensor pre_act;
    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
    if (save_pre_act) {
        pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
                            is_gelu ? opts : opts.dtype(torch::kUInt8));
    }
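    // Note: the ReLU mask is packed 8 elements per byte, hence out_features / 8 bytes
    // per row (this implicitly assumes out_features is a multiple of 8).
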
    // See https://github.com/pytorch/pytorch/issues/73328 for the reasoning behind the original 1M default.
    // Apex sets it to 4M; TransformerEngine uses 32M for Hopper and 4M for other GPUs, which is what we follow here:
    // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
    size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));

    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
        auto result = linear_act_forward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
            in_features,
            batch_size,
            out_features,
            is_gelu,
            heuristic,
            output.data_ptr<scalar_t>(),
            save_pre_act ? pre_act.data_ptr() : nullptr,
            (void*) (lt_workspace.data_ptr()),
            workspaceSize);
        TORCH_CHECK(result == 0, "linear_act_forward failed.");
    });

    std::vector<at::Tensor> result = {output};
    if (save_pre_act) { result.push_back(pre_act); }
    return result;
}
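
// Fused backward through (bias + activation) into the preceding linear layer:
// roughly d_input = act'(pre_act) * (d_output @ weight) and d_bias = d_input.sum(0),
// where pre_act is what linear_act_forward saved (values for GELU, a bit-mask for ReLU).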
std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
    at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
) {
    int64_t batch_size = d_output.size(0);
    int64_t out_features = d_output.size(1);
    int64_t in_features = weight.size(1);

    TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
    TORCH_CHECK(weight.dtype() == d_output.dtype());
    TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
    TORCH_CHECK(weight.is_cuda());
    TORCH_CHECK(d_output.is_cuda());
    TORCH_CHECK(pre_act.is_cuda());
    TORCH_CHECK(weight.is_contiguous());
    TORCH_CHECK(d_output.is_contiguous());
    TORCH_CHECK(pre_act.is_contiguous());
    CHECK_SHAPE(weight, out_features, in_features);
    CHECK_SHAPE(d_output, batch_size, out_features);
    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
    CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);

    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{weight.device()};

    // create output/workspace tensor
    auto opts = weight.options();
    auto d_bias = at::empty({in_features}, opts);
    auto d_input = at::empty({batch_size, in_features}, opts);
    // See https://github.com/pytorch/pytorch/issues/73328 for the reasoning behind the original 1M default.
    // Apex sets it to 4M; TransformerEngine uses 32M for Hopper and 4M for other GPUs, which is what we follow here:
    // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
    size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));

    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
        auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
            weight.data_ptr<scalar_t>(),
            d_output.data_ptr<scalar_t>(),
            pre_act.data_ptr(),
            in_features,
            batch_size,
            out_features,
            is_gelu,
            heuristic,
            d_input.data_ptr<scalar_t>(),
            d_bias.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr()),
            workspaceSize);
        TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
    });

    return {d_input, d_bias};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
    m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
    m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
}
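
// A minimal C++-side usage sketch (illustrative only, not part of the extension;
// the tensor sizes are arbitrary and `heuristic = 0` is just a plausible default index):
//
//   auto opts   = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
//   auto input  = torch::randn({8, 64}, opts).contiguous();    // (batch_size, in_features)
//   auto weight = torch::randn({128, 64}, opts).contiguous();  // (out_features, in_features)
//   auto bias   = torch::randn({128}, opts).contiguous();
//   auto outs   = linear_act_forward(input, weight, bias,
//                                    /*is_gelu=*/true, /*save_pre_act=*/true, /*heuristic=*/0);
//   // outs[0]: (8, 128) activations; outs[1]: (8, 128) saved pre-GELU values.
//
// From Python, the same entry points are exposed through the compiled module
// (whatever name TORCH_EXTENSION_NAME is given at build time), e.g.
// module.linear_act_forward(input, weight, bias, True, True, 0).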