Răsfoiți Sursa

Implement rotary embedding in CUDA

Tri Dao 2 ani în urmă
părinte
comite
ca81f32e04
5 a modificat fișierele cu 359 adăugiri și 90 ștergeri
  1. 34 0
      csrc/rotary/rotary.cpp
  2. 41 0
      csrc/rotary/rotary_cuda.cu
  3. 118 0
      csrc/rotary/setup.py
  4. 122 90
      flash_attn/rotary.py
  5. 44 0
      tests/test_rotary.py

+ 34 - 0
csrc/rotary/rotary.cpp

@@ -0,0 +1,34 @@
+#include <torch/extension.h>
+
+#define CHECK_DEVICE(x)                                                        \
+  TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...)                                                    \
+  TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}),                  \
+              #x " must have shape (" #__VA_ARGS__ ")")
+
+void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
+                       const torch::Tensor cos, const torch::Tensor sin,
+                       torch::Tensor out1, torch::Tensor out2,
+                       const bool conj);
+
+void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
+                  const torch::Tensor cos, const torch::Tensor sin,
+                  torch::Tensor out1, torch::Tensor out2,
+                  const bool conj) {
+    CHECK_DEVICE(x1); CHECK_DEVICE(x2);
+    CHECK_DEVICE(cos); CHECK_DEVICE(sin);
+    CHECK_DEVICE(out1); CHECK_DEVICE(out1);
+    TORCH_CHECK(x1.dtype() == x2.dtype());
+    TORCH_CHECK(cos.dtype() == sin.dtype());
+    TORCH_CHECK(out1.dtype() == out2.dtype());
+    TORCH_CHECK(x1.dtype() == cos.dtype());
+    TORCH_CHECK(x1.dtype() == out1.dtype());
+    TORCH_CHECK(x1.sizes() == x2.sizes());
+    TORCH_CHECK(cos.sizes() == sin.sizes());
+    TORCH_CHECK(out1.sizes() == out2.sizes());
+    apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("apply_rotary", &apply_rotary, "Apply rotary embedding");
+}

+ 41 - 0
csrc/rotary/rotary_cuda.cu

@@ -0,0 +1,41 @@
+#include <torch/python.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+
+void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
+                       const torch::Tensor cos, const torch::Tensor sin,
+                       torch::Tensor out1, torch::Tensor out2,
+                       const bool conj) {
+    auto iter = at::TensorIteratorConfig()
+        .add_output(out1)
+        .add_output(out2)
+        .add_input(x1)
+        .add_input(x2)
+        .add_input(cos)
+        .add_input(sin)
+        .check_all_same_dtype(false)
+        .promote_inputs_to_common_dtype(false)
+        .build();
+
+    if (!conj) {
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
+            at::native::gpu_kernel_multiple_outputs(
+                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
+                                    scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
+                scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
+                scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
+                return {out1, out2};
+            });
+        });
+    } else {
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
+            at::native::gpu_kernel_multiple_outputs(
+                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
+                                    scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
+                scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
+                scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
+                return {out1, out2};
+            });
+        });
+    }
+}

+ 118 - 0
csrc/rotary/setup.py

@@ -0,0 +1,118 @@
+# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from setuptools import setup, find_packages
+import subprocess
+
+import sys
+import warnings
+import os
+
+# ninja build does not work unless include_dirs are abs path
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
+    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_major = torch.version.cuda.split(".")[0]
+    torch_binary_minor = torch.version.cuda.split(".")[1]
+
+    print("\nCompiling cuda extensions with")
+    print(raw_output + "from " + cuda_dir + "/bin\n")
+
+    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+        raise RuntimeError(
+            "Cuda extensions are being compiled with a version of Cuda that does "
+            "not match the version used to compile Pytorch binaries.  "
+            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+            + "In some cases, a minor-version mismatch will not cause later errors:  "
+            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
+            "You can try commenting out this check (at your own risk)."
+        )
+
+
+def raise_if_cuda_home_none(global_option: str) -> None:
+    if CUDA_HOME is not None:
+        return
+    raise RuntimeError(
+        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
+        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+        "only images whose names contain 'devel' will provide nvcc."
+    )
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+if not torch.cuda.is_available():
+    # https://github.com/NVIDIA/apex/issues/486
+    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
+    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
+    print(
+        "\nWarning: Torch did not find available GPUs on this system.\n",
+        "If your intention is to cross-compile, this is not an error.\n"
+        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
+        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
+        "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
+        "If you wish to cross-compile for a single specific architecture,\n"
+        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
+    )
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+        if int(bare_metal_major) == 11:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
+            if int(bare_metal_minor) > 0:
+                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        else:
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+
+print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
+TORCH_MAJOR = int(torch.__version__.split(".")[0])
+TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+cmdclass = {}
+ext_modules = []
+
+raise_if_cuda_home_none("rotary_emb")
+# Check, if CUDA11 is installed for compute capability 8.0
+cc_flag = []
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_80,code=sm_80")
+
+ext_modules.append(
+    CUDAExtension(
+        'rotary_emb', [
+            'rotary.cpp',
+            'rotary_cuda.cu',
+        ],
+        extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'],
+                            'nvcc': append_nvcc_threads([
+                                '-O3', '--use_fast_math', '--expt-extended-lambda'
+                            ] + cc_flag)
+                           }
+    )
+)
+
+setup(
+    name="rotary_emb",
+    version="0.1",
+    ext_modules=ext_modules,
+    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+)

+ 122 - 90
flash_attn/rotary.py

@@ -1,15 +1,4 @@
-# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
-# We split the input differently ((d 2) -> d 2 instead of (2 d) -> d 2), following the original
-# paper's implementation. This should not matter.
-
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-#
-# This source code is licensed under the BSD license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
-# NOTE: Almost the same right now, moving parts to Triton is the next step
+# Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
 
 from typing import Tuple
 import math
@@ -18,28 +7,118 @@ import torch
 
 from einops import rearrange, repeat
 
+import rotary_emb
 
-def rotate_half(x):
-    # rearrange doesn't work with torch.jit
-    # x = rearrange(x, '... (d r) -> ... d r', r=2)
-    x = x.unflatten(dim=-1, sizes=(-1, 2))
-    x1, x2 = x.unbind(dim=-1)
-    rotated_x = torch.stack((-x2, x1), dim=-1)
-    # return rearrange(rotated_x, '... d r -> ... (d r)')
-    return rotated_x.flatten(start_dim=-2)
 
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
 
-@torch.jit.script
-def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int = -2):
-    # NOTE: This could probably be moved to Triton
 
-    # Handle a possible sequence length mismatch in between q and k
-    cos = cos[:x.shape[seq_dimension], :]
-    sin = sin[:x.shape[seq_dimension], :]
-    if seq_dimension == -3:
-        cos = cos[:, None, :]
-        sin = sin[:, None, :]
-    return (x * cos) + (rotate_half(x) * sin)
+def apply_rotary_emb_torch(x, cos, sin):
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2)
+    """
+    rotary_dim = cos.shape[-1] * 2
+    assert rotary_dim <= x.shape[-1]
+    cos = repeat(cos, 's d -> s 1 (2 d)')
+    sin = repeat(sin, 's d -> s 1 (2 d)')
+    return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin,
+                      x[..., rotary_dim:]], dim=-1)
+
+
+class ApplyRotaryEmb(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, cos, sin, inplace=False):
+        """
+            x: (batch_size, seqlen, nheads, headdim)
+            cos, sin: (seqlen, rotary_dim / 2)
+        rotary_dim must be <= headdim
+        Apply rotary embedding to the first rotary_dim of x.
+        """
+        batch, seqlen, nheads, headdim = x.shape
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert cos.shape == (rotary_seqlen, rotary_dim // 2)
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
+        out = torch.empty_like(x) if not inplace else x
+        o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
+        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:, :seqlen], 's d -> s 1 d'), o1, o2, False)
+        if not inplace and rotary_dim < headdim:
+            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
+        ctx.save_for_backward(cos, sin)
+        ctx.inplace = inplace
+        return out if not inplace else x
+
+    @staticmethod
+    def backward(ctx, do):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, headdim = do.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        inplace = ctx.inplace
+        do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
+        dx = torch.empty_like(do) if not inplace else do
+        dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
+        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:, :seqlen], 's d -> s 1 d'), dx1, dx2, True)
+        if not inplace and rotary_dim < headdim:
+            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
+        return dx, None, None, None
+
+
+apply_rotary_emb_func = ApplyRotaryEmb.apply
+
+
+class ApplyRotaryEmbQKV_(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, qkv, cos, sin):
+        """
+            qkv: (batch_size, seqlen, 3, nheads, headdim)
+            cos, sin: (seqlen, rotary_dim / 2)
+        rotary_dim must be <= headdim
+        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
+        """
+        batch, seqlen, three, nheads, headdim = qkv.shape
+        assert three == 3
+        rotary_seqlen, rotary_dim = cos.shape
+        rotary_dim *= 2
+        assert rotary_dim <= headdim
+        assert seqlen <= rotary_seqlen
+        assert cos.shape == (seqlen, rotary_dim // 2)
+        assert sin.shape == (seqlen, rotary_dim // 2)
+        q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:, :seqlen], 's d -> s 1 d'), q1, q2, False)
+        k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:, :seqlen], 's d -> s 1 d'), k1, k2, False)
+        ctx.save_for_backward(cos, sin)
+        return qkv
+
+    @staticmethod
+    def backward(ctx, dqkv):
+        cos, sin = ctx.saved_tensors
+        _, seqlen, _, _, headdim = dqkv.shape
+        rotary_dim = cos.shape[-1]
+        rotary_dim *= 2
+        dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:, :seqlen], 's d -> s 1 d'), dq1, dq2, True)
+        dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:, :seqlen], 's d -> s 1 d'),
+                                rearrange(sin[:, :seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None
+
+
+apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
 
 
 class RotaryEmbedding(torch.nn.Module):
@@ -55,9 +134,6 @@ class RotaryEmbedding(torch.nn.Module):
     .. _repo: https://github.com/ZhuiyiTechnology/roformer
     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
 
-
-    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
-        (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
     """
 
     def __init__(self, dim_model: int, *_, **__):
@@ -66,70 +142,26 @@ class RotaryEmbedding(torch.nn.Module):
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
         self.register_buffer("inv_freq", inv_freq)
 
-        self._seq_len_cached = None
+        self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None
 
-    def _update_cos_sin_tables(self, x, seq_dimension=-2):
-        seq_len = x.shape[seq_dimension]
-
+    def _update_cos_sin_cache(self, x):
+        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
+        """
+        seqlen = x.shape[1]
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
-        if (seq_len != self._seq_len_cached or self._cos_cached.device != x.device
+        if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
             or self._cos_cached.dtype != x.dtype):
-            self._seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dimension], device=x.device, dtype=self.inv_freq.dtype)
+            self._seq_len_cached = seqlen
+            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq)
-            self._cos_cached = repeat(torch.cos(freqs).to(x.dtype), '... d -> ... (d 2)')
-            self._sin_cached = repeat(torch.sin(freqs).to(x.dtype), '... d -> ... (d 2)')
-
-        return self._cos_cached, self._sin_cached
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor,
-                seq_dimension=-2) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)
-        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
-            k, seq_dimension=seq_dimension
-        )
-
-        return (
-            apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached, seq_dimension),
-            apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached, seq_dimension),
-        )
+            self._cos_cached = torch.cos(freqs).to(x.dtype)
+            self._sin_cached = torch.sin(freqs).to(x.dtype)
 
-
-class RotaryEmbedding2D(torch.nn.Module):
-
-    def __init__(self, dim: int):
-        super().__init__()
-        assert dim % 4 == 0
-        self.rotary_emb1d = RotaryEmbedding(dim // 2)
-
-    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_dimension=-2):
-        assert seq_dimension in [-2, -3]  # Either (bs, h, s, d) or (bs, s, h, d)
-        seqlen = q.shape[seq_dimension]
-        seqlen_sqrt = int(math.sqrt(seqlen))
-        assert seqlen == seqlen_sqrt ** 2
-        if seq_dimension == -3:  # (bs, s, h, d)
-            q = rearrange(q, 'b s h d -> b h s d')
-            k = rearrange(k, 'b s h d -> b h s d')
-        q0, q1 = q.chunk(2, dim=-1)
-        k0, k1 = k.chunk(2, dim=-1)
-        # (bs, h, s, d)
-        q0 = rearrange(q0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        k0 = rearrange(k0, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        q0_emb, k0_emb = self.rotary_emb1d(q0, k0, seq_dimension=-2)
-        q0_emb = rearrange(q0_emb, 'b nheads h w d -> b nheads (h w) d')
-        k0_emb = rearrange(k0_emb, 'b nheads h w d -> b nheads (h w) d')
-        q1 = rearrange(q1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        k1 = rearrange(k1, 'b nheads (h w) d -> b nheads h w d', h=seqlen_sqrt)
-        q1_emb, k1_emb = self.rotary_emb1d(q1, k1, seq_dimension=-3)
-        q1_emb = rearrange(q1_emb, 'b nheads h w d -> b nheads (h w) d')
-        k1_emb = rearrange(k1_emb, 'b nheads h w d -> b nheads (h w) d')
-        q_emb, k_emb = torch.cat([q0_emb, q1_emb], dim=-1), torch.cat([k0_emb, k1_emb], dim=-1)
-        if seq_dimension == -3:
-            q_emb = rearrange(q_emb, 'b h s d -> b s h d')
-            k_emb = rearrange(k_emb, 'b h s d -> b s h d')
-        return q_emb, k_emb
+    def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        self._update_cos_sin_cache(qkv)
+        return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)

+ 44 - 0
tests/test_rotary.py

@@ -0,0 +1,44 @@
+import math
+
+import torch
+import torch.nn.functional as F
+import pytest
+
+from einops import rearrange
+
+from flash_attn.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
+
+
+is_sm8x = torch.cuda.get_device_capability('cuda') >= (8, 0)
+
+@pytest.mark.parametrize('dtype', ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize('rotary_fraction', [1.0, 0.5])
+# @pytest.mark.parametrize('rotary_fraction', [0.5])
+@pytest.mark.parametrize('inplace', [False, True])
+# @pytest.mark.parametrize('inplace', [False])
+def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
+    rtol = 1e-3
+    batch_size = 32
+    nheads = 4
+    seqlen = 217
+    headdim = 128
+    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device='cuda',
+                    requires_grad=True)
+    x_pt = x.detach().clone().requires_grad_()
+    rotary_dim = int(rotary_fraction * headdim)
+    assert rotary_dim % 2 == 0
+    angle = torch.randn(seqlen, rotary_dim // 2, device='cuda')
+    cos = torch.cos(angle).to(dtype=dtype)
+    sin = torch.sin(angle).to(dtype=dtype)
+    out = apply_rotary_emb_func(x, cos, sin, inplace)
+    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
+    # Numerical error if we just do any arithmetic
+    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
+    g = torch.randn_like(out)
+    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
+    out.backward(g)
+    out_pt.backward(g_pt)
+    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)