test_sampler.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039
  1. import itertools
  2. import random
  3. from dataclasses import dataclass
  4. from typing import Dict, List, Optional, Tuple
  5. from unittest.mock import Mock, patch
  6. import pytest
  7. import torch
  8. from transformers import GenerationConfig, GenerationMixin
  9. from aphrodite.common.sequence import (SamplingParams, SequenceData,
  10. SequenceGroupMetadata)
  11. from aphrodite.common.utils import Counter, is_pin_memory_available
  12. from aphrodite.modeling.layers.sampler import Sampler
  13. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  14. from aphrodite.modeling.utils import set_random_seed
  15. class MockLogitsSampler(Sampler):
  16. def __init__(self, fake_logits: torch.Tensor):
  17. super().__init__()
  18. self.fake_logits = fake_logits
  19. def forward(self, *args, **kwargs):
  20. return super().forward(*args, **kwargs)
  21. def _prepare_test(
  22. batch_size: int
  23. ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
  24. input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
  25. fake_logits = torch.full((batch_size, VOCAB_SIZE),
  26. 1e-2,
  27. dtype=input_tensor.dtype)
  28. sampler = MockLogitsSampler(fake_logits)
  29. return input_tensor, fake_logits, sampler
  30. VOCAB_SIZE = 32000
  31. RANDOM_SEEDS = list(range(128))
  32. CUDA_DEVICES = [
  33. f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
  34. ]
  35. def _do_sample(
  36. batch_size: int,
  37. input_tensor: torch.Tensor,
  38. sampler: MockLogitsSampler,
  39. sampling_params: SamplingParams,
  40. device: str,
  41. ):
  42. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  43. seq_lens: List[int] = []
  44. for i in range(batch_size):
  45. seq_group_metadata_list.append(
  46. SequenceGroupMetadata(
  47. request_id=f"test_{i}",
  48. is_prompt=True,
  49. seq_data={0: SequenceData.from_seqs([1, 2, 3])},
  50. sampling_params=sampling_params,
  51. block_tables={0: [1]},
  52. ))
  53. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  54. sampling_metadata = SamplingMetadata.prepare(
  55. seq_group_metadata_list,
  56. seq_lens,
  57. query_lens=seq_lens,
  58. device=device,
  59. pin_memory=is_pin_memory_available())
  60. return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
  61. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  62. @pytest.mark.parametrize("device", CUDA_DEVICES)
  63. def test_sampler_all_greedy(seed: int, device: str):
  64. set_random_seed(seed)
  65. torch.set_default_device(device)
  66. batch_size = random.randint(1, 256)
  67. input_tensor, fake_logits, sampler = _prepare_test(batch_size)
  68. sampling_params = SamplingParams(temperature=0)
  69. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  70. sampling_params, device)
  71. expected = torch.argmax(fake_logits, dim=-1)
  72. for i, sequence_output in enumerate(sampler_output):
  73. for nth_output in sequence_output.samples:
  74. assert nth_output.output_token == expected[i].item()
  75. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  76. @pytest.mark.parametrize("device", CUDA_DEVICES)
  77. def test_sampler_all_random(seed: int, device: str):
  78. set_random_seed(seed)
  79. torch.set_default_device(device)
  80. batch_size = random.randint(1, 256)
  81. _, fake_logits, sampler = _prepare_test(batch_size)
  82. for i in range(batch_size):
  83. fake_logits[i, i] = 1e2
  84. sampling_params = SamplingParams(
  85. temperature=1.0,
  86. n=random.randint(1, 10),
  87. )
  88. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  89. sampling_params, device)
  90. for i, sequence_output in enumerate(sampler_output):
  91. for nth_output in sequence_output.samples:
  92. assert nth_output.output_token == i
  93. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  94. @pytest.mark.parametrize("device", CUDA_DEVICES)
  95. def test_sampler_all_random_seed(seed: int, device: str):
  96. set_random_seed(seed)
  97. torch.set_default_device(device)
  98. batch_size = random.randint(1, 256)
  99. _, fake_logits, sampler = _prepare_test(batch_size)
  100. for i in range(batch_size):
  101. fake_logits[i, i] = 1e2
  102. sampling_params = SamplingParams(
  103. temperature=1.0,
  104. n=random.randint(1, 10),
  105. seed=random.randint(0, 10000),
  106. )
  107. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  108. sampling_params, device)
  109. for i, sequence_output in enumerate(sampler_output):
  110. for nth_output in sequence_output.samples:
  111. assert nth_output.output_token == i
  112. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  113. @pytest.mark.parametrize("device", CUDA_DEVICES)
  114. def test_sampler_all_random_seed_deterministic(seed: int, device: str):
  115. set_random_seed(seed)
  116. torch.set_default_device(device)
  117. batch_size = random.randint(1, 256)
  118. _, fake_logits, sampler = _prepare_test(batch_size)
  119. sampling_params = SamplingParams(
  120. temperature=1.0,
  121. n=random.randint(1, 10),
  122. seed=random.randint(0, 10000),
  123. )
  124. first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
  125. sampling_params, device)
  126. second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
  127. sampling_params, device)
  128. assert first_sampler_output == second_sampler_output
  129. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  130. @pytest.mark.parametrize("device", CUDA_DEVICES)
  131. def test_sampler_all_beam(seed: int, device: str):
  132. set_random_seed(seed)
  133. torch.set_default_device(device)
  134. batch_size = random.randint(1, 256)
  135. _, fake_logits, sampler = _prepare_test(batch_size)
  136. sampling_params = SamplingParams(
  137. temperature=0,
  138. best_of=2,
  139. use_beam_search=True,
  140. )
  141. _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
  142. # no assertion here as I am not sure how to determine whether
  143. # the outputs are expected - in other words, this just tests
  144. # whether there are no exceptions in the sampler
  145. # when handling an all-beam search case.
  146. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  147. @pytest.mark.parametrize("device", CUDA_DEVICES)
  148. def test_sampler_min_tokens_penalty(seed: int, device: str):
  149. seq_id_counter = Counter(start=random.randint(0, 100))
  150. set_random_seed(seed)
  151. torch.set_default_device(device)
  152. def create_sampling_params(min_tokens,
  153. eos_token_id=0,
  154. *,
  155. stop_token_ids: Optional[List[int]] = None,
  156. prompt_logprobs: Optional[int] = None):
  157. sampling_params = SamplingParams(
  158. min_tokens=min_tokens,
  159. max_tokens=9999, # keep higher than max of min_tokens
  160. stop_token_ids=stop_token_ids,
  161. # requesting prompt_logprobs changes the structure of `logits`
  162. prompt_logprobs=prompt_logprobs,
  163. )
  164. sampling_params.all_stop_token_ids.add(eos_token_id)
  165. return sampling_params
  166. def create_sequence_data(num_input=3, num_generated=0):
  167. seq_data = SequenceData.from_seqs(
  168. random.choices(range(0, VOCAB_SIZE), k=num_input))
  169. if num_generated > 0:
  170. seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
  171. k=num_generated)
  172. return seq_data
  173. def generate_test_case():
  174. # generate multiple seq groups but limit total batch size
  175. batch_size = random.randint(1, 128)
  176. expected_penalization = []
  177. sequence_metadata_list: List[SequenceGroupMetadata] = []
  178. # 20% chance to generate seq group metadata list with all prompts
  179. is_prompt = random.random() < 0.2
  180. while batch_size > 0:
  181. num_seqs = 1 if is_prompt else random.randint(1, batch_size)
  182. eos_token_id = random.randint(0, VOCAB_SIZE - 1)
  183. min_tokens = random.randint(0, 50)
  184. num_stop_tokens = random.randint(0, 8)
  185. if num_stop_tokens > 0:
  186. stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
  187. k=num_stop_tokens)
  188. else:
  189. stop_token_ids = None
  190. sampling_params = create_sampling_params(
  191. min_tokens=min_tokens,
  192. eos_token_id=eos_token_id,
  193. stop_token_ids=stop_token_ids)
  194. seq_data: Dict[int, SequenceData] = {}
  195. seq_group_penalization: List[bool] = []
  196. for _ in range(num_seqs):
  197. num_input = random.randint(1, 100)
  198. num_generated = 0 if is_prompt else random.randint(1, 100)
  199. seq_data[next(seq_id_counter)] = create_sequence_data(
  200. num_input=num_input, num_generated=num_generated)
  201. seq_group_penalization.append(num_generated < min_tokens)
  202. expected_penalization.extend(seq_group_penalization)
  203. sequence_metadata_list.append(
  204. SequenceGroupMetadata(
  205. request_id=f"test_{batch_size}",
  206. is_prompt=is_prompt,
  207. seq_data=seq_data,
  208. sampling_params=sampling_params,
  209. block_tables={},
  210. ))
  211. batch_size -= num_seqs
  212. return {
  213. "expected_penalization": expected_penalization,
  214. "seq_group_metadata_list": sequence_metadata_list,
  215. }
  216. # define some explicit test cases for edge case behavior
  217. prompt_without_penalization = {
  218. "expected_penalization": [False],
  219. "seq_group_metadata_list": [
  220. SequenceGroupMetadata(
  221. request_id="test_1",
  222. is_prompt=True,
  223. seq_data={
  224. next(seq_id_counter): create_sequence_data(),
  225. },
  226. sampling_params=create_sampling_params(0),
  227. block_tables={},
  228. ),
  229. ]
  230. }
  231. prompt_with_penalization = {
  232. "expected_penalization": [True],
  233. "seq_group_metadata_list": [
  234. SequenceGroupMetadata(
  235. request_id="test_1",
  236. is_prompt=True,
  237. seq_data={
  238. next(seq_id_counter): create_sequence_data(),
  239. },
  240. sampling_params=create_sampling_params(1),
  241. block_tables={},
  242. ),
  243. ]
  244. }
  245. prompt_with_penalization_and_prompt_logprobs = {
  246. "expected_penalization": [False, False, True],
  247. "seq_group_metadata_list": [
  248. SequenceGroupMetadata(
  249. request_id="test_1",
  250. is_prompt=True,
  251. seq_data={
  252. next(seq_id_counter): create_sequence_data(num_input=3),
  253. },
  254. sampling_params=create_sampling_params(1, prompt_logprobs=3),
  255. block_tables={},
  256. ),
  257. ]
  258. }
  259. stop_penalizing_after_min_tokens = {
  260. "expected_penalization": [False],
  261. "seq_group_metadata_list": [
  262. SequenceGroupMetadata(
  263. request_id="test_1",
  264. is_prompt=False,
  265. seq_data={
  266. next(seq_id_counter):
  267. create_sequence_data(num_generated=1),
  268. },
  269. sampling_params=create_sampling_params(1),
  270. block_tables={},
  271. )
  272. ]
  273. }
  274. stop_token_ids = [42, 99, 42, 0] # intentional duplication
  275. prompt_combination = {
  276. "expected_penalization": [False, True, False],
  277. "seq_group_metadata_list": [
  278. SequenceGroupMetadata(
  279. request_id="test_2",
  280. is_prompt=True,
  281. seq_data={
  282. next(seq_id_counter): create_sequence_data(num_input=2),
  283. },
  284. sampling_params=create_sampling_params(1, prompt_logprobs=3),
  285. block_tables={},
  286. ),
  287. SequenceGroupMetadata(
  288. request_id="test_3",
  289. is_prompt=True,
  290. seq_data={
  291. next(seq_id_counter): create_sequence_data(),
  292. },
  293. sampling_params=create_sampling_params(
  294. 0, stop_token_ids=stop_token_ids),
  295. block_tables={},
  296. )
  297. ]
  298. }
  299. stop_token_ids = [1, 999, 37, 37] # intentional duplication
  300. decode_combination = {
  301. "expected_penalization": [True, False, False, True, False],
  302. "seq_group_metadata_list": [
  303. SequenceGroupMetadata(
  304. request_id="test_1",
  305. is_prompt=False,
  306. seq_data={
  307. next(seq_id_counter):
  308. create_sequence_data(num_generated=1),
  309. next(seq_id_counter):
  310. create_sequence_data(num_generated=100),
  311. },
  312. sampling_params=create_sampling_params(
  313. 2, stop_token_ids=stop_token_ids),
  314. block_tables={},
  315. ),
  316. SequenceGroupMetadata(
  317. request_id="test_2",
  318. is_prompt=False,
  319. seq_data={
  320. next(seq_id_counter):
  321. create_sequence_data(num_generated=20),
  322. next(seq_id_counter):
  323. create_sequence_data(num_generated=1),
  324. next(seq_id_counter):
  325. create_sequence_data(num_generated=10),
  326. },
  327. sampling_params=create_sampling_params(
  328. 10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
  329. block_tables={},
  330. ),
  331. ]
  332. }
  333. if seed == 0:
  334. test_cases = [
  335. prompt_without_penalization,
  336. prompt_with_penalization,
  337. prompt_with_penalization_and_prompt_logprobs,
  338. stop_penalizing_after_min_tokens,
  339. prompt_combination,
  340. decode_combination,
  341. ]
  342. else:
  343. test_cases = [generate_test_case()]
  344. def run_test_case(*, expected_penalization: List[bool],
  345. seq_group_metadata_list: List[SequenceGroupMetadata]):
  346. assert expected_penalization, \
  347. "Invalid test case, need expected_penalization"
  348. assert seq_group_metadata_list, \
  349. "Invalid test case, need seq_group_metadata_list"
  350. batch_size = 0
  351. seq_lens: List[int] = []
  352. sampling_params_per_row: List[SamplingParams] = []
  353. for sgm in seq_group_metadata_list:
  354. sampling_params = sgm.sampling_params
  355. num_rows = len(sgm.seq_data)
  356. if sgm.is_prompt:
  357. # a prompt seq_group has only one sequence
  358. seq_data = next(iter(sgm.seq_data.values()))
  359. prompt_len = seq_data.get_prompt_len()
  360. seq_lens.append(prompt_len)
  361. if sgm.sampling_params.prompt_logprobs:
  362. # with prompt_logprobs each token in the prompt has a row in
  363. # logits
  364. num_rows = prompt_len
  365. batch_size += num_rows
  366. sampling_params_per_row.extend(
  367. itertools.repeat(sampling_params, num_rows))
  368. assert len(
  369. expected_penalization
  370. ) == batch_size, \
  371. ("Invalid test case, expected_penalization does not match computed"
  372. "batch size")
  373. _, fake_logits, sampler = _prepare_test(batch_size)
  374. sampling_metadata = SamplingMetadata.prepare(
  375. seq_group_metadata_list,
  376. seq_lens=seq_lens if seq_lens else None,
  377. query_lens=seq_lens if seq_lens else None,
  378. device=device,
  379. pin_memory=is_pin_memory_available())
  380. # the logits tensor is modified in-place by the sampler
  381. _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
  382. for logits_idx, (should_penalize, sampling_params) in enumerate(
  383. zip(expected_penalization, sampling_params_per_row)):
  384. tokens_to_check = sampling_params.all_stop_token_ids
  385. if should_penalize:
  386. for token_id in tokens_to_check:
  387. assert fake_logits[logits_idx, token_id] == -float(
  388. 'inf'
  389. ), f"Expected token {token_id} for logits row {logits_idx}"
  390. " to be penalized"
  391. # no other tokens should be set to -inf
  392. assert torch.count_nonzero(
  393. fake_logits[logits_idx, :] == -float('inf')) == len(
  394. tokens_to_check
  395. ), f"Expected only {len(tokens_to_check)} to be penalized"
  396. else:
  397. # no tokens should be set to -inf
  398. assert torch.count_nonzero(
  399. fake_logits[logits_idx, :] ==
  400. -float('inf')) == 0, "No tokens should have been penalized"
  401. for test_case in test_cases:
  402. run_test_case(**test_case)
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    """Sample a batch that mixes greedy, random, seeded and beam requests.

    Each request is randomly assigned one of four kinds. The batch is sampled
    once, then shuffled (requests, expectations, lengths and logits rows all
    with the same permutation) and sampled a second time; seeded requests are
    checked on the second pass against what they produced on the first.
    """
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler = _prepare_test(batch_size)

    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    # Per request: the tokens it may/must produce, or None when not yet known
    # (seeded requests before the first pass) or never checked (beam search).
    expected_tokens: List[Optional[List[int]]] = []
    seq_lens: List[int] = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        # 0 = greedy, 1 = unseeded random, 2 = seeded random, 3 = beam search
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            sampling_params = SamplingParams(temperature=0)
            expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
        elif sampling_type in (1, 2):
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                # a drawn 0 is turned into -1
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                # Unseeded: boost n consecutive tokens starting at column i so
                # that any sample is expected to land on one of them.
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    # The same dict object is passed to SamplingMetadata.prepare on both
    # passes. NOTE(review): presumably it carries per-request generator state
    # between the calls -- confirm against SamplingMetadata.prepare.
    generators: Dict[str, torch.Generator] = {}

    def test_sampling():
        # Reads the enclosing lists at call time, so the second call sees the
        # shuffled batch and the expectations recorded by the first call.
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens=seq_lens,
            device=device,
            pin_memory=is_pin_memory_available(),
            generators=generators)
        sampler_output = sampler(logits=fake_logits,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            # Beam-search results are not checked.
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record seeded random result to compare with results of
                # second invocation
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            expected_tokens_item = expected_tokens[i]
            assert expected_tokens_item is not None

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or random with seed
                    assert nth_output.output_token == expected_tokens_item[n]
                else:
                    # For non-seeded random check that one of the high-logit
                    # tokens were chosen
                    assert nth_output.output_token in expected_tokens_item

    # Test batch
    test_sampling()

    # Shuffle the batch and resample
    target_index = list(range(batch_size))
    # A fresh Random(seed) per list applies the same permutation to every
    # list, keeping indices, metadata, expectations and lengths aligned.
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, seq_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    # Reorder the tensor rows in place to follow the shuffled requests.
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, results of seeded random samples will be compared with
    # the corresponding sample in the pre-shuffled batch
    test_sampling()
  496. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  497. @pytest.mark.parametrize("device", CUDA_DEVICES)
  498. def test_sampler_top_k_top_p(seed: int, device: str):
  499. set_random_seed(seed)
  500. batch_size = random.randint(1, 256)
  501. top_k = random.randint(100, 500)
  502. top_p = random.random() * 0.1
  503. vocab_size = 32000
  504. input_tensor = torch.rand((batch_size, 1024),
  505. device=device,
  506. dtype=torch.float16)
  507. fake_logits = torch.normal(0,
  508. 5,
  509. size=(batch_size, vocab_size),
  510. device=input_tensor.device,
  511. dtype=input_tensor.dtype)
  512. sampler = MockLogitsSampler(fake_logits)
  513. generation_model = GenerationMixin()
  514. generation_config = GenerationConfig(top_k=top_k,
  515. top_p=top_p,
  516. do_sample=True)
  517. @dataclass
  518. class MockConfig:
  519. is_encoder_decoder: bool = False
  520. generation_model.config = MockConfig() # needed by the following method
  521. generation_model._prepare_special_tokens(generation_config, device=device)
  522. processors = generation_model._get_logits_processor(generation_config,
  523. None,
  524. None,
  525. None, [],
  526. device=device)
  527. assert len(processors) == 2 # top_p and top_k
  528. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  529. seq_lens: List[int] = []
  530. for i in range(batch_size):
  531. seq_group_metadata_list.append(
  532. SequenceGroupMetadata(
  533. request_id=f"test_{i}",
  534. is_prompt=True,
  535. seq_data={0: SequenceData.from_seqs([1, 2, 3])},
  536. sampling_params=SamplingParams(
  537. temperature=1,
  538. top_k=top_k,
  539. top_p=top_p,
  540. ),
  541. block_tables={0: [1]},
  542. ))
  543. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  544. sampling_metadata = SamplingMetadata.prepare(
  545. seq_group_metadata_list,
  546. seq_lens,
  547. query_lens=seq_lens,
  548. device=device,
  549. pin_memory=is_pin_memory_available())
  550. sample_probs = None
  551. def mock_sample(probs, *args, **kwargs):
  552. nonlocal sample_probs
  553. sample_probs = probs
  554. return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
  555. for prob in probs], None)
  556. with patch("aphrodite.modeling.layers.sampler._sample", mock_sample):
  557. sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
  558. assert sample_probs is not None
  559. hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
  560. hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
  561. torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
  562. assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
  563. @pytest.mark.parametrize("device", CUDA_DEVICES)
  564. def test_sampler_repetition_penalty_mixed(device: str):
  565. vocab_size = 8
  566. def test_sampling_params(sampling_params: List[SamplingParams]):
  567. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  568. seq_lens: List[int] = []
  569. for i in range(2):
  570. seq_group_metadata_list.append(
  571. SequenceGroupMetadata(
  572. request_id=f"test_{i}",
  573. is_prompt=True,
  574. seq_data={0: SequenceData.from_seqs([1, 2, 3])},
  575. sampling_params=sampling_params[i],
  576. block_tables={0: [1]},
  577. ))
  578. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  579. sampling_metadata = SamplingMetadata.prepare(
  580. seq_group_metadata_list,
  581. seq_lens,
  582. query_lens=seq_lens,
  583. device=device,
  584. pin_memory=is_pin_memory_available())
  585. fake_logits = torch.full((2, vocab_size),
  586. 1e-2,
  587. device=device,
  588. dtype=torch.float16)
  589. fake_logits[:, 5] = 1.1e-2
  590. fake_logits[:, 1] = 1.2e-2
  591. sampler = MockLogitsSampler(fake_logits)
  592. sampler_output = sampler(logits=fake_logits,
  593. sampling_metadata=sampling_metadata)
  594. generated_tokens = []
  595. for output in sampler_output:
  596. generated_tokens.append(output.samples[0].output_token)
  597. return generated_tokens
  598. # one configuration is greedy with repetition_penalty
  599. sampling_params_rep = SamplingParams(
  600. temperature=0.0,
  601. repetition_penalty=2.0,
  602. )
  603. # other configuration is sampling w/o repetition_penalty
  604. sampling_params_sample = SamplingParams(
  605. temperature=1.0,
  606. top_k=1,
  607. seed=42,
  608. )
  609. tokens1 = test_sampling_params(
  610. [sampling_params_rep, sampling_params_sample])
  611. tokens2 = test_sampling_params(
  612. [sampling_params_sample, sampling_params_rep])
  613. assert tokens1[0] == tokens2[1]
  614. assert tokens1[1] == tokens2[0]
  615. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  616. @pytest.mark.parametrize("device", CUDA_DEVICES)
  617. def test_sampler_no_repeat_ngram(seed: int, device: str):
  618. """Test that no-repeat-ngram sampling behaves as expected."""
  619. set_random_seed(seed)
  620. torch.set_default_device(device)
  621. batch_size = random.randint(1, 256)
  622. _, fake_logits, sampler = _prepare_test(batch_size)
  623. test_sequences = {
  624. # Format: sequence: [tokens_that_should_be_blocked]
  625. (1, 2, 3): [3], # With ngram_size=2, should block 3 after [2]
  626. (4, 5, 4, 5): [4], # With ngram_size=2, should block 4 after [5]
  627. (6, 7, 8, 6, 7): [8], # With ngram_size=3, should block 8 after [6, 7]
  628. (1, 2, 3, 4, 1, 2): [3], # With ngram_size=4, should block 3 after [1, 2] # noqa: E501
  629. }
  630. for input_seq, blocked_tokens in test_sequences.items():
  631. for ngram_size in [2, 3, 4]:
  632. sampling_params = SamplingParams(
  633. temperature=1.0,
  634. no_repeat_ngram_size=ngram_size,
  635. seed=random.randint(0, 10000),
  636. )
  637. sampler_output = _do_sample(
  638. 1,
  639. fake_logits[0:1].clone(), # Just use first row
  640. sampler,
  641. sampling_params,
  642. device
  643. )
  644. if len(input_seq) >= ngram_size:
  645. # check if blocked tokens have -inf logits
  646. for token in blocked_tokens:
  647. assert sampler_output[0].samples[0].output_token != token, \
  648. f"Token {token} should have been blocked by {ngram_size}-gram repetition prevention" # noqa: E501
  649. # disabled
  650. sampling_params = SamplingParams(
  651. temperature=1.0,
  652. no_repeat_ngram_size=0,
  653. seed=random.randint(0, 10000),
  654. )
  655. sampler_output = _do_sample(
  656. 1,
  657. fake_logits[0:1].clone(),
  658. sampler,
  659. sampling_params,
  660. device
  661. )
  662. output_token = sampler_output[0].samples[0].output_token
  663. assert output_token is not None, "Should produce output token with ngram_size=0" # noqa: E501
  664. # determinism
  665. sampling_params = SamplingParams(
  666. temperature=1.0,
  667. no_repeat_ngram_size=3,
  668. seed=random.randint(0, 10000),
  669. )
  670. first_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  671. sampling_params, device)
  672. second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  673. sampling_params, device)
  674. assert first_output == second_output, \
  675. "No-repeat-ngram sampling is not deterministic with same seed"
  676. @pytest.mark.parametrize("device", CUDA_DEVICES)
  677. def test_sampler_dry(device: str):
  678. vocab_size = 8
  679. def test_sampling_params(sampling_params: List[SamplingParams]):
  680. seq_group_metadata_list: List[SequenceGroupMetadata] = []
  681. seq_lens: List[int] = []
  682. for i in range(2):
  683. seq_group_metadata_list.append(
  684. SequenceGroupMetadata(
  685. request_id=f"test_{i}",
  686. is_prompt=True,
  687. seq_data={
  688. 0: SequenceData.from_seqs([1, 2, 3, 1, 2])
  689. },
  690. sampling_params=sampling_params[i],
  691. block_tables={0: [1]},
  692. ))
  693. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  694. sampling_metadata = SamplingMetadata.prepare(
  695. seq_group_metadata_list,
  696. seq_lens,
  697. query_lens=seq_lens,
  698. device=device,
  699. pin_memory=is_pin_memory_available())
  700. fake_logits = torch.full((2, vocab_size),
  701. 1e-2,
  702. device=device,
  703. dtype=torch.float16)
  704. fake_logits[:, 3] = 1.0
  705. sampler = MockLogitsSampler(fake_logits)
  706. sampler_output = sampler(logits=fake_logits,
  707. sampling_metadata=sampling_metadata)
  708. generated_tokens = []
  709. for output in sampler_output:
  710. generated_tokens.append(output.samples[0].output_token)
  711. return generated_tokens
  712. # Test case 1: DRY disabled (multiplier = 0)
  713. sampling_params_no_dry = SamplingParams(
  714. temperature=0.0,
  715. dry_multiplier=0.0,
  716. )
  717. # Test case 2: DRY enabled with full range
  718. sampling_params_full_dry = SamplingParams(
  719. temperature=0.0,
  720. dry_multiplier=1.0,
  721. dry_allowed_length=2,
  722. dry_base=2.0,
  723. dry_range=0,
  724. )
  725. sampling_params_limited_dry = SamplingParams(
  726. temperature=0.0,
  727. dry_multiplier=1.0,
  728. dry_allowed_length=2,
  729. dry_base=2.0,
  730. dry_range=3,
  731. )
  732. tokens1 = test_sampling_params(
  733. [sampling_params_no_dry, sampling_params_full_dry])
  734. assert tokens1[0] == 3, "Without DRY, should choose highest logit token"
  735. assert tokens1[1] != 3, "With full-range DRY, should avoid repeating pattern" # noqa: E501
  736. tokens2 = test_sampling_params(
  737. [sampling_params_full_dry, sampling_params_limited_dry])
  738. assert tokens2[0] != 3, "Full-range DRY should detect full pattern"
  739. assert tokens2[1] == 3, "Limited-range DRY should only consider recent tokens" # noqa: E501
  740. tokens3 = test_sampling_params(
  741. [sampling_params_full_dry, sampling_params_limited_dry])
  742. assert tokens2 == tokens3, "DRY sampling should be deterministic"
  743. @pytest.mark.parametrize("device", CUDA_DEVICES)
  744. def test_sampler_dry_sequence_breakers(device: str):
  745. """Test that DRY respects sequence breakers."""
  746. vocab_size = 8
  747. # 7 is a sequence breaker
  748. input_sequence = [1, 2, 7, 1, 2]
  749. seq_group_metadata = SequenceGroupMetadata(
  750. request_id="test_0",
  751. is_prompt=True,
  752. seq_data={0: SequenceData.from_seqs(input_sequence)},
  753. sampling_params=SamplingParams(
  754. temperature=0.0,
  755. dry_multiplier=1.0,
  756. dry_allowed_length=2,
  757. dry_base=2.0,
  758. dry_range=0,
  759. dry_sequence_breaker_ids=[7],
  760. ),
  761. block_tables={0: [1]},
  762. )
  763. sampling_metadata = SamplingMetadata.prepare(
  764. [seq_group_metadata],
  765. seq_lens=[len(input_sequence)],
  766. query_lens=[len(input_sequence)],
  767. device=device,
  768. pin_memory=is_pin_memory_available())
  769. fake_logits = torch.full((1, vocab_size),
  770. 1e-2,
  771. device=device,
  772. dtype=torch.float16)
  773. fake_logits[0, 3] = 1.0
  774. sampler = MockLogitsSampler(fake_logits)
  775. sampler_output = sampler(logits=fake_logits,
  776. sampling_metadata=sampling_metadata)
  777. assert sampler_output[0].samples[0].output_token == 3, \
  778. "DRY should not detect patterns across sequence breakers"
  779. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  780. @pytest.mark.parametrize("device", CUDA_DEVICES)
  781. def test_sampler_nsigma(seed: int, device: str):
  782. """Test that top-nsigma sampling behaves as expected."""
  783. set_random_seed(seed)
  784. torch.set_default_device(device)
  785. batch_size = random.randint(1, 256)
  786. _, fake_logits, sampler = _prepare_test(batch_size)
  787. # Create a clear separation in logits for testing
  788. high_logit_indices = {} # Store high logit indices for each batch
  789. for i in range(batch_size):
  790. # Set a few logits significantly higher than others
  791. num_high_logits = random.randint(1, 5)
  792. high_indices = random.sample(range(fake_logits.size(1)),
  793. num_high_logits)
  794. high_logit_indices[i] = set(high_indices) # Store for verification
  795. for idx in high_indices:
  796. fake_logits[i, idx] = 10.0 # Clearly above the mean
  797. # Test with different nsigma values
  798. for nsigma in [1.5, 2.0, 3.0]:
  799. sampling_params = SamplingParams(
  800. temperature=1.0,
  801. nsigma=nsigma,
  802. seed=random.randint(0, 10000),
  803. )
  804. sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  805. sampling_params, device)
  806. # Verify that sampling only selects from high logits
  807. for batch_idx, sequence_output in enumerate(sampler_output):
  808. for nth_output in sequence_output.samples:
  809. token_id = nth_output.output_token
  810. # The token should come from the high logits region
  811. assert token_id in high_logit_indices[batch_idx], \
  812. f"Sampled token {token_id} for batch {batch_idx} was not in the high logit set" # noqa
  813. # Test determinism
  814. second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  815. sampling_params, device)
  816. assert sampler_output == second_output, \
  817. "Top-nsigma sampling is not deterministic with same seed"
  818. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  819. @pytest.mark.parametrize("device", CUDA_DEVICES)
  820. def test_sampler_skew(seed: int, device: str):
  821. """Test that skew sampling behaves as expected."""
  822. set_random_seed(seed)
  823. torch.set_default_device(device)
  824. batch_size = random.randint(1, 256)
  825. _, fake_logits, sampler = _prepare_test(batch_size)
  826. high_prob_tokens = {}
  827. for i in range(batch_size):
  828. # Make token i have a much higher logit in sequence i
  829. fake_logits[i, i] = 10.0
  830. high_prob_tokens[i] = i
  831. test_cases = [
  832. # (skew, expected_behavior)
  833. (2.0, "low"), # Strong bias away from high probability tokens
  834. (0.5, "subtle"), # Subtle bias away from high probability tokens
  835. (0.0, "neutral"), # No bias (regular sampling)
  836. ]
  837. for skew, expected_behavior in test_cases:
  838. sampling_params = SamplingParams(
  839. temperature=1.0, # neutral temperature
  840. skew=skew,
  841. seed=random.randint(0, 10000), # for determinism
  842. )
  843. sampler_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  844. sampling_params, device)
  845. for batch_idx, sequence_output in enumerate(sampler_output):
  846. token_id = sequence_output.samples[0].output_token
  847. if expected_behavior == "low":
  848. # strong skew should bias away from high probability tokens
  849. assert token_id != high_prob_tokens[batch_idx], \
  850. f"With high skew {skew}, should not select high " \
  851. f"probability token {high_prob_tokens[batch_idx]}"
  852. elif expected_behavior == "subtle":
  853. # we don't assert anything for subtle effect,
  854. # as it's probabilistic
  855. pass
  856. # determinism
  857. second_output = _do_sample(batch_size, fake_logits.clone(), sampler,
  858. sampling_params, device)
  859. assert sampler_output == second_output, \
  860. f"Skew sampling with seed is not deterministic for skew={skew}"
  861. @pytest.mark.parametrize("device", CUDA_DEVICES)
  862. def test_sampler_include_gpu_probs_tensor(device: str):
  863. set_random_seed(42)
  864. torch.set_default_device(device)
  865. batch_size = random.randint(1, 256)
  866. _, fake_logits, sampler = _prepare_test(batch_size)
  867. sampler.include_gpu_probs_tensor = True
  868. sampler.should_modify_greedy_probs_inplace = False
  869. sampling_params = SamplingParams(temperature=0)
  870. mock_inplace = Mock()
  871. with patch(
  872. "aphrodite.modeling.layers.sampler._modify_greedy_probs_inplace",
  873. mock_inplace):
  874. sampler_output = _do_sample(batch_size, fake_logits, sampler,
  875. sampling_params, device)
  876. mock_inplace.assert_not_called()
  877. assert sampler_output.sampled_token_probs is not None
  878. assert sampler_output.logprobs is not None
  879. assert sampler_output.sampled_token_ids is not None