test_pos_encoding.py

from typing import Optional, Tuple

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from aphrodite import pos_encoding_ops

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 12, 40, 52]
NUM_TOKENS = [11, 83, 2048]
SEEDS = [0]


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
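
# Illustrative example (not used by the tests): for a last-dim vector
# x = [x0, x1, x2, x3],
#   rotate_neox(x) == [-x2, -x3,  x0,  x1]   # rotate the two halves
#   rotate_gptj(x) == [-x1,  x0, -x3,  x2]   # rotate adjacent even/odd pairs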


def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
    q_embed = (q * cos) + (rotate_fn(q) * sin)
    k_embed = (k * cos) + (rotate_fn(k) * sin)
    return q_embed, k_embed


class RefRotaryEmbedding(nn.Module):

    def __init__(
        self,
        dim: int,
        is_neox_style: bool,
        max_position_embeddings: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__()
        self.rotary_dim = dim
        self.is_neox_style = is_neox_style
        self.max_position_embeddings = max_position_embeddings

        # Create the cos and sin embeddings.
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
        if is_neox_style:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.repeat_interleave(freqs, 2, -1)
        cos = emb.cos().to(dtype=inv_freq.dtype)
        sin = emb.sin().to(dtype=inv_freq.dtype)
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]

        query_rot = query_rot.transpose(0, 1)
        key_rot = key_rot.transpose(0, 1)
        cos = F.embedding(positions, self.cos_cached)
        sin = F.embedding(positions, self.sin_cached)
        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
                                        self.is_neox_style)
        query_rot = query_rot.transpose(0, 1).contiguous()
        key_rot = key_rot.transpose(0, 1).contiguous()

        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
        return query, key


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    if rotary_dim is None:
        rotary_dim = head_size
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
    query = torch.randn(num_tokens,
                        num_heads * head_size,
                        dtype=dtype,
                        device="cuda")
    key = torch.randn(num_tokens,
                      num_heads * head_size,
                      dtype=dtype,
                      device="cuda")

    # Create the rotary embedding cache (cos and sin concatenated on the
    # last dimension).
    inv_freq = 1.0 / (base**(
        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos()
    sin = freqs.sin()
    cos_sin_cache = torch.cat((cos, sin), dim=-1)
    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")

    # Run the kernel. It updates out_query and out_key in place, so clone the
    # inputs to keep the originals for the reference implementation.
    out_query = query.clone()
    out_key = key.clone()
    pos_encoding_ops.rotary_embedding(
        positions,
        out_query,
        out_key,
        head_size,
        cos_sin_cache,
        is_neox_style,
    )

    # Run the reference implementation and compare.
    ref_rotary_embedding = RefRotaryEmbedding(
        dim=rotary_dim,
        is_neox_style=is_neox_style,
        max_position_embeddings=max_position,
        base=base,
    ).to(dtype=dtype, device="cuda")
    ref_query, ref_key = ref_rotary_embedding(
        positions,
        query.view(num_tokens, num_heads, head_size),
        key.view(num_tokens, num_heads, head_size),
    )
    ref_query = ref_query.view(num_tokens, num_heads * head_size)
    ref_key = ref_key.view(num_tokens, num_heads * head_size)

    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
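

if __name__ == "__main__":
    # Minimal standalone sketch for running a single configuration outside
    # pytest. It assumes a CUDA device is available and that the
    # `pos_encoding_ops` extension from `aphrodite` is built; the parameter
    # values below are illustrative, not prescribed by the test suite.
    test_rotary_embedding(
        is_neox_style=True,
        num_tokens=11,
        num_heads=7,
        head_size=64,
        rotary_dim=None,
        dtype=torch.float,
        seed=0,
    )
    print("test_rotary_embedding passed for one configuration")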