# _ipex_ops.py

from typing import List, Optional, Tuple

import torch
from loguru import logger

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning(f"Import error msg: {e.msg}")


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2

    @staticmethod
    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.silu_and_mul(x, out)

    @staticmethod
    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_and_mul(x, out)

    @staticmethod
    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.gelu(x)

    @staticmethod
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        ipex.llm.functional.gelu_quick(x, out)
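
    # Usage sketch (assumption, not from this file): following the vLLM
    # convention, `x` packs the gate and up projections along the last
    # dimension, so `out` has half of x's last dimension. Assumes an XPU
    # device with IPEX installed.
    #
    #   x = torch.randn(num_tokens, 2 * d, device="xpu", dtype=torch.float16)
    #   out = torch.empty(num_tokens, d, device="xpu", dtype=torch.float16)
    #   ipex_ops.silu_and_mul(out, x)  # out = silu(x[..., :d]) * x[..., d:]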

    @staticmethod
    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        # Map each query head to its KV head (grouped-query attention).
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device=query.device,
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v1(out, query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, scale,
                                     block_tables, context_lens, block_size,
                                     max_context_len, alibi_slopes)
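
    # Note: k_scale, v_scale, tp_rank and the blocksparse_* parameters are
    # accepted for interface compatibility but are not forwarded to the XPU
    # kernel above, and only kv_cache_dtype == "auto" is supported.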

    @staticmethod
    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        # Map each query head to its KV head (grouped-query attention).
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            dtype=torch.int32,
            device=query.device,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
                                     query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, block_tables,
                                     context_lens, scale, block_size,
                                     max_context_len, alibi_slopes)
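
    # paged_attention_v2 mirrors vLLM's partitioned ("v2") decode kernel:
    # exp_sum, max_logits and tmp_out are caller-allocated intermediate
    # buffers for the per-partition softmax reduction (an interpretation of
    # the signature; the semantics are not documented in this file).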

    @staticmethod
    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        rot_dim = cos_sin_cache.size(1)
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim)
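
    # This helper and batched_rotary_embedding below both return None, so
    # query and key are presumably rotated in place by
    # ipex.llm.functional.rotary_embedding_batched; rot_dim is taken from the
    # second dimension of cos_sin_cache.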

    @staticmethod
    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
                                                     head_size, cos_sin_cache,
                                                     is_neox, rot_dim,
                                                     cos_sin_cache_offsets)

    @staticmethod
    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
        return ipex.llm.functional.rms_norm(input, weight, epsilon)

    @staticmethod
    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)
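
    # Usage sketch (illustrative, assumes an XPU device with IPEX installed):
    #
    #   hidden = torch.randn(num_tokens, hidden_size, device="xpu")
    #   weight = torch.ones(hidden_size, device="xpu")
    #   normed = ipex_ops.rms_norm(hidden, weight, 1e-6)
    #
    # fused_add_rms_norm returns None and writes the normalized result back
    # into `input` via input.copy_(tmp) above.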

    @staticmethod
    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        ipex.llm.functional.varlen_attention(query.contiguous(),
                                             key.contiguous(),
                                             value.contiguous(), out,
                                             seqlen_q.int(), seqlen_k.int(),
                                             max_seqlen_q, max_seqlen_k,
                                             pdropout, softmax_scale,
                                             zero_tensors, is_causal,
                                             return_softmax, gen_)
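
    # varlen_attention follows a FlashAttention-style "varlen" interface:
    # seqlen_q / seqlen_k describe the per-sequence boundaries of a packed
    # batch and are cast to int32 before dispatch, and query/key/value are
    # made contiguous. The exact layout expected by the IPEX kernel is not
    # documented in this file.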

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)

    @staticmethod
    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)
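
# How these ops fit together (an interpretation of the vLLM-style interface
# above, not code from this file): reshape_and_cache scatters new key/value
# tensors into the paged key_cache/value_cache at the slots given by
# slot_mapping; paged_attention_v1/v2 later gather from those caches using
# block_tables and context_lens during decode; copy_blocks and swap_blocks
# let the cache manager duplicate or migrate whole cache blocks.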