// torch_bindings.cpp (585 bytes)

  1. #include "registration.h"
  2. #include "punica_ops.h"
  3. TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  4. m.def(
  5. "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
  6. "layer_idx, float scale) -> ()");
  7. m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
  8. m.def(
  9. "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
  10. "Tensor indicies, int layer_idx,"
  11. "float scale, int h_in, int h_out,"
  12. "int y_offset) -> ()");
  13. m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
  14. }
  15. REGISTER_EXTENSION(TORCH_EXTENSION_NAME)