paged_attn.py

from typing import List, Optional

import torch

from aphrodite._C import cache_ops
from aphrodite._C import ops
from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.attention.ops.prefix_prefill import (
    context_attention_fwd)

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttentionImpl:
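    """Static helpers around the compiled PagedAttention kernels exposed by
    `aphrodite._C` (`ops` for attention, `cache_ops` for cache writes)."""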

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def reshape_and_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
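        """Scatter the new `key`/`value` tokens into the paged KV cache.

        `input_metadata.slot_mapping` holds, per token, the flat index of
        its slot in the cache; flattening collapses the (batch, seq_len)
        layout into one slot index per token for the kernel.
        """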
        cache_ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            input_metadata.slot_mapping.flatten(),
            input_metadata.kv_cache_dtype,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
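        """Paged attention for the decode phase (one query token per seq).

        `query` is (num_seqs, num_heads, head_size) and the block size is
        read from `value_cache.shape[3]`. The call is dispatched to the V1
        or V2 kernel by the heuristic below; the result has `query`'s shape.
        """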
        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE: We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO: Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory
        # shortage.
        use_v1 = input_metadata.max_context_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512)
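        # Illustrative numbers (not from the original source): at
        # max_context_len = 4096 there are 8 partitions, and 16 seqs * 40
        # heads = 640 > 512 still selects V1; any max_context_len > 8192
        # forces V2 regardless of the partition/work terms.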
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
                input_metadata.kv_cache_dtype,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
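            # V2 splits each context into _PARTITION_SIZE-token partitions
            # (the assert above keeps partitions aligned to cache blocks),
            # accumulates per-partition attention into `tmp_output`, and
            # records the softmax statistics (`exp_sums`, `max_logits`)
            # needed to reduce the partials into `output` stably.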
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
                input_metadata.kv_cache_dtype,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> torch.Tensor:
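        """Prefill attention over new tokens plus any cached prefix.

        The new `key`/`value` tokens attend to themselves and to context
        already resident in the paged caches; `context_attention_fwd`
        writes the result into `output` in place.
        """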
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            input_metadata.block_tables,
            # subquery_start_loc is (batch_size + 1,)
            input_metadata.subquery_start_loc[:-1],
            input_metadata.prompt_lens_tensor,
            input_metadata.context_lens,
            input_metadata.max_subquery_len,
            alibi_slopes,
        )
        return output