# rotary.py (21 KB)
  1. # Copyright (c) 2023, Tri Dao.
  2. import math
  3. from typing import Optional, Tuple, Union
  4. import torch
  5. from einops import rearrange, repeat
  6. from flash_attn.ops.triton.rotary import apply_rotary
  7. def rotate_half(x, interleaved=False):
  8. if not interleaved:
  9. x1, x2 = x.chunk(2, dim=-1)
  10. return torch.cat((-x2, x1), dim=-1)
  11. else:
  12. x1, x2 = x[..., ::2], x[..., 1::2]
  13. return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
  14. def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
  15. """
  16. x: (batch_size, seqlen, nheads, headdim)
  17. cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
  18. """
  19. ro_dim = cos.shape[-1] * 2
  20. assert ro_dim <= x.shape[-1]
  21. cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
  22. sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
  23. return torch.cat(
  24. [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
  25. dim=-1,
  26. )
  27. class ApplyRotaryEmb(torch.autograd.Function):
  28. @staticmethod
  29. def forward(
  30. ctx,
  31. x,
  32. cos,
  33. sin,
  34. interleaved=False,
  35. inplace=False,
  36. seqlen_offsets: Union[int, torch.Tensor] = 0,
  37. cu_seqlens: Optional[torch.Tensor] = None,
  38. max_seqlen: Optional[int] = None,
  39. ):
  40. out = apply_rotary(
  41. x,
  42. cos,
  43. sin,
  44. seqlen_offsets=seqlen_offsets,
  45. cu_seqlens=cu_seqlens,
  46. max_seqlen=max_seqlen,
  47. interleaved=interleaved,
  48. inplace=inplace,
  49. )
  50. if isinstance(seqlen_offsets, int):
  51. ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
  52. ctx.seqlen_offsets = seqlen_offsets
  53. else:
  54. ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
  55. ctx.seqlen_offsets = None
  56. ctx.interleaved = interleaved
  57. ctx.inplace = inplace
  58. ctx.max_seqlen = max_seqlen
  59. return out if not inplace else x
  60. @staticmethod
  61. def backward(ctx, do):
  62. seqlen_offsets = ctx.seqlen_offsets
  63. if seqlen_offsets is None:
  64. cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
  65. else:
  66. cos, sin, cu_seqlens = ctx.saved_tensors
  67. # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
  68. # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
  69. if not ctx.interleaved and not ctx.inplace:
  70. do = do.clone()
  71. dx = apply_rotary(
  72. do,
  73. cos,
  74. sin,
  75. seqlen_offsets=seqlen_offsets,
  76. cu_seqlens=cu_seqlens,
  77. max_seqlen=ctx.max_seqlen,
  78. interleaved=ctx.interleaved,
  79. inplace=ctx.inplace,
  80. conjugate=True,
  81. )
  82. return dx, None, None, None, None, None, None, None
  83. def apply_rotary_emb(
  84. x,
  85. cos,
  86. sin,
  87. interleaved=False,
  88. inplace=False,
  89. seqlen_offsets: Union[int, torch.Tensor] = 0,
  90. cu_seqlens: Optional[torch.Tensor] = None,
  91. max_seqlen: Optional[int] = None,
  92. ):
  93. """
  94. Arguments:
  95. x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
  96. else (total_seqlen, nheads, headdim)
  97. cos, sin: (seqlen_rotary, rotary_dim / 2)
  98. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
  99. of 1st half and 2nd half (GPT-NeoX style).
  100. inplace: if True, apply rotary embedding in-place.
  101. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
  102. Most commonly used in inference when we have KV cache.
  103. cu_seqlens: (batch + 1,) or None
  104. max_seqlen: int
  105. Return:
  106. out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
  107. else (total_seqlen, nheads, headdim)
  108. rotary_dim must be <= headdim
  109. Apply rotary embedding to the first rotary_dim of x.
  110. """
  111. return ApplyRotaryEmb.apply(
  112. x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
  113. )
  114. # For backward compatibility
  115. apply_rotary_emb_func = apply_rotary_emb
  116. class ApplyRotaryEmbQKV_(torch.autograd.Function):
  117. @staticmethod
  118. def forward(
  119. ctx,
  120. qkv,
  121. cos,
  122. sin,
  123. cos_k=None,
  124. sin_k=None,
  125. interleaved=False,
  126. seqlen_offsets: Union[int, torch.Tensor] = 0,
  127. num_heads_q: Union[int] = None,
  128. ):
  129. if cos_k is None and sin_k is None and qkv.is_contiguous():
  130. # Call 1 kernel instead of 2 kernels
  131. # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
  132. # dimensions, we get the same tensor
  133. if qkv.dim() == 5:
  134. batch, seqlen, three, nheads, headdim = qkv.shape
  135. assert three == 3
  136. # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
  137. qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
  138. else:
  139. assert qkv.dim() == 4
  140. assert num_heads_q is not None
  141. num_heads_k = (qkv.shape[2] - num_heads_q) // 2
  142. assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
  143. qk = qkv[:, :, :num_heads_q + num_heads_k]
  144. apply_rotary(
  145. qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
  146. )
  147. else:
  148. cos_k = cos if cos_k is None else cos_k
  149. sin_k = sin if sin_k is None else sin_k
  150. if qkv.dim() == 5:
  151. q, k = qkv[:, :, 0], qkv[:, :, 1]
  152. else:
  153. assert qkv.dim() == 4
  154. assert num_heads_q is not None
  155. num_heads_k = (qkv.shape[2] - num_heads_q) // 2
  156. assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
  157. q, k = qkv[:, :, :num_heads_q], qkv[:, :, num_heads_q : num_heads_q + num_heads_k]
  158. apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
  159. apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
  160. ctx.save_for_backward(cos, sin, cos_k, sin_k)
  161. if isinstance(seqlen_offsets, int):
  162. ctx.save_for_backward(cos, sin, cos_k, sin_k)
  163. ctx.seqlen_offsets = seqlen_offsets
  164. else:
  165. ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
  166. ctx.seqlen_offsets = None
  167. ctx.interleaved = interleaved
  168. ctx.num_heads_q = num_heads_q
  169. return qkv
  170. @staticmethod
  171. def backward(ctx, dqkv):
  172. seqlen_offsets = ctx.seqlen_offsets
  173. if seqlen_offsets is None:
  174. cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
  175. else:
  176. cos, sin, cos_k, sin_k = ctx.saved_tensors
  177. if cos_k is None and sin_k is None and dqkv.is_contiguous():
  178. # Call 1 kernel instead of 2 kernels
  179. # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
  180. # dimensions, we get the same tensor
  181. if dqkv.dim() == 5:
  182. dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
  183. else:
  184. assert dqkv.dim() == 4
  185. assert ctx.num_heads_q is not None
  186. num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
  187. assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
  188. dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k]
  189. apply_rotary(
  190. dqk,
  191. cos,
  192. sin,
  193. seqlen_offsets=seqlen_offsets,
  194. interleaved=ctx.interleaved,
  195. inplace=True,
  196. conjugate=True,
  197. )
  198. else:
  199. cos_k = cos if cos_k is None else cos_k
  200. sin_k = sin if sin_k is None else sin_k
  201. if dqkv.dim() == 5:
  202. dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
  203. else:
  204. assert dqkv.dim() == 4
  205. assert ctx.num_heads_q is not None
  206. num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
  207. assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
  208. dq = dqkv[:, :, : ctx.num_heads_q]
  209. dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k]
  210. apply_rotary(
  211. dq,
  212. cos,
  213. sin,
  214. seqlen_offsets,
  215. interleaved=ctx.interleaved,
  216. inplace=True,
  217. conjugate=True,
  218. )
  219. apply_rotary(
  220. dk,
  221. cos_k,
  222. sin_k,
  223. seqlen_offsets,
  224. interleaved=ctx.interleaved,
  225. inplace=True,
  226. conjugate=True,
  227. )
  228. return dqkv, None, None, None, None, None, None, None
  229. def apply_rotary_emb_qkv_(
  230. qkv,
  231. cos,
  232. sin,
  233. cos_k=None,
  234. sin_k=None,
  235. interleaved=False,
  236. seqlen_offsets: Union[int, torch.Tensor] = 0,
  237. num_heads_q: Optional[int] = None,
  238. ):
  239. """
  240. Arguments:
  241. qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim).
  242. If qkv has shape (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
  243. then num_heads_q must be provided.
  244. cos, sin: (seqlen, rotary_dim / 2)
  245. cos_k, sin_k: (seqlen, rotary_dim / 2), optional
  246. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
  247. 1st half and 2nd half (GPT-NeoX style).
  248. seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
  249. Most commonly used in inference when we have KV cache.
  250. Return:
  251. qkv: (batch_size, seqlen, 3, nheads, headdim) or (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim)
  252. rotary_dim must be <= headdim
  253. Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
  254. """
  255. return ApplyRotaryEmbQKV_.apply(
  256. qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q
  257. )
  258. class ApplyRotaryEmbKV_(torch.autograd.Function):
  259. @staticmethod
  260. def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
  261. batch, seqlen, two, nheads, headdim = kv.shape
  262. assert two == 2
  263. k = kv[:, :, 0]
  264. apply_rotary(
  265. k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
  266. )
  267. if isinstance(seqlen_offsets, int):
  268. ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
  269. ctx.seqlen_offsets = seqlen_offsets
  270. else:
  271. ctx.save_for_backward(cos, sin, seqlen_offsets)
  272. ctx.seqlen_offsets = None
  273. ctx.interleaved = interleaved
  274. return kv
  275. @staticmethod
  276. def backward(ctx, dkv):
  277. seqlen_offsets = ctx.seqlen_offsets
  278. if seqlen_offsets is None:
  279. cos, sin, seqlen_offsets = ctx.saved_tensors
  280. else:
  281. cos, sin = ctx.saved_tensors
  282. apply_rotary(
  283. dkv[:, :, 0],
  284. cos,
  285. sin,
  286. seqlen_offsets=seqlen_offsets,
  287. interleaved=ctx.interleaved,
  288. inplace=True,
  289. conjugate=True,
  290. )
  291. return dkv, None, None, None, None
  292. apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
  293. def apply_rotary_emb_kv_(
  294. kv,
  295. cos,
  296. sin,
  297. interleaved=False,
  298. seqlen_offsets: Union[int, torch.Tensor] = 0,
  299. ):
  300. """
  301. Arguments:
  302. kv: (batch_size, seqlen, 2, nheads, headdim)
  303. cos, sin: (seqlen, rotary_dim / 2)
  304. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
  305. 1st half and 2nd half (GPT-NeoX style).
  306. seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
  307. Most commonly used in inference when we have KV cache.
  308. Return:
  309. kv: (batch_size, seqlen, 2, nheads, headdim)
  310. rotary_dim must be <= headdim
  311. Apply rotary embedding *inplace* to the first rotary_dim of K.
  312. """
  313. return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
  314. class RotaryEmbedding(torch.nn.Module):
  315. """
  316. The rotary position embeddings from RoFormer_ (Su et. al).
  317. A crucial insight from the method is that the query and keys are
  318. transformed by rotation matrices which depend on the relative positions.
  319. Other implementations are available in the Rotary Transformer repo_ and in
  320. GPT-NeoX_, GPT-NeoX was an inspiration
  321. .. _RoFormer: https://arxiv.org/abs/2104.09864
  322. .. _repo: https://github.com/ZhuiyiTechnology/roformer
  323. .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
  324. If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
  325. A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
  326. Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
  327. """
  328. def __init__(
  329. self,
  330. dim: int,
  331. base=10000.0,
  332. interleaved=False,
  333. scale_base=None,
  334. pos_idx_in_fp32=True,
  335. device=None,
  336. ):
  337. """
  338. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
  339. of 1st half and 2nd half (GPT-NeoX style).
  340. pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
  341. otherwise they might be in lower precision.
  342. This option was added because previously (before 2023-07-02), when we construct
  343. the position indices, we use the dtype of self.inv_freq. In most cases this would
  344. be fp32, but if the model is trained in pure bf16 (not mixed precision), then
  345. self.inv_freq would be bf16, and the position indices are also in bf16.
  346. Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
  347. embeddings for some positions will coincide.
  348. To maintain compatibility with models previously trained in pure bf16,
  349. we add this option.
  350. """
  351. super().__init__()
  352. self.dim = dim
  353. self.base = float(base)
  354. self.pos_idx_in_fp32 = pos_idx_in_fp32
  355. # Generate and save the inverse frequency buffer (non trainable)
  356. inv_freq = self._compute_inv_freq(device)
  357. self.register_buffer("inv_freq", inv_freq, persistent=False)
  358. self.interleaved = interleaved
  359. self.scale_base = scale_base
  360. scale = (
  361. (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
  362. if scale_base is not None
  363. else None
  364. )
  365. self.register_buffer("scale", scale, persistent=False)
  366. self._seq_len_cached = 0
  367. self._cos_cached = None
  368. self._sin_cached = None
  369. self._cos_k_cached = None
  370. self._sin_k_cached = None
  371. def _compute_inv_freq(self, device=None):
  372. return 1.0 / (
  373. self.base
  374. ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
  375. )
  376. def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
  377. # Reset the tables if the sequence length has changed,
  378. # if we're on a new device (possibly due to tracing for instance),
  379. # or if we're switching from inference mode to training
  380. if (
  381. seqlen > self._seq_len_cached
  382. or self._cos_cached is None
  383. or self._cos_cached.device != device
  384. or self._cos_cached.dtype != dtype
  385. or (self.training and self._cos_cached.is_inference())
  386. ):
  387. self._seq_len_cached = seqlen
  388. # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
  389. # And the output of arange can be quite large, so bf16 would lose a lot of precision.
  390. # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
  391. if self.pos_idx_in_fp32:
  392. t = torch.arange(seqlen, device=device, dtype=torch.float32)
  393. # We want fp32 here as well since inv_freq will be multiplied with t, and the output
  394. # will be large. Having it in bf16 will lose a lot of precision and cause the
  395. # cos & sin output to change significantly.
  396. # We want to recompute self.inv_freq if it was not loaded in fp32
  397. if self.inv_freq.dtype != torch.float32:
  398. inv_freq = self._compute_inv_freq(device=device)
  399. else:
  400. inv_freq = self.inv_freq
  401. else:
  402. t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
  403. inv_freq = self.inv_freq
  404. # Don't do einsum, it converts fp32 to fp16 under AMP
  405. # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
  406. freqs = torch.outer(t, inv_freq)
  407. if self.scale is None:
  408. self._cos_cached = torch.cos(freqs).to(dtype)
  409. self._sin_cached = torch.sin(freqs).to(dtype)
  410. else:
  411. power = (
  412. torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
  413. - seqlen // 2
  414. ) / self.scale_base
  415. scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
  416. # We want the multiplication by scale to happen in fp32
  417. self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
  418. self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
  419. self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
  420. self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
  421. def forward(
  422. self,
  423. qkv: torch.Tensor,
  424. kv: Optional[torch.Tensor] = None,
  425. seqlen_offset: Union[int, torch.Tensor] = 0,
  426. max_seqlen: Optional[int] = None,
  427. num_heads_q: Optional[int] = None,
  428. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  429. """
  430. qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
  431. if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim).
  432. If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
  433. then num_heads_q must be provided.
  434. kv: (batch, seqlen, 2, nheads, headdim)
  435. seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
  436. Most commonly used in inference when we have KV cache.
  437. If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
  438. should pass in max_seqlen, which will update the cos / sin cache up to that length.
  439. Apply rotary embedding *inplace* to qkv and / or kv.
  440. """
  441. seqlen = qkv.shape[1]
  442. if max_seqlen is not None:
  443. self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
  444. elif isinstance(seqlen_offset, int):
  445. self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
  446. if kv is None:
  447. if self.scale is None:
  448. return apply_rotary_emb_qkv_(
  449. qkv,
  450. self._cos_cached,
  451. self._sin_cached,
  452. interleaved=self.interleaved,
  453. seqlen_offsets=seqlen_offset,
  454. num_heads_q=num_heads_q,
  455. )
  456. else:
  457. return apply_rotary_emb_qkv_(
  458. qkv,
  459. self._cos_cached,
  460. self._sin_cached,
  461. self._cos_k_cached,
  462. self._sin_k_cached,
  463. interleaved=self.interleaved,
  464. seqlen_offsets=seqlen_offset,
  465. num_heads_q=num_heads_q,
  466. )
  467. else:
  468. q = qkv
  469. q = apply_rotary_emb_func(
  470. q,
  471. self._cos_cached,
  472. self._sin_cached,
  473. interleaved=self.interleaved,
  474. inplace=True,
  475. seqlen_offsets=seqlen_offset,
  476. )
  477. if self.scale is None:
  478. kv = apply_rotary_emb_kv_(
  479. kv,
  480. self._cos_cached,
  481. self._sin_cached,
  482. interleaved=self.interleaved,
  483. seqlen_offsets=seqlen_offset,
  484. )
  485. else:
  486. kv = apply_rotary_emb_kv_(
  487. kv,
  488. self._cos_k_cached,
  489. self._sin_k_cached,
  490. interleaved=self.interleaved,
  491. seqlen_offsets=seqlen_offset,
  492. )
  493. return q, kv