123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- import math
- import torch
- from einops import rearrange, repeat
- from flash_attn.bert_padding import pad_input, unpad_input
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    """Build a boolean padding mask with randomly drawn per-sequence lengths.

    Arguments:
        max_seqlen: padded sequence length (number of mask columns).
        batch_size: number of sequences (number of mask rows).
        device: device on which the mask is allocated.
        mode: how per-sequence lengths are chosen:
            "full"   - every sequence has length ``max_seqlen``;
            "random" - lengths drawn uniformly from
                       [max(lo, max_seqlen - 20), max_seqlen], where ``lo`` is 0
                       when ``zero_lengths`` is set and 1 otherwise;
            "third"  - lengths drawn uniformly from [max_seqlen // 3, max_seqlen].
        zero_lengths: if True, additionally force the length of every 5th sequence
            (indices 0, 5, 10, ...) and of the last sequence to 0.
    Returns:
        padding_mask: (batch_size, max_seqlen) bool tensor where
            ``padding_mask[b, s]`` is True iff ``s < length[b]``.
    """
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    else:  # mode == "third"
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    # Guard on batch_size: indexing lengths[-1] on an empty batch would raise IndexError.
    if zero_lengths and batch_size > 0:
        # Zero-length sequences every 5 batches and in the last batch.
        lengths[::5] = 0
        lengths[-1] = 0
    # Broadcasting (max_seqlen,) against (batch_size, 1) gives (batch_size, max_seqlen).
    padding_mask = torch.arange(max_seqlen, device=device) < lengths
    return padding_mask
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
    add_unused_qkv=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Build padded and unpadded (varlen) copies of q, k, v plus functions that re-pad outputs.

    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen_q), bool, or None for no query padding
        key_padding_mask: (batch_size, seqlen_k), bool, or None for no key padding
        kvpacked: stack k and v (dim 1 of the unpadded tensors, dim 2 of the padded ones).
        qkvpacked: stack q, k and v the same way. Requires identical query/key padding
            (both None, or equal masks) and nheads == nheads_k. Exclusive with kvpacked.
        add_unused_qkv: not read by this function.
        query_unused_mask, key_unused_mask: forwarded to ``unpad_input``. Only allowed
            when neither kvpacked nor qkvpacked is set.
    Returns:
        qkvpacked: (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn, dqkv_pad_fn)
        kvpacked:  (q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                    max_seqlen_k, q, kv, output_pad_fn, dq_pad_fn, dkv_pad_fn)
        otherwise: (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q,
                    seqused_k, max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn,
                    dq_pad_fn, dk_pad_fn)
        The q/k/v-derived tensors are returned as detached leaves with requires_grad=True.
        Without a padding mask, cu_seqlens is a uniform int32 arange with stride seqlen and
        seqused is None; with a mask, these come from ``unpad_input``.
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q, query_padding_mask, query_unused_mask,
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

    if qkvpacked:
        # Packing needs identical query/key padding. Compare None-ness explicitly:
        # `None == None` is a plain bool without `.all()`, so a tensor-only comparison
        # raises AttributeError whenever either mask is None.
        if query_padding_mask is None or key_padding_mask is None:
            assert query_padding_mask is None and key_padding_mask is None
        else:
            assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k)
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k)
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size)
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )
def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    query_padding_mask=None,
    key_padding_mask=None,
    device=None,
    key_leftpad=None,
):
    """
    Return a bool mask that is True where a key lies OUTSIDE the local attention window.

    Let sq / sk be the number of valid queries / keys: the padding-mask row sums when a
    mask is given, else seqlen_q / seqlen_k. Query row i is aligned with key column
    i + sk - sq (the last valid query lines up with the last valid key). Key column j is
    masked when
        j > i + sk - sq + window_size[1]   (capped at sk when window_size[0] >= 0), or
        j < i + sk - sq - window_size[0].
    A negative window_size entry means that side of the window is unbounded.

    Arguments:
        seqlen_q, seqlen_k: padded query / key lengths.
        window_size: (left, right) window sizes; -1 means infinite.
        query_padding_mask: (batch_size, seqlen_q) bool, or None.
        key_padding_mask: (batch_size, seqlen_k) bool, or None.
        device: device for the index tensors.
        key_leftpad: (batch_size,) integer tensor, or None. Key column indices are shifted
            left by this amount per batch; columns before the shift are always masked.
    Returns:
        Bool mask of shape (seqlen_q, seqlen_k) when no per-batch argument is given, else
        (batch_size, 1, seqlen_q, seqlen_k).
    """
    window_left, window_right = window_size
    if window_right < 0:
        # Unbounded right window. Since sq <= seqlen_q, this offset guarantees
        # i + sk - sq + window_right >= sk for every row i. The right-hand test then only
        # excludes columns beyond sk and the left-pad sentinel below.
        window_right = seqlen_q + seqlen_k
    row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long)[:, None]
    col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
    if key_leftpad is not None:
        key_leftpad = key_leftpad.reshape(-1, 1, 1, 1)
        # Broadcasting (seqlen_k,) against (batch, 1, 1, 1) yields (batch, 1, 1, seqlen_k).
        # Columns before the left pad get a huge sentinel so they always fall outside.
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else key_padding_mask.sum(-1).reshape(-1, 1, 1, 1)
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else query_padding_mask.sum(-1).reshape(-1, 1, 1, 1)
    )
    if window_left < 0:
        return col_idx > row_idx + sk - sq + window_right
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    return torch.logical_or(
        col_idx > torch.minimum(row_idx + sk - sq + window_right, sk),
        col_idx < row_idx + sk - sq - window_left,
    )
def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Reference (non-fused) scaled dot-product attention.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
            k/v heads are repeated nheads // nheads_k times so that consecutive groups
            of q heads share one k/v head.
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k); added after
            masking, before softmax.
        dropout_p: float. v is scaled by 1 / (1 - dropout_p) whether or not dropout_mask
            is given.
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k); True entries are kept.
        causal: whether to apply causal masking (forces the right window size to 0).
        window_size: (int, int), left and right window size
        softcap: if > 0, scores are replaced by softcap * tanh(scores / softcap).
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
            output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
            without changing the math. This is to estimate the numerical error from operation
            reordering.
        key_leftpad: forwarded to ``construct_local_mask`` when a local window is active.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax probabilities BEFORE
            dropout. Rows for padded queries, and rows whose keys are all masked out, are 0.
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    # Head h of k/v is repeated g times consecutively (einops "b s h d -> b s (h g) d").
    k = k.repeat_interleave(q.shape[2] // k.shape[2], dim=2)
    v = v.repeat_interleave(q.shape[2] // v.shape[2], dim=2)
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    if softcap > 0:
        scores /= softcap
        scores = scores.tanh()
        scores *= softcap
    if key_padding_mask is not None:
        scores.masked_fill_((~key_padding_mask)[:, None, None, :], float("-inf"))
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    # A row whose scores are all -inf has every key removed by key padding and/or the local
    # window (e.g. a batch with zero valid keys). Such a row softmaxes to NaN, which would
    # leak into the returned attention and into dV, so it is zeroed below.
    fully_masked_rows = torch.all(scores == float("-inf"), dim=-1, keepdim=True)
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    attention = attention.masked_fill(fully_masked_rows, 0.0)
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill((~query_padding_mask)[:, None, :, None], 0.0)
    dropout_scaling = 1.0 / (1 - dropout_p)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_((~query_padding_mask)[:, :, None, None], 0.0)
    if key_padding_mask is not None:
        # Batches with no valid key at all produce an all-zero output.
        no_valid_key = torch.logical_not(torch.any(key_padding_mask, 1))
        output.masked_fill_(no_valid_key[:, None, None, None], 0.0)
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
|