# _ipex_ops.py
from typing import List, Optional, Tuple

import torch
from loguru import logger

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning(f"Import error msg: {e.msg}")


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2
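
    # Shape sketch (illustrative, not from the source): a gated-activation
    # input x of shape [num_tokens, 2 * d] is split along the feature dim into
    # x1 = x[:, :d] and x2 = x[:, d:], each of shape [num_tokens, d]. For
    # example, x: [4, 2048] -> x1, x2: [4, 1024]; silu_and_mul below is then
    # expected to write silu(x1) * x2 into an `out` of shape [4, 1024].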

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.silu_mul(x1, x2, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "none")

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

    @staticmethod
    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

    @staticmethod
    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))
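
    # Note: gelu_fast and gelu_new above both fall back to PyTorch's default
    # erf-based GELU (approximate="none"), whereas those names conventionally
    # denote tanh approximations; the fallback is an approximation swap rather
    # than an exact match.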

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        # Map every query head to the KV head it shares (GQA/MQA).
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device=query.device,
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # TODO: ipex will refactor namespace
        torch.xpu.paged_attention_v1(out, query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, scale,
                                     block_tables, context_lens, block_size,
                                     max_context_len, alibi_slopes)
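
    # Worked example for head_mapping (illustrative): with num_heads = 8 and
    # num_kv_heads = 2, num_queries_per_tokens = 4 and head_mapping becomes
    # [0, 0, 0, 0, 1, 1, 1, 1] (int32), i.e. each group of four query heads
    # reads the same KV head.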

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            dtype=torch.int32,
            device=query.device,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # TODO: ipex will refactor namespace
        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
                                     query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, block_tables,
                                     context_lens, scale, block_size,
                                     max_context_len, alibi_slopes)
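
    # Note (assumption, based on the usual paged-attention v2 scheme): exp_sum,
    # max_logits and tmp_out are caller-allocated scratch buffers holding
    # per-partition softmax statistics and partial outputs that the kernel
    # reduces into `out`.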

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)

        rotary_dim = cos_sin_cache.size(1)
        # View as [..., num_heads, head_size] and rotate only the first
        # rotary_dim features of each head.
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[positions.long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)
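
    # Usage sketch (hypothetical shapes; an XPU build of PyTorch plus
    # intel_extension_for_pytorch is assumed):
    #   positions = torch.arange(16, device="xpu").view(1, 16)    # [B=1, S=16]
    #   query = torch.randn(1, 16, 32 * 128, device="xpu")        # num_heads=32, head_size=128
    #   key = torch.randn(1, 16, 8 * 128, device="xpu")           # num_kv_heads=8
    #   cos_sin_cache = torch.randn(4096, 128, device="xpu")      # [max_pos, rot_dim]
    #   ipex_ops.rotary_embedding(positions, query, key, 128, cos_sin_cache, True)
    # query and key are presumably updated in place: query_rot/key_rot are
    # views into them and the IPEX call's return value is not captured.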

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)

        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[torch.add(positions,
                                          cos_sin_cache_offsets).long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)
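
    # Note (assumption): cos_sin_cache_offsets shifts each position before the
    # cache lookup, so several rotary variants (e.g. different scaling factors)
    # can be packed into a single cos_sin_cache and selected per token.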

    @staticmethod
    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
        out.copy_(tmp)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)
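
    # RMSNorm reference (for context):
    #   y = x * weight / sqrt(mean(x**2, dim=-1, keepdim=True) + eps)
    # Both wrappers above compute the result out of place through IPEX and copy
    # it back, presumably to preserve the in-place contract of the ops they
    # stand in for.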

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
                                             seqlen_k, max_seqlen_q,
                                             max_seqlen_k, pdropout,
                                             softmax_scale, zero_tensors,
                                             is_causal, return_softmax, gen_)
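
    # Note (assumption, following the flash-attn varlen convention): seqlen_q
    # and seqlen_k are cumulative sequence-length tensors of shape
    # [batch_size + 1], with query/key/value packed as
    # [total_tokens, num_heads, head_size] and `out` matching query.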

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)
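
    # reshape_and_cache scatters the per-token key/value tensors into the paged
    # KV caches at the flat slot indices in slot_mapping (assumed to follow the
    # usual layout, slot = block_number * block_size + block_offset).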

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)
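

# Minimal usage sketch (commented out; assumes an XPU build of PyTorch with
# intel_extension_for_pytorch installed, and uses invented shapes):
#
#   x = torch.randn(4, 2 * 1024, device="xpu", dtype=torch.float16)
#   out = torch.empty(4, 1024, device="xpu", dtype=torch.float16)
#   ipex_ops.silu_and_mul(out, x)  # expected: out <- silu(x[:, :1024]) * x[:, 1024:]
#
#   hidden = torch.randn(4, 1024, device="xpu", dtype=torch.float16)
#   weight = torch.ones(1024, device="xpu", dtype=torch.float16)
#   normed = torch.empty_like(hidden)
#   ipex_ops.rms_norm(normed, hidden, weight, 1e-6)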