from typing import List, Optional, Tuple

import torch
from loguru import logger

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning(f"Import error msg: {e.msg}")
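

# ipex_ops collects the XPU custom-op wrappers used on Intel GPUs: fused
# activations, paged attention, rotary embeddings, RMSNorm, and KV-cache
# management, implemented as thin wrappers over ipex.llm and torch.xpu kernels.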
class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2
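
    # Fused activation wrappers. The *_and_mul variants take x of shape
    # [num_tokens, 2 * d] (typically the gate and up projections concatenated)
    # and write the [num_tokens, d] result into `out`. Note that the wrappers
    # take (out, x) while the IPEX kernels are called as (x, out).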

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        # The tanh-approximate variant is routed to the same IPEX kernel.
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_quick(x, out)
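
    # Decode-phase paged attention. head_mapping maps each query head to its
    # KV head so grouped-query / multi-query models index the shared KV cache
    # correctly. Only kv_cache_dtype == "auto" is supported (asserted below);
    # the scaling, tp_rank, and blocksparse_* parameters are accepted for
    # interface compatibility but are not forwarded to the XPU kernel.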

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device=query.device,
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # TODO: ipex will refactor namespace
        torch.xpu.paged_attention_v1(out, query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, scale,
                                     block_tables, context_lens, block_size,
                                     max_context_len, alibi_slopes)
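
    # Two-pass ("v2") paged attention for longer contexts: exp_sum, max_logits,
    # and tmp_out are the partition-level reduction buffers the kernel uses
    # before producing the final `out`.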

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            dtype=torch.int32,
            device=query.device,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # TODO: ipex will refactor namespace
        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
                                     query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, block_tables,
                                     context_lens, scale, block_size,
                                     max_context_len, alibi_slopes)
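
    # Rotary position embeddings are applied in place to query and key. The
    # batched variant additionally takes per-token cos_sin_cache_offsets,
    # e.g. when different sequences in the batch use different RoPE offsets.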

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim,
                                                     cos_sin_cache_offsets)
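
    # RMSNorm wrappers. fused_add_rms_norm fuses the residual addition with
    # the normalization; the IPEX kernel returns a new tensor, so the result
    # is copied back into `input` to preserve the in-place contract.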

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)
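
    # Prefill attention over variable-length sequences packed into one batch.
    # The query/key/value tensors are made contiguous and the sequence-length
    # tensors are cast to int32 before handing off to the IPEX kernel.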

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        ipex.llm.functional.varlen_attention(query.contiguous(),
                                             key.contiguous(),
                                             value.contiguous(), out,
                                             seqlen_q.int(), seqlen_k.int(),
                                             max_seqlen_q, max_seqlen_k,
                                             pdropout, softmax_scale,
                                             zero_tensors, is_causal,
                                             return_softmax, gen_)
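
    # KV-cache management helpers: reshape_and_cache scatters newly computed
    # key/value tokens into the paged caches according to slot_mapping, while
    # copy_blocks duplicates cache blocks across caches and swap_blocks moves
    # blocks between source and destination tensors per block_mapping.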

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)