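"""Tests for the Aphrodite sampler layer: greedy, random, beam-search,
mixed, and logits-processor sampling paths."""
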
import random
from typing import Tuple
from unittest.mock import patch

import pytest
import torch

from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.utils import set_random_seed
from aphrodite.task_handler.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
    """Sampler whose logits are fixed ahead of time.

    ``_prune_hidden_states`` is patched to an identity function and
    ``_get_logits`` to a constant, so the full sampling pipeline runs
    against the controlled ``fake_logits`` tensor instead of real model
    output.
    """

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch("aphrodite.modeling.layers.sampler._prune_hidden_states",
                   lambda x, y: x), patch(
                       "aphrodite.modeling.layers.sampler._get_logits",
                       lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device="cuda",
                              dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             device=input_tensor.device,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None)
    return input_tensor, fake_logits, sampler, model_runner
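

# A minimal sketch (not part of the original suite, and not collected by
# pytest since its name lacks the test_ prefix) of the tensor shapes
# _prepare_test is expected to return, derived from the constants
# hard-coded above. Running it requires the same CUDA setup as the tests.
def _example_prepare_test_shapes():
    input_tensor, fake_logits, _, _ = _prepare_test(4)
    assert input_tensor.shape == (4, 1024)      # (batch_size, hidden_size)
    assert fake_logits.shape == (4, 32000)      # (batch_size, vocab_size)
    assert fake_logits.dtype == torch.float16   # matches the input dtype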


RANDOM_SEEDS = list(range(128))
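# Each test below is parametrized over these seeds; seeding the RNGs via
# set_random_seed makes the randomized batch construction reproducible
# per seed.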


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    # Boost token i for the i-th sequence so that random sampling is
    # (near-)certain to pick it over the uniform 1e-2 background logits.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1.0,
                    n=random.randint(1, 10),
                ),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=0,
                    best_of=2,
                    use_beam_search=True,
                ),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler(embedding=None,
            hidden_states=input_tensor,
            sampling_metadata=sampling_metadata)
    # There is no correctness assertion here: the expected beam-search
    # output is hard to specify in advance, so this test only verifies
    # that the sampler raises no exceptions on an all-beam-search batch.


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens = []
    prompt_lens = []
    for i in range(batch_size):
        n = 1
        # Randomly assign each sequence a greedy, random, or beam-search
        # sampling configuration.
        sampling_type = random.randint(0, 2)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
        elif sampling_type == 1:
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        # Boost n consecutive tokens so any of them is an acceptable sample.
        for idx in range(n):
            fake_logits[i, i + idx] = 1e2
            expected_tokens.append(i + idx)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for i, sequence_output in enumerate(sampler_output):
        if seq_group_metadata_list[i].sampling_params.use_beam_search:
            continue
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives an infinite score to the i-th
    # token, where i is the length of the input sequence. We therefore
    # expect the output token sequence to be [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
                persistent_data={},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
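
# To run just this file (assuming pytest and a CUDA-enabled torch build
# are installed):
#
#     pytest test_samplers.py -q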