@@ -38,12 +38,23 @@ class PagedAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """
 
-    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 num_kv_heads: Optional[int] = None) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.head_mapping = torch.repeat_interleave(
+            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+            self.num_queries_per_kv)
 
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
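For illustration only, a minimal sketch of what the new `head_mapping` tensor looks like, using made-up head counts and the CPU rather than `device="cuda"`: with 8 query heads sharing 2 KV heads, each group of 4 consecutive query heads reads the same KV head.

```python
import torch

# Hypothetical sizes; the real values come from the model config.
num_heads = 8        # query heads
num_kv_heads = 2     # shared key/value heads
num_queries_per_kv = num_heads // num_kv_heads  # 4

# head_mapping[i] is the KV head that serves query head i.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```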
@@ -70,11 +81,19 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-        #TODO: The unsqueeze op may incur some CPU overhead. See if optimzation is possible.
+
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
+        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
             key.unsqueeze(0),
@@ -84,7 +103,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
             op=self.attn_op,
         )
-        # TODO: Unnecessary copy. Optimize.
+        # TODO(woosuk): Unnecessary copy. Optimize.
         output.copy_(out.squeeze(0))
         return output
 
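A standalone sketch of the key/value expansion added on the prompt path above, with made-up tensor sizes: `torch.repeat_interleave` along the head dimension gives every query head a matching KV head before the fused xformers call, which expects query, key, and value to carry the same number of heads.

```python
import torch

num_prompt_tokens, num_heads, num_kv_heads, head_size = 16, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_prompt_tokens, num_kv_heads, head_size)
value = torch.randn(num_prompt_tokens, num_kv_heads, head_size)

# Expand [T, num_kv_heads, H] -> [T, num_heads, H] by repeating each KV head
# once per query head it serves.
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
assert key.shape == (num_prompt_tokens, num_heads, head_size)
assert value.shape == (num_prompt_tokens, num_heads, head_size)
```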
@@ -101,9 +120,9 @@ class PagedAttention(nn.Module):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
@@ -112,6 +131,7 @@ class PagedAttention(nn.Module):
             query,
             key_cache,
             value_cache,
+            self.head_mapping,
             self.scale,
             input_metadata.block_tables,
             input_metadata.context_lens,
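To convey the semantics of the new `head_mapping` argument on the decode path, here is a rough PyTorch reference for a single generation token. It deliberately ignores the blocked cache layout: keys and values are plain `[context_len, num_kv_heads, head_size]` tensors rather than the paged caches the CUDA kernel actually reads, and all names are illustrative.

```python
import torch

def reference_single_token_attention(query, keys, values, head_mapping, scale):
    # query: [num_heads, head_size]
    # keys, values: [context_len, num_kv_heads, head_size]
    # head_mapping: [num_heads], maps each query head to its KV head
    out = torch.empty_like(query)
    for h in range(query.shape[0]):
        kv = head_mapping[h]
        scores = (keys[:, kv] @ query[h]) * scale   # [context_len]
        probs = torch.softmax(scores, dim=0)
        out[h] = probs @ values[:, kv]              # [head_size]
    return out
```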
@@ -137,11 +157,12 @@ class PagedAttention(nn.Module):
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_heads * head_size]
-            value: shape = [num_tokens, num_heads * head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
             cache_event: event to wait for the cache operations to finish.
 
@@ -151,8 +172,8 @@ class PagedAttention(nn.Module):
 
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_heads, self.head_size)
-        value = value.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
 
         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
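As a usage sketch with made-up sizes, the reshapes above expect the flattened layout documented in the docstring: the query keeps `num_heads`, while key and value carry only `num_kv_heads` heads. A grouped-query model would construct the layer with, e.g., `PagedAttention(num_heads=64, head_size=128, scale=128 ** -0.5, num_kv_heads=8)`.

```python
import torch

num_tokens, num_heads, num_kv_heads, head_size = 4, 64, 8, 128

# Flattened projections as they arrive from the QKV linear layers.
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_kv_heads * head_size)
value = torch.randn(num_tokens, num_kv_heads * head_size)

query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
assert key.shape == (num_tokens, num_kv_heads, head_size)
```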
@@ -308,8 +329,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             bias = bias.to(self.alibi_slopes.device)
 
             # When using custom attention bias, xformers requires the bias to
-            # be sliced from a tensor whose length is a multiple of 8. Likely
-            # due to a CUDA Tensor code limitation.
+            # be sliced from a tensor whose length is a multiple of 8.
             padded_len = (prompt_len + 7) // 8 * 8
             bias = torch.empty(
                 self.num_heads,
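For example, the rounding above pads the bias for a 13-token prompt to length 16 (`(13 + 7) // 8 * 8 == 16`), while a prompt length that is already a multiple of 8 is left unchanged.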
@@ -338,7 +358,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             value: shape = [num_prompt_tokens, num_heads, head_size]
             input_metadata: metadata for paged attention.
         """
-        # FIXME: Because xformers does not support dynamic sequence
+        # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
         start = 0
@@ -353,7 +373,7 @@ class PagedAttentionWithALiBi(PagedAttention):
                 scale=self.scale,
                 op=self.attn_op,
             )
-            # TODO: Unnecessary copy. Optimize.
+            # TODO(woosuk): Unnecessary copy. Optimize.
             output[start:end].copy_(out.squeeze(0))
             start += prompt_len
         return output