from typing import Optional

import pytest
import torch

from aphrodite.modeling.layers.rotary_embedding import get_rope

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
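# Use only device 0 when a single GPU is visible; otherwise exercise the first two devices.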
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: int,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
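    """Check the custom (in-place) rotary embedding kernel against the reference.

    ``rope._forward`` is the reference implementation and ``rope.forward`` invokes
    the custom in-place kernel; the two results must agree within tolerance.
    """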
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
    rope = rope.to(dtype=dtype, device=gpu_id)
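
    # Random token positions and flattened (num_heads * head_size) query/key inputs.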
    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device=gpu_id)
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype,
                        device=gpu_id)
    key = torch.randn_like(query)

    # NOTE: The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope._forward(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
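

# Example invocation (requires pytest and at least one CUDA device; the file
# path below is illustrative):
#   pytest tests/kernels/test_pos_encoding.py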