@@ -1,4 +1,4 @@
-# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
+# Copyright (c) 2023, Tri Dao.
 
 from typing import Tuple
 import math
@@ -10,31 +10,37 @@ from einops import rearrange, repeat
 import rotary_emb
 
 
-def rotate_half(x):
-    x1, x2 = x.chunk(2, dim=-1)
-    return torch.cat((-x2, x1), dim=-1)
+def rotate_half(x, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
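
For intuition, a minimal sketch of what the two conventions do to a 4-element head dimension (illustrative only, using the rotate_half defined above):

    import torch

    x = torch.tensor([0., 1., 2., 3.])
    # GPT-NeoX style: negate-and-swap the two halves (0, 1) and (2, 3).
    rotate_half(x)                    # tensor([-2., -3.,  0.,  1.])
    # GPT-J style: rotate within each (even, odd) pair.
    rotate_half(x, interleaved=True)  # tensor([-1.,  0., -3.,  2.])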
 
 
-def apply_rotary_emb_torch(x, cos, sin):
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
     """
     x: (batch_size, seqlen, nheads, headdim)
     cos, sin: (seqlen, rotary_dim / 2)
     """
-    rotary_dim = cos.shape[-1] * 2
-    assert rotary_dim <= x.shape[-1]
-    cos = repeat(cos, 's d -> s 1 (2 d)')
-    sin = repeat(sin, 's d -> s 1 (2 d)')
-    return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin,
-                      x[..., rotary_dim:]], dim=-1)
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    # For interleaved, each cos/sin entry must be duplicated pairwise ((d 2)),
+    # not tiled ((2 d)), so that elements 2*i and 2*i+1 both see cos[i]/sin[i].
+    cos = repeat(cos, 's d -> s 1 (d 2)' if interleaved else 's d -> s 1 (2 d)')
+    sin = repeat(sin, 's d -> s 1 (d 2)' if interleaved else 's d -> s 1 (2 d)')
+    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+                      x[..., ro_dim:]], dim=-1)
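
The pure-PyTorch reference runs on CPU, so it doubles as a test oracle for the fused kernel; a sketch, assuming cos/sin are built from the usual inverse-frequency table (as RotaryEmbedding below does) and rotary_dim equals headdim:

    import torch

    batch, seqlen, nheads, headdim = 1, 4, 2, 8
    x = torch.randn(batch, seqlen, nheads, headdim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2).float() / headdim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # (seqlen, headdim / 2)
    cos, sin = freqs.cos(), freqs.sin()
    out_neox = apply_rotary_emb_torch(x, cos, sin)                    # half-split rotation
    out_gptj = apply_rotary_emb_torch(x, cos, sin, interleaved=True)  # pairwise rotation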
 
 
 class ApplyRotaryEmb(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, x, cos, sin, inplace=False):
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
         """
         x: (batch_size, seqlen, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
         rotary_dim must be <= headdim
         Apply rotary embedding to the first rotary_dim of x.
         """
@@ -44,14 +50,21 @@ class ApplyRotaryEmb(torch.autograd.Function):
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-        x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
+        x_ro = x[..., :rotary_dim]
+        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
         out = torch.empty_like(x) if not inplace else x
-        o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
+        out_ro = out[..., :rotary_dim]
+        if inplace:
+            o1, o2 = x1, x2
+        else:
+            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
+                      else (out_ro[..., ::2], out_ro[..., 1::2]))
         rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
         if not inplace and rotary_dim < headdim:
             out[..., rotary_dim:].copy_(x[..., rotary_dim:])
         ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
         ctx.inplace = inplace
         return out if not inplace else x
 
@@ -62,14 +75,21 @@ class ApplyRotaryEmb(torch.autograd.Function):
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
         inplace = ctx.inplace
-        do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
+        do_ro = do[..., :rotary_dim]
+        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (do_ro[..., ::2], do_ro[..., 1::2]))
         dx = torch.empty_like(do) if not inplace else do
-        dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
+        if inplace:
+            dx1, dx2 = do1, do2
+        else:
+            dx_ro = dx[..., :rotary_dim]
+            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
+                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
         rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
         if not inplace and rotary_dim < headdim:
             dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
-        return dx, None, None, None
+        return dx, None, None, None, None
 
 
 apply_rotary_emb_func = ApplyRotaryEmb.apply
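
A usage sketch for the exported function, assuming the rotary_emb CUDA extension is built and a GPU is available. Note that torch.autograd.Function.apply only takes positional arguments, so interleaved and inplace are passed by position:

    import torch

    seqlen, rotary_dim = 16, 64
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
    cos = freqs.cos().to(device='cuda', dtype=torch.float16)
    sin = freqs.sin().to(device='cuda', dtype=torch.float16)

    x = torch.randn(2, seqlen, 8, rotary_dim, device='cuda',
                    dtype=torch.float16, requires_grad=True)
    out = apply_rotary_emb_func(x, cos, sin, False, False)  # interleaved=False, inplace=False
    out.sum().backward()  # gradient flows through the fused backward kernel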
@@ -78,11 +98,13 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
 class ApplyRotaryEmbQKV_(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         """
         qkv: (batch_size, seqlen, 3, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
         rotary_dim must be <= headdim
         Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
@@ -95,13 +117,16 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         cos_k = cos if cos_k is None else cos_k
         sin_k = sin if sin_k is None else sin_k
         assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        q_ro = qkv[:, :, 0, :, :rotary_dim]
+        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
         rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
-        k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        k_ro = qkv[:, :, 1, :, :rotary_dim]
+        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
         rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        ctx.interleaved = interleaved
         return qkv
 
     @staticmethod
@@ -110,13 +135,17 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
-        dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
+        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
         rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
-        dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
         rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
-        return dqkv, None, None, None, None
+        return dqkv, None, None, None, None, None
 
 
 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
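
Likewise for the packed-QKV variant (same assumptions and cos/sin as the sketch above): it rotates q and k in place, leaves v untouched, and cos_k/sin_k fall back to cos/sin when passed as None:

    qkv = torch.randn(2, seqlen, 3, 8, rotary_dim, device='cuda', dtype=torch.float16)
    v_before = qkv[:, :, 2].clone()
    qkv = apply_rotary_emb_qkv_(qkv, cos, sin, None, None, False)  # interleaved=False
    assert torch.equal(qkv[:, :, 2], v_before)  # v (index 2) is never rotated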
@@ -135,22 +164,25 @@ class RotaryEmbedding(torch.nn.Module):
     .. _repo: https://github.com/ZhuiyiTechnology/roformer
     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
 
-    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
     A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """
 
-    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
+    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
         """
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
         """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                 dtype=torch.float32) / dim))
         self.register_buffer("inv_freq", inv_freq)
+        self.interleaved = interleaved
         self.scale_base = scale_base
         scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
-                 / (1.4 * dim) if scale_base > 0 else None)
+                 / (1.4 * dim) if scale_base is not None else None)
         self.register_buffer("scale", scale)
 
         self._seq_len_cached = 0
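
For context, XPos multiplies the q tables and divides the k tables by a per-position power of this scale. A rough sketch of how the cached tables follow from the buffers above (modeled on the torchscale reference linked in the docstring; names here are illustrative, not the module's internals):

    import torch

    dim, base, scale_base, seqlen = 64, 10000, 512, 16
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
    scale = (torch.arange(0, dim, 2).float() + 0.4 * dim) / (1.4 * dim)
    power = (torch.arange(seqlen).float() - seqlen // 2) / scale_base
    scale = scale ** power.unsqueeze(-1)                       # (seqlen, dim / 2)
    cos_q, sin_q = freqs.cos() * scale, freqs.sin() * scale    # tables applied to q
    cos_k, sin_k = freqs.cos() / scale, freqs.sin() / scale    # tables applied to k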
@@ -187,16 +219,19 @@ class RotaryEmbedding(torch.nn.Module):
 
-    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> torch.Tensor:
         """
+        qkv: (batch, seqlen, 3, nheads, headdim)
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
         """
         self._update_cos_sin_cache(qkv, seqlen_offset)
         if self.scale is None:
             return apply_rotary_emb_qkv_(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
+                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                None, None, self.interleaved
             )
         else:
             return apply_rotary_emb_qkv_(
                 qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:]
+                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
+                self.interleaved
             )
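
Finally, a sketch of the seqlen_offset path during incremental decoding (again assuming the CUDA extension and shapes as documented above): the prompt is rotated once, then each decode step passes the number of positions already consumed so its cos/sin rows line up with its absolute position:

    import torch

    rotary = RotaryEmbedding(dim=64, interleaved=False).cuda()
    prompt = torch.randn(1, 12, 3, 8, 64, device='cuda', dtype=torch.float16)
    step = torch.randn(1, 1, 3, 8, 64, device='cuda', dtype=torch.float16)

    prompt = rotary(prompt)                # rotates positions 0..11
    step = rotary(step, seqlen_offset=12)  # rotates position 12 only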