# Copyright (c) 2023, Tri Dao.

import math
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.utils.distributed import get_dim_for_local_rank

try:
    from flash_attn import (
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
except ImportError:
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_with_kvcache = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
except ImportError:
    RotaryEmbedding = None


# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
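

# Illustrative sketch (example only, not used by the library): the slope pattern
# produced by get_alibi_slopes. For a power-of-2 head count the slopes form a
# geometric sequence starting at 2 ** -(8 / nheads); for other head counts the
# remainder is taken from the pattern for the next power of 2.
def _example_alibi_slopes():
    slopes_8 = get_alibi_slopes(8)  # [0.5, 0.25, ..., 2 ** -8]
    assert math.isclose(slopes_8[0], 0.5) and math.isclose(slopes_8[-1], 2**-8)
    slopes_12 = get_alibi_slopes(12)  # the 8-head slopes, then 4 taken from the 16-head pattern
    assert slopes_12[:8] == slopes_8 and len(slopes_12) == 12
    return slopes_8, slopes_12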


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            return flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            return flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
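

# Illustrative usage sketch (example only, not used by the library): calling
# FlashSelfAttention on packed QKV in both the padded and the variable-length
# layouts described in the forward() docstring. Assumes flash-attn is installed
# and a CUDA device with fp16/bf16 support.
def _example_flash_self_attention():
    attn = FlashSelfAttention(causal=True)
    # Padded layout: (B, S, 3, H, D)
    qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.float16)
    out = attn(qkv)  # (2, 128, 8, 64)
    # Variable-length layout: (total, 3, H, D) plus cumulative sequence lengths
    qkv_varlen = torch.randn(96 + 128, 3, 8, 64, device="cuda", dtype=torch.float16)
    cu_seqlens = torch.tensor([0, 96, 96 + 128], dtype=torch.int32, device="cuda")
    out_varlen = attn(qkv_varlen, cu_seqlens=cu_seqlens, max_seqlen=128)  # (224, 8, 64)
    return out, out_varlen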


class FlashCrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
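

# Illustrative usage sketch (example only): cross-attention where the queries and
# the packed key/value come from sequences of different lengths. Assumes flash-attn
# is installed and a CUDA device with fp16/bf16 support.
def _example_flash_cross_attention():
    attn = FlashCrossAttention()
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)  # (B, Sq, H, D)
    kv = torch.randn(2, 256, 2, 8, 64, device="cuda", dtype=torch.float16)  # (B, Sk, 2, H_k, D)
    return attn(q, kv)  # (2, 128, 8, 64)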


class SelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        causal = self.causal if causal is None else causal
        q, k, v = qkv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # "triu_tril_cuda_template" not implemented for 'BFloat16'
            # So we have to construct the mask in float
            causal_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + causal_mask.to(dtype=scores.dtype)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
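

# Illustrative usage sketch (example only): the pure-PyTorch reference path, which
# also runs on CPU/fp32 and supports a boolean key_padding_mask (True = keep).
def _example_self_attention_reference():
    attn = SelfAttention(causal=True)
    qkv = torch.randn(2, 16, 3, 4, 32)  # (B, S, 3, H, D)
    key_padding_mask = torch.ones(2, 16, dtype=torch.bool)
    key_padding_mask[1, -4:] = False  # mask out the last 4 positions of the second sequence
    return attn(qkv, key_padding_mask=key_padding_mask)  # (2, 16, 4, 32)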


class CrossAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        causal = self.causal if causal is None else causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
        if key_padding_mask is not None:
            padding_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            row_idx = rearrange(
                torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
            )
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
        attention_drop = self.drop(attention)
        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
        return output
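

# Illustrative usage sketch (example only): MQA/GQA in the reference CrossAttention,
# where kv carries fewer heads than q and is repeated internally to match.
def _example_cross_attention_gqa():
    attn = CrossAttention()
    q = torch.randn(2, 16, 8, 32)  # 8 query heads
    kv = torch.randn(2, 24, 2, 2, 32)  # 2 KV heads, repeated 4x inside forward()
    return attn(q, kv)  # (2, 16, 8, 32)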


class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
    # Pre-allocate memory for key-values for inference.
    num_heads, head_dim = kv.shape[-2:]
    if layer_idx not in inference_params.key_value_memory_dict:
        kv_cache = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    # Adjust key and value for inference
    batch_start = inference_params.batch_size_offset
    batch_end = batch_start + kv.shape[0]
    sequence_start = inference_params.seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    assert batch_end <= kv_cache.shape[0]
    assert sequence_end <= kv_cache.shape[1]
    assert kv_cache is not None
    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
    return kv_cache[batch_start:batch_end, :sequence_end, ...]
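

# Illustrative sketch (example only): the attributes _update_kv_cache reads from
# `inference_params`. This minimal stand-in only documents that contract; the real
# inference-params object is supplied by the caller (e.g. generation utilities).
class _InferenceParamsSketch:
    def __init__(self, max_batch_size, max_seqlen):
        self.max_batch_size = max_batch_size  # capacity of the pre-allocated cache (batch dim)
        self.max_seqlen = max_seqlen  # capacity of the pre-allocated cache (seqlen dim)
        self.batch_size_offset = 0  # first batch index to write into the cache
        self.seqlen_offset = 0  # number of tokens already cached per sequence
        self.key_value_memory_dict = {}  # layer_idx -> (batch, seqlen, 2, nheads, head_dim)
        self.lengths_per_sample = None  # optional per-sample lengths, used by the MHA fast paths


def _example_update_kv_cache():
    params = _InferenceParamsSketch(max_batch_size=2, max_seqlen=32)
    kv_step = torch.randn(2, 1, 2, 8, 64)  # one decoding step of new K/V
    cache = _update_kv_cache(kv_step, params, layer_idx=0)  # allocates the cache, then writes kv_step
    # The caller is responsible for advancing seqlen_offset between steps.
    return cache.shape  # (2, 1, 2, 8, 64) after the first step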


class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for the post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv
        if self.rotary_emb_dim > 0:
            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful, e.g., for ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
                https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
            assert self.rotary_emb_dim == 0
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        batch, seqlen = x.shape[:2]
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)
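

# Illustrative usage sketch (example only): a causal self-attention MHA block with
# rotary embeddings on the FlashAttention path. Assumes flash-attn (including its
# rotary extension) is installed and a CUDA device with fp16 support.
def _example_mha_forward():
    mha = MHA(
        embed_dim=512,
        num_heads=8,
        causal=True,
        rotary_emb_dim=64,
        use_flash_attn=True,
        device="cuda",
        dtype=torch.float16,
    )
    x = torch.randn(2, 128, 512, device="cuda", dtype=torch.float16)
    return mha(x)  # (2, 128, 512)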


class ParallelMHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        process_group,
        num_heads_kv=None,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        rotary_emb_dim=0,
        rotary_emb_base=10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved=False,
        use_alibi=False,
        window_size=(-1, -1),
        use_flash_attn=False,
        checkpointing=False,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.process_group = process_group
        self.world_size = process_group.size()
        self.local_rank = torch.distributed.get_rank(process_group)
        self.num_heads = num_heads
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        self.num_heads_per_rank = get_dim_for_local_rank(
            self.num_heads, self.world_size, self.local_rank
        )
        self.num_heads_kv_per_rank = get_dim_for_local_rank(
            self.num_heads_kv, self.world_size, self.local_rank
        )
        self.head_dim = self.embed_dim // num_heads
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        if use_alibi:
            assert use_flash_attn, "ALiBi code path requires flash_attn"
            num_heads_local = math.ceil(self.num_heads / self.world_size)
            alibi_slopes = torch.tensor(
                get_alibi_slopes(num_heads)[
                    self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local
                ],
                device=device,
            )
        else:
            alibi_slopes = None
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, "rotary_emb is not installed"
            self.rotary_emb = RotaryEmbedding(
                self.rotary_emb_dim,
                base=rotary_emb_base,
                scale_base=rotary_emb_scale_base,
                interleaved=rotary_emb_interleaved,
                device=device,
            )
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        self.Wqkv = ColumnParallelLinear(
            embed_dim,
            qkv_dim,
            process_group,
            bias=qkv_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
            **factory_kwargs,
        )
        inner_attn_cls = (
            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        self.inner_attn = inner_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            process_group,
            bias=out_proj_bias,
            sequence_parallel=sequence_parallel,
            multiple_of=self.head_dim,
            **factory_kwargs,
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv_per_rank,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combines 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        if self.rotary_emb_dim > 0:
            assert self.rotary_emb.scale is None, "This code path does not support xPos"
            self.rotary_emb._update_cos_sin_cache(
                inference_params.max_seqlen, device=q.device, dtype=q.dtype
            )
            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
        else:
            rotary_cos, rotary_sin = None, None
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
            alibi_slopes=alibi_slopes,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if inference_params.seqlen_offset == 0 or not self.use_flash_attn:
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
            context = flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=alibi_slopes,
            )
            return context

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x for sequence parallelism, we split along the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is not None:
            qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
        seqlen_offset = (
            0
            if inference_params is None
            else (
                inference_params.lengths_per_sample
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
        )
        rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
        if self.num_heads_kv == self.num_heads:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    qkv = self.rotary_emb(
                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_attn(qkv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            q = rearrange(
                qkv[..., : self.num_heads_per_rank * self.head_dim],
                "... (h d) -> ... h d",
                d=self.head_dim,
            )
            kv = rearrange(
                qkv[..., self.num_heads_per_rank * self.head_dim :],
                "... (two hkv d) -> ... two hkv d",
                two=2,
                d=self.head_dim,
            )
            if (
                inference_params is None
                or inference_params.seqlen_offset == 0
                or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
                or not self.use_flash_attn
            ):
                if self.rotary_emb_dim > 0:
                    q, kv = self.rotary_emb(
                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
                    )
                if inference_params is None:
                    if not self.checkpointing:
                        context = self.inner_cross_attn(q, kv, **kwargs)
                    else:
                        context = torch.utils.checkpoint.checkpoint(
                            self.inner_cross_attn, q, kv, **kwargs
                        )
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        context = rearrange(context, "b s h d -> b s (h d)")
        if seqlen is not None:
            context = rearrange(context, "b s d -> (b s) d")
        out = self.out_proj(context)
        return out