test_samplers.py

import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.utils import set_random_seed
from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.task_handler.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
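    """Sampler whose logits are fixed ahead of time.

    ``_prune_hidden_states`` and ``_get_logits`` are patched out so that
    ``forward`` operates directly on ``self.fake_logits`` rather than on
    real model outputs, making the sampler's behavior deterministic.
    """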

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
                   lambda x, y: x), patch(
                       "aphrodite.modeling.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
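    """Build dummy hidden states, uniform fake logits, a mock sampler, and a
    bare ModelRunner for a batch of the given size."""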
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


RANDOM_SEEDS = list(range(128))
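# Each test below builds a random-size batch of prompt-phase sequence groups,
# runs the mock sampler on it, and checks the sampled tokens wherever the
# expected outcome is well defined.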


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
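    """With temperature=0 every sequence should sample greedily, i.e. pick
    the argmax of its (fake) logits."""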
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
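    """Random sampling with one dominant logit per sequence: token i is made
    overwhelmingly likely for sequence i, so it should always be picked."""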
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    # Make token i the runaway favorite for sequence i.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
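    """Smoke test: run the sampler on an all-beam-search batch."""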
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler(embedding=None,
            hidden_states=input_tensor,
            sampling_metadata=sampling_metadata)
    # No assertion here: the expected beam-search output is hard to compute,
    # so this test only checks that the sampler raises no exception on an
    # all-beam-search batch.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
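    """Mix greedy, random, and beam-search sequence groups in one batch and
    check that every non-beam sample lands on a boosted token."""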
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    prompt_lens = []
    for i in range(batch_size):
        n = 1
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            # Greedy.
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            # Random sampling with randomized parameters.
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            # Beam search.
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        # Boost n tokens for this sequence so that non-beam samples are
        # effectively forced onto them.
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
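    """Greedy sampling through a logits processor that forces a specific
    token, verifying that processors are applied before sampling."""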
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx