# _custom_ops.py
import contextlib
import functools
from typing import List, Optional, Tuple, Type

import torch
from loguru import logger

from aphrodite._core_ext import ScalarType
from aphrodite.common.utils import is_hip
from aphrodite.platforms import current_platform

if not current_platform.is_tpu():
    try:
        import aphrodite._C
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._C with {e}")

with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import aphrodite._moe_C


def hint_on_error(fn):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                f"Error in calling custom op {fn.__name__}: {e}\n"
                "Possibly you have built or installed an obsolete version of "
                "aphrodite.\n"
                "Please try a clean build and install of aphrodite, "
                "or remove old built files such as aphrodite/*.so and build/ .")
            logger.error(msg)
            raise e

    return wrapper


# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_quick(out, x)
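
# Usage sketch for the fused activations above (illustrative only, not called
# by the library). It assumes the *_and_mul kernels take `x` with gate and up
# projections concatenated on the last dim and write a tensor of half that
# width into `out`. Running it needs the compiled _C extension and a CUDA
# device; the reference expression is an assumption about the kernel's math.
def _silu_and_mul_usage_sketch() -> None:
    d = 128
    x = torch.randn(4, 2 * d, device="cuda", dtype=torch.float16)
    out = torch.empty(4, d, device="cuda", dtype=torch.float16)
    silu_and_mul(out, x)  # kernel writes silu(x[..., :d]) * x[..., d:] to out
    ref = torch.nn.functional.silu(x[..., :d]) * x[..., d:]
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
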
# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)
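
# Usage sketch for rotary_embedding (illustrative only; the shapes below are
# assumptions inferred from the wrapper signature, not documented in this
# file). positions is assumed to be int64 of shape [num_tokens], query/key
# flattened to [num_tokens, num_heads * head_size], and cos_sin_cache of shape
# [max_position, rot_dim] with concatenated cos/sin halves; query and key are
# rotated in place. Requires the compiled _C extension and a CUDA device.
def _rotary_embedding_usage_sketch() -> None:
    num_tokens, num_heads, head_size, max_pos = 8, 4, 64, 2048
    positions = torch.arange(num_tokens, device="cuda")
    query = torch.randn(num_tokens, num_heads * head_size,
                        device="cuda", dtype=torch.float16)
    key = torch.randn_like(query)
    cos_sin_cache = torch.randn(max_pos, head_size,
                                device="cuda", dtype=torch.float16)
    rotary_embedding(positions, query, key, head_size, cos_sin_cache,
                     is_neox=True)  # modifies query and key in place
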
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner."""
    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
                                     input_tokens, sampled_token_ids,
                                     input_positions, seq_lens, slot_mapping,
                                     block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)


# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, size_m, size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: Type[torch.dtype],
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out
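
# Usage sketch for cutlass_scaled_mm (illustrative only). The int8 operands,
# per-tensor fp32 scales, and the column-major layout chosen for `b` below are
# assumptions about what the CUTLASS kernel expects; only the shape/dtype
# checks come from the wrapper above. Running it requires the compiled _C
# extension and a CUDA device of a supported compute capability.
def _cutlass_scaled_mm_usage_sketch() -> None:
    m, k, n = 16, 256, 512
    a = torch.randint(-8, 8, (m, k), device="cuda", dtype=torch.int8)
    # [k, n] column-major view for the second operand (assumed requirement).
    b = torch.randint(-8, 8, (n, k), device="cuda", dtype=torch.int8).t()
    scale_a = torch.full((1, 1), 0.01, device="cuda", dtype=torch.float32)
    scale_b = torch.full((1, 1), 0.02, device="cuda", dtype=torch.float32)
    out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
    assert out.shape == (m, n) and out.dtype == torch.float16
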
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, it will use static scaling, and if you omit it,
    the scale will be determined dynamically. The function also allows
    optional padding of the output tensors for downstream kernels that
    will benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        scale_ub: Optional upper bound for the scale in the dynamic
            per-token quantization case.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            the scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape = input.shape
    # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz
    out_dtype: torch.dtype = torch.float8_e4m3fnuz if \
        is_hip() else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
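
# Usage sketch for scaled_fp8_quant (illustrative only; requires the compiled
# _C extension and a CUDA/ROCm device). The scale shapes follow directly from
# the wrapper above: the dynamic per-token path returns one fp32 scale per
# row, while the per-tensor paths return a single-element scale.
def _scaled_fp8_quant_usage_sketch() -> None:
    x = torch.randn(32, 4096, device="cuda", dtype=torch.float16)
    # Dynamic per-tensor quantization: the kernel computes the scale.
    q_dyn, scale_dyn = scaled_fp8_quant(x)
    # Dynamic per-token quantization: one scale per row, optionally clamped
    # through scale_ub.
    q_tok, scale_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
    # Static quantization: reuse a precomputed per-tensor scale.
    q_static, _ = scaled_fp8_quant(x, scale=scale_dyn)
    assert scale_dyn.numel() == 1 and scale_tok.shape == (32, 1)
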
# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales
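
# Usage sketch for scaled_int8_quant (illustrative only; requires the compiled
# _C extension and a CUDA device). Shapes follow the wrapper above: the
# dynamic path returns one fp32 scale per token row, and a provided static
# scale is returned unchanged.
def _scaled_int8_quant_usage_sketch() -> None:
    x = torch.randn(32, 4096, device="cuda", dtype=torch.float16)
    # Dynamic per-token quantization.
    q_dyn, scales = scaled_int8_quant(x)
    assert q_dyn.dtype == torch.int8 and scales.shape == (32, 1)
    # Static per-tensor quantization.
    static_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
    q_static, _ = scaled_int8_quant(x, scale=static_scale)
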
# quip#
def quip_gemv(
    A: torch.Tensor,
    B: torch.Tensor,
    CB: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._C.quip_gemv(A, B, CB)


def quip_decompress(
    YIs: torch.Tensor,
    CB: torch.Tensor,
    Y: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._C.quip_decompress(YIs, CB, Y)


# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
                                        workspace, size_m, size_n, size_k)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


# fp6
def fp_eXmY_linear_forward_cuda(
    EXPONENT: int,
    MANTISSA: int,
    _in_feats: torch.Tensor,
    _weights: torch.Tensor,
    _scales: torch.Tensor,
    splitK: int = 1,
) -> torch.Tensor:
    return torch.ops._C.fp_eXmY_linear_forward_cuda(EXPONENT, MANTISSA,
                                                    _in_feats, _weights,
                                                    _scales, splitK)


# qgemm ops
@hint_on_error
def qgemm_simple_80(
    input: torch.Tensor,
    weight: torch.Tensor,
    scales: torch.Tensor,
    table: torch.Tensor,
    table2: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    group_size: int,
) -> torch.Tensor:
    return torch.ops._C.qgemm_simple_80(
        input, weight, scales, table, table2, workspace, num_bits, group_size)


@hint_on_error
def qgemm_simple_86(
    input: torch.Tensor,
    weight: torch.Tensor,
    scales: torch.Tensor,
    table: torch.Tensor,
    table2: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    group_size: int,
) -> torch.Tensor:
    return torch.ops._C.qgemm_simple_86(
        input, weight, scales, table, table2, workspace, num_bits, group_size)


@hint_on_error
def qgemm_simple_89(
    input: torch.Tensor,
    weight: torch.Tensor,
    scales: torch.Tensor,
    table: torch.Tensor,
    table2: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    group_size: int,
) -> torch.Tensor:
    return torch.ops._C.qgemm_simple_89(
        input, weight, scales, table, table2, workspace, num_bits, group_size)


@hint_on_error
def qgemm_raw_simple_80(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    scales: torch.Tensor,
    table: torch.Tensor,
    table2: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    group_size: int,
    template_id: int,
) -> None:
    torch.ops._C.qgemm_raw_simple_80(input, weight, output, scales, table,
                                     table2, workspace, num_bits, group_size,
                                     template_id)


@hint_on_error
def qgemm_raw_simple_86(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    scales: torch.Tensor,
    table: torch.Tensor,
    table2: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    group_size: int,
    template_id: int,
) -> None:
    torch.ops._C.qgemm_raw_simple_86(input, weight, output, scales, table,
                                     table2, workspace, num_bits, group_size,
                                     template_id)


@hint_on_error
def qgemm_raw_simple_89(
    input: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    scales: torch.Tensor,
    table: torch.Tensor,
    table2: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    group_size: int,
    template_id: int,
) -> None:
    torch.ops._C.qgemm_raw_simple_89(input, weight, output, scales, table,
                                     table2, workspace, num_bits, group_size,
                                     template_id)


# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      seq_idx_: Optional[torch.Tensor],
                      initial_states_: Optional[torch.Tensor],
                      final_states_out_: Optional[torch.Tensor],
                      silu_activation: bool) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, None,
                                          initial_states_, final_states_out_,
                                          silu_activation)


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor, bias_: Optional[torch.Tensor],
                         silu_activation: bool) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
                                             silu_activation)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool, index_: Optional[torch.Tensor],
                       x: Optional[torch.Tensor]) -> List[torch.Tensor]:
    return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
                                           delta_bias_, delta_softplus, index_,
                                           x)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    # ruff: noqa: E501
    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
        device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
                                                 offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
                                                   full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    torch.ops._C_custom_ar.dispose(fa)


def meta_size() -> int:
    return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)


# TODO: remove this later
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to turn
    # type hints into strings.
    if isinstance(v, fn_type) \
            and v.__code__.co_filename == __file__ \
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type
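
# Note on the block above: after the loop runs, every function defined in this
# module whose annotations mention torch.Tensor is rebound to
# hint_on_error(fn). As a result, calling one of these wrappers against a
# stale or missing compiled extension (e.g. an outdated aphrodite/*.so)
# surfaces as the logged "clean build and install" hint before the original
# AttributeError from torch.ops is re-raised.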