test_pos_encoding.py

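"""Tests for Aphrodite's rotary positional embedding (RoPE) layers.

Each test builds a RoPE module via ``get_rope`` and compares the custom
in-place kernel (``rope.forward``) against the native PyTorch reference
(``rope.forward_native``) across dtypes, head sizes, rotary dimensions, and
both NeoX-style and non-NeoX rotation layouts, including linear-scaling
variants driven by per-token offsets. The final test checks that ``get_rope``
reuses cached modules for identical settings.
"""
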
from itertools import accumulate, product
from typing import Dict, List, Optional

import pytest
import torch

from aphrodite.modeling.layers.rotary_embedding import get_rope

from .allclose_default import get_default_atol, get_default_rtol

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
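# Use a second CUDA device when more than one GPU is available.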
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
    rope = rope.to(dtype=dtype)
    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions, query, key)
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
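    # With a single linear scaling factor of 1 and all-zero offsets (below),
    # the scaled rope is mathematically identical to plain RoPE; this test
    # exercises the offsets code path of the kernel against the same module's
    # native forward.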
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": (1, )
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # NOTE: The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions,
                                      query,
                                      key,
                                      offsets=torch.zeros(batch_size * seq_len,
                                                          dtype=torch.long,
                                                          device=device))
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    scaling_factors: List[int] = [1, 2, 4]
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": tuple(scaling_factors)
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
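
    # Each scaling factor is assumed to own a contiguous segment of the
    # module's concatenated cos/sin cache; offset_map holds the starting
    # offset of each segment so per-token scaling-factor indices
    # (query_types) can be translated into cache offsets.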
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    query_offsets = offset_map[query_types]

    # NOTE: The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key,
                                             query_offsets)
    out_query, out_key = rope.forward(positions, query, key,
                                      query_offsets.flatten())
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))
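

# get_rope is expected to cache modules keyed by their construction
# arguments: distinct settings should yield distinct module objects, while
# repeating a setting should return the already-constructed module.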
@torch.inference_mode()
def test_rope_module_cache():
    MAX_POSITIONS = [123, 1234]
    BASES = [10000, 1000000]
    ROPE_SCALINGS = (None, {
        "type": "linear",
        "factor": (1, )
    }, {
        "type": "dynamic",
        "factor": 1
    })
    settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
                ROPE_SCALINGS, DTYPES)
    rope_setting_id_map: Dict[str, int] = {}
    for setting in product(*settings):
        head_size, rotary_dim, max_position, base, \
            is_neox_style, rope_scaling, dtype = setting
        if rotary_dim is None:
            rotary_dim = head_size
        rope = get_rope(head_size, rotary_dim, max_position, base,
                        is_neox_style, rope_scaling, dtype)
        # different settings cannot share the same rope module
        assert id(rope) not in rope_setting_id_map.values()
        assert all(x.dtype == dtype for x in rope.buffers())
        assert all(x.dtype == dtype for x in rope.parameters())
        rope_setting_id_map[str(setting)] = id(rope)

    for setting in product(*settings):
        head_size, rotary_dim, max_position, base, \
            is_neox_style, rope_scaling, dtype = setting
        if rotary_dim is None:
            rotary_dim = head_size
        rope = get_rope(head_size, rotary_dim, max_position, base,
                        is_neox_style, rope_scaling, dtype)
        # check that the cache takes effect: identical settings must return
        # the same module object
        assert id(rope) == rope_setting_id_map[str(setting)]