1
0

ipex_attn.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import Dict, List, Optional, Tuple
  2. import intel_extension_for_pytorch.llm.modules as ipex_modules
  3. import torch
  4. from aphrodite import _custom_ops as ops
  5. class PagedAttention:
  6. @staticmethod
  7. def get_supported_head_sizes() -> List[int]:
  8. return [64, 80, 96, 112, 128, 256]
  9. @staticmethod
  10. def get_kv_cache_shape(
  11. num_blocks: int,
  12. block_size: int,
  13. num_kv_heads: int,
  14. head_size: int,
  15. *args,
  16. ) -> Tuple[int, ...]:
  17. return (2, num_blocks, block_size * num_kv_heads * head_size)
  18. @staticmethod
  19. def split_kv_cache(
  20. kv_cache: torch.Tensor,
  21. num_kv_heads: int,
  22. head_size: int,
  23. *args,
  24. ) -> Tuple[torch.Tensor, torch.Tensor]:
  25. num_blocks = kv_cache.shape[1]
  26. key_cache = kv_cache[0]
  27. key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
  28. value_cache = kv_cache[1]
  29. value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
  30. return key_cache, value_cache
  31. @staticmethod
  32. def write_to_paged_cache(
  33. key: torch.Tensor,
  34. value: torch.Tensor,
  35. key_cache: torch.Tensor,
  36. value_cache: torch.Tensor,
  37. slot_mapping: torch.Tensor,
  38. kv_cache_dtype: str,
  39. k_scale: float,
  40. v_scale: float,
  41. *args,
  42. ) -> None:
  43. ipex_modules.PagedAttention.reshape_and_cache(
  44. key, value, key_cache, value_cache,
  45. slot_mapping.flatten().int())
  46. @staticmethod
  47. def forward_decode(
  48. query: torch.Tensor,
  49. key_cache: torch.Tensor,
  50. value_cache: torch.Tensor,
  51. block_tables: torch.Tensor,
  52. context_lens: torch.Tensor,
  53. max_context_len: int,
  54. kv_cache_dtype: str,
  55. num_kv_heads: int,
  56. scale: float,
  57. alibi_slopes: Optional[torch.Tensor],
  58. k_scale: float,
  59. v_scale: float,
  60. *args,
  61. ) -> torch.Tensor:
  62. output = torch.empty_like(query)
  63. block_size = value_cache.shape[2]
  64. head_mapping = torch.arange(
  65. 0,
  66. num_kv_heads,
  67. device="cpu",
  68. dtype=torch.int32,
  69. ).view(num_kv_heads,
  70. 1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
  71. ipex_modules.PagedAttention.single_query_cached_kv_attention(
  72. output, query.contiguous(), key_cache, value_cache, head_mapping,
  73. scale, block_tables, context_lens, block_size, max_context_len,
  74. alibi_slopes)
  75. return output
  76. @staticmethod
  77. def forward_prefix(
  78. query: torch.Tensor,
  79. key: torch.Tensor,
  80. value: torch.Tensor,
  81. kv_cache_dtype: str,
  82. key_cache: torch.Tensor,
  83. value_cache: torch.Tensor,
  84. block_tables: torch.Tensor,
  85. subquery_start_loc: torch.Tensor,
  86. prompt_lens_tensor: torch.Tensor,
  87. context_lens: torch.Tensor,
  88. max_subquery_len: int,
  89. alibi_slopes: Optional[torch.Tensor],
  90. *args,
  91. ) -> torch.Tensor:
  92. raise NotImplementedError
  93. @staticmethod
  94. def swap_blocks(
  95. src_kv_cache: torch.Tensor,
  96. dst_kv_cache: torch.Tensor,
  97. src_to_dst: Dict[int, int],
  98. *args,
  99. ) -> None:
  100. raise NotImplementedError
  101. @staticmethod
  102. def copy_blocks(
  103. kv_caches: List[torch.Tensor],
  104. src_to_dists: Dict[int, List[int]],
  105. *args,
  106. ) -> None:
  107. key_caches = [kv_cache[0] for kv_cache in kv_caches]
  108. value_caches = [kv_cache[1] for kv_cache in kv_caches]
  109. ops.copy_blocks(key_caches, value_caches, src_to_dists)