@@ -203,7 +203,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         assert blocksparse_params is None, ValueError(
-            "ROCm FlashAttention does not support block-sparse attention.")
+            "ROCmFlashAttention does not support blocksparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -248,7 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
-                self.attn_func = _naive_attention
+                self.attn_func = _sdpa_attention
                 logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -343,11 +343,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                    query = query.movedim(0, query.dim() - 2)
+                    key = key.movedim(0, key.dim() - 2)
+                    value = value.movedim(0, value.dim() - 2)
+                    # sdpa math backend attention
                     out = self.attn_func(
                         query,
                         key,
                         value,
                         prefill_meta.seq_lens,
+                        num_tokens,
+                        self.num_heads,
+                        self.head_size,
                         self.scale,
                     )
                 else:
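
Note on the layout change in the hunk above (commentary, not part of the patch): the attention tensors arrive token-major as [num_tokens, num_heads, head_size], while torch.nn.functional.scaled_dot_product_attention expects the sequence dimension second-to-last, so the three movedim calls flip the token and head axes before the call. The extra num_tokens / num_heads / head_size arguments let _sdpa_attention allocate its output back in the original token-major layout. A tiny shape sketch:

import torch

x = torch.empty(10, 4, 64)       # [num_tokens, num_heads, head_size]
y = x.movedim(0, x.dim() - 2)    # heads first: [num_heads, num_tokens, head_size]
assert y.shape == (4, 10, 64)
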
@@ -403,45 +410,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         return output.view(num_tokens, hidden_size)
 
 
-def _naive_attention(
+def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
     scale: float,
 ) -> torch.Tensor:
-    output = torch.empty_like(query)
     start = 0
-    for _, seq_len in enumerate(seq_lens):
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for seq_len in seq_lens:
         end = start + seq_len
-        out = _naive_masked_attention(
-            query[start:end],
-            key[start:end],
-            value[start:end],
-            scale,
-        )
-        # TODO: Unnecessary copy. Optimize.
-        output[start:end].copy_(out)
-        start += seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=True,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
     return output
-
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
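
For reference, a minimal standalone sanity check (not part of the patch) that the SDPA math-backend call used by the new _sdpa_attention reproduces the removed _naive_masked_attention. It assumes PyTorch >= 2.1 (for the scale keyword) and runs on CPU, so the torch.backends.cuda.sdp_kernel context manager from the patch is omitted; the function and variable names below are illustrative only.

import torch
import torch.nn.functional as F


def naive_masked_attention(query, key, value, scale):
    # Mirrors the removed _naive_masked_attention.
    # query/key/value: [seq_len, num_heads, head_size]
    seq_len = query.shape[0]
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=query.dtype),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = torch.softmax(attn_weights + attn_mask.float(),
                                 dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", attn_weights, value)


def sdpa_attention(query, key, value, scale):
    # Same layout change as the patch: [tokens, heads, dim] -> [heads, tokens, dim].
    q = query.movedim(0, query.dim() - 2)
    k = key.movedim(0, key.dim() - 2)
    v = value.movedim(0, value.dim() - 2)
    out = F.scaled_dot_product_attention(q, k, v,
                                         dropout_p=0.0,
                                         is_causal=True,
                                         scale=scale)
    # Move back to [tokens, heads, dim], as _sdpa_attention does.
    return out.movedim(query.dim() - 2, 0)


if __name__ == "__main__":
    torch.manual_seed(0)
    seq_len, num_heads, head_size = 16, 4, 64
    q = torch.randn(seq_len, num_heads, head_size)
    k = torch.randn(seq_len, num_heads, head_size)
    v = torch.randn(seq_len, num_heads, head_size)
    scale = head_size**-0.5
    print(torch.allclose(naive_masked_attention(q, k, v, scale),
                         sdpa_attention(q, k, v, scale),
                         atol=1e-5))  # expected: True

If the two paths agree to within floating-point tolerance, the per-sequence loop in _sdpa_attention can stand in for the hand-rolled masked attention without changing results.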