1
0

punica_ops.cc 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. #include <cuda_bf16.h>
  2. #include <cuda_fp16.h>
  3. #include <torch/extension.h>
  4. #include <c10/cuda/CUDAGuard.h>
  5. #include <cstdint>
  6. #include "bgmv/bgmv_config.h"
  7. namespace {
  8. //====== utils ======
  9. inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
  10. const char *a_name, const char *b_name) {
  11. TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
  12. a.dim(), " vs ", b.dim());
  13. for (int i = 0; i < a.dim(); ++i) {
  14. TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
  15. ".size(", i, ")");
  16. }
  17. }
  18. inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  19. return (uint64_t(a) << 32) | uint64_t(b);
  20. }
  21. #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
  22. #define CHECK_CONTIGUOUS(x) \
  23. TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  24. #define CHECK_INPUT(x) \
  25. CHECK_CUDA(x); \
  26. CHECK_CONTIGUOUS(x)
  27. #define CHECK_DIM(d, x) \
  28. TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
  29. #define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
  30. #define CHECK_EQ(a, b) \
  31. TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
  32. //====== bgmv ======
  33. template <typename in_T, typename out_T, typename W_T>
  34. inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
  35. const int64_t *lora_indices,
  36. uint32_t in_features, uint32_t out_features,
  37. int64_t y_offset, int64_t full_y_size,
  38. int64_t batch_size, int64_t num_layers,
  39. int64_t layer_idx, float scale) {
  40. switch (pack_u32(in_features, out_features)) {
  41. #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
  42. case pack_u32(feat_in, feat_out): \
  43. bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
  44. full_y_size, batch_size, num_layers, \
  45. layer_idx, scale); \
  46. break;
  47. #define CASE(_in_T, _out_T, _W_T, narrow, wide) \
  48. CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
  49. CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
  50. FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
  51. #undef CASE
  52. #undef CASE_ONESIDE
  53. default:
  54. return false;
  55. }
  56. return true;
  57. }
  58. void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
  59. torch::Tensor indicies, int64_t layer_idx, float scale) {
  60. CHECK_INPUT(y);
  61. CHECK_INPUT(x);
  62. CHECK_INPUT(w);
  63. CHECK_INPUT(indicies);
  64. CHECK_DIM(2, y);
  65. CHECK_DIM(2, x);
  66. CHECK_DIM(4, w);
  67. CHECK_DIM(1, indicies);
  68. int64_t B = x.size(0);
  69. int64_t h_in = x.size(1);
  70. int64_t h_out = y.size(1);
  71. int64_t num_layers = w.size(1);
  72. CHECK_EQ(w.size(3), h_in);
  73. CHECK_EQ(w.size(2), h_out);
  74. CHECK_EQ(indicies.size(0), x.size(0));
  75. CHECK_EQ(y.size(0), x.size(0));
  76. const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
  77. bool ok = false;
  78. if (h_in <= 128512 && h_out <= 128512) {
  79. // TODO: See if we can get rid of this massive nested switch
  80. switch (x.scalar_type()) {
  81. case at::ScalarType::Half:
  82. switch (y.scalar_type()) {
  83. case at::ScalarType::Half:
  84. switch (w.scalar_type()) {
  85. case at::ScalarType::Half:
  86. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  87. static_cast<nv_half *>(x.data_ptr()),
  88. static_cast<nv_half *>(w.data_ptr()),
  89. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  90. h_out, B, num_layers, layer_idx, scale);
  91. break;
  92. case at::ScalarType::BFloat16:
  93. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  94. static_cast<nv_half *>(x.data_ptr()),
  95. static_cast<nv_bfloat16 *>(w.data_ptr()),
  96. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  97. h_out, B, num_layers, layer_idx, scale);
  98. break;
  99. default:
  100. break;
  101. }
  102. break;
  103. case at::ScalarType::BFloat16:
  104. switch (w.scalar_type()) {
  105. case at::ScalarType::Half:
  106. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  107. static_cast<nv_half *>(x.data_ptr()),
  108. static_cast<nv_half *>(w.data_ptr()),
  109. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  110. h_out, B, num_layers, layer_idx, scale);
  111. break;
  112. case at::ScalarType::BFloat16:
  113. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  114. static_cast<nv_half *>(x.data_ptr()),
  115. static_cast<nv_bfloat16 *>(w.data_ptr()),
  116. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  117. h_out, B, num_layers, layer_idx, scale);
  118. break;
  119. default:
  120. break;
  121. }
  122. break;
  123. case at::ScalarType::Float:
  124. switch (w.scalar_type()) {
  125. case at::ScalarType::Half:
  126. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  127. static_cast<nv_half *>(x.data_ptr()),
  128. static_cast<nv_half *>(w.data_ptr()),
  129. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  130. h_out, B, num_layers, layer_idx, scale);
  131. break;
  132. case at::ScalarType::BFloat16:
  133. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  134. static_cast<nv_half *>(x.data_ptr()),
  135. static_cast<nv_bfloat16 *>(w.data_ptr()),
  136. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  137. h_out, B, num_layers, layer_idx, scale);
  138. break;
  139. default:
  140. break;
  141. }
  142. break;
  143. default:
  144. break;
  145. }
  146. break;
  147. case at::ScalarType::BFloat16:
  148. switch (y.scalar_type()) {
  149. case at::ScalarType::Half:
  150. switch (w.scalar_type()) {
  151. case at::ScalarType::Half:
  152. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  153. static_cast<nv_bfloat16 *>(x.data_ptr()),
  154. static_cast<nv_half *>(w.data_ptr()),
  155. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  156. h_out, B, num_layers, layer_idx, scale);
  157. break;
  158. case at::ScalarType::BFloat16:
  159. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  160. static_cast<nv_bfloat16 *>(x.data_ptr()),
  161. static_cast<nv_bfloat16 *>(w.data_ptr()),
  162. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  163. h_out, B, num_layers, layer_idx, scale);
  164. break;
  165. default:
  166. break;
  167. }
  168. break;
  169. case at::ScalarType::BFloat16:
  170. switch (w.scalar_type()) {
  171. case at::ScalarType::Half:
  172. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  173. static_cast<nv_bfloat16 *>(x.data_ptr()),
  174. static_cast<nv_half *>(w.data_ptr()),
  175. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  176. h_out, B, num_layers, layer_idx, scale);
  177. break;
  178. case at::ScalarType::BFloat16:
  179. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  180. static_cast<nv_bfloat16 *>(x.data_ptr()),
  181. static_cast<nv_bfloat16 *>(w.data_ptr()),
  182. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  183. h_out, B, num_layers, layer_idx, scale);
  184. break;
  185. default:
  186. break;
  187. }
  188. break;
  189. case at::ScalarType::Float:
  190. switch (w.scalar_type()) {
  191. case at::ScalarType::Half:
  192. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  193. static_cast<nv_bfloat16 *>(x.data_ptr()),
  194. static_cast<nv_half *>(w.data_ptr()),
  195. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  196. h_out, B, num_layers, layer_idx, scale);
  197. break;
  198. case at::ScalarType::BFloat16:
  199. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  200. static_cast<nv_bfloat16 *>(x.data_ptr()),
  201. static_cast<nv_bfloat16 *>(w.data_ptr()),
  202. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  203. h_out, B, num_layers, layer_idx, scale);
  204. break;
  205. default:
  206. break;
  207. }
  208. break;
  209. default:
  210. break;
  211. }
  212. break;
  213. case at::ScalarType::Float:
  214. switch (y.scalar_type()) {
  215. case at::ScalarType::Half:
  216. switch (w.scalar_type()) {
  217. case at::ScalarType::Half:
  218. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  219. static_cast<float *>(x.data_ptr()),
  220. static_cast<nv_half *>(w.data_ptr()),
  221. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  222. h_out, B, num_layers, layer_idx, scale);
  223. break;
  224. case at::ScalarType::BFloat16:
  225. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  226. static_cast<float *>(x.data_ptr()),
  227. static_cast<nv_bfloat16 *>(w.data_ptr()),
  228. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  229. h_out, B, num_layers, layer_idx, scale);
  230. break;
  231. default:
  232. break;
  233. }
  234. break;
  235. case at::ScalarType::BFloat16:
  236. switch (w.scalar_type()) {
  237. case at::ScalarType::Half:
  238. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  239. static_cast<float *>(x.data_ptr()),
  240. static_cast<nv_half *>(w.data_ptr()),
  241. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  242. h_out, B, num_layers, layer_idx, scale);
  243. break;
  244. case at::ScalarType::BFloat16:
  245. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  246. static_cast<float *>(x.data_ptr()),
  247. static_cast<nv_bfloat16 *>(w.data_ptr()),
  248. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  249. h_out, B, num_layers, layer_idx, scale);
  250. break;
  251. default:
  252. break;
  253. }
  254. break;
  255. case at::ScalarType::Float:
  256. switch (w.scalar_type()) {
  257. case at::ScalarType::Half:
  258. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  259. static_cast<float *>(x.data_ptr()),
  260. static_cast<nv_half *>(w.data_ptr()),
  261. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  262. h_out, B, num_layers, layer_idx, scale);
  263. break;
  264. case at::ScalarType::BFloat16:
  265. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  266. static_cast<float *>(x.data_ptr()),
  267. static_cast<nv_bfloat16 *>(w.data_ptr()),
  268. indicies.data_ptr<int64_t>(), h_in, h_out, 0,
  269. h_out, B, num_layers, layer_idx, scale);
  270. break;
  271. default:
  272. break;
  273. }
  274. break;
  275. default:
  276. break;
  277. }
  278. break;
  279. default:
  280. break;
  281. }
  282. }
  283. TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
  284. " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
  285. }
  286. void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
  287. torch::Tensor indicies, int64_t layer_idx,
  288. float scale, int64_t h_in, int64_t h_out,
  289. int64_t y_offset) {
  290. CHECK_INPUT(y);
  291. CHECK_INPUT(x);
  292. CHECK_INPUT(w);
  293. CHECK_INPUT(indicies);
  294. CHECK_DIM(2, y);
  295. CHECK_DIM(2, x);
  296. CHECK_DIM(4, w);
  297. CHECK_DIM(1, indicies);
  298. int64_t B = x.size(0);
  299. int64_t num_layers = w.size(1);
  300. int64_t full_y_size = y.size(1);
  301. CHECK_EQ(w.size(3), h_in);
  302. CHECK_EQ(w.size(2), h_out);
  303. CHECK_EQ(indicies.size(0), x.size(0));
  304. CHECK_EQ(y.size(0), x.size(0));
  305. const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
  306. bool ok = false;
  307. if (h_in <= 128512 && h_out <= 128512) {
  308. // TODO: See if we can get rid of this massive nested switch
  309. switch (x.scalar_type()) {
  310. case at::ScalarType::Half:
  311. switch (y.scalar_type()) {
  312. case at::ScalarType::Half:
  313. switch (w.scalar_type()) {
  314. case at::ScalarType::Half:
  315. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  316. static_cast<nv_half *>(x.data_ptr()),
  317. static_cast<nv_half *>(w.data_ptr()),
  318. indicies.data_ptr<int64_t>(), h_in, h_out,
  319. y_offset, full_y_size, B, num_layers,
  320. layer_idx, scale);
  321. break;
  322. case at::ScalarType::BFloat16:
  323. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  324. static_cast<nv_half *>(x.data_ptr()),
  325. static_cast<nv_bfloat16 *>(w.data_ptr()),
  326. indicies.data_ptr<int64_t>(), h_in, h_out,
  327. y_offset, full_y_size, B, num_layers,
  328. layer_idx, scale);
  329. break;
  330. default:
  331. break;
  332. }
  333. break;
  334. case at::ScalarType::BFloat16:
  335. switch (w.scalar_type()) {
  336. case at::ScalarType::Half:
  337. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  338. static_cast<nv_half *>(x.data_ptr()),
  339. static_cast<nv_half *>(w.data_ptr()),
  340. indicies.data_ptr<int64_t>(), h_in, h_out,
  341. y_offset, full_y_size, B, num_layers,
  342. layer_idx, scale);
  343. break;
  344. case at::ScalarType::BFloat16:
  345. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  346. static_cast<nv_half *>(x.data_ptr()),
  347. static_cast<nv_bfloat16 *>(w.data_ptr()),
  348. indicies.data_ptr<int64_t>(), h_in, h_out,
  349. y_offset, full_y_size, B, num_layers,
  350. layer_idx, scale);
  351. break;
  352. default:
  353. break;
  354. }
  355. break;
  356. case at::ScalarType::Float:
  357. switch (w.scalar_type()) {
  358. case at::ScalarType::Half:
  359. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  360. static_cast<nv_half *>(x.data_ptr()),
  361. static_cast<nv_half *>(w.data_ptr()),
  362. indicies.data_ptr<int64_t>(), h_in, h_out,
  363. y_offset, full_y_size, B, num_layers,
  364. layer_idx, scale);
  365. break;
  366. case at::ScalarType::BFloat16:
  367. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  368. static_cast<nv_half *>(x.data_ptr()),
  369. static_cast<nv_bfloat16 *>(w.data_ptr()),
  370. indicies.data_ptr<int64_t>(), h_in, h_out,
  371. y_offset, full_y_size, B, num_layers,
  372. layer_idx, scale);
  373. break;
  374. default:
  375. break;
  376. }
  377. break;
  378. default:
  379. break;
  380. }
  381. break;
  382. case at::ScalarType::BFloat16:
  383. switch (y.scalar_type()) {
  384. case at::ScalarType::Half:
  385. switch (w.scalar_type()) {
  386. case at::ScalarType::Half:
  387. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  388. static_cast<nv_bfloat16 *>(x.data_ptr()),
  389. static_cast<nv_half *>(w.data_ptr()),
  390. indicies.data_ptr<int64_t>(), h_in, h_out,
  391. y_offset, full_y_size, B, num_layers,
  392. layer_idx, scale);
  393. break;
  394. case at::ScalarType::BFloat16:
  395. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  396. static_cast<nv_bfloat16 *>(x.data_ptr()),
  397. static_cast<nv_bfloat16 *>(w.data_ptr()),
  398. indicies.data_ptr<int64_t>(), h_in, h_out,
  399. y_offset, full_y_size, B, num_layers,
  400. layer_idx, scale);
  401. break;
  402. default:
  403. break;
  404. }
  405. break;
  406. case at::ScalarType::BFloat16:
  407. switch (w.scalar_type()) {
  408. case at::ScalarType::Half:
  409. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  410. static_cast<nv_bfloat16 *>(x.data_ptr()),
  411. static_cast<nv_half *>(w.data_ptr()),
  412. indicies.data_ptr<int64_t>(), h_in, h_out,
  413. y_offset, full_y_size, B, num_layers,
  414. layer_idx, scale);
  415. break;
  416. case at::ScalarType::BFloat16:
  417. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  418. static_cast<nv_bfloat16 *>(x.data_ptr()),
  419. static_cast<nv_bfloat16 *>(w.data_ptr()),
  420. indicies.data_ptr<int64_t>(), h_in, h_out,
  421. y_offset, full_y_size, B, num_layers,
  422. layer_idx, scale);
  423. break;
  424. default:
  425. break;
  426. }
  427. break;
  428. case at::ScalarType::Float:
  429. switch (w.scalar_type()) {
  430. case at::ScalarType::Half:
  431. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  432. static_cast<nv_bfloat16 *>(x.data_ptr()),
  433. static_cast<nv_half *>(w.data_ptr()),
  434. indicies.data_ptr<int64_t>(), h_in, h_out,
  435. y_offset, full_y_size, B, num_layers,
  436. layer_idx, scale);
  437. break;
  438. case at::ScalarType::BFloat16:
  439. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  440. static_cast<nv_bfloat16 *>(x.data_ptr()),
  441. static_cast<nv_bfloat16 *>(w.data_ptr()),
  442. indicies.data_ptr<int64_t>(), h_in, h_out,
  443. y_offset, full_y_size, B, num_layers,
  444. layer_idx, scale);
  445. break;
  446. default:
  447. break;
  448. }
  449. break;
  450. default:
  451. break;
  452. }
  453. break;
  454. case at::ScalarType::Float:
  455. switch (y.scalar_type()) {
  456. case at::ScalarType::Half:
  457. switch (w.scalar_type()) {
  458. case at::ScalarType::Half:
  459. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  460. static_cast<float *>(x.data_ptr()),
  461. static_cast<nv_half *>(w.data_ptr()),
  462. indicies.data_ptr<int64_t>(), h_in, h_out,
  463. y_offset, full_y_size, B, num_layers,
  464. layer_idx, scale);
  465. break;
  466. case at::ScalarType::BFloat16:
  467. ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
  468. static_cast<float *>(x.data_ptr()),
  469. static_cast<nv_bfloat16 *>(w.data_ptr()),
  470. indicies.data_ptr<int64_t>(), h_in, h_out,
  471. y_offset, full_y_size, B, num_layers,
  472. layer_idx, scale);
  473. break;
  474. default:
  475. break;
  476. }
  477. break;
  478. case at::ScalarType::BFloat16:
  479. switch (w.scalar_type()) {
  480. case at::ScalarType::Half:
  481. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  482. static_cast<float *>(x.data_ptr()),
  483. static_cast<nv_half *>(w.data_ptr()),
  484. indicies.data_ptr<int64_t>(), h_in, h_out,
  485. y_offset, full_y_size, B, num_layers,
  486. layer_idx, scale);
  487. break;
  488. case at::ScalarType::BFloat16:
  489. ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
  490. static_cast<float *>(x.data_ptr()),
  491. static_cast<nv_bfloat16 *>(w.data_ptr()),
  492. indicies.data_ptr<int64_t>(), h_in, h_out,
  493. y_offset, full_y_size, B, num_layers,
  494. layer_idx, scale);
  495. break;
  496. default:
  497. break;
  498. }
  499. break;
  500. case at::ScalarType::Float:
  501. switch (w.scalar_type()) {
  502. case at::ScalarType::Half:
  503. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  504. static_cast<float *>(x.data_ptr()),
  505. static_cast<nv_half *>(w.data_ptr()),
  506. indicies.data_ptr<int64_t>(), h_in, h_out,
  507. y_offset, full_y_size, B, num_layers,
  508. layer_idx, scale);
  509. break;
  510. case at::ScalarType::BFloat16:
  511. ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
  512. static_cast<float *>(x.data_ptr()),
  513. static_cast<nv_bfloat16 *>(w.data_ptr()),
  514. indicies.data_ptr<int64_t>(), h_in, h_out,
  515. y_offset, full_y_size, B, num_layers,
  516. layer_idx, scale);
  517. break;
  518. default:
  519. break;
  520. }
  521. break;
  522. default:
  523. break;
  524. }
  525. break;
  526. default:
  527. break;
  528. }
  529. }
  530. TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
  531. " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
  532. }
  533. } // namespace
  534. //====== pybind ======
  535. #define DEFINE_pybind(name) m.def(#name, &name, #name);
  536. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  537. m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  538. m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
  539. "dispatch_bgmv_low_level");
  540. }