openvino.py

from dataclasses import dataclass
from typing import List, Tuple, Type

import openvino as ov
import torch

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata)
from aphrodite.attention.backends.utils import CommonAttentionState


class OpenVINOAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "openvino"

    @staticmethod
    def get_impl_cls():
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        return OpenVINOAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: ov.Tensor,
        dst_kv_cache: ov.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # OpenVINO currently supports only CPU, which does not require
        # swap of KV cache blocks
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        for src, dst in src_to_dists:
            for key_cache, value_cache in kv_caches:
                key_cache.data[dst, :] = key_cache.data[src, :]
                value_cache.data[dst, :] = value_cache.data[src, :]
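

# --- Usage sketch (illustrative; not part of the original backend) --------
# A minimal example of how the layout returned by `get_kv_cache_shape`
# pairs with `copy_blocks`: each layer holds a (key_cache, value_cache)
# pair of ov.Tensor objects, and `copy_blocks` duplicates whole blocks
# along the 0th (block) dimension. The helper name and the sizes below
# are assumptions made for this sketch only.
def _example_copy_blocks() -> None:
    import numpy as np

    num_blocks, block_size, num_kv_heads, head_size = 8, 16, 2, 64
    # The full shape packs key and value along a leading dimension of 2;
    # each cache tensor here uses the per-key/value part of that shape.
    _, *kv_shape = OpenVINOAttentionBackend.get_kv_cache_shape(
        num_blocks, block_size, num_kv_heads, head_size)

    # One (key, value) pair per attention layer (two layers here).
    kv_caches = [(ov.Tensor(np.zeros(kv_shape, dtype=np.float32)),
                  ov.Tensor(np.zeros(kv_shape, dtype=np.float32)))
                 for _ in range(2)]

    # Copy the contents of cache block 0 into cache block 3 in every layer.
    OpenVINOAttentionBackend.copy_blocks(kv_caches, [(0, 3)])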


@dataclass
class OpenVINOAttentionMetadata:
    """Metadata for OpenVINOAttentionBackend.

    Basic terms used below:
    - batch_size_in_sequences - total number of sequences to execute
    - prompt_lens - per-sequence number of scheduled tokens
    - batch_size_in_tokens = sum(prompt_lens)
    - max_context_len = max(context_lens)
    - max_num_blocks = div_up(max_context_len, BLOCK_SIZE)
    - num_blocks - total number of blocks in block_indices
    """

    # Describes past KV cache size for each sequence within a batch
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Describes start indices of input / speculative tokens from
    # current sequences within a batch sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # indices along the 0th dimension in key_cache and value_cache inputs
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # for the i-th element, it is an index in block_indices of the
    # first block belonging to the i-th sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Describes max context length
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor
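

# --- Metadata construction sketch (illustrative; not part of the original
# file) ---------------------------------------------------------------------
# A hypothetical batch of two sequences, assuming a block size of 16:
# sequence 0 has 10 cached tokens and 2 scheduled tokens, sequence 1 has
# 20 cached tokens and 1 scheduled token. The concrete values only show
# how the fields documented above relate to each other.
def _example_metadata() -> OpenVINOAttentionMetadata:
    return OpenVINOAttentionMetadata(
        # Past KV cache length per sequence.
        past_lens=torch.tensor([10, 20], dtype=torch.int32),
        # Scheduled-token offsets: sequence 0 owns tokens [0, 2),
        # sequence 1 owns token [2, 3); batch_size_in_tokens = 3.
        subsequence_begins=torch.tensor([0, 2, 3], dtype=torch.int32),
        # Flattened block tables: sequence 0 uses cache block 5; sequence 1
        # uses cache blocks 7 and 9 (a context of 21 tokens needs two
        # 16-token blocks).
        block_indices=torch.tensor([5, 7, 9], dtype=torch.int32),
        # Sequence i's blocks are block_indices[begins[i]:begins[i + 1]].
        block_indices_begins=torch.tensor([0, 1, 3], dtype=torch.int32),
        # max(context_lens) = max(10 + 2, 20 + 1) = 21.
        max_context_len=torch.tensor(21, dtype=torch.int32),
    )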