import contextlib
import functools
from typing import List, Optional, Tuple, Type, Union

import torch
from loguru import logger

import aphrodite.common.envs as envs
from aphrodite._core_ext import ScalarType
from aphrodite.common.utils import is_hip
from aphrodite.platforms import current_platform

if not current_platform.is_tpu():
    try:
        import aphrodite._C
    except ImportError as e:
        logger.warning(f"Failed to import from aphrodite._C with {e}")

if current_platform.is_rocm():
    import aphrodite._rocm_C  # noqa: F401

with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import aphrodite._moe_C


def hint_on_error(fn):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                f"Error in calling custom op {fn.__name__}: {e}\n"
                "Possibly you have built or installed an obsolete version of "
                "aphrodite.\n"
                "Please try a clean build and install of aphrodite, "
                "or remove old built files such as aphrodite/*.so "
                "and build/ .")
            logger.error(msg)
            raise e

    return wrapper
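

# Illustrative sketch (not part of the original module): `hint_on_error` is a
# plain decorator, so it could be applied directly to a wrapper such as the
# `silu_and_mul` function defined just below. In this module the decorator is
# instead applied in bulk to all tensor-typed functions at the bottom of the
# file.
#
#   @hint_on_error
#   def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
#       torch.ops._C.silu_and_mul(out, x)
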
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_new(out, x)


def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_quick(out, x)


# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
        blocksparse_vert_stride, blocksparse_block_size,
        blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_rocm(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                      key_cache, value_cache, num_kv_heads,
                                      scale, block_tables, seq_lens,
                                      block_size, max_seq_len, alibi_slopes,
                                      kv_cache_dtype, k_scale, v_scale)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner"""
    return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
                                               block_size, input_tokens,
                                               sampled_token_ids,
                                               input_positions, seq_lens,
                                               slot_mapping, block_tables)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    return torch.ops._C.advance_step_flashinfer(
        num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
        input_positions, seq_lens, slot_mapping, block_tables,
        paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
        block_table_bound)
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    if envs.APHRODITE_USE_TRITON_AWQ:
        from aphrodite.quantization.awq_triton import awq_dequantize_triton
        return awq_dequantize_triton(qweight, scales, zeros)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    if envs.APHRODITE_USE_TRITON_AWQ:
        from aphrodite.quantization.awq_triton import awq_gemm_triton
        return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters)
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


# TODO: has to be a better way to do this
try:
    torch.ops._C.gptq_gemm  # noqa B018

    @torch.library.register_fake("_C::gptq_gemm")
    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_gptq_qzeros: torch.Tensor,
                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                        use_exllama: bool, bit: int) -> torch.Tensor:
        return torch.empty((a.size(0), b_q_weight.size(1)),
                           dtype=a.dtype,
                           device=a.device)
except Exception:
    pass


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace,
                                    size_m, size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, b_q_type: ScalarType,
                        size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, b_q_type, size_m,
                                            size_n, size_k)


# TODO: has to be a better way to do this
try:
    torch.ops._C.gptq_marlin_24_gemm  # noqa B018

    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
                                  workspace: torch.Tensor,
                                  b_q_type: ScalarType, size_m: int,
                                  size_n: int, size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::gptq_marlin_gemm")
    def _gptq_marlin_gemm_fake(a: torch.Tensor,
                               b_q_weight: torch.Tensor,
                               b_scales: torch.Tensor,
                               b_zeros: torch.Tensor,
                               g_idx: torch.Tensor,
                               perm: torch.Tensor,
                               workspace: torch.Tensor,
                               b_q_type: ScalarType,
                               size_m: int,
                               size_n: int,
                               size_k: int,
                               is_k_full: bool,
                               has_zp: bool = False,
                               use_fp32_reduce: bool = False) -> torch.Tensor:
        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::ggml_dequantize")
    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
                              n: int) -> torch.Tensor:
        return torch.empty((m, n), dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
    def _ggml_mul_mat_vec_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        return torch.empty((1, row), dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::ggml_mul_mat_a8")
    def _ggml_mul_mat_a8_fake(
        W: torch.Tensor,
        X: torch.Tensor,
        quant_type: int,
        row: int,
    ) -> torch.Tensor:
        batch = X.size(0)
        return torch.empty((batch, row), dtype=torch.float16, device=W.device)

    @torch.library.register_fake("_C::marlin_qqq_gemm")
    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              s_tok: torch.Tensor, s_ch: torch.Tensor,
                              s_group: torch.Tensor, workspace: torch.Tensor,
                              size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @torch.library.register_fake("_C::marlin_gemm")
    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                          b_scales: torch.Tensor, workspace: torch.Tensor,
                          size_m: int, size_n: int,
                          size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n),
                           dtype=torch.float16,
                           device=a.device)

    @torch.library.register_fake("_C::awq_dequantize")
    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
                             zeros: torch.Tensor, split_k_iters: int, thx: int,
                             thy: int) -> torch.Tensor:
        in_c = qweight.size(0)
        qout_c = qweight.size(1)
        out_c = qout_c * 8
        return torch.empty((in_c, out_c),
                           dtype=scales.dtype,
                           device=scales.device)

    @torch.library.register_fake("_C::awq_gemm")
    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
                       qzeros: torch.Tensor, scales: torch.Tensor,
                       split_k_iters: int) -> torch.Tensor:
        num_in_feats = input.size(0)
        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
                           dtype=input.dtype,
                           device=input.device).sum(0)
    @torch.library.register_fake("_C::aqlm_gemm")
    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
                        codebooks: torch.Tensor, scales: torch.Tensor,
                        codebook_partition_sizes: List[int],
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
        out_features = codes.size(0) * codebooks.size(2)
        flat_input = input.reshape((-1, input.size(-1)))
        flat_output = torch.empty((flat_input.size(0), out_features),
                                  dtype=input.dtype,
                                  device=input.device)
        output_sizes = list(input.shape)
        output_sizes.pop()
        output_sizes.append(-1)
        return flat_output.reshape(tuple(output_sizes))

    @torch.library.register_fake("_C::aqlm_dequant")
    def _aqlm_dequant_fake(
            codes: torch.Tensor, codebooks: torch.Tensor,
            codebook_partition_sizes: List[int]) -> torch.Tensor:
        in_features = codes.size(1) * 8
        out_features = codes.size(0)
        return torch.empty((out_features, in_features),
                           dtype=codebooks.dtype,
                           device=codebooks.device)

    @torch.library.register_fake("_C::fp8_marlin_gemm")
    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                              b_scales: torch.Tensor, workspace: torch.Tensor,
                              num_bits: int, size_m: int, size_n: int,
                              size_k: int) -> torch.Tensor:
        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

    @torch.library.register_fake("_C::machete_gemm")
    def machete_gemm_fake(
        a: torch.Tensor,
        # b_q should be the tensor returned by machete_prepack_B
        b_q: torch.Tensor,
        b_type: ScalarType,
        b_scales: Optional[torch.Tensor] = None,
        b_zeros: Optional[torch.Tensor] = None,
        b_group_size: Optional[int] = None,
        c: Optional[torch.Tensor] = None,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        schedule: Optional[str] = None,
    ) -> torch.Tensor:
        m = a.size(0)
        n = b_q.size(1)
        return torch.empty((m, n), device=a.device, dtype=a.dtype)

    @torch.library.register_fake("_C::machete_prepack_B")
    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
                               b_type: ScalarType) -> torch.Tensor:
        return torch.empty_like(b_q_weight)

    @torch.library.register_fake("_C::causal_conv1d_fwd")
    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
                               bias_: Optional[torch.Tensor],
                               seq_idx_: Optional[torch.Tensor],
                               initial_states_: Optional[torch.Tensor],
                               final_states_out_: Optional[torch.Tensor],
                               silu_activation: bool) -> torch.Tensor:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::causal_conv1d_update")
    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
                                  weight: torch.Tensor,
                                  bias_: Optional[torch.Tensor],
                                  silu_activation: bool) -> torch.Tensor:
        return torch.empty_like(x)

    @torch.library.register_fake("_C::selective_scan_fwd")
    def selective_scan_fwd_fake(
            u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
            B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
            z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
            delta_softplus: bool, index_: Optional[torch.Tensor],
            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
        a = torch.empty_like(u)
        if x is not None:
            b = x
        else:
            b = torch.empty((u.size(0), u.size(1), A.size(1)),
                            dtype=u.dtype,
                            device=u.device)
        if z_ is not None:
            c = torch.empty_like(z_)
            return [a, b, c]
        else:
            return [a, b]

except Exception:
    pass
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    b_scales: torch.Tensor, workspace: torch.Tensor,
                    num_bits: int, size_m: int, size_n: int,
                    size_k: int) -> torch.Tensor:
    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
                                        num_bits, size_m, size_n, size_k)


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: Type[torch.dtype],
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)
    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out
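

# Illustrative usage sketch (assumption, not part of the original module):
# shapes and dtypes are chosen only to satisfy the asserts above (both
# dimensions of `b` divisible by 16, fp16/bf16 output) with per-tensor
# float32 scales; the kernel may impose additional layout requirements.
#
#   m, k, n = 16, 32, 64
#   a = torch.empty(m, k, dtype=torch.int8, device="cuda")
#   b = torch.empty(k, n, dtype=torch.int8, device="cuda")
#   scale_a = torch.ones(1, 1, dtype=torch.float32, device="cuda")
#   scale_b = torch.ones(1, 1, dtype=torch.float32, device="cuda")
#   out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
#   # out has shape (m, n) and dtype torch.float16
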
def cutlass_scaled_mm_azp(a: torch.Tensor,
                          b: torch.Tensor,
                          scale_a: torch.Tensor,
                          scale_b: torch.Tensor,
                          out_dtype: torch.dtype,
                          azp_adj: torch.Tensor,
                          azp: Optional[torch.Tensor] = None,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or (bias.numel() == b.shape[1]
                            and bias.dtype == out_dtype)
    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
                                       azp, bias)
    return out


# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: List[int],
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: List[int]) -> torch.Tensor:
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int,
                      num_bits: int) -> torch.Tensor:
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)


def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                           size_k: int, size_n: int,
                           num_bits: int) -> torch.Tensor:
    num_experts = b_q_weight.shape[0]
    assert size_k % 16 == 0
    output = torch.empty((num_experts, size_k // 16, size_n * 2),
                         device=b_q_weight.device,
                         dtype=b_q_weight.dtype)
    for e in range(num_experts):
        output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e],
                                                    size_k, size_n, num_bits)
    return output
def gptq_marlin_gemm(a: torch.Tensor,
                     b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor,
                     b_zeros: torch.Tensor,
                     g_idx: torch.Tensor,
                     perm: torch.Tensor,
                     workspace: torch.Tensor,
                     b_q_type: ScalarType,
                     size_m: int,
                     size_n: int,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
                     use_fp32_reduce: bool = False,
                     is_zp_float: bool = False) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
                                         has_zp, use_fp32_reduce,
                                         is_zp_float)


# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
    return torch.ops._C.machete_supported_schedules(b_type)


def machete_gemm(
    a: torch.Tensor,
    b_q: torch.Tensor,  # Should be the tensor returned by machete_prepack_B
    b_type: ScalarType,
    b_scales: Optional[torch.Tensor] = None,
    b_zeros: Optional[torch.Tensor] = None,
    b_group_size: Optional[int] = None,
    c: Optional[torch.Tensor] = None,
    alpha: Optional[float] = None,
    beta: Optional[float] = None,
    schedule: Optional[str] = None,
) -> torch.Tensor:
    return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
                                     b_group_size, c, alpha, beta, schedule)


def machete_prepack_B(b_q_weight: torch.Tensor,
                      b_type: ScalarType) -> torch.Tensor:
    return torch.ops._C.machete_prepack_B(b_q_weight, b_type)


def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.permute_cols(a, perm)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    num_token_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP8 and return the quantized tensor and
    its scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, static scaling is used; if you omit it, the scale
    is determined dynamically. The function also allows optional padding
    of the output tensor for downstream kernels that benefit from it.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        scale_ub: Optional upper bound on the scale, used only for dynamic
            per-token quantization.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
            the scaling factor.
    """
    # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
    shape = input.shape
    # For ROCm, the output fp8 dtype is torch.float8_e4m3fnuz
    out_dtype: torch.dtype = torch.float8_e4m3fnuz \
        if is_hip() else torch.float8_e4m3fn
    if num_token_padding:
        shape = (max(num_token_padding, input.shape[0]), shape[1])
    output = torch.empty(shape, device=input.device, dtype=out_dtype)

    if scale is None:
        if use_per_token_if_dynamic:
            scale = torch.empty((shape[0], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        # num_token_padding not implemented for this case
        assert (scale.numel() == 1 or num_token_padding is None)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)

    return output, scale
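

# Illustrative usage sketch (assumption, not part of the original module):
# dynamic quantization is used when no scale is given, static quantization
# when one is supplied. Input must be 2-D, as asserted above.
#
#   x = torch.randn(8, 16, dtype=torch.float16, device="cuda")
#   # dynamic per-tensor: the kernel computes the scale
#   q, s = scaled_fp8_quant(x)
#   # static: reuse the scale from the dynamic pass
#   q_static, _ = scaled_fp8_quant(x, scale=s)
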
# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Quantize the input tensor to int8 and return the quantized tensor,
    the scale, and optionally the azp (asymmetric zero point).

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.
        azp: Optional zero-point for the int8 quantization.
            Must be provided for asymmetric quantization if `scale` is
            provided.
        symmetric: Whether to use symmetric quantization (scale only,
            azp ignored).

    Returns:
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: Output
            int8 tensor, scales, and optionally azp.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        assert symmetric == (azp is None), \
            "azp must only be provided for asymmetric quantization."
        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
        return output, scale, None

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    input_azp = None if symmetric else torch.empty_like(input_scales,
                                                        dtype=torch.int32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
                                           input_azp)
    return output, input_scales, input_azp
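

# Illustrative usage sketch (assumption, not part of the original module):
# symmetric dynamic-per-token quantization returns no zero point, while the
# asymmetric path also returns per-token azp values.
#
#   x = torch.randn(4, 32, dtype=torch.float16, device="cuda")
#   q, scales, _ = scaled_int8_quant(x)                      # symmetric
#   q, scales, azp = scaled_int8_quant(x, symmetric=False)   # asymmetric
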
# quip#
def quip_gemv(
    A: torch.Tensor,
    B: torch.Tensor,
    CB: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._C.quip_gemv(A, B, CB)


def quip_decompress(
    YIs: torch.Tensor,
    CB: torch.Tensor,
    Y: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._C.quip_decompress(YIs, CB, Y)


# qqq ops
def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                    s_tok: torch.Tensor, s_ch: torch.Tensor,
                    s_group: torch.Tensor, workspace: torch.Tensor,
                    size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group,
                                        workspace, size_m, size_n, size_k)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: int,
) -> torch.Tensor:
    return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


# fp6
def fp_eXmY_linear_forward_cuda(
    EXPONENT: int,
    MANTISSA: int,
    _in_feats: torch.Tensor,
    _weights: torch.Tensor,
    _scales: torch.Tensor,
    splitK: int = 1,
) -> torch.Tensor:
    return torch.ops._C.fp_eXmY_linear_forward_cuda(EXPONENT, MANTISSA,
                                                    _in_feats, _weights,
                                                    _scales, splitK)
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      seq_idx_: Optional[torch.Tensor],
                      initial_states_: Optional[torch.Tensor],
                      final_states_out_: Optional[torch.Tensor],
                      silu_activation: bool) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
                                          initial_states_, final_states_out_,
                                          silu_activation)


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor:
    return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
                                             silu_activation,
                                             conv_state_indices)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                       B: torch.Tensor, C: torch.Tensor,
                       D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
                       delta_bias_: Optional[torch.Tensor],
                       delta_softplus: bool, index_: Optional[torch.Tensor],
                       x: Optional[torch.Tensor]) -> List[torch.Tensor]:
    return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_,
                                           delta_bias_, delta_softplus,
                                           index_, x)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)


def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, k_scale, v_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype, k_scale,
                                                   v_scale)


def copy_blocks(key_caches: List[torch.Tensor],
                value_caches: List[torch.Tensor],
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    # ruff: noqa: E501
    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
        device)
# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
                                                 offsets, rank, full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    torch.ops._C_custom_ar.dispose(fa)


def meta_size() -> int:
    return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)


# Sampling Kernels
def sampling_from_probs(probs: torch.Tensor,
                        uniform_samplers: torch.Tensor,
                        deterministic: bool = True,
                        check_nan: bool = False) -> torch.Tensor:
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("NaN detected in probs")
    return torch.ops._C.sampling_from_probs(probs, uniform_samplers,
                                            deterministic)


def _to_tensor_scalar_tuple(x):
    if isinstance(x, torch.Tensor):
        return (x, 0)
    else:
        return (None, x)
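

# Illustrative note (assumption, not part of the original module): the helper
# above lets the sampling wrappers below accept either a per-request tensor or
# a single Python scalar for thresholds such as top_p/top_k, expanding it into
# the (tensor, scalar) argument pair the underlying ops expect.
#
#   _to_tensor_scalar_tuple(0.9)                     # -> (None, 0.9)
#   _to_tensor_scalar_tuple(torch.full((8,), 0.9))   # -> (<tensor>, 0)
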
def top_p_sampling_from_probs(
        probs: torch.Tensor,
        uniform_samples: torch.Tensor,
        top_p: Union[torch.Tensor, float],
        deterministic: bool = True,
        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("NaN detected in probs")
    return torch.ops._C.top_p_sampling_from_probs(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic)


def top_k_sampling_from_probs(
        probs: torch.Tensor,
        uniform_samples: torch.Tensor,
        top_k: Union[torch.Tensor, int],
        deterministic: bool = True,
        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("NaN detected in probs")
    return torch.ops._C.top_k_sampling_from_probs(
        probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic)


def min_p_sampling_from_probs(
        probs: torch.Tensor,
        uniform_samples: torch.Tensor,
        min_p: Union[torch.Tensor, float],
        deterministic: bool = True,
        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    if check_nan and torch.any(torch.isnan(probs)):
        raise ValueError("NaN detected in probs")
    return torch.ops._C.min_p_sampling_from_probs(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic)


def top_k_mask_logits(
    logits: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    return torch.ops._C.top_k_mask_logits(logits,
                                          *_to_tensor_scalar_tuple(top_k))


def top_p_renorm_prob(
    probs: torch.Tensor,
    top_p: Union[torch.Tensor, float],
) -> torch.Tensor:
    return torch.ops._C.top_p_renorm_prob(probs,
                                          *_to_tensor_scalar_tuple(top_p))


def top_k_renorm_prob(
    probs: torch.Tensor,
    top_k: Union[torch.Tensor, int],
) -> torch.Tensor:
    return torch.ops._C.top_k_renorm_prob(probs,
                                          *_to_tensor_scalar_tuple(top_k))


def top_k_top_p_sampling_from_logits(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if filter_apply_order == "top_k_first":
        masked_logits = top_k_mask_logits(probs, top_k)
        probs = torch.softmax(masked_logits, dim=-1)
        return top_p_sampling_from_probs(probs, uniform_samples, top_p,
                                         deterministic, check_nan)
    elif filter_apply_order == "joint":
        probs = torch.softmax(probs, dim=-1)
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("NaN detected in probs")
        return torch.ops._C.top_k_top_p_sampling_from_logits(
            probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p), deterministic)
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")


def top_k_top_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    top_k: Union[torch.Tensor, int],
    top_p: Union[torch.Tensor, float],
    filter_apply_order: str = "top_k_first",
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if filter_apply_order == "top_k_first":
        renorm_probs = top_k_renorm_prob(probs, top_k)
        return top_p_sampling_from_probs(renorm_probs, uniform_samples, top_p,
                                         deterministic, check_nan)
    elif filter_apply_order == "joint":
        if check_nan and torch.any(torch.isnan(probs)):
            raise ValueError("NaN detected in probs")
        return torch.ops._C.top_k_top_p_sampling_from_probs(
            probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
            *_to_tensor_scalar_tuple(top_p), deterministic)
    else:
        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
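

# Illustrative note (assumption, not part of the original module): the two
# wrappers above differ only in their input domain (logits vs. probabilities)
# and in whether "top_k_first" pre-filters via logit masking or probability
# renormalization before the top-p pass. The uniform-sample shape is an
# assumption about the underlying kernels.
#
#   probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)
#   u = torch.rand(32, 4, device="cuda")
#   samples, success = top_k_top_p_sampling_from_probs(
#       probs, u, top_k=20, top_p=0.9, filter_apply_order="top_k_first")
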
# TODO: remove this later
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
    # find functions that are defined in this file and have torch.Tensor
    # in their annotations. `arg == "torch.Tensor"` is used to handle
    # the case when users use `from __future__ import annotations` to turn
    # type hints into strings.
    if isinstance(v, fn_type) \
            and v.__code__.co_filename == __file__ \
            and any(arg is torch.Tensor or arg == "torch.Tensor"
                    for arg in v.__annotations__.values()):
        names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type