# Copyright (c) 2023, Tri Dao.

from typing import Tuple
import math

import torch
from einops import rearrange, repeat

import rotary_emb


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # In the interleaved (GPT-J) layout each cos/sin value applies to a pair of adjacent
    # dimensions, so the repeat pattern is '(d 2)' instead of '(2 d)'.
    cos = repeat(cos, 's d -> s 1 (2 d)' if not interleaved else 's d -> s 1 (d 2)')
    sin = repeat(sin, 's d -> s 1 (2 d)' if not interleaved else 's d -> s 1 (d 2)')
    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
                      x[..., ro_dim:]], dim=-1)
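
# A minimal usage sketch (added for illustration, not part of the original file):
# how the pure-PyTorch reference above is typically driven. Shapes follow the
# docstring: x is (batch, seqlen, nheads, headdim), cos/sin are (seqlen, rotary_dim / 2).
# Kept commented out so importing this module has no side effects.
#
#   seqlen, nheads, headdim, rotary_dim = 128, 12, 64, 32
#   inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
#   freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
#   cos, sin = torch.cos(freqs), torch.sin(freqs)   # (seqlen, rotary_dim / 2)
#   x = torch.randn(2, seqlen, nheads, headdim)
#   out = apply_rotary_emb_torch(x, cos, sin)       # same shape as x; dims past rotary_dim untouched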


class ApplyRotaryEmb(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding to the first rotary_dim of x.
        """
        batch, seqlen, nheads, headdim = x.shape
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
        x_ro = x[..., :rotary_dim]
        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
        out = torch.empty_like(x) if not inplace else x
        out_ro = out[..., :rotary_dim]
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
                      else (out_ro[..., ::2], out_ro[..., 1::2]))
        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        cos, sin = ctx.saved_tensors
        _, seqlen, _, headdim = do.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (do_ro[..., ::2], do_ro[..., 1::2]))
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None


apply_rotary_emb_func = ApplyRotaryEmb.apply
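
# A minimal usage sketch (added for illustration, not part of the original file):
# calling the fused autograd path. It assumes the `rotary_emb` CUDA extension
# imported above is built and that tensors live on a CUDA device; cos/sin have
# shape (seqlen, rotary_dim / 2) as built in the sketch above. Commented out so
# the module imports cleanly.
#
#   x = torch.randn(2, 128, 12, 64, device='cuda', dtype=torch.float16, requires_grad=True)
#   out = apply_rotary_emb_func(x, cos.cuda().half(), sin.cuda().half(), False, False)
#   out.sum().backward()    # backward re-applies the rotation with the conjugate flag set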


class ApplyRotaryEmbQKV_(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        assert seqlen <= rotary_seqlen
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        _, seqlen, _, _, headdim = dqkv.shape
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
        return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
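
# A minimal usage sketch (added for illustration, not part of the original file):
# the in-place qkv variant used by fused attention layers. Same assumptions as
# above (built `rotary_emb` extension, CUDA tensors, cos/sin from the first sketch);
# q and k are rotated in place, v is left untouched. Commented out so the module
# imports cleanly.
#
#   qkv = torch.randn(2, 128, 3, 12, 64, device='cuda', dtype=torch.float16)
#   qkv = apply_rotary_emb_qkv_(qkv, cos.cuda().half(), sin.cuda().half())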


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
                 / (1.4 * dim) if scale_base is not None else None)
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, seqlen_offset=0):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
        """
        seqlen = x.shape[1] + seqlen_offset
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
                or self._cos_cached.dtype != x.dtype):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
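
    # Note (added for clarity): with xPos, the query tables are scaled by scale ** power
    # and the key tables by its reciprocal, so the dot product q . k picks up a factor
    # scale ** ((m - n) / scale_base) that depends only on the relative position m - n
    # of the query and key tokens.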

    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
            token in the batch.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset)
        if self.scale is None:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                None, None, self.interleaved
            )
        else:
            return apply_rotary_emb_qkv_(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
                self.interleaved
            )
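

# A minimal usage sketch (added for illustration, not part of the original file):
# applying the module to a packed qkv tensor inside an attention layer, under the
# same assumptions as above (built `rotary_emb` extension, CUDA tensors). The dim
# passed to the constructor is the number of head dimensions to rotate (<= headdim).
# Commented out so the module imports cleanly.
#
#   rotary = RotaryEmbedding(dim=32, interleaved=False, device='cuda')
#   qkv = torch.randn(2, 128, 3, 12, 64, device='cuda', dtype=torch.float16)
#   qkv = rotary(qkv)                        # rotates q and k in place, v untouched
#   # incremental decoding: pass how many tokens have already been processed
#   qkv_step = torch.randn(2, 1, 3, 12, 64, device='cuda', dtype=torch.float16)
#   qkv_step = rotary(qkv_step, seqlen_offset=128)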