# test_sampler.py

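"""Unit tests for the Aphrodite Sampler layer.

These tests drive a MockLogitsSampler over hand-crafted fake logits so that
greedy, random, seeded, beam-search, penalty, and top-k/top-p behaviour can
be checked without loading a real model.
"""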

import itertools
import random
from array import array
from typing import Dict, List, Optional, Tuple
from unittest.mock import Mock, patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       SamplingParams, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import Counter, is_pin_memory_available
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_random_seed
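
# Sampler subclass used by all tests below: it keeps a reference to the
# pre-built fake logits and otherwise defers to the real Sampler; the tests
# pass the fake logits tensor in directly as the `logits` argument.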
class MockLogitsSampler(Sampler):

    def __init__(self, fake_logits: torch.Tensor):
        super().__init__()
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)
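
# Helper that builds a random fp16 "hidden state" tensor, a uniform
# fake-logits tensor of shape (batch_size, VOCAB_SIZE), and a matching
# MockLogitsSampler.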
def _prepare_test(
        batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, VOCAB_SIZE),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    return input_tensor, fake_logits, sampler


VOCAB_SIZE = 32000
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
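
# Wrap every batch entry in a single-sequence prompt SequenceGroupMetadata
# with the given SamplingParams, build the SamplingMetadata, and run the
# sampler over the provided logits tensor.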
def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    sampling_params: SamplingParams,
    device: str,
):
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={
                    0: SequenceData(array(
                        APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
                },
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())
    return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
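
# Greedy sampling (temperature=0) must return the argmax token for every
# sequence in the batch.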
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()
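
# Random sampling: a large spike at token i in row i should dominate the
# otherwise uniform logits, both without a seed (below) and with a fixed
# per-request seed (the following test).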
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                sampling_params, device)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i
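
# Two invocations of the sampler with the same per-request seed must produce
# identical outputs.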
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                      sampling_params, device)
    second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                       sampling_params, device)
    assert first_sampler_output == second_sampler_output


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
    # No assertion here: it is unclear how to verify the beam-search outputs,
    # so this only checks that the sampler raises no exceptions when handling
    # an all-beam-search batch.
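
# min_tokens handling: while a sequence has produced fewer than min_tokens
# tokens, the sampler must set the logits of its EOS/stop tokens to -inf (and
# only those). Seed 0 runs hand-written edge cases; other seeds run a
# randomly generated batch of prompt/decode sequence groups.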
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
    seq_id_counter = Counter(start=random.randint(0, 100))
    set_random_seed(seed)
    torch.set_default_device(device)

    def create_sampling_params(min_tokens,
                               eos_token_id=0,
                               *,
                               stop_token_ids: Optional[List[int]] = None,
                               prompt_logprobs: Optional[int] = None):
        sampling_params = SamplingParams(
            min_tokens=min_tokens,
            max_tokens=9999,  # keep higher than max of min_tokens
            stop_token_ids=stop_token_ids,
            # requesting prompt_logprobs changes the structure of `logits`
            prompt_logprobs=prompt_logprobs,
        )
        sampling_params.all_stop_token_ids.add(eos_token_id)
        return sampling_params

    def create_sequence_data(num_input=3, num_generated=0):
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                  random.choices(range(0, VOCAB_SIZE), k=num_input)))
        if num_generated > 0:
            seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
                                                       k=num_generated)
        return seq_data

    def generate_test_case():
        # generate multiple seq groups but limit total batch size
        batch_size = random.randint(1, 128)

        expected_penalization = []
        sequence_metadata_list: List[SequenceGroupMetadata] = []
        # 20% chance to generate seq group metadata list with all prompts
        is_prompt = random.random() < 0.2
        while batch_size > 0:
            num_seqs = 1 if is_prompt else random.randint(1, batch_size)

            eos_token_id = random.randint(0, VOCAB_SIZE - 1)
            min_tokens = random.randint(0, 50)
            num_stop_tokens = random.randint(0, 8)
            if num_stop_tokens > 0:
                stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
                                                k=num_stop_tokens)
            else:
                stop_token_ids = None

            sampling_params = create_sampling_params(
                min_tokens=min_tokens,
                eos_token_id=eos_token_id,
                stop_token_ids=stop_token_ids)

            seq_data: Dict[int, SequenceData] = {}
            seq_group_penalization: List[bool] = []
            for _ in range(num_seqs):
                num_input = random.randint(1, 100)
                num_generated = 0 if is_prompt else random.randint(1, 100)
                seq_data[next(seq_id_counter)] = create_sequence_data(
                    num_input=num_input, num_generated=num_generated)
                seq_group_penalization.append(num_generated < min_tokens)

            expected_penalization.extend(seq_group_penalization)
            sequence_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{batch_size}",
                    is_prompt=is_prompt,
                    seq_data=seq_data,
                    sampling_params=sampling_params,
                    block_tables={},
                ))
            batch_size -= num_seqs

        return {
            "expected_penalization": expected_penalization,
            "seq_group_metadata_list": sequence_metadata_list,
        }

    # define some explicit test cases for edge case behavior
    prompt_without_penalization = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(0),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization = {
        "expected_penalization": [True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            ),
        ]
    }

    prompt_with_penalization_and_prompt_logprobs = {
        "expected_penalization": [False, False, True],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=3),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
        ]
    }

    stop_penalizing_after_min_tokens = {
        "expected_penalization": [False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                },
                sampling_params=create_sampling_params(1),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [42, 99, 42, 0]  # intentional duplication
    prompt_combination = {
        "expected_penalization": [False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(num_input=2),
                },
                sampling_params=create_sampling_params(1, prompt_logprobs=3),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_3",
                is_prompt=True,
                seq_data={
                    next(seq_id_counter): create_sequence_data(),
                },
                sampling_params=create_sampling_params(
                    0, stop_token_ids=stop_token_ids),
                block_tables={},
            )
        ]
    }

    stop_token_ids = [1, 999, 37, 37]  # intentional duplication
    decode_combination = {
        "expected_penalization": [True, False, False, True, False],
        "seq_group_metadata_list": [
            SequenceGroupMetadata(
                request_id="test_1",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=100),
                },
                sampling_params=create_sampling_params(
                    2, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
            SequenceGroupMetadata(
                request_id="test_2",
                is_prompt=False,
                seq_data={
                    next(seq_id_counter):
                    create_sequence_data(num_generated=20),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=1),
                    next(seq_id_counter):
                    create_sequence_data(num_generated=10),
                },
                sampling_params=create_sampling_params(
                    10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
                block_tables={},
            ),
        ]
    }

    if seed == 0:
        test_cases = [
            prompt_without_penalization,
            prompt_with_penalization,
            prompt_with_penalization_and_prompt_logprobs,
            stop_penalizing_after_min_tokens,
            prompt_combination,
            decode_combination,
        ]
    else:
        test_cases = [generate_test_case()]

    def run_test_case(*, expected_penalization: List[bool],
                      seq_group_metadata_list: List[SequenceGroupMetadata]):
        assert expected_penalization, \
            "Invalid test case, need expected_penalization"
        assert seq_group_metadata_list, \
            "Invalid test case, need seq_group_metadata_list"

        batch_size = 0
        seq_lens: List[int] = []
        sampling_params_per_row: List[SamplingParams] = []
        for sgm in seq_group_metadata_list:
            sampling_params = sgm.sampling_params

            num_rows = len(sgm.seq_data)
            if sgm.is_prompt:
                # a prompt seq_group has only one sequence
                seq_data = next(iter(sgm.seq_data.values()))
                prompt_len = seq_data.get_prompt_len()
                seq_lens.append(prompt_len)

                if sgm.sampling_params.prompt_logprobs:
                    # with prompt_logprobs each token in the prompt has a row in
                    # logits
                    num_rows = prompt_len

            batch_size += num_rows
            sampling_params_per_row.extend(
                itertools.repeat(sampling_params, num_rows))

        assert len(
            expected_penalization
        ) == batch_size, \
            ("Invalid test case, expected_penalization does not match computed"
             " batch size")

        _, fake_logits, sampler = _prepare_test(batch_size)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens=seq_lens if seq_lens else None,
            query_lens=seq_lens if seq_lens else None,
            device=device,
            pin_memory=is_pin_memory_available())
        # the logits tensor is modified in-place by the sampler
        _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

        for logits_idx, (should_penalize, sampling_params) in enumerate(
                zip(expected_penalization, sampling_params_per_row)):

            tokens_to_check = sampling_params.all_stop_token_ids
            if should_penalize:
                for token_id in tokens_to_check:
                    assert fake_logits[logits_idx, token_id] == -float(
                        'inf'
                    ), (f"Expected token {token_id} for logits row "
                        f"{logits_idx} to be penalized")
                # no other tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] == -float('inf')) == len(
                        tokens_to_check
                    ), f"Expected only {len(tokens_to_check)} to be penalized"
            else:
                # no tokens should be set to -inf
                assert torch.count_nonzero(
                    fake_logits[logits_idx, :] ==
                    -float('inf')) == 0, "No tokens should have been penalized"

    for test_case in test_cases:
        run_test_case(**test_case)
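
# Mixed batch: greedy, unseeded random, seeded random, and beam-search
# requests are sampled together. Expected tokens are checked where they are
# deterministic; seeded-random outputs from the first pass are recorded and
# must be reproduced after the batch is shuffled and resampled.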
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={
                    0: SequenceData(array(
                        APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
                },
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    generators: Dict[str, torch.Generator] = {}

    def test_sampling():
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            expected_tokens_item = expected_tokens[i]
            assert expected_tokens_item is not None

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens_item[n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens was chosen
                    assert nth_output.output_token in expected_tokens_item

    # Test batch
    test_sampling()

    # Shuffle the batch and resample
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling()
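
# The sampler's top-k/top-p filtering is compared against HuggingFace's
# logits warpers: the internal sampling step is intercepted so the
# post-filtering probability distribution can be inspected directly.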
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config, device)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={
                    0: SequenceData(array(
                        APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
                },
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list,
        seq_lens,
        query_lens=seq_lens,
        device=device,
        pin_memory=is_pin_memory_available())

    sample_probs = None

    def mock_sample(probs, *args, **kwargs):
        nonlocal sample_probs
        sample_probs = probs
        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                 for prob in probs], None)

    with patch("aphrodite.modeling.layers.sampler._sample", mock_sample):
        sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

    assert sample_probs is not None

    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
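
# Repetition penalty and plain seeded sampling are mixed in one batch, in
# both orders; each request's output must not depend on its position in the
# batch.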
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_repetition_penalty_mixed(device: str):

    vocab_size = 8

    def test_sampling_params(sampling_params: List[SamplingParams]):

        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        seq_lens: List[int] = []
        for i in range(2):
            seq_group_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=f"test_{i}",
                    is_prompt=True,
                    seq_data={
                        0:
                        SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                           [1, 2, 3]))
                    },
                    sampling_params=sampling_params[i],
                    block_tables={0: [1]},
                ))
            seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available())

        fake_logits = torch.full((2, vocab_size),
                                 1e-2,
                                 device=device,
                                 dtype=torch.float16)

        fake_logits[:, 5] = 1.1e-2
        fake_logits[:, 1] = 1.2e-2

        sampler = MockLogitsSampler(fake_logits)

        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        generated_tokens = []
        for output in sampler_output:
            generated_tokens.append(output.samples[0].output_token)

        return generated_tokens

    # one configuration is greedy with repetition_penalty
    sampling_params_rep = SamplingParams(
        temperature=0.0,
        repetition_penalty=2.0,
    )

    # other configuration is sampling w/o repetition_penalty
    sampling_params_sample = SamplingParams(
        temperature=1.0,
        top_k=1,
        seed=42,
    )

    tokens1 = test_sampling_params(
        [sampling_params_rep, sampling_params_sample])
    tokens2 = test_sampling_params(
        [sampling_params_sample, sampling_params_rep])

    assert tokens1[0] == tokens2[1]
    assert tokens1[1] == tokens2[0]
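
# When include_gpu_probs_tensor is enabled (and in-place greedy prob
# modification disabled), the sampler output must carry the GPU probability,
# logprob, and token-id tensors, and _modify_greedy_probs_inplace must not
# be called.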
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
    set_random_seed(42)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    _, fake_logits, sampler = _prepare_test(batch_size)
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = False

    sampling_params = SamplingParams(temperature=0)

    mock_inplace = Mock()
    with patch(
            "aphrodite.modeling.layers.sampler._modify_greedy_probs_inplace",
            mock_inplace):
        sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                    sampling_params, device)

        mock_inplace.assert_not_called()

        assert sampler_output.sampled_token_probs is not None
        assert sampler_output.logprobs is not None
        assert sampler_output.sampled_token_ids is not None