Переглянути джерело

feat: finish kernel testing

AlpinDale 1 рік тому
батько
коміт
b20083a1f3
2 змінених файлів з 211 додано та 0 видалено
  1. 42 0
      tests/kernels/conftest.py
  2. 169 0
      tests/kernels/test_pos_encoding.py

+ 42 - 0
tests/kernels/conftest.py

@@ -0,0 +1,42 @@
+from typing import List, Tuple
+
+import pytest
+import torch
+
+
+def create_kv_caches(
+        num_blocks: int,
+        block_size: int,
+        num_layers: int,
+        num_heads: int,
+        head_size: int,
+        dtype: torch.dtype,
+        seed: int,
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = head_size**-0.5
+    x = 16 // torch.tensor([], dtype=dtype).element_size()
+    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
+    key_caches = []
+    for _ in range(num_layers):
+        key_cache = torch.empty(size=key_cache_shape,
+                                dtype=dtype,
+                                device='cuda')
+        key_cache.uniform_(-scale, scale)
+        key_caches.append(key_cache)
+
+    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
+    values_caches = []
+    for _ in range(num_layers):
+        values_cache = torch.empty(size=value_cache_shape,
+                                   dtype=dtype,
+                                   device='cuda')
+        values_cache.uniform_(-scale, scale)
+        values_caches.append(values_cache)
+    return key_caches, values_caches
+
+@pytest.fixture()
+def kv_cache_factory():
+    return create_kv_caches

+ 169 - 0
tests/kernels/test_pos_encoding.py

@@ -0,0 +1,169 @@
+from typing import Optional, Tuple
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from aphrodite import pos_encoding_ops
+
+IS_NEOX_STYLE = [True, False]
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+ROTARY_DIMS = [None, 32] # None means rotary dim == head size
+NUM_HEADS = [7, 12, 40, 52]
+NUM_TOKENS = [11, 83, 2048]
+SEEDS = [0]
+
+
+def rotate_neox(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+def apply_rope(
+        q: torch.Tensor,
+        k: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        is_neox_style: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+    q_embed = (q * cos) + (rotate_fn(q) * sin)
+    k_embed = (k * cos) + (rotate_fn(k) * sin)
+    return q_embed, k_embed
+
+class RefRotaryEmbedding(nn.Module):
+
+    def __init__(
+            self,
+            dim: int,
+            is_neox_style: bool,
+            max_position_embeddings: int = 8192,
+            base: int = 10000,
+    ) -> None:
+        super().__init__()
+        self.rotary_dim = dim
+        self.is_neox_style = is_neox_style
+        self.max_position_embeddings = max_position_embeddings
+
+        # create cos and sin embeddings
+        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
+        t = torch.arange(max_position_embeddings).float()
+        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
+        if is_neox_style:
+            emb = torch.cat((freqs, freqs), dim=-1)
+        else:
+            emb = torch.repeat_interleave(freqs, 2, -1)
+        cos = emb.cos().to(dtype=inv_freq.dtype)
+        sin = emb.sin().to(dtype=inv_freq.dtype)
+        self.register_buffer("cos_cached", cos, persistent=False)
+        self.register_buffer("sin_cached", sin, persistent=False)
+
+    def forward(
+            self,
+            positions: torch.Tensor,
+            query: torch.Tensor,
+            key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+
+        query_rot = query_rot.transpose(0, 1)
+        key_rot = key_rot.transpose(0, 1)
+        cos = F.embedding(positions, self.cos_cached)
+        sin = F.embedding(positions, self.sin_cached)
+
+        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
+                                        self.is_neox_style)
+        query_rot = query_rot.transpose(0, 1).contiguous()
+        key_rot = key_rot.transpose(0, 1).contiguous()
+
+        query = torch.cat((query_rot, query_pass), dim=-1)
+        key = torch.cat((key_rot, key_pass), dim=-1)
+
+        return query, key
+    
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_rotary_embedding(
+    is_neox_style: bool,
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
+    query = torch.randn(num_tokens,
+                        num_heads * head_size,
+                        dtype=dtype,
+                        device="cuda")
+    key = torch.randn(num_tokens,
+                      num_heads * head_size,
+                      dtype=dtype,
+                      device="cuda")
+    
+    # create the rotary embedding
+    inv_freq = 1.0 / (base**(
+        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+    t = torch.arange(max_position).float()
+    freqs = torch.einsum("i,j -> ij", t, inv_freq)
+    cos = freqs.cos()
+    sin = freqs.sin()
+    cos_sin_cache = torch.cat((cos, sin), dim=-1)
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+
+    out_query = query.clone()
+    out_key = key.clone()
+    pos_encoding_ops.rotary_embedding(
+        positions,
+        out_query,
+        out_key,
+        head_size,
+        cos_sin_cache,
+        is_neox_style,
+    )
+
+    ref_rotary_embedding = RefRotaryEmbedding(
+        dim=rotary_dim,
+        is_neox_style=is_neox_style,
+        max_position_embeddings=max_position,
+        base=base,
+    ).to(dtype=dtype, device="cuda")
+    ref_query, ref_key = ref_rotary_embedding(
+        positions,
+        query.view(num_tokens, num_heads, head_size),
+        key.view(num_tokens, num_heads, head_size),
+    )
+    ref_query = ref_query.view(num_tokens, num_heads * head_size)
+    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+
+    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
+
+
+
+