# openvino.py

from dataclasses import dataclass
from typing import List, Tuple

import openvino as ov
import torch

from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                   AttentionMetadata)


class OpenVINOAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "openvino"

    @staticmethod
    def get_impl_cls():
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        return OpenVINOAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
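        # Layout: axis 0 separates the key and value planes, axis 1 indexes
        # physical cache blocks. As an illustration with hypothetical sizes,
        # num_blocks=1024, block_size=16, num_kv_heads=8, head_size=64 would
        # give a per-layer cache of shape (2, 1024, 8, 16, 64).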
        return (2, num_blocks, num_kv_heads, block_size, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: ov.Tensor,
        dst_kv_cache: ov.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        # OpenVINO currently supports only CPU, which does not require
        # swapping of KV cache blocks
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
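        # Each (src, dst) pair names a source and destination block index;
        # the copy is applied to the key and value planes of every layer's
        # cache (e.g. when forked sequences trigger a copy of shared blocks).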
        for src, dst in src_to_dists:
            for key_cache, value_cache in kv_caches:
                key_cache.data[dst, :] = key_cache.data[src, :]
                value_cache.data[dst, :] = value_cache.data[src, :]


@dataclass
class OpenVINOAttentionMetadata:
    """Metadata for OpenVINOAttentionBackend.

    Basic terms used below:
    - batch_size_in_sequences - total number of sequences to execute
    - prompt_lens - number of scheduled tokens per sequence
    - batch_size_in_tokens = sum(prompt_lens)
    - max_context_len = max(context_lens)
    - max_num_blocks = div_up(max_context_len, BLOCK_SIZE)
    - num_blocks - total number of blocks in block_indices
    """

    # Describes past KV cache size for each sequence within a batch
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Describes start indices of input / speculative tokens from
    # current sequences within a batch sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # indices along 0th dimension in key_cache and value_cache inputs
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # the i-th element is an index into block_indices of the first
    # block belonging to the i-th sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Describes max context length
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor
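

# A minimal usage sketch, not part of the backend itself: the helper and the
# field values below are hypothetical and only illustrate how the pieces fit
# together for the two-sequence example described in the dataclass above.
def _example_metadata() -> OpenVINOAttentionMetadata:
    return OpenVINOAttentionBackend.make_openvino_metadata(
        past_lens=torch.tensor([3, 0], dtype=torch.int32),
        subsequence_begins=torch.tensor([0, 2, 7], dtype=torch.int32),
        block_indices=torch.tensor([4, 9], dtype=torch.int32),
        block_indices_begins=torch.tensor([0, 1, 2], dtype=torch.int32),
        max_context_len=torch.tensor(5, dtype=torch.int32),
    )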