@@ -13,7 +13,7 @@ from aphrodite import cache_ops
 from aphrodite import pos_encoding_ops
 from aphrodite.modeling.metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES =
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]


 class PagedAttention(nn.Module):
     """GPT-style multi-head PagedAttention.
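The new list pins the head sizes the attention kernels accept. As context, here is a minimal sketch of how a constant like _SUPPORTED_HEAD_SIZES is typically consumed; the validation helper below is an assumed pattern for illustration, not code from this file:

# Hypothetical usage sketch (assumption): reject unsupported head sizes
# up front so the attention kernels never see a shape they cannot handle.
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]

def _validate_head_size(head_size: int) -> None:
    if head_size not in _SUPPORTED_HEAD_SIZES:
        raise ValueError(
            f"head_size ({head_size}) is not supported. "
            f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

_validate_head_size(128)  # passes silently; 100 would raise ValueError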
@@ -150,8 +150,8 @@ class PagedAttention(nn.Module):
             )

             self.single_query_cached_kv_attention(
-                output[num_prompt_tokens:num_valid_tokens]
-                query[num_prompt_tokens:num_valid_tokens]
+                output[num_prompt_tokens:num_valid_tokens],
+                query[num_prompt_tokens:num_valid_tokens],
                 key_cache,
                 value_cache,
                 input_metadata)
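The removed lines were missing trailing commas, so the two tensor slices did not parse as separate arguments; the hunk restores them. The slicing itself matches the flattened token layout suggested by the code above, with prompt tokens first and generation tokens at positions num_prompt_tokens through num_valid_tokens. A small self-contained check of why the old form was a syntax error (names are illustrative):

# Two adjacent subscript expressions with no comma between them
# do not parse in Python; the comma makes them separate arguments.
import ast

old_form = "f(output[2:5] query[2:5], key_cache)"   # as in the removed lines
new_form = "f(output[2:5], query[2:5], key_cache)"  # as in the added lines

try:
    ast.parse(old_form)
except SyntaxError as exc:
    print("old form fails to parse:", exc.msg)

ast.parse(new_form)  # parses cleanly
print("new form parses")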