@@ -1,6 +1,6 @@
 # Copyright (c) 2023, Tri Dao.

-from typing import Optional, Union
+from typing import Optional, Sequence, Tuple, Union

 import torch
 import torch.nn as nn
@@ -14,6 +14,7 @@ import flash_attn_2_cuda as flash_attn_cuda
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x

+
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256
@@ -45,11 +46,49 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
         return 64


+def round_multiple(x, m):
+    return (x + m - 1) // m * m
+
+
+# torch.compile() support is only enabled for pytorch >= 2.4
+# The reason for this is that we are using the new custom_op and register_fake
+# APIs, which support inplace modification of inputs in the function itself
+if torch.__version__ >= "2.4.0":
+    _torch_custom_op_wrapper = torch.library.custom_op
+    _torch_register_fake_wrapper = torch.library.register_fake
+else:
+    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+        if fn is None:
+            return wrap
+        return fn
+    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+        if fn is None:
+            return wrap
+        return fn
+    _torch_custom_op_wrapper = noop_custom_op_wrapper
+    _torch_register_fake_wrapper = noop_register_fake_wrapper
+
+
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
-):
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    return_softmax: bool
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
+    out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
         k,
         v,
@@ -58,36 +97,71 @@ def _flash_attn_forward(
         dropout_p,
         softmax_scale,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         return_softmax,
         None,
     )
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
+    return out, softmax_lse, S_dmask, rng_state
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
+def _flash_attn_forward_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    return_softmax: bool
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    batch_size, seqlen_q, num_heads, head_size = q.shape
+    seqlen_k = k.shape[1]
+    out = torch.empty_like(q)
+    softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
+    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
+    if return_softmax:
+        p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
+    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
+
+    return out, softmax_lse, p, rng_state
+
+
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
+else:
+    _wrapped_flash_attn_forward = _flash_attn_forward


+@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
 def _flash_attn_varlen_forward(
-    q,
-    k,
-    v,
-    cu_seqlens_q,
-    cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    dropout_p,
-    softmax_scale,
-    causal,
-    window_size=(-1, -1),
-    softcap=0.0,
-    alibi_slopes=None,
-    return_softmax=False,
-    block_table=None,
-    leftpad_k=None,
-    seqused_k=None,
-):
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    return_softmax: bool = False,
+    block_table: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
-    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
+    out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
         k,
         v,
@@ -104,36 +178,81 @@ def _flash_attn_varlen_forward(
         softmax_scale,
         False,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         return_softmax,
         None,
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
-    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
-
-
+    return out, softmax_lse, S_dmask, rng_state
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
+def _flash_attn_varlen_forward_fake(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int = -1,
+    window_size_right: int = -1,
+    softcap: float = 0.0,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    return_softmax: bool = False,
+    block_table: Optional[torch.Tensor] = None,
+    leftpad_k: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
+    paged_kv = block_table is not None
+    batch_size = cu_seqlens_q.numel() - 1
+    total_q, num_heads, _ = q.shape
+
+    out = torch.empty_like(q)
+    softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
+    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
+    seqlen_q_rounded = round_multiple(max_seqlen_q, 128)
+    seqlen_k_rounded = round_multiple(max_seqlen_k, 128)
+    if return_softmax:
+        p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout)
+    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
+    return out, softmax_lse, p, rng_state
+
+
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
+else:
+    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
+
+
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
 def _flash_attn_backward(
-    dout,
-    q,
-    k,
-    v,
-    out,
-    softmax_lse,
-    dq,
-    dk,
-    dv,
-    dropout_p,
-    softmax_scale,
-    causal,
-    window_size,
-    softcap,
-    alibi_slopes,
-    deterministic,
-    rng_state=None,
-):
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
@@ -155,39 +274,81 @@ def _flash_attn_backward(
         dropout_p,
         softmax_scale,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         deterministic,
         None,
         rng_state,
     )
-    return dq, dk, dv, softmax_d
+    return softmax_d
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
+def _flash_attn_backward_fake(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    if dq is None:
+        dq = torch.empty_like(q)
+    if dk is None:
+        dk = torch.empty_like(k)
+    if dv is None:
+        dv = torch.empty_like(v)
+    batch_size, seqlen_q, num_heads, _ = q.shape
+    softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)
+
+    return softmax_d
+

+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
+else:
+    _wrapped_flash_attn_backward = _flash_attn_backward

+
+@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
 def _flash_attn_varlen_backward(
-    dout,
-    q,
-    k,
-    v,
-    out,
-    softmax_lse,
-    dq,
-    dk,
-    dv,
-    cu_seqlens_q,
-    cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    dropout_p,
-    softmax_scale,
-    causal,
-    window_size,
-    softcap,
-    alibi_slopes,
-    deterministic,
-    rng_state=None,
-):
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
@@ -214,8 +375,8 @@ def _flash_attn_varlen_backward(
         softmax_scale,
         False,
         causal,
-        window_size[0],
-        window_size[1],
+        window_size_left,
+        window_size_right,
         softcap,
         deterministic,
         None,
@@ -223,7 +384,53 @@ def _flash_attn_varlen_backward(
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
-    return dq, dk, dv, softmax_d
+    return softmax_d
+
+
+@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
+def _flash_attn_varlen_backward_fake(
+    dout: torch.Tensor,
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    out: torch.Tensor,
+    softmax_lse: torch.Tensor,
+    dq: Optional[torch.Tensor],
+    dk: Optional[torch.Tensor],
+    dv: Optional[torch.Tensor],
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float,
+    softmax_scale: float,
+    causal: bool,
+    window_size_left: int,
+    window_size_right: int,
+    softcap: float,
+    alibi_slopes: Optional[torch.Tensor],
+    deterministic: bool,
+    rng_state: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+    batch_size = cu_seqlens_q.numel() - 1
+    total_q, num_heads, _ = q.shape
+
+    if dq is None:
+        dq = torch.empty_like(q)
+    if dk is None:
+        dk = torch.empty_like(k)
+    if dv is None:
+        dv = torch.empty_like(v)
+    softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)
+
+    return softmax_d
+
+
+if torch.__version__ >= "2.4.0":
+    _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
+else:
+    _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward


 class FlashAttnQKVPackedFunc(torch.autograd.Function):
@@ -242,14 +449,21 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
-            qkv[:, :, 0],
-            qkv[:, :, 1],
-            qkv[:, :, 2],
+        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
+        head_size_og = q.size(3)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
+            q,
+            k,
+            v,
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -262,6 +476,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -269,8 +484,12 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-        _flash_attn_backward(
-            dout,
+        head_size_og = dout.size(3)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -282,7 +501,8 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -310,10 +530,16 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
-            qkv[:, 0],
-            qkv[:, 1],
-            qkv[:, 2],
+        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
+        head_size_og = q.size(2)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
+            q,
+            k,
+            v,
             cu_seqlens,
             cu_seqlens,
             max_seqlen,
@@ -321,7 +547,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -336,6 +563,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -343,8 +571,12 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
-        _flash_attn_varlen_backward(
-            dout,
+        head_size_og = dout.size(2)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_varlen_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -360,7 +592,8 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -387,14 +620,21 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
+        k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
+        head_size_og = q.size(3)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
             q,
-            kv[:, :, 0],
-            kv[:, :, 1],
+            k,
+            v,
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -407,6 +647,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -415,8 +656,12 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
-        _flash_attn_backward(
-            dout,
+        head_size_og = dout.size(3)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -428,7 +673,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -460,10 +706,16 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+        k, v = kv[:, 0].detach(), kv[:, 1].detach()
+        head_size_og = q.size(2)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
             q,
-            kv[:, 0],
-            kv[:, 1],
+            k,
+            v,
             cu_seqlens_q,
             cu_seqlens_k,
             max_seqlen_q,
@@ -471,7 +723,8 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -489,6 +742,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
@@ -497,8 +751,12 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         dq = torch.empty_like(q)
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
-        _flash_attn_varlen_backward(
-            dout,
+        head_size_og = dout.size(2)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_varlen_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -514,7 +772,8 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -543,14 +802,20 @@ class FlashAttnFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
+        head_size_og = q.size(3)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
             q,
             k,
             v,
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -563,14 +828,19 @@ class FlashAttnFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-        _flash_attn_backward(
-            dout,
+        head_size_og = dout.size(3)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -582,7 +852,8 @@ class FlashAttnFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,
@@ -617,7 +888,12 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+        head_size_og = q.size(2)
+        if head_size_og % 8 != 0:
+            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
+            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
+            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
+        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
             q,
             k,
             v,
@@ -628,7 +904,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             dropout_p,
             softmax_scale,
             causal=causal,
-            window_size=window_size,
+            window_size_left=window_size[0],
+            window_size_right=window_size[1],
             softcap=softcap,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
@@ -646,14 +923,19 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         ctx.softcap = softcap
         ctx.alibi_slopes = alibi_slopes
         ctx.deterministic = deterministic
+        out = out_padded[..., :head_size_og]
         return out if not return_softmax else (out, softmax_lse, S_dmask)

     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
-        _flash_attn_varlen_backward(
-            dout,
+        head_size_og = dout.size(2)
+        dout_padded = dout
+        if head_size_og % 8 != 0:
+            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
+        _wrapped_flash_attn_varlen_backward(
+            dout_padded,
             q,
             k,
             v,
@@ -669,7 +951,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.dropout_p,
             ctx.softmax_scale,
             ctx.causal,
-            ctx.window_size,
+            ctx.window_size[0],
+            ctx.window_size[1],
             ctx.softcap,
             ctx.alibi_slopes,
             ctx.deterministic,