@@ -1,15 +1,4 @@
-# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
-# We split the input differently ((d 2) -> d 2 instead of (2 d) -> d 2), following the original
-# paper's implementation. This should not matter.
-
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
-# NOTE: Almost the same right now, moving parts to Triton is the next step
+# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py

 from typing import Tuple
 import math
@@ -18,28 +7,118 @@ import torch

 from einops import rearrange, repeat

+import rotary_emb

-def rotate_half(x):
-    # rearrange doesn't work with torch.jit
-    # x = rearrange(x, '... (d r) -> ... d r', r=2)
-    x = x.unflatten(dim=-1, sizes=(-1, 2))
-    x1, x2 = x.unbind(dim=-1)
-    rotated_x = torch.stack((-x2, x1), dim=-1)
-    # return rearrange(rotated_x, '... d r -> ... (d r)')
-    return rotated_x.flatten(start_dim=-2)
+def rotate_half(x):
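+    # Split into two contiguous halves (GPT-NeoX-style layout), not interleaved (d 2) pairs as before.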
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)

-@torch.jit.script
-def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int = -2):
-    # NOTE: This could probably be moved to Triton
-
-    # Handle a possible sequence length mismatch in between q and k
-    cos = cos[:x.shape[seq_dimension], :]
-    sin = sin[:x.shape[seq_dimension], :]
-    if seq_dimension == -3:
-        cos = cos[:, None, :]
-        sin = sin[:, None, :]
-    return (x * cos) + (rotate_half(x) * sin)
+def apply_rotary_emb_torch(x, cos, sin):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2)
+    """
+    rotary_dim = cos.shape[-1] * 2
+    assert rotary_dim <= x.shape[-1]
+    cos = repeat(cos, 's d -> s 1 (2 d)')
+    sin = repeat(sin, 's d -> s 1 (2 d)')
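+    # cos/sin are tiled from rotary_dim / 2 to rotary_dim; only the first rotary_dim features are rotated.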
+    return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin,
+                      x[..., rotary_dim:]], dim=-1)
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, cos, sin, inplace=False):
+        """
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert cos.shape == (rotary_seqlen, rotary_dim // 2)
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
+        out = torch.empty_like(x) if not inplace else x
+        o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
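+        # rotary_emb.apply_rotary is the fused CUDA kernel; it is assumed to write
+        # o1 = x1 * cos - x2 * sin and o2 = x1 * sin + x2 * cos (last arg False = forward rotation).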
+        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
+        if not inplace and rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.inplace = inplace
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        inplace = ctx.inplace
+        do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
+        dx = torch.empty_like(do) if not inplace else do
+        dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
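+        # Backward applies the inverse rotation: the final True is assumed to select the conjugate,
+        # i.e. dx1 = do1 * cos + do2 * sin and dx2 = -do1 * sin + do2 * cos.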
+        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
+        if not inplace and rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None
+
+
+apply_rotary_emb_func = ApplyRotaryEmb.apply
+
+
+class ApplyRotaryEmbQKV_(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, qkv, cos, sin):
+        """
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        rotary_dim must be <= headdim
+        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
+        """
+        batch, seqlen, three, nheads, headdim = qkv.shape
+        assert three == 3
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert cos.shape == (rotary_seqlen, rotary_dim // 2)
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
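+        # Rotate q (index 0) and k (index 1) in place; v (index 2) is left untouched.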
+        q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
+        k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        ctx.save_for_backward(cos, sin)
+        return qkv
+
+    @staticmethod
+    def backward(ctx, dqkv):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dqkv.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
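+        # Gradients of q and k get the inverse rotation (conj=True), again applied in place.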
+        dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
+        dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None
+
+
+apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply


 class RotaryEmbedding(torch.nn.Module):
@@ -55,9 +134,6 @@ class RotaryEmbedding(torch.nn.Module):
     .. _repo: https://github.com/ZhuiyiTechnology/roformer
     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

-
-    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
-        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
     """

     def __init__(self, dim_model: int, *_, **__):
@@ -66,70 +142,26 @@ class RotaryEmbedding(torch.nn.Module):
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
         self.register_buffer("inv_freq", inv_freq)

-        self._seq_len_cached = None
+        self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None

-    def _update_cos_sin_tables(self, x, seq_dimension=-2):
-        seq_len = x.shape[seq_dimension]
-
+    def _update_cos_sin_cache(self, x):
+        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
+        """
+        seqlen = x.shape[1]
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if (seq_len != self._seq_len_cached or self._cos_cached.device != x.device
+        if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
             or self._cos_cached.dtype != x.dtype):
-            self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq)
-            self._cos_cached = repeat(torch.cos(freqs).to(x.dtype), '... d -> ... (d 2)')
-            self._sin_cached = repeat(torch.sin(freqs).to(x.dtype), '... d -> ... (d 2)')
-
-        return self._cos_cached, self._sin_cached
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor,
-                seq_dimension=-2) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)
-        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
-            k, seq_dimension=seq_dimension
-        )
-
-        return (
-            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
-            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
-        )
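+            # Cache cos/sin at shape (seqlen, dim_model / 2), kept in the activation dtype.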
+            self._cos_cached = torch.cos(freqs).to(x.dtype)
+            self._sin_cached = torch.sin(freqs).to(x.dtype)

-
-class RotaryEmbedding2D(torch.nn.Module):
-
-    def __init__(self, dim: int):
-        super().__init__()
-        assert dim % 4 == 0
-        self.rotary_emb1d = RotaryEmbedding(dim // 2)
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension=-2):
-        assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)
-        seqlen = q.shape[seq_dimension]
-        seqlen_sqrt = int(math.sqrt(seqlen))
-        assert seqlen == seqlen_sqrt ** 2
-        if seq_dimension == -3:  # (bs, s, h, d)
-            q = rearrange(q, 'b s h d -> b h s d')
-            k = rearrange(k, 'b s h d -> b h s d')
-        q0, q1 = q.chunk(2, dim=-1)
-        k0, k1 = k.chunk(2, dim=-1)
-        # (bs, h, s, d)
-        q0 = rearrange(q0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        k0 = rearrange(k0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        q0_emb, k0_emb = self.rotary_emb1d(q0, k0, seq_dimension=-2)
-        q0_emb = rearrange(q0_emb, 'b nheads h w d -> b nheads (h w) d')
-        k0_emb = rearrange(k0_emb, 'b nheads h w d -> b nheads (h w) d')
-        q1 = rearrange(q1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        k1 = rearrange(k1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        q1_emb, k1_emb = self.rotary_emb1d(q1, k1, seq_dimension=-3)
-        q1_emb = rearrange(q1_emb, 'b nheads h w d -> b nheads (h w) d')
-        k1_emb = rearrange(k1_emb, 'b nheads h w d -> b nheads (h w) d')
-        q_emb, k_emb = torch.cat([q0_emb, q1_emb], dim=-1), torch.cat([k0_emb, k1_emb], dim=-1)
-        if seq_dimension == -3:
-            q_emb = rearrange(q_emb, 'b h s d -> b s h d')
-            k_emb = rearrange(k_emb, 'b h s d -> b s h d')
-        return q_emb, k_emb
+    def forward(self, qkv: torch.Tensor) -> torch.Tensor:
+        self._update_cos_sin_cache(qkv)
+        return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
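+
+    # Usage sketch (assumed setup): for packed qkv of shape (batch, seqlen, 3, nheads, headdim),
+    # rotary = RotaryEmbedding(headdim); qkv = rotary(qkv) rotates q and k in place through the
+    # fused rotary_emb CUDA kernel.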