  1. """Tests for rejection sampling."""
  2. import pytest
  3. import torch
  4. from aphrodite.modeling.layers.typical_acceptance_sampler import (
  5. TypicalAcceptanceSampler)
  6. from aphrodite.modeling.utils import set_random_seed
  7. CUDA_DEVICES = [f"cuda:{i}" for i in range(1)]


def get_zero_temperature_prob_dist(batch_size, k, vocab_size):
    """
    Generates a fake temperature zero probability distribution.
    Returns:
        1. A fake temperature zero probability distribution of shape
           [batch_size, k, vocab_size]
        2. Tensor of shape [batch_size, k] containing the token ids
           of the probability 1.0 tokens at each position.
    """
    # Simulate a temperature 0 probability distribution for the target
    # probabilities, i.e. one in which exactly one token id per position
    # has probability 1.0.
    probs = torch.rand(batch_size, k, vocab_size)
    _, zero_temperature_token_ids = torch.max(probs, dim=-1)
    # Set the probability of the tokens with ids in
    # zero_temperature_token_ids to 1 and the rest to 0.
    target_probs = torch.zeros_like(probs).scatter_(
        -1, zero_temperature_token_ids.unsqueeze(-1), 1.0)
    return target_probs, zero_temperature_token_ids


def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
                        token_ids_to_exclude: torch.Tensor):
    """
    Returns a tensor of shape [batch_size, k] of fake draft token ids
    drawn randomly from a vocab of size vocab_size. We however ensure
    that token_ids from token_ids_to_exclude are excluded at the
    corresponding positions.
    """
    draft_token_ids = torch.empty(batch_size, k, dtype=torch.long)
    for i in range(batch_size):
        for j in range(k):
            # Generate a random token ID excluding token_ids_to_exclude[i, j]
            while True:
                token_id = torch.randint(0, vocab_size, (1, )).item()
                if token_id != token_ids_to_exclude[i, j]:
                    draft_token_ids[i, j] = token_id
                    break
    return draft_token_ids


def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    disable_bonus_tokens: bool = False,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """
    Initializes and returns a TypicalAcceptanceSampler.
    """
    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
                                    disable_bonus_tokens, strict_mode)
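

# For orientation only: a minimal sketch of typical acceptance as described
# in the Medusa paper, which is an assumption about the criterion rather
# than the implementation under test. A draft token is accepted when its
# target probability exceeds
#     min(posterior_threshold, posterior_alpha * exp(-H(p))),
# where H(p) is the entropy of the target distribution at that position.
# High-entropy (near-uniform) targets drive the threshold toward zero
# (accept almost anything), while zero-entropy (one-hot) targets keep it at
# posterior_threshold. The worked numbers in the test comments below refer
# to this sketch.
def _reference_acceptance_threshold(
        target_probs: torch.Tensor,
        posterior_threshold: float = 0.03,
        posterior_alpha: float = 0.9) -> torch.Tensor:
    # Entropy per position; the small epsilon guards against log(0).
    entropy = -(target_probs * torch.log(target_probs + 1e-5)).sum(dim=-1)
    # Elementwise minimum of the fixed threshold and the entropy-scaled one.
    return torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy))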


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    device: str):
    """
    Tests that the TypicalAcceptanceSampler forward pass succeeds for
    different combinations of k, vocab_size, batch_size and device.
    """
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler()
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    # Verify that sampling succeeds for all cases.
    typical_acceptance_sampler(target_probs,
                               bonus_token_ids,
                               draft_probs=None,
                               draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str, device: str):
    """
    Tests that we throw an exception if the token ids fall outside
    the bounds of the provided vocabulary.
    """
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    # Verify that appropriate exceptions are thrown for out-of-bounds
    # token ids.
    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()
    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()
    oob_token_ids[0][0] = rogue_token_id
    with pytest.raises(AssertionError):
        typical_acceptance_sampler(target_probs,
                                   bonus_token_ids,
                                   draft_probs=None,
                                   draft_token_ids=draft_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_uniform_target_distribution_accepts_all_tokens(
        seed: int, disable_bonus_tokens: bool, device: str):
    """
    Test the TypicalAcceptanceSampler with a uniform target probability
    distribution.

    This test verifies that when provided with a uniform target probability
    distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
    high entropy of the uniform target distribution should lead to all
    draft tokens being accepted. The test also ensures that the behavior
    regarding bonus tokens is consistent with the `disable_bonus_tokens`
    flag.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    # We are using a uniform target probability distribution.
    # For a uniform distribution the entropy is very high and it
    # should lead to all draft tokens being accepted. Verify that.
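    # A rough worked number, assuming the Medusa-style criterion sketched
    # in _reference_acceptance_threshold above: for a near-uniform
    # distribution over vocab_size = 30_000, H(p) is roughly ln(30_000),
    # about 10.3, so posterior_alpha * exp(-H) is on the order of 3e-5 and
    # the acceptance threshold min(0.03, ~3e-5) is tiny. Essentially every
    # draft token's target probability clears it.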
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    if disable_bonus_tokens:
        assert torch.all(output_token_ids[:, -1] == -1)
    else:
        assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
    assert torch.all(output_token_ids[:, :k] == draft_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_temperature_zero_target_distribution(seed: int,
                                              disable_bonus_tokens: bool,
                                              device: str):
    """
    Test the TypicalAcceptanceSampler with a zero-temperature target
    probability distribution.

    This test verifies that when using a zero-temperature target probability
    distribution, where only one token has a probability of 1.0, the
    TypicalAcceptanceSampler correctly rejects all draft tokens that do not
    match this probability. Additionally, it ensures that when all draft
    tokens are rejected, the sampler falls back to greedy sampling to select
    a single token from the target distribution.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # Simulate a temperature 0 probability distribution for the target
    # probabilities such that only 1 token id has probability 1.0.
    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
        batch_size, k, vocab_size)
    # Populate draft_token_ids such that they exclude the token ids
    # with probability = 1.0.
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    # The target probability distribution is a temperature zero distribution
    # with zero entropy. Since our draft token ids don't match the
    # probability 1.0 tokens in the target distribution, we will reject all
    # of them and fall back to greedy sampling to select 1 token for each
    # sequence. Verify the same.
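    # Rough worked number (same hedged criterion as above): H(p) = 0 for a
    # one-hot distribution, so the threshold stays at
    # min(0.03, 0.9 * exp(0)) = 0.03, while every mismatched draft token
    # has target probability 0.0 and is rejected.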
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, -1] == -1)
    assert torch.all(
        output_token_ids[:, 0] == zero_temperature_token_ids[:, 0])


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                   device: str):
    """
    Test the TypicalAcceptanceSampler with a mixed target probability
    distribution.

    This test ensures that the TypicalAcceptanceSampler handles a mixed
    target probability distribution correctly. Specifically, it uses a
    zero-temperature distribution for some sequences and a uniform
    distribution for others. The test verifies that:
    - For sequences with a zero-temperature distribution, only the token
      with a probability of 1.0 is accepted, and all other tokens are
      rejected.
    - For sequences with a uniform distribution, all draft tokens are
      accepted.
    - When `disable_bonus_tokens` is False, the bonus tokens are also
      accepted for sequences with a uniform distribution.
    """
    set_random_seed(seed)
    k = 3
    batch_size = 4
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # For sequences 0 and 2 set the distribution to a temperature
    # zero distribution. For sequences 1 and 3 set it to a uniform
    # distribution.
    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
        batch_size, k, vocab_size)
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
    target_probs[[1, 3]] = uniform_probs
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    # Verify the shape of output_token_ids.
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    # For sequences 0 and 2 verify that only 1 token is accepted,
    # which is the token with probability 1.0 in the target distribution
    # at position 0.
    assert torch.all(output_token_ids[[0, 2], 1:] == -1)
    assert torch.all(
        output_token_ids[[0, 2], 0] == zero_temperature_token_ids[[0, 2], 0])
    # For sequences 1 and 3 verify that all tokens are accepted since the
    # target probability distribution is uniform. In addition verify that
    # if disable_bonus_tokens is false then we also accept the bonus tokens.
    assert torch.all(
        output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
    if disable_bonus_tokens:
        assert torch.all(output_token_ids[[1, 3], -1] == -1)
    else:
        assert torch.all(output_token_ids[[1, 3], -1] != -1)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                 device: str):
    """
    Test the TypicalAcceptanceSampler's behavior when only a subset of
    draft tokens should be accepted.

    This test verifies that the TypicalAcceptanceSampler correctly accepts
    or rejects draft tokens based on a zero-temperature target probability
    distribution. Specifically, it ensures that:
    - When all draft tokens match tokens with a probability of 1.0 in the
      target distribution, all draft tokens are accepted.
    - When only some draft tokens match tokens with a probability of 1.0 in
      the target distribution, only those matching tokens are accepted, and
      the rest are rejected.
    """
    set_random_seed(seed)
    k = 5
    batch_size = 1
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # Create a temperature zero target probability distribution and ensure
    # all draft token ids correspond to the tokens with 1.0 probability.
    # Verify that all of them are accepted.
    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
        batch_size, k, vocab_size)
    draft_token_ids = zero_temperature_token_ids
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
    if disable_bonus_tokens:
        assert torch.all(output_token_ids[:, -1] == -1)
    else:
        assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
    # Next, only keep the first 2 draft tokens the same as the zero
    # temperature tokens. For the remaining 3, choose some other tokens.
    # In the response we expect the first 2 tokens to be the same as the
    # draft tokens and the rest to be -1.
    draft_token_ids_to_replace = get_draft_token_ids(
        batch_size, k, vocab_size, zero_temperature_token_ids)
    draft_token_ids = torch.cat(
        (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
    assert torch.all(output_token_ids[:, -3:] == -1)


@pytest.mark.parametrize("seed", list(range(1)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_accept_tokens_set_non_default_posteriors(seed: int,
                                                  disable_bonus_tokens: bool,
                                                  device: str):
    """
    Test the TypicalAcceptanceSampler with custom posterior thresholds and
    alpha values. This test verifies that by modifying the posterior
    thresholds and alpha values we can change the acceptance behavior of
    the sampler.
    """
    set_random_seed(seed)
    k = 5
    batch_size = 1
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    # Simulate a temperature 0 probability distribution for the target
    # probabilities such that only 1 token id has probability 1.0 and the
    # others have a very low probability of 0.00001. Populate
    # draft_token_ids such that they exclude the token ids with
    # probability = 1.0. Without any changes to the posterior thresholds,
    # none of the draft tokens are accepted.
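    # Rough worked number (hedged criterion as above): the near-one-hot
    # target keeps the entropy close to 0, so the default threshold stays
    # around min(0.03, 0.9 * exp(0)) = 0.03, and the draft tokens' target
    # probability of 0.00001 falls well below it.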
    target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
        batch_size, k, vocab_size)
    target_probs[target_probs == 0] = 0.00001
    draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
                                          zero_temperature_token_ids)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, 1:-1] == -1)
    # Change the posterior threshold values to 0.0 so that we will
    # now accept even draft tokens with very low probability in the
    # target distribution. Simulate and verify the same.
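    # With posterior_threshold = posterior_alpha = 0.0 the sketched
    # acceptance threshold min(0.0, 0.0 * exp(-H)) is 0.0, so any draft
    # token with nonzero target probability (here 0.00001) should now be
    # accepted.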
    typical_acceptance_sampler = TypicalAcceptanceSampler(
        strict_mode=True,
        disable_bonus_tokens=disable_bonus_tokens,
        posterior_threshold=0.0,
        posterior_alpha=0.0)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    output_token_ids = typical_acceptance_sampler(
        target_probs,
        bonus_token_ids,
        draft_probs=None,
        draft_token_ids=draft_token_ids)
    assert output_token_ids.shape[0] == batch_size
    assert output_token_ids.shape[1] == (k + 1)
    assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
    if disable_bonus_tokens:
        assert torch.all(output_token_ids[:, -1] == -1)
    else:
        assert torch.all(output_token_ids[:, -1] == bonus_token_ids)


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
                               device: str):
    """
    Test the TypicalAcceptanceSampler's method for generating
    replacement token IDs.

    This test verifies that the `_replacement_token_ids` method of the
    TypicalAcceptanceSampler correctly identifies the token IDs to be used
    as replacements based on the target probability distribution.
    Specifically, it ensures that the method correctly identifies the
    tokens with the highest probability for each sequence in the batch.
    """
    set_random_seed(seed)
    k = 10
    batch_size = 5
    vocab_size = 30_000
    torch.set_default_device(device)
    typical_acceptance_sampler = get_acceptance_sampler(
        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
    typical_acceptance_sampler.init_gpu_tensors(device=device)
    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    expected_replacement_tokens = -torch.ones(
        (batch_size, k), dtype=torch.long)
    expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :],
                                                     dim=1)
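    # Only position 0 is expected to carry a real replacement: consistent
    # with the greedy-fallback behavior exercised in the tests above, the
    # replacement is the argmax (greedy) token at the first position and
    # the remaining slots are padded with -1.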
    actual_replacement_tokens = (
        typical_acceptance_sampler._replacement_token_ids(target_probs))
    assert torch.all(expected_replacement_tokens == actual_replacement_tokens)
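
# Example invocation (the path is illustrative; adjust to your checkout):
#     pytest test_typical_acceptance_sampler.py -q
# These tests require at least one CUDA device; CUDA_DEVICES above only
# enumerates "cuda:0".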