_custom_ops.py

import contextlib
from typing import List, Optional, Tuple, Type

import torch

try:
    import aphrodite._C
except ImportError as e:
    from loguru import logger
    logger.warning("Failed to import from aphrodite._C with {!r}", e)
with contextlib.suppress(ImportError):
    import aphrodite._moe_C

with contextlib.suppress(ImportError):
    # ruff: noqa: F401
    import aphrodite._punica_C


def is_custom_op_supported(op_name: str) -> bool:
    op, overloads = torch._C._jit_get_operation(op_name)
    return op is not None
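# Example (illustrative sketch, not part of the original module): op names
# follow the "namespace::op" schema convention used by
# torch._C._jit_get_operation, so a probe might look like:
#
#   if is_custom_op_supported("_C::rms_norm"):
#       ...  # the custom rms_norm kernel is available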
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.silu_and_mul(out, x)


def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_and_mul(out, x)


def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_tanh_and_mul(out, x)


def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_fast(out, x)


def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
    torch.ops._C.gelu_new(out, x)
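# Example (hedged sketch, assuming the usual gated-activation layout in which
# the kernel reads `x` as two halves along the last dimension): these wrappers
# write into a caller-allocated `out` tensor rather than returning one.
#
#   d = 4096
#   x = torch.randn(16, 2 * d, dtype=torch.float16, device="cuda")
#   out = torch.empty(16, d, dtype=x.dtype, device=x.device)
#   silu_and_mul(out, x)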
# page attention ops
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v1(
        out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
        kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    kv_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    torch.ops._C.paged_attention_v2(
        out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
        num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
        alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
        blocksparse_local_blocks, blocksparse_vert_stride,
        blocksparse_block_size, blocksparse_head_sliding_step)


# pos encoding ops
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    torch.ops._C.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, is_neox)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor, head_size: int,
                             cos_sin_cache: torch.Tensor, is_neox: bool,
                             rot_dim: int,
                             cos_sin_cache_offsets: torch.Tensor) -> None:
    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                          cos_sin_cache, is_neox, rot_dim,
                                          cos_sin_cache_offsets)


# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
             epsilon: float) -> None:
    torch.ops._C.rms_norm(out, input, weight, epsilon)


def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor, epsilon: float) -> None:
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
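# Example (hedged sketch): rms_norm writes into a separate, caller-allocated
# `out` tensor, while fused_add_rms_norm is understood to update `input` and
# `residual` in place (residual accumulates the skip connection before
# normalization). Shapes below are illustrative.
#
#   hidden_size = 4096
#   hidden = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
#   weight = torch.ones(hidden_size, dtype=hidden.dtype, device=hidden.device)
#   out = torch.empty_like(hidden)
#   rms_norm(out, hidden, weight, 1e-6)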
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)


def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
             scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
    return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)


# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
              b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
              b_g_idx: torch.Tensor, use_exllama: bool,
              bit: int) -> torch.Tensor:
    return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
                                  b_g_idx, use_exllama, bit)


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                 bit: int) -> None:
    torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)


# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
                    lookup_table: torch.Tensor) -> None:
    torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table)


# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
                size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
                                    size_n, size_k)


# marlin_24
def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                        b_meta: torch.Tensor, b_scales: torch.Tensor,
                        workspace: torch.Tensor, num_bits: int, size_m: int,
                        size_n: int, size_k: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
                                            workspace, num_bits, size_m,
                                            size_n, size_k)


# cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
                         scale_a: torch.Tensor, scale_b: torch.Tensor,
                         out_dtype: Type[torch.dtype]) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)

    m = a.shape[0]
    n = b.shape[1]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b)
    return out
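# Example (hedged sketch): per the asserts above, both dimensions of `b` must
# be multiples of 16 and the output dtype must be float16 or bfloat16; the
# scales are passed as tensors. The fp8 input dtype below is illustrative.
#
#   a = torch.randn(128, 256, device="cuda").to(torch.float8_e4m3fn)
#   b = torch.randn(256, 512, device="cuda").to(torch.float8_e4m3fn)
#   scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
#   scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
#   c = cutlass_scaled_mm_dq(a, b, scale_a, scale_b, out_dtype=torch.float16)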
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
              codebooks: torch.Tensor, scales: torch.Tensor,
              codebook_partition_sizes: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
    return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
                                  codebook_partition_sizes, bias)


def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
                 codebook_partition_sizes: torch.Tensor) -> torch.Tensor:
    return torch.ops._C.aqlm_dequant(codes, codebooks,
                                     codebook_partition_sizes)


# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
                       size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
                                           num_bits)


def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, g_idx: torch.Tensor,
                     perm: torch.Tensor, workspace: torch.Tensor,
                     num_bits: int, size_m: int, size_n: int, size_k: int,
                     is_k_full: bool) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
                                         workspace, num_bits, size_m, size_n,
                                         size_k, is_k_full)


# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to FP8 and return the quantized tensor and scale.

    This function supports both static and dynamic quantization: if you
    provide the scale, static scaling is used; if you omit it, the scale is
    determined dynamically. The function also allows optional padding of the
    output tensor for downstream kernels that benefit from padding.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        batch_dim_padding: If specified, pad the first dimension
            of the output to at least this value.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and the
            scaling factor.
    """
    if batch_dim_padding:
        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
        output = torch.empty(shape,
                             device=input.device,
                             dtype=torch.float8_e4m3fn)
    else:
        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    return output, scale
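# Example (hedged sketch): dynamic vs. static scaling, plus optional padding
# of the batch dimension. Shapes are illustrative.
#
#   x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
#   x_fp8, x_scale = scaled_fp8_quant(x)             # dynamic: scale computed
#   y_fp8, _ = scaled_fp8_quant(x, scale=x_scale)    # static: reuse the scale
#   z_fp8, _ = scaled_fp8_quant(x, batch_dim_padding=32)  # pad dim 0 to >= 32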
# int8
def scaled_int8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        torch.ops._C.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales
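# Example (hedged sketch): with no scale, dynamic per-token quantization
# returns one scale per row; with a scale, static per-tensor quantization
# reuses it.
#
#   x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
#   x_int8, x_scales = scaled_int8_quant(x)          # x_scales shape: (16, 1)
#   fixed = torch.tensor(0.02, device="cuda", dtype=torch.float32)
#   y_int8, _ = scaled_int8_quant(x, scale=fixed)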
# quip#
def quip_gemv(
    A: torch.Tensor,
    B: torch.Tensor,
    CB: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._C.quip_gemv(A, B, CB)


def quip_decompress(
    YIs: torch.Tensor,
    CB: torch.Tensor,
    Y: torch.Tensor,
) -> torch.Tensor:
    return torch.ops._C.quip_decompress(YIs, CB, Y)


# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                         block_size: int, sorted_token_ids: torch.Tensor,
                         experts_ids: torch.Tensor,
                         num_tokens_post_pad: torch.Tensor) -> None:
    torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
                                      sorted_token_ids, experts_ids,
                                      num_tokens_post_pad)

def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indicies: torch.Tensor,
                 gating_output: torch.Tensor) -> None:
    torch.ops._moe_C.topk_softmax(topk_weights, topk_ids,
                                  token_expert_indicies, gating_output)

def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    kv_scale: float,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
                                             value_cache, slot_mapping,
                                             kv_cache_dtype, kv_scale)


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
                                                   value_cache, slot_mapping,
                                                   kv_cache_dtype)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(output: torch.Tensor,
                input: torch.Tensor,
                scale: float = 1.0,
                kv_dtype: str = "fp8") -> None:
    torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
    # ruff: noqa: E501
    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
        device)


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
                   handles: List[str], offsets: List[int], rank: int,
                   full_nvlink: bool) -> int:
    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
                                                 offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
                     full_nvlink: bool) -> bool:
    return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
                                                   full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)


def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor,
                     out: torch.Tensor) -> None:
    torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out)


def dispose(fa: int) -> None:
    torch.ops._C_custom_ar.dispose(fa)


def meta_size() -> int:
    return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
                    offsets: List[int]) -> None:
    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
    return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(fa: int, handles: List[str],
                           offsets: List[List[int]]) -> None:
    torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)


# punica
def dispatch_bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
) -> None:
    torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
                                      scale)


def dispatch_bgmv_low_level(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.Tensor,
    layer_idx: int,
    scale: float,
    h_in: int,
    h_out: int,
    y_offset: int,
) -> None:
    torch.ops._punica_C.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indicies,
        layer_idx,
        scale,
        h_in,
        h_out,
        y_offset,
    )