rotary.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. # Copyright (c) 2023, Tri Dao.
  2. import math
  3. from typing import Optional, Tuple, Union
  4. import torch
  5. from einops import rearrange, repeat
  6. from flash_attn.ops.triton.rotary import apply_rotary
  7. def rotate_half(x, interleaved=False):
  8. if not interleaved:
  9. x1, x2 = x.chunk(2, dim=-1)
  10. return torch.cat((-x2, x1), dim=-1)
  11. else:
  12. x1, x2 = x[..., ::2], x[..., 1::2]
  13. return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
  14. def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
  15. """
  16. x: (batch_size, seqlen, nheads, headdim)
  17. cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
  18. """
  19. ro_dim = cos.shape[-1] * 2
  20. assert ro_dim <= x.shape[-1]
  21. cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
  22. sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
  23. return torch.cat(
  24. [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
  25. dim=-1,
  26. )
  27. class ApplyRotaryEmb(torch.autograd.Function):
  28. @staticmethod
  29. def forward(
  30. ctx,
  31. x,
  32. cos,
  33. sin,
  34. interleaved=False,
  35. inplace=False,
  36. seqlen_offsets: Union[int, torch.Tensor] = 0,
  37. cu_seqlens: Optional[torch.Tensor] = None,
  38. max_seqlen: Optional[int] = None,
  39. ):
  40. out = apply_rotary(
  41. x,
  42. cos,
  43. sin,
  44. seqlen_offsets=seqlen_offsets,
  45. cu_seqlens=cu_seqlens,
  46. max_seqlen=max_seqlen,
  47. interleaved=interleaved,
  48. inplace=inplace,
  49. )
  50. if isinstance(seqlen_offsets, int):
  51. ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
  52. ctx.seqlen_offsets = seqlen_offsets
  53. else:
  54. ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
  55. ctx.seqlen_offsets = None
  56. ctx.interleaved = interleaved
  57. ctx.inplace = inplace
  58. ctx.max_seqlen = max_seqlen
  59. return out if not inplace else x
  60. @staticmethod
  61. def backward(ctx, do):
  62. seqlen_offsets = ctx.seqlen_offsets
  63. if seqlen_offsets is None:
  64. cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
  65. else:
  66. cos, sin, cu_seqlens = ctx.saved_tensors
  67. # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
  68. # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
  69. if not ctx.interleaved and not ctx.inplace:
  70. do = do.clone()
  71. dx = apply_rotary(
  72. do,
  73. cos,
  74. sin,
  75. seqlen_offsets=seqlen_offsets,
  76. cu_seqlens=cu_seqlens,
  77. max_seqlen=ctx.max_seqlen,
  78. interleaved=ctx.interleaved,
  79. inplace=ctx.inplace,
  80. conjugate=True,
  81. )
  82. return dx, None, None, None, None, None, None, None
  83. def apply_rotary_emb(
  84. x,
  85. cos,
  86. sin,
  87. interleaved=False,
  88. inplace=False,
  89. seqlen_offsets: Union[int, torch.Tensor] = 0,
  90. cu_seqlens: Optional[torch.Tensor] = None,
  91. max_seqlen: Optional[int] = None,
  92. ):
  93. """
  94. Arguments:
  95. x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
  96. else (total_seqlen, nheads, headdim)
  97. cos, sin: (seqlen_rotary, rotary_dim / 2)
  98. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
  99. of 1st half and 2nd half (GPT-NeoX style).
  100. inplace: if True, apply rotary embedding in-place.
  101. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
  102. Most commonly used in inference when we have KV cache.
  103. cu_seqlens: (batch + 1,) or None
  104. max_seqlen: int
  105. Return:
  106. out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
  107. else (total_seqlen, nheads, headdim)
  108. rotary_dim must be <= headdim
  109. Apply rotary embedding to the first rotary_dim of x.
  110. """
  111. return ApplyRotaryEmb.apply(
  112. x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
  113. )
  114. # For backward compatibility
  115. apply_rotary_emb_func = apply_rotary_emb
  116. class ApplyRotaryEmbQKV_(torch.autograd.Function):
  117. @staticmethod
  118. def forward(
  119. ctx,
  120. qkv,
  121. cos,
  122. sin,
  123. cos_k=None,
  124. sin_k=None,
  125. interleaved=False,
  126. seqlen_offsets: Union[int, torch.Tensor] = 0,
  127. ):
  128. batch, seqlen, three, nheads, headdim = qkv.shape
  129. assert three == 3
  130. if cos_k is None and sin_k is None and qkv.is_contiguous():
  131. # Call 1 kernel instead of 2 kernels
  132. # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
  133. # dimensions, we get the same tensor
  134. # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
  135. qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
  136. apply_rotary(
  137. qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
  138. )
  139. else:
  140. cos_k = cos if cos_k is None else cos_k
  141. sin_k = sin if sin_k is None else sin_k
  142. q, k = qkv[:, :, 0], qkv[:, :, 1]
  143. apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
  144. apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
  145. ctx.save_for_backward(cos, sin, cos_k, sin_k)
  146. if isinstance(seqlen_offsets, int):
  147. ctx.save_for_backward(cos, sin, cos_k, sin_k)
  148. ctx.seqlen_offsets = seqlen_offsets
  149. else:
  150. ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
  151. ctx.seqlen_offsets = None
  152. ctx.interleaved = interleaved
  153. return qkv
  154. @staticmethod
  155. def backward(ctx, dqkv):
  156. seqlen_offsets = ctx.seqlen_offsets
  157. if seqlen_offsets is None:
  158. cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
  159. else:
  160. cos, sin, cos_k, sin_k = ctx.saved_tensors
  161. if cos_k is None and sin_k is None and dqkv.is_contiguous():
  162. # Call 1 kernel instead of 2 kernels
  163. # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
  164. # dimensions, we get the same tensor
  165. dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
  166. apply_rotary(
  167. dqk,
  168. cos,
  169. sin,
  170. seqlen_offsets=seqlen_offsets,
  171. interleaved=ctx.interleaved,
  172. inplace=True,
  173. conjugate=True,
  174. )
  175. else:
  176. cos_k = cos if cos_k is None else cos_k
  177. sin_k = sin if sin_k is None else sin_k
  178. dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
  179. apply_rotary(
  180. dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True
  181. )
  182. apply_rotary(
  183. dk,
  184. cos_k,
  185. sin_k,
  186. seqlen_offsets,
  187. interleaved=ctx.interleaved,
  188. inplace=True,
  189. conjugate=True,
  190. )
  191. return dqkv, None, None, None, None, None, None
  192. def apply_rotary_emb_qkv_(
  193. qkv,
  194. cos,
  195. sin,
  196. cos_k=None,
  197. sin_k=None,
  198. interleaved=False,
  199. seqlen_offsets: Union[int, torch.Tensor] = 0,
  200. ):
  201. """
  202. Arguments:
  203. qkv: (batch_size, seqlen, 3, nheads, headdim)
  204. cos, sin: (seqlen, rotary_dim / 2)
  205. cos_k, sin_k: (seqlen, rotary_dim / 2), optional
  206. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
  207. 1st half and 2nd half (GPT-NeoX style).
  208. seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
  209. Most commonly used in inference when we have KV cache.
  210. Return:
  211. qkv: (batch_size, seqlen, 3, nheads, headdim)
  212. rotary_dim must be <= headdim
  213. Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
  214. """
  215. return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
  216. class ApplyRotaryEmbKV_(torch.autograd.Function):
  217. @staticmethod
  218. def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
  219. batch, seqlen, two, nheads, headdim = kv.shape
  220. assert two == 2
  221. k = kv[:, :, 0]
  222. apply_rotary(
  223. k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
  224. )
  225. if isinstance(seqlen_offsets, int):
  226. ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward
  227. ctx.seqlen_offsets = seqlen_offsets
  228. else:
  229. ctx.save_for_backward(cos, sin, seqlen_offsets)
  230. ctx.seqlen_offsets = None
  231. ctx.interleaved = interleaved
  232. return kv
  233. @staticmethod
  234. def backward(ctx, dkv):
  235. seqlen_offsets = ctx.seqlen_offsets
  236. if seqlen_offsets is None:
  237. cos, sin, seqlen_offsets = ctx.saved_tensors
  238. else:
  239. cos, sin = ctx.saved_tensors
  240. apply_rotary(
  241. dkv[:, :, 0],
  242. cos,
  243. sin,
  244. seqlen_offsets=seqlen_offsets,
  245. interleaved=ctx.interleaved,
  246. inplace=True,
  247. conjugate=True,
  248. )
  249. return dkv, None, None, None, None
  250. apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
  251. def apply_rotary_emb_kv_(
  252. kv,
  253. cos,
  254. sin,
  255. interleaved=False,
  256. seqlen_offsets: Union[int, torch.Tensor] = 0,
  257. ):
  258. """
  259. Arguments:
  260. kv: (batch_size, seqlen, 2, nheads, headdim)
  261. cos, sin: (seqlen, rotary_dim / 2)
  262. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
  263. 1st half and 2nd half (GPT-NeoX style).
  264. seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
  265. Most commonly used in inference when we have KV cache.
  266. Return:
  267. kv: (batch_size, seqlen, 2, nheads, headdim)
  268. rotary_dim must be <= headdim
  269. Apply rotary embedding *inplace* to the first rotary_dim of K.
  270. """
  271. return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
  272. class RotaryEmbedding(torch.nn.Module):
  273. """
  274. The rotary position embeddings from RoFormer_ (Su et. al).
  275. A crucial insight from the method is that the query and keys are
  276. transformed by rotation matrices which depend on the relative positions.
  277. Other implementations are available in the Rotary Transformer repo_ and in
  278. GPT-NeoX_, GPT-NeoX was an inspiration
  279. .. _RoFormer: https://arxiv.org/abs/2104.09864
  280. .. _repo: https://github.com/ZhuiyiTechnology/roformer
  281. .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
  282. If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
  283. A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
  284. Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
  285. """
  286. def __init__(
  287. self,
  288. dim: int,
  289. base=10000.0,
  290. interleaved=False,
  291. scale_base=None,
  292. pos_idx_in_fp32=True,
  293. device=None,
  294. ):
  295. """
  296. interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
  297. of 1st half and 2nd half (GPT-NeoX style).
  298. pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
  299. otherwise they might be in lower precision.
  300. This option was added because previously (before 2023-07-02), when we construct
  301. the position indices, we use the dtype of self.inv_freq. In most cases this would
  302. be fp32, but if the model is trained in pure bf16 (not mixed precision), then
  303. self.inv_freq would be bf16, and the position indices are also in bf16.
  304. Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
  305. embeddings for some positions will coincide.
  306. To maintain compatibility with models previously trained in pure bf16,
  307. we add this option.
  308. """
  309. super().__init__()
  310. self.dim = dim
  311. self.base = float(base)
  312. self.pos_idx_in_fp32 = pos_idx_in_fp32
  313. # Generate and save the inverse frequency buffer (non trainable)
  314. inv_freq = self._compute_inv_freq(device)
  315. self.register_buffer("inv_freq", inv_freq, persistent=False)
  316. self.interleaved = interleaved
  317. self.scale_base = scale_base
  318. scale = (
  319. (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
  320. if scale_base is not None
  321. else None
  322. )
  323. self.register_buffer("scale", scale, persistent=False)
  324. self._seq_len_cached = 0
  325. self._cos_cached = None
  326. self._sin_cached = None
  327. self._cos_k_cached = None
  328. self._sin_k_cached = None
  329. def _compute_inv_freq(self, device=None):
  330. return 1.0 / (
  331. self.base
  332. ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
  333. )
  334. def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
  335. # Reset the tables if the sequence length has changed,
  336. # if we're on a new device (possibly due to tracing for instance),
  337. # or if we're switching from inference mode to training
  338. if (
  339. seqlen > self._seq_len_cached
  340. or self._cos_cached is None
  341. or self._cos_cached.device != device
  342. or self._cos_cached.dtype != dtype
  343. or (self.training and self._cos_cached.is_inference())
  344. ):
  345. self._seq_len_cached = seqlen
  346. # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
  347. # And the output of arange can be quite large, so bf16 would lose a lot of precision.
  348. # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
  349. if self.pos_idx_in_fp32:
  350. t = torch.arange(seqlen, device=device, dtype=torch.float32)
  351. # We want fp32 here as well since inv_freq will be multiplied with t, and the output
  352. # will be large. Having it in bf16 will lose a lot of precision and cause the
  353. # cos & sin output to change significantly.
  354. # We want to recompute self.inv_freq if it was not loaded in fp32
  355. if self.inv_freq.dtype != torch.float32:
  356. inv_freq = self._compute_inv_freq(device=device)
  357. else:
  358. inv_freq = self.inv_freq
  359. else:
  360. t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
  361. inv_freq = self.inv_freq
  362. # Don't do einsum, it converts fp32 to fp16 under AMP
  363. # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
  364. freqs = torch.outer(t, inv_freq)
  365. if self.scale is None:
  366. self._cos_cached = torch.cos(freqs).to(dtype)
  367. self._sin_cached = torch.sin(freqs).to(dtype)
  368. else:
  369. power = (
  370. torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
  371. - seqlen // 2
  372. ) / self.scale_base
  373. scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
  374. # We want the multiplication by scale to happen in fp32
  375. self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
  376. self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
  377. self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
  378. self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
  379. def forward(
  380. self,
  381. qkv: torch.Tensor,
  382. kv: Optional[torch.Tensor] = None,
  383. seqlen_offset: Union[int, torch.Tensor] = 0,
  384. max_seqlen: Optional[int] = None,
  385. ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  386. """
  387. qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
  388. else it's just q of shape (batch, seqlen, nheads, headdim)
  389. kv: (batch, seqlen, 2, nheads, headdim)
  390. seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
  391. Most commonly used in inference when we have KV cache.
  392. If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
  393. should pass in max_seqlen, which will update the cos / sin cache up to that length.
  394. Apply rotary embedding *inplace* to qkv and / or kv.
  395. """
  396. seqlen = qkv.shape[1]
  397. if max_seqlen is not None:
  398. self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
  399. elif isinstance(seqlen_offset, int):
  400. self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
  401. if kv is None:
  402. if self.scale is None:
  403. return apply_rotary_emb_qkv_(
  404. qkv,
  405. self._cos_cached,
  406. self._sin_cached,
  407. interleaved=self.interleaved,
  408. seqlen_offsets=seqlen_offset,
  409. )
  410. else:
  411. return apply_rotary_emb_qkv_(
  412. qkv,
  413. self._cos_cached,
  414. self._sin_cached,
  415. self._cos_k_cached,
  416. self._sin_k_cached,
  417. interleaved=self.interleaved,
  418. seqlen_offsets=seqlen_offset,
  419. )
  420. else:
  421. q = qkv
  422. q = apply_rotary_emb_func(
  423. q,
  424. self._cos_cached,
  425. self._sin_cached,
  426. interleaved=self.interleaved,
  427. inplace=True,
  428. seqlen_offsets=seqlen_offset,
  429. )
  430. if self.scale is None:
  431. kv = apply_rotary_emb_kv_(
  432. kv,
  433. self._cos_cached,
  434. self._sin_cached,
  435. interleaved=self.interleaved,
  436. seqlen_offsets=seqlen_offset,
  437. )
  438. else:
  439. kv = apply_rotary_emb_kv_(
  440. kv,
  441. self._cos_k_cached,
  442. self._sin_k_cached,
  443. interleaved=self.interleaved,
  444. seqlen_offsets=seqlen_offset,
  445. )
  446. return q, kv