@@ -22,23 +22,29 @@ class ipex_ops:
         x2 = x2.reshape(num, d)
         return x1, x2

+    @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)

+    @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)

+    @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)

-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    @staticmethod
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)

-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    @staticmethod
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
+
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)

     def paged_attention_v1(
         out: torch.Tensor,
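The hunk above switches the gated-activation wrappers from the manual split-then-call pattern (`_reshape_activation_tensor` plus `silu_mul`/`gelu_mul`) to fused `silu_and_mul`/`gelu_and_mul` entry points that take the packed tensor directly. For orientation, here is a minimal pure-PyTorch sketch of the semantics these wrappers implement, assuming the usual vLLM layout in which `x` packs the gate and up halves along the last dimension and `silu_mul(x1, x2, out)` computes `silu(x1) * x2`; the name `silu_and_mul_ref` and the shapes are illustrative and not part of the change:

import torch


def silu_and_mul_ref(out: torch.Tensor, x: torch.Tensor) -> None:
    # x packs [x1 | x2] along the last dimension; out has half that width.
    d = x.size(-1) // 2
    x1, x2 = x[..., :d], x[..., d:]
    # SwiGLU-style gating: silu(x1) * x2, written into the preallocated out.
    out.copy_(torch.nn.functional.silu(x1) * x2)


num_tokens, d = 4, 128
x = torch.randn(num_tokens, 2 * d)
out = torch.empty(num_tokens, d)
silu_and_mul_ref(out, x)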
@@ -128,65 +134,25 @@ class ipex_ops:
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)

     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
                                  rot_dim: int,
                                  cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-            cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
-
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
+
+    @staticmethod
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)

     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                            weight: torch.Tensor, epsilon: float) -> None:
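Two call-style changes stand out in this hunk: both rotary helpers now delegate to a single batched IPEX entry point instead of slicing and rotating the query/key halves by hand, and `rms_norm` returns its result rather than copying into a caller-provided `out` tensor. As a reference for the latter, here is a minimal pure-PyTorch sketch of RMSNorm semantics with the new return-style signature; `rms_norm_ref` and the shapes are illustrative, not part of the diff:

import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + epsilon) * weight


hidden = torch.randn(4, 64)
weight = torch.ones(64)
normed = rms_norm_ref(hidden, weight, 1e-6)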
@@ -210,11 +176,14 @@ class ipex_ops:
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)

     def reshape_and_cache(
         key: torch.Tensor,
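The updated `varlen_attention` call defensively makes `query`, `key`, and `value` contiguous and casts the sequence-length tensors to int32 before handing them to IPEX. A small sketch of how such arguments are typically laid out, assuming the flash-attention-style convention in which `seqlen_q`/`seqlen_k` are cumulative token offsets with a leading zero (shapes and names here are illustrative only):

import torch

# Two variable-length sequences packed along one token dimension.
seq_lens = [3, 5]
num_heads, head_dim = 8, 64
total_tokens = sum(seq_lens)

query = torch.randn(total_tokens, num_heads, head_dim)
key = torch.randn(total_tokens, num_heads, head_dim)
value = torch.randn(total_tokens, num_heads, head_dim)
out = torch.empty_like(query)

# Cumulative offsets [0, 3, 8]; torch.cumsum yields int64, hence the .int()
# cast applied by the wrapper before the IPEX call.
seqlen_q = torch.cumsum(torch.tensor([0] + seq_lens), dim=0)
seqlen_k = seqlen_q.clone()
max_seqlen_q = max_seqlen_k = max(seq_lens)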