1
0

test_sampler.py 39 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059
  1. import itertools
  2. import random
  3. from array import array
  4. from dataclasses import dataclass
  5. from typing import Dict, List, Optional, Tuple
  6. from unittest.mock import Mock, patch
  7. import pytest
  8. import torch
  9. from transformers import GenerationConfig, GenerationMixin
  10. from aphrodite.common.sequence import (SamplingParams, SequenceData,
  11. SequenceGroupMetadata)
  12. from aphrodite.common.utils import Counter, is_pin_memory_available
  13. from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
  14. from aphrodite.modeling.layers.sampler import Sampler
  15. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  16. from aphrodite.modeling.utils import set_random_seed
  17. class MockLogitsSampler(Sampler):
  18. def __init__(self, fake_logits: torch.Tensor):
  19. super().__init__()
  20. self.fake_logits = fake_logits
  21. def forward(self, *args, **kwargs):
  22. return super().forward(*args, **kwargs)
  23. def _prepare_test(
  24. batch_size: int
  25. ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
  26. input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
  27. fake_logits = torch.full((batch_size, VOCAB_SIZE),
  28. 1e-2,
  29. dtype=input_tensor.dtype)
  30. sampler = MockLogitsSampler(fake_logits)
  31. return input_tensor, fake_logits, sampler
  32. VOCAB_SIZE = 32000
  33. RANDOM_SEEDS = list(range(128))
  34. CUDA_DEVICES = [
  35. f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
  36. ]
  37. def _do_sample(
  38. batch_size: int,
  39. input_tensor: torch.Tensor,
  40. sampler: MockLogitsSampler,
  41. sampling_params: SamplingParams,
  42. device: str,
  43. ):
  44. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  45. seq_lens: List[int] = []
  46. for i in range(batch_size):
  47. seq_group_metadata_list.append(
  48. SequenceGroupMetadata(
  49. request_id=f"test_{i}",
  50. is_prompt=True,
  51. seq_data={
  52. 0: SequenceData(array(
  53. APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
  54. },
  55. sampling_params=sampling_params,
  56. block_tables={0: [1]},
  57. ))
  58. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  59. sampling_metadata = SamplingMetadata.prepare(
  60. seq_group_metadata_list,
  61. seq_lens,
  62. query_lens=seq_lens,
  63. device=device,
  64. pin_memory=is_pin_memory_available())
  65. return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
  66. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  67. @pytest.mark.parametrize("device", CUDA_DEVICES)
  68. def test_sampler_all_greedy(seed: int, device: str):
  69. set_random_seed(seed)
  70. torch.set_default_device(device)
  71. batch_size = random.randint(1, 256)
  72. input_tensor, fake_logits, sampler = _prepare_test(batch_size)
  73. sampling_params = SamplingParams(temperature=0)
  74. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  75. sampling_params, device)
  76. expected = torch.argmax(fake_logits, dim=-1)
  77. for i, sequence_output in enumerate(sampler_output):
  78. for nth_output in sequence_output.samples:
  79. assert nth_output.output_token == expected[i].item()
  80. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  81. @pytest.mark.parametrize("device", CUDA_DEVICES)
  82. def test_sampler_all_random(seed: int, device: str):
  83. set_random_seed(seed)
  84. torch.set_default_device(device)
  85. batch_size = random.randint(1, 256)
  86. _, fake_logits, sampler = _prepare_test(batch_size)
  87. for i in range(batch_size):
  88. fake_logits[i, i] = 1e2
  89. sampling_params = SamplingParams(
  90. temperature=1.0,
  91. n=random.randint(1, 10),
  92. )
  93. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  94. sampling_params, device)
  95. for i, sequence_output in enumerate(sampler_output):
  96. for nth_output in sequence_output.samples:
  97. assert nth_output.output_token == i
  98. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  99. @pytest.mark.parametrize("device", CUDA_DEVICES)
  100. def test_sampler_all_random_seed(seed: int, device: str):
  101. set_random_seed(seed)
  102. torch.set_default_device(device)
  103. batch_size = random.randint(1, 256)
  104. _, fake_logits, sampler = _prepare_test(batch_size)
  105. for i in range(batch_size):
  106. fake_logits[i, i] = 1e2
  107. sampling_params = SamplingParams(
  108. temperature=1.0,
  109. n=random.randint(1, 10),
  110. seed=random.randint(0, 10000),
  111. )
  112. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  113. sampling_params, device)
  114. for i, sequence_output in enumerate(sampler_output):
  115. for nth_output in sequence_output.samples:
  116. assert nth_output.output_token == i
  117. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  118. @pytest.mark.parametrize("device", CUDA_DEVICES)
  119. def test_sampler_all_random_seed_deterministic(seed: int, device: str):
  120. set_random_seed(seed)
  121. torch.set_default_device(device)
  122. batch_size = random.randint(1, 256)
  123. _, fake_logits, sampler = _prepare_test(batch_size)
  124. sampling_params = SamplingParams(
  125. temperature=1.0,
  126. n=random.randint(1, 10),
  127. seed=random.randint(0, 10000),
  128. )
  129. first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
  130. sampling_params, device)
  131. second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
  132. sampling_params, device)
  133. assert first_sampler_output == second_sampler_output
  134. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  135. @pytest.mark.parametrize("device", CUDA_DEVICES)
  136. def test_sampler_all_beam(seed: int, device: str):
  137. set_random_seed(seed)
  138. torch.set_default_device(device)
  139. batch_size = random.randint(1, 256)
  140. _, fake_logits, sampler = _prepare_test(batch_size)
  141. sampling_params = SamplingParams(
  142. temperature=0,
  143. best_of=2,
  144. use_beam_search=True,
  145. )
  146. _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
  147. # no assertion here as I am not sure how to determine whether
  148. # the outputs are expected - in other words, this just tests
  149. # whether there are no exceptions in the sampler
  150. # when handling an all-beam search case.
  151. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  152. @pytest.mark.parametrize("device", CUDA_DEVICES)
  153. def test_sampler_min_tokens_penalty(seed: int, device: str):
  154. seq_id_counter = Counter(start=random.randint(0, 100))
  155. set_random_seed(seed)
  156. torch.set_default_device(device)
  157. def create_sampling_params(min_tokens,
  158. eos_token_id=0,
  159. *,
  160. stop_token_ids: Optional[List[int]] = None,
  161. prompt_logprobs: Optional[int] = None):
  162. sampling_params = SamplingParams(
  163. min_tokens=min_tokens,
  164. max_tokens=9999, # keep higher than max of min_tokens
  165. stop_token_ids=stop_token_ids,
  166. # requesting prompt_logprobs changes the structure of `logits`
  167. prompt_logprobs=prompt_logprobs,
  168. )
  169. sampling_params.all_stop_token_ids.add(eos_token_id)
  170. return sampling_params
  171. def create_sequence_data(num_input=3, num_generated=0):
  172. seq_data = SequenceData(
  173. array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
  174. random.choices(range(0, VOCAB_SIZE), k=num_input)))
  175. if num_generated > 0:
  176. seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
  177. k=num_generated)
  178. return seq_data
  179. def generate_test_case():
  180. # generate multiple seq groups but limit total batch size
  181. batch_size = random.randint(1, 128)
  182. expected_penalization = []
  183. sequence_metadata_list: List[SequenceGroupMetadata] = []
  184. # 20% chance to generate seq group metadata list with all prompts
  185. is_prompt = random.random() < 0.2
  186. while batch_size > 0:
  187. num_seqs = 1 if is_prompt else random.randint(1, batch_size)
  188. eos_token_id = random.randint(0, VOCAB_SIZE - 1)
  189. min_tokens = random.randint(0, 50)
  190. num_stop_tokens = random.randint(0, 8)
  191. if num_stop_tokens > 0:
  192. stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
  193. k=num_stop_tokens)
  194. else:
  195. stop_token_ids = None
  196. sampling_params = create_sampling_params(
  197. min_tokens=min_tokens,
  198. eos_token_id=eos_token_id,
  199. stop_token_ids=stop_token_ids)
  200. seq_data: Dict[int, SequenceData] = {}
  201. seq_group_penalization: List[bool] = []
  202. for _ in range(num_seqs):
  203. num_input = random.randint(1, 100)
  204. num_generated = 0 if is_prompt else random.randint(1, 100)
  205. seq_data[next(seq_id_counter)] = create_sequence_data(
  206. num_input=num_input, num_generated=num_generated)
  207. seq_group_penalization.append(num_generated < min_tokens)
  208. expected_penalization.extend(seq_group_penalization)
  209. sequence_metadata_list.append(
  210. SequenceGroupMetadata(
  211. request_id=f"test_{batch_size}",
  212. is_prompt=is_prompt,
  213. seq_data=seq_data,
  214. sampling_params=sampling_params,
  215. block_tables={},
  216. ))
  217. batch_size -= num_seqs
  218. return {
  219. "expected_penalization": expected_penalization,
  220. "seq_group_metadata_list": sequence_metadata_list,
  221. }
  222. # define some explicit test cases for edge case behavior
  223. prompt_without_penalization = {
  224. "expected_penalization": [False],
  225. "seq_group_metadata_list": [
  226. SequenceGroupMetadata(
  227. request_id="test_1",
  228. is_prompt=True,
  229. seq_data={
  230. next(seq_id_counter): create_sequence_data(),
  231. },
  232. sampling_params=create_sampling_params(0),
  233. block_tables={},
  234. ),
  235. ]
  236. }
  237. prompt_with_penalization = {
  238. "expected_penalization": [True],
  239. "seq_group_metadata_list": [
  240. SequenceGroupMetadata(
  241. request_id="test_1",
  242. is_prompt=True,
  243. seq_data={
  244. next(seq_id_counter): create_sequence_data(),
  245. },
  246. sampling_params=create_sampling_params(1),
  247. block_tables={},
  248. ),
  249. ]
  250. }
  251. prompt_with_penalization_and_prompt_logprobs = {
  252. "expected_penalization": [False, False, True],
  253. "seq_group_metadata_list": [
  254. SequenceGroupMetadata(
  255. request_id="test_1",
  256. is_prompt=True,
  257. seq_data={
  258. next(seq_id_counter): create_sequence_data(num_input=3),
  259. },
  260. sampling_params=create_sampling_params(1, prompt_logprobs=3),
  261. block_tables={},
  262. ),
  263. ]
  264. }
  265. stop_penalizing_after_min_tokens = {
  266. "expected_penalization": [False],
  267. "seq_group_metadata_list": [
  268. SequenceGroupMetadata(
  269. request_id="test_1",
  270. is_prompt=False,
  271. seq_data={
  272. next(seq_id_counter):
  273. create_sequence_data(num_generated=1),
  274. },
  275. sampling_params=create_sampling_params(1),
  276. block_tables={},
  277. )
  278. ]
  279. }
  280. stop_token_ids = [42, 99, 42, 0] # intentional duplication
  281. prompt_combination = {
  282. "expected_penalization": [False, True, False],
  283. "seq_group_metadata_list": [
  284. SequenceGroupMetadata(
  285. request_id="test_2",
  286. is_prompt=True,
  287. seq_data={
  288. next(seq_id_counter): create_sequence_data(num_input=2),
  289. },
  290. sampling_params=create_sampling_params(1, prompt_logprobs=3),
  291. block_tables={},
  292. ),
  293. SequenceGroupMetadata(
  294. request_id="test_3",
  295. is_prompt=True,
  296. seq_data={
  297. next(seq_id_counter): create_sequence_data(),
  298. },
  299. sampling_params=create_sampling_params(
  300. 0, stop_token_ids=stop_token_ids),
  301. block_tables={},
  302. )
  303. ]
  304. }
  305. stop_token_ids = [1, 999, 37, 37] # intentional duplication
  306. decode_combination = {
  307. "expected_penalization": [True, False, False, True, False],
  308. "seq_group_metadata_list": [
  309. SequenceGroupMetadata(
  310. request_id="test_1",
  311. is_prompt=False,
  312. seq_data={
  313. next(seq_id_counter):
  314. create_sequence_data(num_generated=1),
  315. next(seq_id_counter):
  316. create_sequence_data(num_generated=100),
  317. },
  318. sampling_params=create_sampling_params(
  319. 2, stop_token_ids=stop_token_ids),
  320. block_tables={},
  321. ),
  322. SequenceGroupMetadata(
  323. request_id="test_2",
  324. is_prompt=False,
  325. seq_data={
  326. next(seq_id_counter):
  327. create_sequence_data(num_generated=20),
  328. next(seq_id_counter):
  329. create_sequence_data(num_generated=1),
  330. next(seq_id_counter):
  331. create_sequence_data(num_generated=10),
  332. },
  333. sampling_params=create_sampling_params(
  334. 10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
  335. block_tables={},
  336. ),
  337. ]
  338. }
  339. if seed == 0:
  340. test_cases = [
  341. prompt_without_penalization,
  342. prompt_with_penalization,
  343. prompt_with_penalization_and_prompt_logprobs,
  344. stop_penalizing_after_min_tokens,
  345. prompt_combination,
  346. decode_combination,
  347. ]
  348. else:
  349. test_cases = [generate_test_case()]
  350. def run_test_case(*, expected_penalization: List[bool],
  351. seq_group_metadata_list: List[SequenceGroupMetadata]):
  352. assert expected_penalization, \
  353. "Invalid test case, need expected_penalization"
  354. assert seq_group_metadata_list, \
  355. "Invalid test case, need seq_group_metadata_list"
  356. batch_size = 0
  357. seq_lens: List[int] = []
  358. sampling_params_per_row: List[SamplingParams] = []
  359. for sgm in seq_group_metadata_list:
  360. sampling_params = sgm.sampling_params
  361. num_rows = len(sgm.seq_data)
  362. if sgm.is_prompt:
  363. # a prompt seq_group has only one sequence
  364. seq_data = next(iter(sgm.seq_data.values()))
  365. prompt_len = seq_data.get_prompt_len()
  366. seq_lens.append(prompt_len)
  367. if sgm.sampling_params.prompt_logprobs:
  368. # with prompt_logprobs each token in the prompt has a row in
  369. # logits
  370. num_rows = prompt_len
  371. batch_size += num_rows
  372. sampling_params_per_row.extend(
  373. itertools.repeat(sampling_params, num_rows))
  374. assert len(
  375. expected_penalization
  376. ) == batch_size, \
  377. ("Invalid test case, expected_penalization does not match computed"
  378. "batch size")
  379. _, fake_logits, sampler = _prepare_test(batch_size)
  380. sampling_metadata = SamplingMetadata.prepare(
  381. seq_group_metadata_list,
  382. seq_lens=seq_lens if seq_lens else None,
  383. query_lens=seq_lens if seq_lens else None,
  384. device=device,
  385. pin_memory=is_pin_memory_available())
  386. # the logits tensor is modified in-place by the sampler
  387. _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
  388. for logits_idx, (should_penalize, sampling_params) in enumerate(
  389. zip(expected_penalization, sampling_params_per_row)):
  390. tokens_to_check = sampling_params.all_stop_token_ids
  391. if should_penalize:
  392. for token_id in tokens_to_check:
  393. assert fake_logits[logits_idx, token_id] == -float(
  394. 'inf'
  395. ), f"Expected token {token_id} for logits row {logits_idx}"
  396. " to be penalized"
  397. # no other tokens should be set to -inf
  398. assert torch.count_nonzero(
  399. fake_logits[logits_idx, :] == -float('inf')) == len(
  400. tokens_to_check
  401. ), f"Expected only {len(tokens_to_check)} to be penalized"
  402. else:
  403. # no tokens should be set to -inf
  404. assert torch.count_nonzero(
  405. fake_logits[logits_idx, :] ==
  406. -float('inf')) == 0, "No tokens should have been penalized"
  407. for test_case in test_cases:
  408. run_test_case(**test_case)
  409. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  410. @pytest.mark.parametrize("device", CUDA_DEVICES)
  411. def test_sampler_mixed(seed: int, device: str):
  412. set_random_seed(seed)
  413. torch.set_default_device(device)
  414. batch_size = random.randint(1, 256)
  415. input_tensor, fake_logits, sampler = _prepare_test(batch_size)
  416. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  417. expected_tokens: List[Optional[List[int]]] = []
  418. seq_lens: List[int] = []
  419. for i in range(batch_size):
  420. expected: Optional[List[int]] = None
  421. sampling_type = random.randint(0, 3)
  422. if sampling_type == 0:
  423. sampling_params = SamplingParams(temperature=0)
  424. expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
  425. elif sampling_type in (1, 2):
  426. n = random.randint(1, 10)
  427. sampling_params = SamplingParams(
  428. temperature=random.random() + 0.1,
  429. top_p=min(random.random() + 0.1, 1),
  430. top_k=random.randint(0, 10) or -1,
  431. n=n,
  432. presence_penalty=random.randint(0, 1),
  433. )
  434. if sampling_type == 2:
  435. sampling_params.seed = random.randint(0, 10000)
  436. else:
  437. for idx in range(n):
  438. fake_logits[i, i + idx] = 1e2
  439. expected = list(range(i, i + n))
  440. else:
  441. sampling_params = SamplingParams(temperature=0,
  442. use_beam_search=True,
  443. best_of=2)
  444. expected_tokens.append(expected)
  445. seq_group_metadata_list.append(
  446. SequenceGroupMetadata(
  447. request_id=f"test_{i}",
  448. is_prompt=True,
  449. seq_data={
  450. 0: SequenceData(array(
  451. APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
  452. },
  453. sampling_params=sampling_params,
  454. block_tables={0: [1]},
  455. ))
  456. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  457. generators: Dict[str, torch.Generator] = {}
  458. def test_sampling():
  459. sampling_metadata = SamplingMetadata.prepare(
  460. seq_group_metadata_list,
  461. seq_lens,
  462. query_lens=seq_lens,
  463. device=device,
  464. pin_memory=is_pin_memory_available(),
  465. generators=generators)
  466. sampler_output = sampler(logits=fake_logits,
  467. sampling_metadata=sampling_metadata)
  468. for i, (sequence_output, metadata) in enumerate(
  469. zip(sampler_output, seq_group_metadata_list)):
  470. if metadata.sampling_params.use_beam_search:
  471. continue
  472. if (metadata.sampling_params.seed is not None
  473. and expected_tokens[i] is None):
  474. # Record seeded random result to compare with results of
  475. # second invocation
  476. expected_tokens[i] = [
  477. nth_output.output_token
  478. for nth_output in sequence_output.samples
  479. ]
  480. continue
  481. expected_tokens_item = expected_tokens[i]
  482. assert expected_tokens_item is not None
  483. for n, nth_output in enumerate(sequence_output.samples):
  484. if (metadata.sampling_params.temperature == 0
  485. or metadata.sampling_params.seed is not None):
  486. # Ensure exact matches for greedy or random with seed
  487. assert nth_output.output_token == expected_tokens_item[n]
  488. else:
  489. # For non-seeded random check that one of the high-logit
  490. # tokens were chosen
  491. assert nth_output.output_token in expected_tokens_item
  492. # Test batch
  493. test_sampling()
  494. # Shuffle the batch and resample
  495. target_index = list(range(batch_size))
  496. for list_to_shuffle in (target_index, seq_group_metadata_list,
  497. expected_tokens, seq_lens):
  498. random.Random(seed).shuffle(list_to_shuffle)
  499. target_index = torch.tensor(target_index)
  500. input_tensor.data = input_tensor.index_select(0, target_index)
  501. fake_logits.data = fake_logits.index_select(0, target_index)
  502. # This time, results of seeded random samples will be compared with
  503. # the corresponding sample in the pre-shuffled batch
  504. test_sampling()
  505. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  506. @pytest.mark.parametrize("device", CUDA_DEVICES)
  507. def test_sampler_top_k_top_p(seed: int, device: str):
  508. set_random_seed(seed)
  509. batch_size = random.randint(1, 256)
  510. top_k = random.randint(100, 500)
  511. top_p = random.random() * 0.1
  512. vocab_size = 32000
  513. input_tensor = torch.rand((batch_size, 1024),
  514. device=device,
  515. dtype=torch.float16)
  516. fake_logits = torch.normal(0,
  517. 5,
  518. size=(batch_size, vocab_size),
  519. device=input_tensor.device,
  520. dtype=input_tensor.dtype)
  521. sampler = MockLogitsSampler(fake_logits)
  522. generation_model = GenerationMixin()
  523. generation_config = GenerationConfig(top_k=top_k,
  524. top_p=top_p,
  525. do_sample=True)
  526. @dataclass
  527. class MockConfig:
  528. is_encoder_decoder: bool = False
  529. generation_model.config = MockConfig() # needed by the following method
  530. generation_model._prepare_special_tokens(generation_config, device=device)
  531. processors = generation_model._get_logits_processor(generation_config,
  532. None,
  533. None,
  534. None, [],
  535. device=device)
  536. assert len(processors) == 2 # top_p and top_k
  537. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  538. seq_lens: List[int] = []
  539. for i in range(batch_size):
  540. seq_group_metadata_list.append(
  541. SequenceGroupMetadata(
  542. request_id=f"test_{i}",
  543. is_prompt=True,
  544. seq_data={
  545. 0: SequenceData(array(
  546. APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3]))
  547. },
  548. sampling_params=SamplingParams(
  549. temperature=1,
  550. top_k=top_k,
  551. top_p=top_p,
  552. ),
  553. block_tables={0: [1]},
  554. ))
  555. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  556. sampling_metadata = SamplingMetadata.prepare(
  557. seq_group_metadata_list,
  558. seq_lens,
  559. query_lens=seq_lens,
  560. device=device,
  561. pin_memory=is_pin_memory_available())
  562. sample_probs = None
  563. def mock_sample(probs, *args, **kwargs):
  564. nonlocal sample_probs
  565. sample_probs = probs
  566. return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
  567. for prob in probs], None)
  568. with patch("aphrodite.modeling.layers.sampler._sample", mock_sample):
  569. sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
  570. assert sample_probs is not None
  571. hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
  572. hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
  573. torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
  574. assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
  575. @pytest.mark.parametrize("device", CUDA_DEVICES)
  576. def test_sampler_repetition_penalty_mixed(device: str):
  577. vocab_size = 8
  578. def test_sampling_params(sampling_params: List[SamplingParams]):
  579. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  580. seq_lens: List[int] = []
  581. for i in range(2):
  582. seq_group_metadata_list.append(
  583. SequenceGroupMetadata(
  584. request_id=f"test_{i}",
  585. is_prompt=True,
  586. seq_data={
  587. 0:
  588. SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
  589. [1, 2, 3]))
  590. },
  591. sampling_params=sampling_params[i],
  592. block_tables={0: [1]},
  593. ))
  594. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  595. sampling_metadata = SamplingMetadata.prepare(
  596. seq_group_metadata_list,
  597. seq_lens,
  598. query_lens=seq_lens,
  599. device=device,
  600. pin_memory=is_pin_memory_available())
  601. fake_logits = torch.full((2, vocab_size),
  602. 1e-2,
  603. device=device,
  604. dtype=torch.float16)
  605. fake_logits[:, 5] = 1.1e-2
  606. fake_logits[:, 1] = 1.2e-2
  607. sampler = MockLogitsSampler(fake_logits)
  608. sampler_output = sampler(logits=fake_logits,
  609. sampling_metadata=sampling_metadata)
  610. generated_tokens = []
  611. for output in sampler_output:
  612. generated_tokens.append(output.samples[0].output_token)
  613. return generated_tokens
  614. # one configuration is greedy with repetition_penalty
  615. sampling_params_rep = SamplingParams(
  616. temperature=0.0,
  617. repetition_penalty=2.0,
  618. )
  619. # other configuration is sampling w/o repetition_penalty
  620. sampling_params_sample = SamplingParams(
  621. temperature=1.0,
  622. top_k=1,
  623. seed=42,
  624. )
  625. tokens1 = test_sampling_params(
  626. [sampling_params_rep, sampling_params_sample])
  627. tokens2 = test_sampling_params(
  628. [sampling_params_sample, sampling_params_rep])
  629. assert tokens1[0] == tokens2[1]
  630. assert tokens1[1] == tokens2[0]
  631. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  632. @pytest.mark.parametrize("device", CUDA_DEVICES)
  633. def test_sampler_no_repeat_ngram(seed: int, device: str):
  634. """Test that no-repeat-ngram sampling behaves as expected."""
  635. set_random_seed(seed)
  636. torch.set_default_device(device)
  637. batch_size = random.randint(1, 256)
  638. _, fake_logits, sampler = _prepare_test(batch_size)
  639. test_sequences = {
  640. # Format: sequence: [tokens_that_should_be_blocked]
  641. (1, 2, 3): [3], # With ngram_size=2, should block 3 after [2]
  642. (4, 5, 4, 5): [4], # With ngram_size=2, should block 4 after [5]
  643. (6, 7, 8, 6, 7): [8], # With ngram_size=3, should block 8 after [6, 7]
  644. (1, 2, 3, 4, 1, 2): [3], # With ngram_size=4, should block 3 after [1, 2] # noqa: E501
  645. }
  646. for input_seq, blocked_tokens in test_sequences.items():
  647. for ngram_size in [2, 3, 4]:
  648. sampling_params = SamplingParams(
  649. temperature=1.0,
  650. no_repeat_ngram_size=ngram_size,
  651. seed=random.randint(0, 10000),
  652. )
  653. sampler_output = _do_sample(
  654. 1,
  655. fake_logits[0:1].clone(), # Just use first row
  656. sampler,
  657. sampling_params,
  658. device
  659. )
  660. if len(input_seq) >= ngram_size:
  661. # check if blocked tokens have -inf logits
  662. for token in blocked_tokens:
  663. assert sampler_output[0].samples[0].output_token != token, \
  664. f"Token {token} should have been blocked by {ngram_size}-gram repetition prevention" # noqa: E501
  665. # disabled
  666. sampling_params = SamplingParams(
  667. temperature=1.0,
  668. no_repeat_ngram_size=0,
  669. seed=random.randint(0, 10000),
  670. )
  671. sampler_output = _do_sample(
  672. 1,
  673. fake_logits[0:1].clone(),
  674. sampler,
  675. sampling_params,
  676. device
  677. )
  678. output_token = sampler_output[0].samples[0].output_token
  679. assert output_token is not None, "Should produce output token with ngram_size=0" # noqa: E501
  680. # determinism
  681. sampling_params = SamplingParams(
  682. temperature=1.0,
  683. no_repeat_ngram_size=3,
  684. seed=random.randint(0, 10000),
  685. )
  686. first_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  687. sampling_params, device)
  688. second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  689. sampling_params, device)
  690. assert first_output == second_output, \
  691. "No-repeat-ngram sampling is not deterministic with same seed"
  692. @pytest.mark.parametrize("device", CUDA_DEVICES)
  693. def test_sampler_dry(device: str):
  694. vocab_size = 8
  695. def test_sampling_params(sampling_params: List[SamplingParams]):
  696. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  697. seq_lens: List[int] = []
  698. for i in range(2):
  699. seq_group_metadata_list.append(
  700. SequenceGroupMetadata(
  701. request_id=f"test_{i}",
  702. is_prompt=True,
  703. seq_data={
  704. 0: SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
  705. [1, 2, 3, 1, 2]))
  706. },
  707. sampling_params=sampling_params[i],
  708. block_tables={0: [1]},
  709. ))
  710. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  711. sampling_metadata = SamplingMetadata.prepare(
  712. seq_group_metadata_list,
  713. seq_lens,
  714. query_lens=seq_lens,
  715. device=device,
  716. pin_memory=is_pin_memory_available())
  717. fake_logits = torch.full((2, vocab_size),
  718. 1e-2,
  719. device=device,
  720. dtype=torch.float16)
  721. fake_logits[:, 3] = 1.0
  722. sampler = MockLogitsSampler(fake_logits)
  723. sampler_output = sampler(logits=fake_logits,
  724. sampling_metadata=sampling_metadata)
  725. generated_tokens = []
  726. for output in sampler_output:
  727. generated_tokens.append(output.samples[0].output_token)
  728. return generated_tokens
  729. # Test case 1: DRY disabled (multiplier = 0)
  730. sampling_params_no_dry = SamplingParams(
  731. temperature=0.0,
  732. dry_multiplier=0.0,
  733. )
  734. # Test case 2: DRY enabled with full range
  735. sampling_params_full_dry = SamplingParams(
  736. temperature=0.0,
  737. dry_multiplier=1.0,
  738. dry_allowed_length=2,
  739. dry_base=2.0,
  740. dry_range=0,
  741. )
  742. sampling_params_limited_dry = SamplingParams(
  743. temperature=0.0,
  744. dry_multiplier=1.0,
  745. dry_allowed_length=2,
  746. dry_base=2.0,
  747. dry_range=3,
  748. )
  749. tokens1 = test_sampling_params(
  750. [sampling_params_no_dry, sampling_params_full_dry])
  751. assert tokens1[0] == 3, "Without DRY, should choose highest logit token"
  752. assert tokens1[1] != 3, "With full-range DRY, should avoid repeating pattern" # noqa: E501
  753. tokens2 = test_sampling_params(
  754. [sampling_params_full_dry, sampling_params_limited_dry])
  755. assert tokens2[0] != 3, "Full-range DRY should detect full pattern"
  756. assert tokens2[1] == 3, "Limited-range DRY should only consider recent tokens" # noqa: E501
  757. tokens3 = test_sampling_params(
  758. [sampling_params_full_dry, sampling_params_limited_dry])
  759. assert tokens2 == tokens3, "DRY sampling should be deterministic"
  760. @pytest.mark.parametrize("device", CUDA_DEVICES)
  761. def test_sampler_dry_sequence_breakers(device: str):
  762. """Test that DRY respects sequence breakers."""
  763. vocab_size = 8
  764. # 7 is a sequence breaker
  765. input_sequence = [1, 2, 7, 1, 2]
  766. seq_group_metadata = SequenceGroupMetadata(
  767. request_id="test_0",
  768. is_prompt=True,
  769. seq_data={
  770. 0: SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
  771. input_sequence))
  772. },
  773. sampling_params=SamplingParams(
  774. temperature=0.0,
  775. dry_multiplier=1.0,
  776. dry_allowed_length=2,
  777. dry_base=2.0,
  778. dry_range=0,
  779. dry_sequence_breaker_ids=[7],
  780. ),
  781. block_tables={0: [1]},
  782. )
  783. sampling_metadata = SamplingMetadata.prepare(
  784. [seq_group_metadata],
  785. seq_lens=[len(input_sequence)],
  786. query_lens=[len(input_sequence)],
  787. device=device,
  788. pin_memory=is_pin_memory_available())
  789. fake_logits = torch.full((1, vocab_size),
  790. 1e-2,
  791. device=device,
  792. dtype=torch.float16)
  793. fake_logits[0, 3] = 1.0
  794. sampler = MockLogitsSampler(fake_logits)
  795. sampler_output = sampler(logits=fake_logits,
  796. sampling_metadata=sampling_metadata)
  797. assert sampler_output[0].samples[0].output_token == 3, \
  798. "DRY should not detect patterns across sequence breakers"
  799. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  800. @pytest.mark.parametrize("device", CUDA_DEVICES)
  801. def test_sampler_nsigma(seed: int, device: str):
  802. """Test that top-nsigma sampling behaves as expected."""
  803. set_random_seed(seed)
  804. torch.set_default_device(device)
  805. batch_size = random.randint(1, 256)
  806. _, fake_logits, sampler = _prepare_test(batch_size)
  807. # Create a clear separation in logits for testing
  808. high_logit_indices = {} # Store high logit indices for each batch
  809. for i in range(batch_size):
  810. # Set a few logits significantly higher than others
  811. num_high_logits = random.randint(1, 5)
  812. high_indices = random.sample(range(fake_logits.size(1)),
  813. num_high_logits)
  814. high_logit_indices[i] = set(high_indices) # Store for verification
  815. for idx in high_indices:
  816. fake_logits[i, idx] = 10.0 # Clearly above the mean
  817. # Test with different nsigma values
  818. for nsigma in [1.5, 2.0, 3.0]:
  819. sampling_params = SamplingParams(
  820. temperature=1.0,
  821. nsigma=nsigma,
  822. seed=random.randint(0, 10000),
  823. )
  824. sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  825. sampling_params, device)
  826. # Verify that sampling only selects from high logits
  827. for batch_idx, sequence_output in enumerate(sampler_output):
  828. for nth_output in sequence_output.samples:
  829. token_id = nth_output.output_token
  830. # The token should come from the high logits region
  831. assert token_id in high_logit_indices[batch_idx], \
  832. f"Sampled token {token_id} for batch {batch_idx} was not in the high logit set" # noqa
  833. # Test determinism
  834. second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  835. sampling_params, device)
  836. assert sampler_output == second_output, \
  837. "Top-nsigma sampling is not deterministic with same seed"
  838. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  839. @pytest.mark.parametrize("device", CUDA_DEVICES)
  840. def test_sampler_skew(seed: int, device: str):
  841. """Test that skew sampling behaves as expected."""
  842. set_random_seed(seed)
  843. torch.set_default_device(device)
  844. batch_size = random.randint(1, 256)
  845. _, fake_logits, sampler = _prepare_test(batch_size)
  846. high_prob_tokens = {}
  847. for i in range(batch_size):
  848. # Make token i have a much higher logit in sequence i
  849. fake_logits[i, i] = 10.0
  850. high_prob_tokens[i] = i
  851. test_cases = [
  852. # (skew, expected_behavior)
  853. (2.0, "low"), # Strong bias away from high probability tokens
  854. (0.5, "subtle"), # Subtle bias away from high probability tokens
  855. (0.0, "neutral"), # No bias (regular sampling)
  856. ]
  857. for skew, expected_behavior in test_cases:
  858. sampling_params = SamplingParams(
  859. temperature=1.0, # neutral temperature
  860. skew=skew,
  861. seed=random.randint(0, 10000), # for determinism
  862. )
  863. sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  864. sampling_params, device)
  865. for batch_idx, sequence_output in enumerate(sampler_output):
  866. token_id = sequence_output.samples[0].output_token
  867. if expected_behavior == "low":
  868. # strong skew should bias away from high probability tokens
  869. assert token_id != high_prob_tokens[batch_idx], \
  870. f"With high skew {skew}, should not select high " \
  871. f"probability token {high_prob_tokens[batch_idx]}"
  872. elif expected_behavior == "subtle":
  873. # we don't assert anything for subtle effect,
  874. # as it's probabilistic
  875. pass
  876. # determinism
  877. second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  878. sampling_params, device)
  879. assert sampler_output == second_output, \
  880. f"Skew sampling with seed is not deterministic for skew={skew}"
  881. @pytest.mark.parametrize("device", CUDA_DEVICES)
  882. def test_sampler_include_gpu_probs_tensor(device: str):
  883. set_random_seed(42)
  884. torch.set_default_device(device)
  885. batch_size = random.randint(1, 256)
  886. _, fake_logits, sampler = _prepare_test(batch_size)
  887. sampler.include_gpu_probs_tensor = True
  888. sampler.should_modify_greedy_probs_inplace = False
  889. sampling_params = SamplingParams(temperature=0)
  890. mock_inplace = Mock()
  891. with patch(
  892. "aphrodite.modeling.layers.sampler._modify_greedy_probs_inplace",
  893. mock_inplace):
  894. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  895. sampling_params, device)
  896. mock_inplace.assert_not_called()
  897. assert sampler_output.sampled_token_probs is not None
  898. assert sampler_output.logprobs is not None
  899. assert sampler_output.sampled_token_ids is not None