punica_ops.cc

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <cstdint>

#include "bgmv/bgmv_config.h"

namespace {

//====== utils ======

inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
                        const char *a_name, const char *b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
              a.dim(), " vs ", b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
                ".size(", i, ")");
  }
}

inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}
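
// Compile-time illustration: the dispatch key carries in_features in the high
// 16 bits and out_features in the low 16 bits, so one (in, out) shape pair
// maps to exactly one switch case below.
static_assert(pack_u16(16, 4096) == 0x00101000u,
              "pack_u16 puts in_features in the high half-word");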

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

#define CHECK_DIM(d, x) \
  TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)

#define CHECK_EQ(a, b) \
  TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

//====== bgmv ======

template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                               const int64_t *lora_indices,
                               uint16_t in_features, uint16_t out_features,
                               int64_t y_offset, int64_t full_y_size,
                               int64_t batch_size, int64_t num_layers,
                               int64_t layer_idx, float scale) {
  switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out)            \
  case pack_u16(feat_in, feat_out):                                     \
    bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset,     \
                                   full_y_size, batch_size, num_layers, \
                                   layer_idx, scale);                   \
    break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide)  \
  CASE_ONESIDE(in_T, out_T, W_T, narrow, wide)   \
  CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

    FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE

    default:
      return false;
  }

  return true;
}
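
// FOR_BGMV_WIDE_NARROW (from bgmv_config.h) is an X-macro that invokes CASE
// once per supported (narrow, wide) feature pair, and CASE instantiates both
// orientations of the pair. A sketch of one expansion, assuming (16, 4096)
// is among the configured pairs:
//
//   case pack_u16(16, 4096):
//     bgmv_kernel<16, 4096>(Y, X, W, lora_indices, y_offset, full_y_size,
//                           batch_size, num_layers, layer_idx, scale);
//     break;
//   case pack_u16(4096, 16):
//     bgmv_kernel<4096, 16>(Y, X, W, lora_indices, y_offset, full_y_size,
//                           batch_size, num_layers, layer_idx, scale);
//     break;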

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t h_in = x.size(1);
  int64_t h_out = y.size(1);
  int64_t num_layers = w.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));

  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
      case at::ScalarType::Half:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::Float:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
                    num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      default:
        break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
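
// dispatch_bgmv above always writes the full output row (it passes
// y_offset = 0 and full_y_size = h_out). dispatch_bgmv_low_level instead
// takes h_in, h_out, and y_offset explicitly, so a caller can direct the
// kernel at the slice [y_offset, y_offset + h_out) of a wider output tensor
// whose row width is full_y_size = y.size(1).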
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t num_layers = w.size(1);
  int64_t full_y_size = y.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));

  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
      case at::ScalarType::Half:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_half *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<nv_bfloat16 *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::Float:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16 *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_half *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float *>(y.data_ptr()),
                    static_cast<float *>(x.data_ptr()),
                    static_cast<nv_bfloat16 *>(w.data_ptr()),
                    indicies.data_ptr<int64_t>(), h_in, h_out, y_offset,
                    full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      default:
        break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

}  // namespace

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);
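// Convenience macro, unused in this file; the bindings below call m.def
// directly with the same (name, function, docstring) triple it would expand
// to.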

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
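
// Usage sketch (illustrative only; the concrete shapes, the number of LoRA
// adapters in w's leading dimension, and the assumption that (16, 4096) is a
// configured feature pair are ours, not guaranteed by this file). From C++
// with libtorch, dispatch_bgmv expects x: [B, h_in], y: [B, h_out],
// w: [num_loras, num_layers, h_out, h_in], indicies: [B] of int64:
//
//   auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto x = torch::randn({8, 16}, opts);           // B = 8, h_in = 16
//   auto y = torch::zeros({8, 4096}, opts);         // h_out = 4096
//   auto w = torch::randn({4, 1, 4096, 16}, opts);  // 4 LoRAs, 1 layer
//   auto idx = torch::zeros(
//       {8}, torch::dtype(torch::kInt64).device(torch::kCUDA));
//   dispatch_bgmv(y, x, w, idx, /*layer_idx=*/0, /*scale=*/1.0f);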