punica_ops.h 457 B

1234567891011
  1. #pragma once
  2. #include <torch/all.h>
  3. void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
  4. torch::Tensor indicies, int64_t layer_idx, double scale);
  5. void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
  6. torch::Tensor indicies, int64_t layer_idx,
  7. double scale, int64_t h_in, int64_t h_out,
  8. int64_t y_offset);