- // Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
- // We make it work for bfloat16
- #include <torch/extension.h>
- #include <torch/torch.h>
- #include <vector>
- #include <stdio.h>
- // https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
- // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- #define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
- switch (TYPE) { \
- case at::ScalarType::Half: { \
- using scalar_t = at::Half; \
- __VA_ARGS__(); \
- break; \
- } \
- case at::ScalarType::BFloat16: { \
- using scalar_t = at::BFloat16; \
- __VA_ARGS__(); \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
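- // Typical usage (mirroring the call sites below): dispatch on a tensor's dtype and run a
- // lambda with scalar_t bound to at::Half or at::BFloat16, e.g.
- //   DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
- //     scalar_t *ptr = input.data_ptr<scalar_t>();
- //   });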
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
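- // e.g. CHECK_SHAPE(d_input, batch_size, in_features) throws unless d_input.sizes() == {batch_size, in_features}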
- template <typename T>
- int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
- template <typename T>
- int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);
- template <typename T>
- int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
- template <typename T>
- int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);
- template <typename T>
- int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
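- // The *_cuda functions declared above are expected to be implemented in a companion CUDA
- // translation unit (e.g. a fused_dense_cuda.cu driving cublasLt); each returns 0 on success.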
- at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int out_features = weight.size(0);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
- scalar_t* w_ptr = weight.data_ptr<scalar_t>();
- auto result = linear_bias_forward_cuda<scalar_t>(
- input,
- w_ptr,
- bias,
- in_features,
- batch_size,
- out_features,
- out,
- //out.data_ptr<scalar_t>(),
- // reserved_space.data_ptr<scalar_t>(),
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_bias_forward failed.");
- });
- return {out};
- }
- std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int out_features = weight.size(0);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto opts = input.options();
- auto d_weight = at::empty({out_features, in_features}, opts);
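- // With cuBLAS older than 11.6 the bias gradient presumably cannot be fused into the cublasLt
- // epilogue, so d_bias is computed here with a PyTorch reduction; otherwise the CUDA side fills it.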
- #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
- auto d_bias = d_output.view({-1, out_features}).sum(0, false);
- #else
- auto d_bias = at::empty({out_features}, opts);
- #endif
- auto d_input = at::empty({batch_size, in_features}, opts);
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, opts);
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
- scalar_t* w_ptr = weight.data_ptr<scalar_t>();
- auto result = linear_bias_backward_cuda<scalar_t>(
- input.data_ptr<scalar_t>(),
- w_ptr,
- d_output.data_ptr<scalar_t>(),
- in_features,
- batch_size,
- out_features,
- d_weight.data_ptr<scalar_t>(),
- d_bias.data_ptr<scalar_t>(),
- d_input.data_ptr<scalar_t>(),
- // reserved_space.data_ptr<scalar_t>(),
- /*residual=*/false,
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_bias_backward failed.");
- });
- return {d_input, d_weight, d_bias};
- }
- std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int out_features = d_output.size(1);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto opts = input.options();
- auto d_weight = at::empty({out_features, in_features}, opts);
- #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
- auto d_bias = d_output.view({-1, out_features}).sum(0, false);
- #else
- auto d_bias = at::empty({out_features}, opts);
- #endif
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, opts);
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
- auto result = linear_bias_wgrad_cuda<scalar_t>(
- input.data_ptr<scalar_t>(),
- d_output.data_ptr<scalar_t>(),
- in_features,
- batch_size,
- out_features,
- d_weight.data_ptr<scalar_t>(),
- d_bias.data_ptr<scalar_t>(),
- // reserved_space.data_ptr<scalar_t>(),
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
- });
- return {d_weight, d_bias};
- }
- std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int out_features = weight.size(0);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto opts = input.options();
- auto d_weight = at::empty({out_features, in_features}, opts);
- #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
- auto d_bias = d_output.view({-1, out_features}).sum(0, false);
- #else
- auto d_bias = at::empty({out_features}, opts);
- #endif
- CHECK_SHAPE(d_input, batch_size, in_features);
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, opts);
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_residual_backward", [&] {
- scalar_t* w_ptr = weight.data_ptr<scalar_t>();
- auto result = linear_bias_backward_cuda<scalar_t>(
- input.data_ptr<scalar_t>(),
- w_ptr,
- d_output.data_ptr<scalar_t>(),
- in_features,
- batch_size,
- out_features,
- d_weight.data_ptr<scalar_t>(),
- d_bias.data_ptr<scalar_t>(),
- d_input.data_ptr<scalar_t>(),
- // reserved_space.data_ptr<scalar_t>(),
- /*residual=*/true,
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.");
- });
- return {d_input, d_weight, d_bias};
- }
- std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
- bool save_gelu_in, int heuristic) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int out_features = weight.size(0);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto opts = input.options();
- auto output = at::empty({batch_size, out_features}, opts);
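- // optionally stash the pre-GELU activations so the fused backward pass can reuse them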
- at::Tensor gelu_in;
- if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, opts);
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
- scalar_t* w_ptr = weight.data_ptr<scalar_t>();
- scalar_t* b_ptr = bias.data_ptr<scalar_t>();
- auto result = linear_gelu_forward_cuda<scalar_t>(
- input.data_ptr<scalar_t>(),
- w_ptr,
- b_ptr,
- in_features,
- batch_size,
- out_features,
- heuristic,
- output.data_ptr<scalar_t>(),
- save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
- // reserved_space.data_ptr<scalar_t>(),
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
- });
- std::vector<at::Tensor> result = {output};
- if (save_gelu_in) { result.push_back(gelu_in); }
- return result;
- }
- std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int hidden_features = weight1.size(0);
- int out_features = weight2.size(0);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto opts = input.options();
- auto d_weight1 = at::empty({hidden_features, in_features}, opts);
- auto d_weight2 = at::empty({out_features, hidden_features}, opts);
- auto d_bias1 = at::empty({hidden_features}, opts);
- auto d_bias2 = at::empty({out_features}, opts);
- auto d_input = at::empty({batch_size, in_features}, opts);
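- // d_output1 is scratch for the gradient w.r.t. the first linear's output, filled on the CUDA side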
- auto d_output1 = at::empty({batch_size, hidden_features}, opts);
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, opts);
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_linear_backward", [&] {
- //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
- //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
- auto result = linear_gelu_linear_backward_cuda<scalar_t>(
- input.data_ptr<scalar_t>(),
- gelu_in.data_ptr<scalar_t>(),
- output1.data_ptr<scalar_t>(),
- weight1.data_ptr<scalar_t>(),
- weight2.data_ptr<scalar_t>(),
- d_output1.data_ptr<scalar_t>(),
- d_output2.data_ptr<scalar_t>(),
- in_features,
- batch_size,
- hidden_features,
- out_features,
- heuristic,
- d_weight1.data_ptr<scalar_t>(),
- d_weight2.data_ptr<scalar_t>(),
- d_bias1.data_ptr<scalar_t>(),
- d_bias2.data_ptr<scalar_t>(),
- d_input.data_ptr<scalar_t>(),
- // reserved_space.data_ptr<scalar_t>(),
- /*residual=*/false,
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.");
- });
- return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
- }
- std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
- auto batch_size = input.size(0);
- auto in_features = input.size(1);
- int hidden_features = weight1.size(0);
- int out_features = weight2.size(0);
- //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
- // create output/workspace tensor
- auto opts = input.options();
- auto d_weight1 = at::empty({hidden_features, in_features}, opts);
- auto d_weight2 = at::empty({out_features, hidden_features}, opts);
- auto d_bias1 = at::empty({hidden_features}, opts);
- auto d_bias2 = at::empty({out_features}, opts);
- CHECK_SHAPE(d_input, batch_size, in_features);
- auto d_output1 = at::empty({batch_size, hidden_features}, opts);
- //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
- // allocate a fixed cuBLASLt workspace of 1 << 22 elements for now (at least 4 MB for fp16/bf16)
- auto lt_workspace = at::empty({1 << 22}, opts);
- DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_residual_gelu_linear_backward", [&] {
- //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
- //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
- auto result = linear_gelu_linear_backward_cuda<scalar_t>(
- input.data_ptr<scalar_t>(),
- gelu_in.data_ptr<scalar_t>(),
- output1.data_ptr<scalar_t>(),
- weight1.data_ptr<scalar_t>(),
- weight2.data_ptr<scalar_t>(),
- d_output1.data_ptr<scalar_t>(),
- d_output2.data_ptr<scalar_t>(),
- in_features,
- batch_size,
- hidden_features,
- out_features,
- heuristic,
- d_weight1.data_ptr<scalar_t>(),
- d_weight2.data_ptr<scalar_t>(),
- d_bias1.data_ptr<scalar_t>(),
- d_bias2.data_ptr<scalar_t>(),
- d_input.data_ptr<scalar_t>(),
- // reserved_space.data_ptr<scalar_t>(),
- /*residual=*/true,
- (void*) (lt_workspace.data_ptr<scalar_t>()));
- TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.");
- });
- return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
- }
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
- m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
- m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
- m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
- m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
- m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
- m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
- }
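- // Rough usage sketch (not part of this file): if the extension is built under a module name
- // such as "fused_dense_lib" (whatever TORCH_EXTENSION_NAME resolves to at build time), the
- // bindings above could be driven from Python roughly as:
- //   import fused_dense_lib
- //   out = fused_dense_lib.linear_bias_forward(x, weight, bias)          # x: (batch, in), weight: (out, in)
- //   dx, dw, db = fused_dense_lib.linear_bias_backward(x, weight, dout)  # dout: (batch, out)
- // with all tensors being contiguous fp16/bf16 CUDA tensors.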