123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554 |
- #include <torch/extension.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <cstdint>
- #include "type_convert.h"
- #include "../cuda_compat.h"
- #include "bgmv/bgmv_config.h"
- //====== utils ======
- inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
- const char *a_name, const char *b_name) {
- TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
- a.dim(), " vs ", b.dim());
- for (int i = 0; i < a.dim(); ++i) {
- TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
- ".size(", i, ")");
- }
- }
// Combines two 32-bit values into a single 64-bit key, with `a` occupying the
// upper 32 bits and `b` the lower 32 bits. Being constexpr, the result can be
// used as a `case` label (see launch_bgmv_kernel's shape switch).
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  const uint64_t high_word = static_cast<uint64_t>(a) << 32;
  const uint64_t low_word = static_cast<uint64_t>(b);
  return high_word | low_word;
}
// Argument-validation helpers used by the dispatch entry points. Each one
// fails through TORCH_CHECK with a message that names the offending source
// expression (via # stringification).
//
// Macro arguments are parenthesized in the expansions so that compound
// expressions (e.g. `CHECK_EQ(n, flag ? 1 : 2)`) are compared as a whole,
// and CHECK_INPUT is wrapped in do { } while (0) so it expands to a single
// statement and stays correct under an unbraced `if` / `else`.
#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
#define CHECK_DIM(d, x) \
  TORCH_CHECK((x).dim() == (d), #x " must be a " #d "D tensor")
// Requires check_shape (defined in this file) to be visible at the use site.
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
  TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
- //====== bgmv ======
- template <typename in_T, typename out_T, typename W_T>
- inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
- const int64_t *lora_indices,
- uint32_t in_features, uint32_t out_features,
- int64_t y_offset, int64_t full_y_size,
- int64_t batch_size, int64_t num_layers,
- int64_t layer_idx, float scale) {
- switch (pack_u32(in_features, out_features)) {
- #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
- case pack_u32(feat_in, feat_out): \
- bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
- full_y_size, batch_size, num_layers, \
- layer_idx, scale); \
- break;
- #define CASE(_in_T, _out_T, _W_T, narrow, wide) \
- CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
- CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
- FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
- FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
- #undef CASE
- #undef CASE_ONESIDE
- default:
- return false;
- }
- return true;
- }
- void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
- torch::Tensor indicies, int64_t layer_idx, float scale) {
- CHECK_INPUT(y);
- CHECK_INPUT(x);
- CHECK_INPUT(w);
- CHECK_INPUT(indicies);
- CHECK_DIM(2, y);
- CHECK_DIM(2, x);
- CHECK_DIM(4, w);
- CHECK_DIM(1, indicies);
- int64_t B = x.size(0);
- int64_t h_in = x.size(1);
- int64_t h_out = y.size(1);
- int64_t num_layers = w.size(1);
- CHECK_EQ(w.size(3), h_in);
- CHECK_EQ(w.size(2), h_out);
- CHECK_EQ(indicies.size(0), x.size(0));
- CHECK_EQ(y.size(0), x.size(0));
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
- bool ok = false;
- if (h_in <= 128512 && h_out <= 128512) {
- // TODO: See if we can get rid of this massive nested switch
- switch (x.scalar_type()) {
- case at::ScalarType::Half:
- switch (y.scalar_type()) {
- case at::ScalarType::Half:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (y.scalar_type()) {
- case at::ScalarType::Half:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (y.scalar_type()) {
- case at::ScalarType::Half:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out, 0,
- h_out, B, num_layers, layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- }
- TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
- " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
- }
- void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
- torch::Tensor indicies, int64_t layer_idx,
- float scale, int64_t h_in, int64_t h_out,
- int64_t y_offset) {
- CHECK_INPUT(y);
- CHECK_INPUT(x);
- CHECK_INPUT(w);
- CHECK_INPUT(indicies);
- CHECK_DIM(2, y);
- CHECK_DIM(2, x);
- CHECK_DIM(4, w);
- CHECK_DIM(1, indicies);
- int64_t B = x.size(0);
- int64_t num_layers = w.size(1);
- int64_t full_y_size = y.size(1);
- CHECK_EQ(w.size(3), h_in);
- CHECK_EQ(w.size(2), h_out);
- CHECK_EQ(indicies.size(0), x.size(0));
- CHECK_EQ(y.size(0), x.size(0));
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
- bool ok = false;
- if (h_in <= 128512 && h_out <= 128512) {
- // TODO: See if we can get rid of this massive nested switch
- switch (x.scalar_type()) {
- case at::ScalarType::Half:
- switch (y.scalar_type()) {
- case at::ScalarType::Half:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_half *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (y.scalar_type()) {
- case at::ScalarType::Half:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<nv_bfloat16 *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (y.scalar_type()) {
- case at::ScalarType::Half:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::BFloat16:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- case at::ScalarType::Float:
- switch (w.scalar_type()) {
- case at::ScalarType::Half:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_half *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- case at::ScalarType::BFloat16:
- ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
- static_cast<float *>(x.data_ptr()),
- static_cast<nv_bfloat16 *>(w.data_ptr()),
- indicies.data_ptr<int64_t>(), h_in, h_out,
- y_offset, full_y_size, B, num_layers,
- layer_idx, scale);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- }
- TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
- " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
- }
|