test_pos_encoding.py

from typing import Optional

import pytest
import torch

from aphrodite.modeling.layers.rotary_embedding import get_rope

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
# Use the only visible GPU, or the first two when more than one is available.
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: int,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    gpu_id = f"cuda:{device}"
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
    rope = rope.to(dtype=dtype, device=gpu_id)
    positions = torch.randint(0,
                              max_position, (batch_size, seq_len),
                              device=gpu_id)
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype,
                        device=gpu_id)
    key = torch.randn_like(query)
    # NOTE: The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope._forward(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
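
# For intuition, below is a minimal sketch of the NeoX-style rotation that
# get_rope(..., is_neox_style=True) is expected to apply: the first half of
# each rotary slice is paired with the second half and both are rotated by
# position-dependent angles. This is an illustrative reference under the
# usual RoPE definition, not Aphrodite's implementation; `rope_neox_sketch`
# is a hypothetical helper and is not used by the test above.
def rope_neox_sketch(x: torch.Tensor, positions: torch.Tensor,
                     rotary_dim: int, base: int = 10000) -> torch.Tensor:
    # x: [batch, seq, head_size], positions: [batch, seq]
    inv_freq = 1.0 / (base**(torch.arange(
        0, rotary_dim, 2, dtype=torch.float, device=x.device) / rotary_dim))
    angles = positions.unsqueeze(-1).float() * inv_freq  # [batch, seq, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1 = x[..., :rotary_dim // 2].float()
    x2 = x[..., rotary_dim // 2:rotary_dim].float()
    rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    # Dimensions beyond rotary_dim pass through unchanged.
    return torch.cat((rotated, x[..., rotary_dim:].float()),
                     dim=-1).to(x.dtype)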