# _ipex_ops.py

from typing import List, Optional, Tuple

import torch
from loguru import logger

try:
    import intel_extension_for_pytorch as ipex
except ImportError as e:
    logger.warning(f"Import error msg: {e.msg}")


class ipex_ops:

    @staticmethod
    def _reshape_activation_tensor(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        num = x.size(0)
        d = x.size(1) // 2
        x = x.reshape(num, 2, d)
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        x1 = x1.reshape(num, d)
        x2 = x2.reshape(num, d)
        return x1, x2
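
    # Shape sketch (illustrative): for x of shape [num_tokens, 2 * d], the
    # helper above returns x1 = x[:, :d] and x2 = x[:, d:], e.g.
    # x: [8, 2048] -> x1: [8, 1024], x2: [8, 1024].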

    def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.silu_mul(x1, x2, out)

    def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "none")

    def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
        x1, x2 = ipex_ops._reshape_activation_tensor(x)
        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))

    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
        out.copy_(torch.nn.functional.gelu(x))
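
    # Usage sketch (illustrative, assumes an XPU device): `x` packs both
    # halves of the gate/up projection and `out` is pre-allocated with half
    # of x's last dimension:
    #   x = torch.randn(num_tokens, 2 * intermediate, device="xpu")
    #   out = torch.empty(num_tokens, intermediate, device="xpu")
    #   ipex_ops.silu_and_mul(out, x)  # out = silu(x[:, :i]) * x[:, i:]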

    def paged_attention_v1(
        out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        kv_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            device=query.device,
            dtype=torch.int32,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v1(out, query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, scale,
                                     block_tables, context_lens, block_size,
                                     max_context_len, alibi_slopes)
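
    # head_mapping example (illustrative): with 8 query heads and
    # num_kv_heads = 2, num_queries_per_tokens = 4 and head_mapping is
    # [0, 0, 0, 0, 1, 1, 1, 1], i.e. each group of four query heads reads
    # the same KV head (grouped-query attention).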

    def paged_attention_v2(
        out: torch.Tensor,
        exp_sum: torch.Tensor,
        max_logits: torch.Tensor,
        tmp_out: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        block_size: int,
        max_context_len: int,
        alibi_slopes: Optional[torch.Tensor],
        kv_cache_dtype: str,
        kv_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> None:
        assert kv_cache_dtype == "auto"
        num_heads = out.size(1)
        num_queries_per_tokens = num_heads // num_kv_heads
        head_mapping = torch.arange(
            0,
            num_kv_heads,
            dtype=torch.int32,
            device=query.device,
        ).view(num_kv_heads,
               1).repeat_interleave(num_queries_per_tokens).flatten()
        # todo: ipex will refactor namespace
        torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
                                     query.contiguous(),
                                     key_cache.view_as(value_cache),
                                     value_cache, head_mapping, block_tables,
                                     context_lens, scale, block_size,
                                     max_context_len, alibi_slopes)
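
    # Note (assumption based on the reference paged-attention v2 kernel):
    # exp_sum, max_logits and tmp_out are caller-allocated per-partition
    # softmax reduction buffers; this wrapper only forwards them unchanged.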

    def rotary_embedding(
        positions: torch.Tensor,  # [batch_size, seq_len]
        query: torch.Tensor,  # [batch_size, seq_len, num_heads*head_size]
        key: torch.Tensor,  # [batch_size, seq_len, num_kv_heads*head_size]
        head_size: int,
        cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
        is_neox: bool,
    ) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)

        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)

        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[positions.long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)
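
    # Usage sketch (illustrative shapes, assuming rotary_dim == head_size):
    #   positions:     [bs, seq_len] integer token positions
    #   query:         [bs, seq_len, num_heads * head_size]
    #   key:           [bs, seq_len, num_kv_heads * head_size]
    #   cos_sin_cache: [max_position, rotary_dim], cos values in the first
    #                  half of the last dim, sin values in the second half
    #   ipex_ops.rotary_embedding(positions, query, key, head_size,
    #                             cos_sin_cache, is_neox=True)
    # The wrapper returns None, so it relies on the kernel rotating
    # query/key in place through the sliced views.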

    def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                 key: torch.Tensor, head_size: int,
                                 cos_sin_cache: torch.Tensor, is_neox: bool,
                                 rot_dim: int,
                                 cos_sin_cache_offsets: torch.Tensor) -> None:
        if positions.dim() == 1:
            positions = positions.unsqueeze(0)
            query = query.unsqueeze(0)
            key = key.unsqueeze(0)
        cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)

        rotary_dim = cos_sin_cache.size(1)
        query = query.view(*query.shape[:-1], -1, head_size)
        key = key.view(*key.shape[:-1], -1, head_size)

        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]

        cos_sin = cos_sin_cache[torch.add(positions,
                                          cos_sin_cache_offsets).long()]
        cos, sin = cos_sin.chunk(2, dim=-1)

        if is_neox:
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
                                             rotary_dim, is_neox, positions)
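
    # Relative to rotary_embedding above, cos_sin_cache_offsets adds a
    # per-token offset into cos_sin_cache (the effective index is
    # positions + cos_sin_cache_offsets), e.g. to select between stacked
    # caches built for different RoPE scaling factors.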

    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> None:
        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
        out.copy_(tmp)

    def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, epsilon: float) -> None:
        tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
                                               epsilon, True)
        input.copy_(tmp)
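
    # Usage sketch (illustrative): both norms write their result back into a
    # caller-provided tensor:
    #   ipex_ops.rms_norm(out, hidden, weight, eps)        # out <- RMSNorm(hidden)
    #   ipex_ops.fused_add_rms_norm(hidden, res, w, eps)   # hidden <- RMSNorm(hidden + res)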

    def varlen_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        seqlen_q: torch.Tensor,
        seqlen_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        pdropout: float,
        softmax_scale: float,
        zero_tensors: bool,
        is_causal: bool,
        return_softmax: bool,
        gen_: torch.Generator,
    ) -> None:
        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
                                             seqlen_k, max_seqlen_q,
                                             max_seqlen_k, pdropout,
                                             softmax_scale, zero_tensors,
                                             is_causal, return_softmax, gen_)
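
    # Note (assumption): seqlen_q / seqlen_k appear to follow the usual
    # varlen convention of cumulative per-sequence lengths, with query, key
    # and value packed along the token dimension; the wrapper simply forwards
    # every argument to the IPEX kernel.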

    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
    ) -> None:
        assert kv_cache_dtype == "auto"
        ipex.llm.modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)
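
    # slot_mapping sketch (illustrative): each entry maps one token to a flat
    # slot index block_number * block_size + block_offset in the paged KV
    # cache, e.g. block_size = 16, block 3, offset 5 -> slot 53.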

    @staticmethod
    def copy_blocks(key_caches: List[torch.Tensor],
                    value_caches: List[torch.Tensor],
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)

    def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                    block_mapping: torch.Tensor) -> None:
        torch.xpu.swap_blocks(src, dst, block_mapping)
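

# Minimal smoke test (illustrative sketch; assumes an Intel GPU reachable
# through torch.xpu and a successful IPEX import above -- shapes are made up
# for demonstration).
if __name__ == "__main__":
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        x = torch.randn(4, 16, device="xpu", dtype=torch.float16)
        out = torch.empty(4, 8, device="xpu", dtype=torch.float16)
        ipex_ops.silu_and_mul(out, x)  # out = silu(x[:, :8]) * x[:, 8:]
        print("silu_and_mul output shape:", tuple(out.shape))
    else:
        print("No XPU device available; skipping ipex_ops smoke test.")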