123456789101112131415161718 |
- #include "registration.h"
- #include "punica_ops.h"
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
- m.def(
- "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
- "layer_idx, float scale) -> ()");
- m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
- m.def(
- "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
- "Tensor indicies, int layer_idx,"
- "float scale, int h_in, int h_out,"
- "int y_offset) -> ()");
- m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
- }
// Invokes REGISTER_EXTENSION with the same TORCH_EXTENSION_NAME used for the
// operator library above.
// NOTE(review): the macro is presumably provided by "registration.h" and
// presumably emits the module entry point for the extension; confirm there.
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|