- # Copyright (c) 2023, Tri Dao.
- import math
- from functools import partial
- import torch
- import torch.nn as nn
- from einops import rearrange, repeat
- from flash_attn.utils.distributed import get_dim_for_local_rank
- try:
- from flash_attn import (
- flash_attn_kvpacked_func,
- flash_attn_qkvpacked_func,
- flash_attn_varlen_kvpacked_func,
- flash_attn_varlen_qkvpacked_func,
- flash_attn_with_kvcache,
- )
- except ImportError:
- flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
- flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
- flash_attn_with_kvcache = None
- try:
- from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
- except ImportError:
- FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
- try:
- from flash_attn.layers.rotary import RotaryEmbedding
- except ImportError:
- RotaryEmbedding = None
- # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
- def get_alibi_slopes(nheads):
- def get_slopes_power_of_2(nheads):
- start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
- ratio = start
- return [start * ratio**i for i in range(nheads)]
- if math.log2(nheads).is_integer():
- return get_slopes_power_of_2(nheads)
- else:
- closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
- return (
- get_slopes_power_of_2(closest_power_of_2)
- + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
- )
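- # Illustrative sketch (not part of the original module): for a power-of-two head count the
- # ALiBi slopes form a geometric sequence, e.g. get_alibi_slopes(8) should return
- # [2**-1, 2**-2, ..., 2**-8]. A quick self-check:
- #     slopes = get_alibi_slopes(8)
- #     assert slopes == [2.0 ** -(i + 1) for i in range(8)]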
- class FlashSelfAttention(nn.Module):
- """Implement the scaled dot product attention with softmax.
- Arguments
- ---------
- softmax_scale: The temperature to use for the softmax attention.
- (default: 1/sqrt(d_keys) where d_keys is computed at
- runtime)
- attention_dropout: The dropout rate to apply to the attention
- (default: 0.0)
- """
- def __init__(
- self,
- causal=False,
- softmax_scale=None,
- attention_dropout=0.0,
- window_size=(-1, -1),
- alibi_slopes=None,
- deterministic=False,
- ):
- super().__init__()
- assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
- assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
- self.causal = causal
- self.softmax_scale = softmax_scale
- self.drop = nn.Dropout(attention_dropout)
- self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
- self.window_size = window_size
- self.deterministic = deterministic
- def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
- """Implements the multihead softmax attention.
- Arguments
- ---------
- qkv: The tensor containing the query, key, and value.
- If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
- If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
- (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
- causal: if passed, will override self.causal
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
- of the sequences in the batch, used to index into qkv.
- max_seqlen: int. Maximum sequence length in the batch.
- Returns
- -------
- out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
- else (B, S, H, D).
- """
- assert qkv.dtype in [torch.float16, torch.bfloat16]
- assert qkv.is_cuda
- causal = self.causal if causal is None else causal
- unpadded = cu_seqlens is not None
- if self.alibi_slopes is not None:
- self.alibi_slopes = self.alibi_slopes.to(torch.float32)
- if unpadded:
- assert cu_seqlens.dtype == torch.int32
- assert max_seqlen is not None
- assert isinstance(max_seqlen, int)
- return flash_attn_varlen_qkvpacked_func(
- qkv,
- cu_seqlens,
- max_seqlen,
- self.drop.p if self.training else 0.0,
- softmax_scale=self.softmax_scale,
- causal=causal,
- alibi_slopes=self.alibi_slopes,
- window_size=self.window_size,
- deterministic=self.deterministic,
- )
- else:
- return flash_attn_qkvpacked_func(
- qkv,
- self.drop.p if self.training else 0.0,
- softmax_scale=self.softmax_scale,
- causal=causal,
- alibi_slopes=self.alibi_slopes,
- window_size=self.window_size,
- deterministic=self.deterministic,
- )
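- # Illustrative usage sketch (not part of the original module): a minimal forward pass through
- # FlashSelfAttention, assuming flash-attn is installed and a CUDA device is available. The qkv
- # layout follows the docstring above, (B, S, 3, H, D), in fp16 or bf16.
- def _example_flash_self_attention():
-     attn = FlashSelfAttention(causal=True)
-     qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16)
-     return attn(qkv)  # (2, 128, 8, 64)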
- class FlashCrossAttention(nn.Module):
- """Implement the scaled dot product attention with softmax.
- Arguments
- ---------
- softmax_scale: The temperature to use for the softmax attention.
- (default: 1/sqrt(d_keys) where d_keys is computed at
- runtime)
- attention_dropout: The dropout rate to apply to the attention
- (default: 0.0)
- """
- def __init__(
- self,
- causal=False,
- softmax_scale=None,
- attention_dropout=0.0,
- alibi_slopes=None,
- window_size=(-1, -1),
- deterministic=False,
- ):
- super().__init__()
- assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
- assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
- self.causal = causal
- self.softmax_scale = softmax_scale
- self.drop = nn.Dropout(attention_dropout)
- self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
- self.window_size = window_size
- self.deterministic = deterministic
- def forward(
- self,
- q,
- kv,
- causal=None,
- cu_seqlens=None,
- max_seqlen=None,
- cu_seqlens_k=None,
- max_seqlen_k=None,
- ):
- """Implements the multihead softmax attention.
- Arguments
- ---------
- q: The tensor containing the query. (B, Sq, H, D)
- kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
- causal: if passed, will override self.causal
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
- of the sequences in the batch, used to index into q.
- max_seqlen: int. Maximum sequence length in the batch of q.
- cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
- of the sequences in the batch, used to index into kv.
- max_seqlen_k: int. Maximum sequence length in the batch of k and v.
- """
- assert q.dtype in [torch.float16, torch.bfloat16]
- assert q.is_cuda and kv.is_cuda
- causal = self.causal if causal is None else causal
- unpadded = cu_seqlens is not None
- if self.alibi_slopes is not None:
- self.alibi_slopes = self.alibi_slopes.to(torch.float32)
- if unpadded:
- assert cu_seqlens.dtype == torch.int32
- assert max_seqlen is not None
- assert isinstance(max_seqlen, int)
- assert cu_seqlens_k is not None
- assert cu_seqlens_k.dtype == torch.int32
- assert max_seqlen_k is not None
- assert isinstance(max_seqlen_k, int)
- return flash_attn_varlen_kvpacked_func(
- q,
- kv,
- cu_seqlens,
- cu_seqlens_k,
- max_seqlen,
- max_seqlen_k,
- self.drop.p if self.training else 0.0,
- softmax_scale=self.softmax_scale,
- causal=causal,
- alibi_slopes=self.alibi_slopes,
- window_size=self.window_size,
- deterministic=self.deterministic,
- )
- else:
- batch_size, seqlen_q = q.shape[0], q.shape[1]
- seqlen_k = kv.shape[1]
- assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
- return flash_attn_kvpacked_func(
- q,
- kv,
- self.drop.p if self.training else 0.0,
- causal=causal,
- softmax_scale=self.softmax_scale,
- alibi_slopes=self.alibi_slopes,
- window_size=self.window_size,
- deterministic=self.deterministic,
- )
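- # Illustrative usage sketch (not part of the original module): cross-attention with a packed kv
- # tensor, again assuming flash-attn and a CUDA device. Using fewer kv heads than query heads
- # relies on the kernel's MQA/GQA support.
- def _example_flash_cross_attention():
-     attn = FlashCrossAttention()
-     q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)  # (B, Sq, H, D)
-     kv = torch.randn(2, 256, 2, 2, 64, device="cuda", dtype=torch.float16)  # (B, Sk, 2, H_k, D)
-     return attn(q, kv)  # (2, 128, 8, 64)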
- class SelfAttention(nn.Module):
- """Implement the scaled dot product attention with softmax.
- Arguments
- ---------
- softmax_scale: The temperature to use for the softmax attention.
- (default: 1/sqrt(d_keys) where d_keys is computed at
- runtime)
- attention_dropout: The dropout rate to apply to the attention
- (default: 0.0)
- """
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
- super().__init__()
- self.causal = causal
- self.softmax_scale = softmax_scale
- self.drop = nn.Dropout(attention_dropout)
- def forward(self, qkv, causal=None, key_padding_mask=None):
- """Implements the multihead softmax attention.
- Arguments
- ---------
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
- causal: if passed, will override self.causal
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
- False means to mask out. (B, S)
- """
- batch_size, seqlen = qkv.shape[0], qkv.shape[1]
- causal = self.causal if causal is None else causal
- q, k, v = qkv.unbind(dim=2)
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
- if key_padding_mask is not None:
- padding_mask = torch.full(
- (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
- )
- padding_mask.masked_fill_(key_padding_mask, 0.0)
- # TD [2022-09-30]: Adding is faster than masked_fill_ (presumably a better kernel).
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
- if causal:
- # "triu_tril_cuda_template" not implemented for 'BFloat16'
- # So we have to construct the mask in float
- causal_mask = torch.triu(
- torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
- )
- # TD [2022-09-30]: Adding is faster than masked_fill_ (presumably a better kernel).
- scores = scores + causal_mask.to(dtype=scores.dtype)
- attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
- attention_drop = self.drop(attention)
- output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
- return output
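- # Illustrative usage sketch (not part of the original module): the pure-PyTorch reference path,
- # which also runs on CPU and accepts a boolean key_padding_mask (True = keep).
- def _example_self_attention():
-     attn = SelfAttention(causal=True)
-     qkv = torch.randn(2, 16, 3, 4, 32)  # (B, S, 3, H, D)
-     mask = torch.ones(2, 16, dtype=torch.bool)
-     return attn(qkv, key_padding_mask=mask)  # (2, 16, 4, 32)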
- class CrossAttention(nn.Module):
- """Implement the scaled dot product attention with softmax.
- Arguments
- ---------
- softmax_scale: The temperature to use for the softmax attention.
- (default: 1/sqrt(d_keys) where d_keys is computed at
- runtime)
- attention_dropout: The dropout rate to apply to the attention
- (default: 0.0)
- """
- def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
- super().__init__()
- self.causal = causal
- self.softmax_scale = softmax_scale
- self.drop = nn.Dropout(attention_dropout)
- def forward(self, q, kv, causal=None, key_padding_mask=None):
- """Implements the multihead softmax attention.
- Arguments
- ---------
- q: The tensor containing the query. (B, Sq, H, D)
- kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
- causal: if passed, will override self.causal
- key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
- False means to mask out. (B, Sk)
- """
- batch_size, seqlen_q = q.shape[0], q.shape[1]
- causal = self.causal if causal is None else causal
- seqlen_k = kv.shape[1]
- assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
- if kv.shape[3] != q.shape[2]: # MQA/GQA
- kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
- k, v = kv.unbind(dim=2)
- softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
- scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
- if key_padding_mask is not None:
- padding_mask = torch.full(
- (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
- )
- padding_mask.masked_fill_(key_padding_mask, 0.0)
- # TD [2022-09-30]: Adding is faster than masked_fill_ (presumably a better kernel).
- scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
- if causal:
- # causal mask needs to take into account the difference between seqlen_q and seqlen_k
- row_idx = rearrange(
- torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
- )
- col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
- sk = (
- seqlen_k
- if key_padding_mask is None
- else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
- )
- causal_mask = col_idx > row_idx + sk - seqlen_q
- scores = scores.masked_fill(causal_mask, -10000.0)
- attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
- attention_drop = self.drop(attention)
- output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
- return output
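- # Illustrative usage sketch (not part of the original module): the reference cross-attention
- # with fewer kv heads than query heads; the kv heads are repeated internally for MQA/GQA.
- def _example_cross_attention():
-     attn = CrossAttention()
-     q = torch.randn(2, 16, 8, 32)  # (B, Sq, H, D)
-     kv = torch.randn(2, 24, 2, 2, 32)  # (B, Sk, 2, H_k, D)
-     return attn(q, kv)  # (2, 16, 8, 32)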
- class LinearResidual(nn.Linear):
- """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
- def forward(self, input: torch.Tensor):
- return super().forward(input), input
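- # Illustrative sketch (not part of the original module): LinearResidual mirrors the
- # return_residual=True behaviour of FusedDense, e.g.
- #     proj = LinearResidual(256, 768)
- #     y, residual = proj(torch.randn(2, 16, 256))  # y: (2, 16, 768), residual is the input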
- def _update_kv_cache(kv, inference_params, layer_idx):
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
- # Pre-allocate memory for key-values for inference.
- num_heads, head_dim = kv.shape[-2:]
- if layer_idx not in inference_params.key_value_memory_dict:
- kv_cache = torch.empty(
- inference_params.max_batch_size,
- inference_params.max_seqlen,
- 2,
- num_heads,
- head_dim,
- dtype=kv.dtype,
- device=kv.device,
- )
- inference_params.key_value_memory_dict[layer_idx] = kv_cache
- else:
- kv_cache = inference_params.key_value_memory_dict[layer_idx]
- # Adjust key and value for inference
- batch_start = inference_params.batch_size_offset
- batch_end = batch_start + kv.shape[0]
- sequence_start = inference_params.seqlen_offset
- sequence_end = sequence_start + kv.shape[1]
- assert kv_cache is not None
- assert batch_end <= kv_cache.shape[0]
- assert sequence_end <= kv_cache.shape[1]
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
- return kv_cache[batch_start:batch_end, :sequence_end, ...]
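- # Illustrative sketch (not part of the original module): _update_kv_cache only touches the
- # attributes below, so a minimal stand-in object (hypothetical, standing in for the
- # InferenceParams class used at generation time) is enough to exercise it.
- def _example_update_kv_cache():
-     class _Params:
-         max_batch_size, max_seqlen = 2, 64
-         batch_size_offset, seqlen_offset = 0, 0
-         key_value_memory_dict = {}
-         lengths_per_sample = None
-     kv = torch.randn(2, 16, 2, 4, 32)  # (batch_size, seqlen, 2, nheads_kv, head_dim)
-     return _update_kv_cache(kv, _Params(), layer_idx=0)  # (2, 16, 2, 4, 32)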
- class MHA(nn.Module):
- """Multi-head self-attention and cross-attention"""
- def __init__(
- self,
- embed_dim,
- num_heads,
- num_heads_kv=None,
- cross_attn=False,
- qkv_proj_bias=True,
- out_proj_bias=True,
- dropout=0.0,
- softmax_scale=None,
- causal=False,
- layer_idx=None,
- dwconv=False,
- rotary_emb_dim=0,
- rotary_emb_base=10000.0,
- rotary_emb_scale_base=None,
- rotary_emb_interleaved=False,
- use_alibi=False,
- window_size=(-1, -1),
- fused_bias_fc=False,
- use_flash_attn=False,
- return_residual=False,
- checkpointing=False,
- device=None,
- dtype=None,
- ) -> None:
- """
- num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
- return_residual: whether to return the input x along with the output. This is for
- performance reasons: for post-norm architectures, returning the input allows us
- to fuse the backward of nn.Linear with the residual connection.
- """
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.embed_dim = embed_dim
- self.cross_attn = cross_attn
- self.causal = causal
- self.layer_idx = layer_idx
- self.dwconv = dwconv
- self.rotary_emb_dim = rotary_emb_dim
- self.use_flash_attn = use_flash_attn
- self.return_residual = return_residual
- self.checkpointing = checkpointing
- if use_alibi:
- assert use_flash_attn, "ALiBi code path requires flash_attn"
- alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
- else:
- alibi_slopes = None
- if window_size != (-1, -1):
- assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
- self.num_heads = num_heads
- self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
- assert (
- self.num_heads % self.num_heads_kv == 0
- ), "num_heads must be divisible by num_heads_kv"
- assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
- self.head_dim = self.embed_dim // num_heads
- qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
- kv_dim = 2 * self.head_dim * self.num_heads_kv
- if self.rotary_emb_dim > 0:
- assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
- assert RotaryEmbedding is not None, "rotary_emb is not installed"
- self.rotary_emb = RotaryEmbedding(
- self.rotary_emb_dim,
- base=rotary_emb_base,
- scale_base=rotary_emb_scale_base,
- interleaved=rotary_emb_interleaved,
- device=device,
- )
- if fused_bias_fc and FusedDense is None:
- raise ImportError("fused_dense is not installed")
- linear_cls = nn.Linear if not fused_bias_fc else FusedDense
- linear_resid_cls = (
- LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
- )
- wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
- inner_attn_cls = (
- partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
- if use_flash_attn
- else SelfAttention
- )
- inner_cross_attn_cls = (
- partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
- if use_flash_attn
- else CrossAttention
- )
- if not self.cross_attn:
- self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
- else:
- self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
- self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
- if self.dwconv:
- if self.num_heads_kv == self.num_heads:
- self.dwconv_qkv = nn.Conv1d(
- qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
- )
- else:
- self.dwconv_q = nn.Conv1d(
- embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
- )
- self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
- self.inner_attn = inner_attn_cls(
- causal=causal,
- softmax_scale=softmax_scale,
- attention_dropout=dropout,
- )
- self.inner_cross_attn = inner_cross_attn_cls(
- causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
- )
- self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
- dtype = self.out_proj.weight.dtype if dtype is None else dtype
- device = self.out_proj.weight.device
- return torch.empty(
- batch_size,
- max_seqlen,
- 2,
- self.num_heads_kv,
- self.head_dim,
- dtype=dtype,
- device=device,
- )
- def _update_kv_cache(self, kv, inference_params):
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
- assert not self.dwconv, "Generation does not support dwconv yet"
- assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
- return _update_kv_cache(kv, inference_params, self.layer_idx)
- def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
- """
- Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
- q: (batch_size, seqlen_q, nheads, head_dim)
- kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
- """
- assert inference_params is not None and inference_params.seqlen_offset > 0
- assert self.use_flash_attn
- if self.rotary_emb_dim > 0:
- assert self.rotary_emb.scale is None, "This code path does not support xPos"
- self.rotary_emb._update_cos_sin_cache(
- inference_params.max_seqlen, device=q.device, dtype=q.dtype
- )
- rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
- else:
- rotary_cos, rotary_sin = None, None
- batch = q.shape[0]
- kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
- cache_seqlens = (
- inference_params.lengths_per_sample[:batch]
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
- context = flash_attn_with_kvcache(
- q,
- kv_cache[:, :, 0],
- kv_cache[:, :, 1],
- kv[:, :, 0],
- kv[:, :, 1],
- rotary_cos=rotary_cos,
- rotary_sin=rotary_sin,
- cache_seqlens=cache_seqlens,
- softmax_scale=self.inner_cross_attn.softmax_scale,
- causal=self.inner_cross_attn.causal,
- rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
- alibi_slopes=alibi_slopes,
- )
- return context
- def _update_kvcache_attention(self, q, kv, inference_params):
- """Write kv to inference_params, then do attention"""
- if (
- inference_params.seqlen_offset == 0
- or flash_attn_with_kvcache is None
- or not self.use_flash_attn
- ):
- # TODO: this only uses seqlen_offset and not lengths_per_sample.
- kv = self._update_kv_cache(kv, inference_params)
- return self.inner_cross_attn(q, kv)
- else:
- batch = q.shape[0]
- kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
- cache_seqlens = (
- inference_params.lengths_per_sample[:batch]
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
- return flash_attn_with_kvcache(
- q,
- kv_cache[:, :, 0],
- kv_cache[:, :, 1],
- kv[:, :, 0],
- kv[:, :, 1],
- cache_seqlens=cache_seqlens,
- softmax_scale=self.inner_cross_attn.softmax_scale,
- causal=self.inner_cross_attn.causal,
- alibi_slopes=alibi_slopes,
- )
- def forward(
- self,
- x,
- x_kv=None,
- key_padding_mask=None,
- cu_seqlens=None,
- max_seqlen=None,
- mixer_subset=None,
- inference_params=None,
- **kwargs,
- ):
- """
- Arguments:
- x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
- cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
- is the sum of the sequence lengths in the batch.
- x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
- cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
- of the sequences in the batch, used to index into x. Only applicable when using
- FlashAttention.
- max_seqlen: int. Maximum sequence length in the batch.
- key_padding_mask: boolean mask, True means to keep, False means to mask out.
- (batch, seqlen). Only applicable when not using FlashAttention.
- mixer_subset: for cross-attention only. If not None, will take a subset of x
- before applying the query projection. Useful e.g. for ViT, where we only care
- about the CLS token in the last layer.
- inference_params: for generation. Adapted from Megatron-LM (and Apex)
- https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
- """
- if cu_seqlens is not None:
- assert max_seqlen is not None
- assert key_padding_mask is None
- assert self.use_flash_attn
- assert not self.dwconv
- assert self.rotary_emb_dim == 0
- if key_padding_mask is not None:
- assert cu_seqlens is None
- assert max_seqlen is None
- assert not self.use_flash_attn
- if inference_params is not None:
- assert key_padding_mask is None
- assert cu_seqlens is None and max_seqlen is None
- assert not self.dwconv
- kwargs = (
- {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
- if self.use_flash_attn
- else {"key_padding_mask": key_padding_mask, **kwargs}
- )
- seqlen_offset = (
- 0
- if inference_params is None
- else (
- inference_params.lengths_per_sample
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- )
- rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
- batch, seqlen = x.shape[:2]
- if not self.cross_attn and self.num_heads_kv == self.num_heads:
- assert x_kv is None and mixer_subset is None
- if not self.return_residual:
- qkv = self.Wqkv(x)
- else:
- qkv, x = self.Wqkv(x)
- if self.dwconv:
- qkv = rearrange(
- self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
- ).contiguous()
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
- if (
- inference_params is None
- or inference_params.seqlen_offset == 0
- or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
- or not self.use_flash_attn
- ):
- if self.rotary_emb_dim > 0:
- qkv = self.rotary_emb(
- qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
- )
- if inference_params is None:
- if not self.checkpointing:
- context = self.inner_attn(qkv, **kwargs)
- else:
- context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
- else:
- context = self._update_kvcache_attention(
- qkv[:, :, 0], qkv[:, :, 1:], inference_params
- )
- else:
- context = self._apply_rotary_update_kvcache_attention(
- qkv[:, :, 0], qkv[:, :, 1:], inference_params
- )
- else:
- if self.cross_attn:
- if not self.return_residual:
- q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
- kv = self.Wkv(x_kv if x_kv is not None else x)
- else:
- if x_kv is not None:
- kv, x_kv = self.Wkv(x_kv)
- else:
- kv, x = self.Wkv(x)
- q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
- else:
- assert self.num_heads_kv != self.num_heads
- if not self.return_residual:
- qkv = self.Wqkv(x)
- else:
- qkv, x = self.Wqkv(x)
- q = qkv[..., : self.num_heads * self.head_dim]
- kv = qkv[..., self.num_heads * self.head_dim :]
- q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
- if self.dwconv:
- q = rearrange(
- self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
- ).contiguous()
- kv = rearrange(
- self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
- ).contiguous()
- if (
- inference_params is None
- or inference_params.seqlen_offset == 0
- or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
- or not self.use_flash_attn
- ):
- if self.rotary_emb_dim > 0:
- q, kv = self.rotary_emb(
- q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
- )
- if inference_params is None:
- if not self.checkpointing:
- context = self.inner_cross_attn(q, kv, **kwargs)
- else:
- context = torch.utils.checkpoint.checkpoint(
- self.inner_cross_attn, q, kv, **kwargs
- )
- else:
- context = self._update_kvcache_attention(q, kv, inference_params)
- else:
- context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
- out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
- return out if not self.return_residual else (out, x)
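- # Illustrative usage sketch (not part of the original module): a causal self-attention MHA block
- # with grouped-query attention (num_heads_kv < num_heads), using the non-flash reference path so
- # it also runs on CPU. Passing use_flash_attn=True (fp16/bf16 inputs on CUDA) switches to the
- # FlashAttention kernels instead.
- def _example_mha():
-     mha = MHA(embed_dim=512, num_heads=8, num_heads_kv=2, causal=True)
-     x = torch.randn(2, 32, 512)  # (batch, seqlen, hidden_dim)
-     return mha(x)  # (2, 32, 512)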
- class ParallelMHA(nn.Module):
- """Multi-head self-attention and cross-attention"""
- def __init__(
- self,
- embed_dim,
- num_heads,
- process_group,
- num_heads_kv=None,
- qkv_proj_bias=True,
- out_proj_bias=True,
- dropout=0.0,
- softmax_scale=None,
- causal=False,
- layer_idx=None,
- rotary_emb_dim=0,
- rotary_emb_base=10000.0,
- rotary_emb_scale_base=None,
- rotary_emb_interleaved=False,
- use_alibi=False,
- window_size=(-1, -1),
- use_flash_attn=False,
- checkpointing=False,
- sequence_parallel=True,
- device=None,
- dtype=None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.embed_dim = embed_dim
- self.causal = causal
- self.layer_idx = layer_idx
- self.rotary_emb_dim = rotary_emb_dim
- self.use_flash_attn = use_flash_attn
- self.checkpointing = checkpointing
- self.process_group = process_group
- self.world_size = process_group.size()
- self.local_rank = torch.distributed.get_rank(process_group)
- self.num_heads = num_heads
- assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
- self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
- assert (
- self.num_heads % self.num_heads_kv == 0
- ), "num_heads must be divisible by num_heads_kv"
- self.num_heads_per_rank = get_dim_for_local_rank(
- self.num_heads, self.world_size, self.local_rank
- )
- self.num_heads_kv_per_rank = get_dim_for_local_rank(
- self.num_heads_kv, self.world_size, self.local_rank
- )
- self.head_dim = self.embed_dim // num_heads
- qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
- if use_alibi:
- assert use_flash_attn, "ALiBi code path requires flash_attn"
- num_heads_local = math.ceil(self.num_heads / self.world_size)
- alibi_slopes = torch.tensor(
- get_alibi_slopes(num_heads)[
- self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
- ],
- device=device,
- )
- else:
- alibi_slopes = None
- if window_size != (-1, -1):
- assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
- if self.rotary_emb_dim > 0:
- assert RotaryEmbedding is not None, "rotary_emb is not installed"
- self.rotary_emb = RotaryEmbedding(
- self.rotary_emb_dim,
- base=rotary_emb_base,
- scale_base=rotary_emb_scale_base,
- interleaved=rotary_emb_interleaved,
- device=device,
- )
- if ColumnParallelLinear is None or RowParallelLinear is None:
- raise ImportError("fused_dense is not installed")
- self.Wqkv = ColumnParallelLinear(
- embed_dim,
- qkv_dim,
- process_group,
- bias=qkv_proj_bias,
- sequence_parallel=sequence_parallel,
- multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
- **factory_kwargs,
- )
- inner_attn_cls = (
- partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
- if use_flash_attn
- else SelfAttention
- )
- inner_cross_attn_cls = (
- partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
- if use_flash_attn
- else CrossAttention
- )
- self.inner_attn = inner_attn_cls(
- causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
- )
- self.inner_cross_attn = inner_cross_attn_cls(
- causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
- )
- self.out_proj = RowParallelLinear(
- embed_dim,
- embed_dim,
- process_group,
- bias=out_proj_bias,
- sequence_parallel=sequence_parallel,
- multiple_of=self.head_dim,
- **factory_kwargs,
- )
- def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
- dtype = self.out_proj.weight.dtype if dtype is None else dtype
- device = self.out_proj.weight.device
- return torch.empty(
- batch_size,
- max_seqlen,
- 2,
- self.num_heads_kv_per_rank,
- self.head_dim,
- dtype=dtype,
- device=device,
- )
- def _update_kv_cache(self, kv, inference_params):
- """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
- assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
- return _update_kv_cache(kv, inference_params, self.layer_idx)
- def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
- """
- Fast path that combines 3 steps: apply rotary to Q and K, update the kv cache, and apply attention.
- q: (batch_size, seqlen_q, nheads, head_dim)
- kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
- """
- assert inference_params is not None and inference_params.seqlen_offset > 0
- assert self.use_flash_attn
- if self.rotary_emb_dim > 0:
- assert self.rotary_emb.scale is None, "This code path does not support xPos"
- self.rotary_emb._update_cos_sin_cache(
- inference_params.max_seqlen, device=q.device, dtype=q.dtype
- )
- rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
- else:
- rotary_cos, rotary_sin = None, None
- batch = q.shape[0]
- kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
- cache_seqlens = (
- inference_params.lengths_per_sample[:batch]
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
- context = flash_attn_with_kvcache(
- q,
- kv_cache[:, :, 0],
- kv_cache[:, :, 1],
- kv[:, :, 0],
- kv[:, :, 1],
- rotary_cos=rotary_cos,
- rotary_sin=rotary_sin,
- cache_seqlens=cache_seqlens,
- softmax_scale=self.inner_cross_attn.softmax_scale,
- causal=self.inner_cross_attn.causal,
- rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
- alibi_slopes=alibi_slopes,
- )
- return context
- def _update_kvcache_attention(self, q, kv, inference_params):
- """Write kv to inference_params, then do attention"""
- if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
- # TODO: this only uses seqlen_offset and not lengths_per_sample.
- kv = self._update_kv_cache(kv, inference_params)
- return self.inner_cross_attn(q, kv)
- else:
- batch = q.shape[0]
- kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
- cache_seqlens = (
- inference_params.lengths_per_sample[:batch]
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
- context = flash_attn_with_kvcache(
- q,
- kv_cache[:, :, 0],
- kv_cache[:, :, 1],
- kv[:, :, 0],
- kv[:, :, 1],
- cache_seqlens=cache_seqlens,
- softmax_scale=self.inner_cross_attn.softmax_scale,
- causal=self.inner_cross_attn.causal,
- alibi_slopes=alibi_slopes,
- )
- return context
- def forward(self, x, seqlen=None, inference_params=None, **kwargs):
- """
- Arguments:
- x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
- If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
- split x for sequence parallelism, we split along the batch * seqlen dimension
- (in case batch is small).
- """
- qkv = self.Wqkv(x)
- if seqlen is not None:
- qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
- seqlen_offset = (
- 0
- if inference_params is None
- else (
- inference_params.lengths_per_sample
- if inference_params.lengths_per_sample is not None
- else inference_params.seqlen_offset
- )
- )
- rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
- if self.num_heads_kv == self.num_heads:
- qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
- if (
- inference_params is None
- or inference_params.seqlen_offset == 0
- or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
- or not self.use_flash_attn
- ):
- if self.rotary_emb_dim > 0:
- qkv = self.rotary_emb(
- qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
- )
- if inference_params is None:
- if not self.checkpointing:
- context = self.inner_attn(qkv, **kwargs)
- else:
- context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
- else:
- context = self._update_kvcache_attention(
- qkv[:, :, 0], qkv[:, :, 1:], inference_params
- )
- else:
- context = self._apply_rotary_update_kvcache_attention(
- qkv[:, :, 0], qkv[:, :, 1:], inference_params
- )
- else:
- q = rearrange(
- qkv[..., : self.num_heads_per_rank * self.head_dim],
- "... (h d) -> ... h d",
- d=self.head_dim,
- )
- kv = rearrange(
- qkv[..., self.num_heads_per_rank * self.head_dim :],
- "... (two hkv d) -> ... two hkv d",
- two=2,
- d=self.head_dim,
- )
- if (
- inference_params is None
- or inference_params.seqlen_offset == 0
- or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
- or not self.use_flash_attn
- ):
- if self.rotary_emb_dim > 0:
- q, kv = self.rotary_emb(
- q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
- )
- if inference_params is None:
- if not self.checkpointing:
- context = self.inner_cross_attn(q, kv, **kwargs)
- else:
- context = torch.utils.checkpoint.checkpoint(
- self.inner_cross_attn, q, kv, **kwargs
- )
- else:
- context = self._update_kvcache_attention(q, kv, inference_params)
- else:
- context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
- context = rearrange(context, "b s h d -> b s (h d)")
- if seqlen is not None:
- context = rearrange(context, "b s d -> (b s) d")
- out = self.out_proj(context)
- return out
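- # Illustrative sketch (not part of the original module): ParallelMHA shards heads across the
- # tensor-parallel group, so each rank holds num_heads_per_rank query heads and
- # num_heads_kv_per_rank kv heads; with num_heads=8, num_heads_kv=2 and a world size of 2,
- # get_dim_for_local_rank(8, 2, rank) and get_dim_for_local_rank(2, 2, rank) should give
- # 4 and 1 on every rank. Constructing the module itself additionally needs an initialized
- # torch.distributed process group and the fused_dense extension, roughly:
- #     tp_group = torch.distributed.new_group(ranks=[0, 1])  # hypothetical 2-way TP group
- #     mha = ParallelMHA(embed_dim=512, num_heads=8, process_group=tp_group, num_heads_kv=2)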