test_rejection_sampling.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. """Tests for rejection sampling."""
  2. import pytest
  3. from typing import List, Tuple
  4. import torch
  5. import torch.nn.functional as F
  6. from aphrodite.modeling.utils import set_random_seed
  7. from aphrodite.modeling.layers.rejection import RejectionSampler
  8. def mock_causal_accepted_tensor(
  9. k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
  10. """Generate an "accepted" tensor which should yield causally-accepted tokens
  11. up to last accepted indices.
  12. Tokens after last_accepted_indices+1 may also be accepted, although they
  13. will not be causally accepted.
  14. """
  15. batch_size = last_accepted_indices.shape[0]
  16. accepted = (torch.arange(k).expand(batch_size, k) <=
  17. last_accepted_indices.unsqueeze(-1).broadcast_to(
  18. batch_size, k)).to(device="cuda")
  19. # Sprinkle accepted values after the contiguous initial accepted values.
  20. # This replicates the behavior of rejection sampling, which may "accept"
  21. # a token that cannot be accepted because of causality.
  22. sprinkle_candidates = (
  23. torch.arange(k).expand(batch_size, k) >
  24. last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
  25. sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
  26. accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
  27. return accepted
  28. @pytest.mark.parametrize("seed", list(range(10)))
  29. @pytest.mark.parametrize(
  30. "which_tokens_accepted",
  31. ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
  32. @torch.inference_mode()
  33. def test_correct_output_format(which_tokens_accepted: str, seed: int):
  34. """Verify the output has correct format given predetermined accepted matrix.
  35. """
  36. set_random_seed(seed)
  37. batch_size = 10
  38. k = 5
  39. vocab_size = 3000
  40. if which_tokens_accepted == "all_tokens_accepted":
  41. accepted = mock_causal_accepted_tensor(
  42. k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
  43. elif which_tokens_accepted == "no_tokens_accepted":
  44. accepted = mock_causal_accepted_tensor(
  45. k, -torch.ones((batch_size, ), dtype=torch.long))
  46. elif which_tokens_accepted == "some_tokens_accepted":
  47. last_accepted_indices = torch.randint(low=-1,
  48. high=k,
  49. size=(batch_size, ))
  50. accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
  51. else:
  52. raise AssertionError()
  53. recovered_token_ids = torch.randint(low=0,
  54. high=vocab_size,
  55. size=(batch_size, k),
  56. dtype=torch.int64,
  57. device="cuda")
  58. draft_token_ids = torch.randint(low=0,
  59. high=vocab_size,
  60. size=(batch_size, k),
  61. dtype=torch.int64,
  62. device="cuda")
  63. bonus_token_ids = torch.randint(low=0,
  64. high=vocab_size,
  65. size=(batch_size, 1),
  66. dtype=torch.int64,
  67. device="cuda")
  68. rejection_sampler = RejectionSampler()
  69. rejection_sampler.init_gpu_tensors(rank=0)
  70. output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
  71. accepted,
  72. recovered_token_ids,
  73. draft_token_ids,
  74. bonus_token_ids,
  75. )
  76. if which_tokens_accepted == "all_tokens_accepted":
  77. # Expect all tokens to be equal to draft tokens.
  78. assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
  79. # Expect all bonus tokens to be included.
  80. assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
  81. elif which_tokens_accepted == "no_tokens_accepted":
  82. # Expect first token to be equal to recovered tokens.
  83. assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
  84. # Expect everything else to be -1.
  85. assert torch.equal(output_token_ids[:, 1:],
  86. torch.ones_like(output_token_ids[:, 1:]) * -1)
  87. elif which_tokens_accepted == "some_tokens_accepted":
  88. recovered_plus_bonus = torch.cat(
  89. (recovered_token_ids, bonus_token_ids), dim=-1)
  90. # Assert first rejected token is a recovered token or bonus token.
  91. assert torch.equal(
  92. recovered_plus_bonus[torch.arange(0, batch_size),
  93. last_accepted_indices + 1],
  94. output_token_ids[torch.arange(0, batch_size),
  95. last_accepted_indices + 1])
  96. # Assert every subsequent token is -1.
  97. subsequent_mask = torch.arange(0, k + 1).expand(
  98. batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
  99. assert torch.all(output_token_ids[subsequent_mask] == -1)
  100. @pytest.mark.parametrize("k", list(range(1, 6)))
  101. @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
  102. @pytest.mark.parametrize("batch_size", list(range(1, 32)))
  103. @torch.inference_mode()
  104. def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
  105. rejection_sampler = RejectionSampler()
  106. rejection_sampler.init_gpu_tensors(rank=0)
  107. draft_probs = torch.rand(batch_size,
  108. k,
  109. vocab_size,
  110. dtype=torch.float32,
  111. device="cuda")
  112. target_probs = torch.rand(batch_size,
  113. k,
  114. vocab_size,
  115. dtype=torch.float32,
  116. device="cuda")
  117. bonus_token_ids = torch.randint(low=0,
  118. high=vocab_size,
  119. size=(batch_size, 1),
  120. dtype=torch.int64,
  121. device="cuda")
  122. draft_token_ids = torch.randint(low=0,
  123. high=vocab_size,
  124. size=(batch_size, k),
  125. dtype=torch.int64,
  126. device="cuda")
  127. rejection_sampler(target_probs, bonus_token_ids, draft_probs,
  128. draft_token_ids)
  129. @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
  130. @pytest.mark.parametrize("which_token_ids",
  131. ["bonus_token_ids", "draft_token_ids"])
  132. @torch.inference_mode()
  133. def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
  134. which_token_ids: str):
  135. k = 3
  136. batch_size = 5
  137. vocab_size = 30_000
  138. rejection_sampler = RejectionSampler(strict_mode=True)
  139. rejection_sampler.init_gpu_tensors(rank=0)
  140. draft_probs = torch.rand(batch_size,
  141. k,
  142. vocab_size,
  143. dtype=torch.float32,
  144. device="cuda")
  145. target_probs = torch.rand(batch_size,
  146. k,
  147. vocab_size,
  148. dtype=torch.float32,
  149. device="cuda")
  150. bonus_token_ids = torch.randint(low=0,
  151. high=vocab_size,
  152. size=(batch_size, 1),
  153. dtype=torch.int64,
  154. device="cuda")
  155. draft_token_ids = torch.randint(low=0,
  156. high=vocab_size,
  157. size=(batch_size, k),
  158. dtype=torch.int64,
  159. device="cuda")
  160. oob_token_ids = None
  161. if which_token_ids == "bonus_token_ids":
  162. oob_token_ids = bonus_token_ids
  163. elif which_token_ids == "draft_token_ids":
  164. oob_token_ids = draft_token_ids
  165. else:
  166. raise AssertionError()
  167. if above_or_below_vocab_range == "above":
  168. rogue_token_id = vocab_size + 1
  169. elif above_or_below_vocab_range == "below":
  170. rogue_token_id = -1
  171. else:
  172. raise AssertionError()
  173. oob_token_ids[0][0] = rogue_token_id
  174. with pytest.raises(AssertionError):
  175. rejection_sampler(target_probs, bonus_token_ids, draft_probs,
  176. draft_token_ids)
  177. @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
  178. @pytest.mark.parametrize("seed", list(range(5)))
  179. @torch.inference_mode()
  180. def test_rejection_sampling_approximates_target_distribution(
  181. seed: int, draft_and_target_probs_equal: bool):
  182. """Verify rejection sampling approximates target distribution,
  183. despite sampling from a potentially distinct draft distribution.
  184. This is done by first creating a random target probability
  185. distribution and a random draft probability distribution. We then
  186. sample token ids from the rejection sampler using these draft
  187. and target distributions. The samples are used to estimate
  188. the output probability distribution, which we expect to approximate
  189. the target distribution.
  190. A basic distance metric is used to determine similarity between
  191. distributions.
  192. We expect that as we increase the number of samples,
  193. the distance between the observed distribution and the target
  194. distribution decreases. To measure this, we compare the distance
  195. of the observed distribution against both the target distribution
  196. and a uniform random distribution. We expect the distance between
  197. the observed distribution and the target distribution to improve
  198. much more than the distance improvement between the observed
  199. distribution and the random distribution.
  200. When draft_and_target_probs_equal=True, the draft and target
  201. probabilities are exactly equal. Rejection sampling should
  202. still work without any NaNs or exceptions.
  203. """
  204. set_random_seed(seed)
  205. helper = _CorrectnessTestHelper(
  206. vocab_size=10,
  207. rejection_sampler=RejectionSampler(),
  208. )
  209. draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
  210. draft_and_target_probs_equal)
  211. sample_sizes = [10, 100, 1_000, 10_000, 100_000]
  212. distance_wrt_reference = []
  213. distance_wrt_target = []
  214. for num_samples in sample_sizes:
  215. (reference_vs_rejsample_dist,
  216. target_vs_rejsample_dist) = helper.run_and_compare_distributions(
  217. draft_probs,
  218. target_probs,
  219. reference_probs,
  220. num_samples,
  221. )
  222. distance_wrt_reference.append(reference_vs_rejsample_dist)
  223. distance_wrt_target.append(target_vs_rejsample_dist)
  224. relative_change_in_distance_wrt_target = get_ratio_first_to_last(
  225. distance_wrt_target)
  226. relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
  227. distance_wrt_reference)
  228. print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
  229. f"{reference_vs_rejsample_dist=:.05f}")
  230. print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
  231. f"{relative_change_in_distance_wrt_reference=:.02f}")
  232. relative_change_in_distance_wrt_target = get_ratio_first_to_last(
  233. distance_wrt_target)
  234. relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
  235. distance_wrt_reference)
  236. expected_improvement_multiplier = 20
  237. assert (relative_change_in_distance_wrt_target >
  238. relative_change_in_distance_wrt_reference *
  239. expected_improvement_multiplier)
  240. def get_ratio_first_to_last(elements: List[float]) -> float:
  241. return elements[0] / elements[-1]
  242. class _CorrectnessTestHelper:
  243. """Class that packages together logic required for the unit-level
  244. rejection sampling correctness test.
  245. """
  246. def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
  247. self.rejection_sampler = rejection_sampler
  248. self.vocab_size = vocab_size
  249. self.vocab_range = (0, vocab_size)
  250. self.rejection_sampler.init_gpu_tensors(rank=0)
  251. # Keep test simple, use k=1
  252. self.k = 1
  253. # Bonus tokens not used, but rejection sampler requires
  254. # correct shape.
  255. self.num_bonus_tokens = 1
  256. def generate_probs_for_test(
  257. self, draft_and_target_probs_equal: bool
  258. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  259. draft_probs, target_probs = [
  260. F.softmax(
  261. torch.rand(self.vocab_size, dtype=torch.float32),
  262. dim=-1,
  263. ) for _ in range(2)
  264. ]
  265. num_reference_probs = 100
  266. reference_probs = F.softmax(
  267. torch.rand(num_reference_probs,
  268. self.vocab_size,
  269. dtype=torch.float32),
  270. dim=-1,
  271. )
  272. if draft_and_target_probs_equal:
  273. target_probs = draft_probs.clone()
  274. return draft_probs, target_probs, reference_probs
  275. def run_and_compare_distributions(self, draft_probs: torch.Tensor,
  276. target_probs: torch.Tensor,
  277. reference_probs: torch.Tensor,
  278. num_samples: int) -> Tuple[float, float]:
  279. # Sample using rejection sampling.
  280. rej_sample_probs = self._estimate_rejection_sampling_pdf(
  281. draft_probs, target_probs, num_samples)
  282. # Average distance from reference probs.
  283. reference_vs_rejsample_dist = torch.dist(
  284. reference_probs,
  285. rej_sample_probs).item() / reference_probs.shape[0]
  286. target_vs_rejsample_dist = torch.dist(target_probs,
  287. rej_sample_probs).item()
  288. return reference_vs_rejsample_dist, target_vs_rejsample_dist
  289. def _estimate_rejection_sampling_pdf(
  290. self,
  291. draft_probs: torch.Tensor,
  292. target_probs: torch.Tensor,
  293. num_samples: int,
  294. ) -> torch.Tensor:
  295. # Repeat draft probs num_samples times.
  296. draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
  297. num_samples, 1, 1)
  298. # Repeat target probs num_samples * k times.
  299. # Rejection sampler requires bonus token probs, but they aren't used.
  300. target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
  301. num_samples, self.k, 1)
  302. # Randomly sample draft token ids from draft probs.
  303. draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
  304. num_samples=1,
  305. replacement=True).reshape(
  306. num_samples, self.k)
  307. # Bonus tokens not used but required.
  308. bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
  309. dtype=torch.int64,
  310. device="cuda").repeat(num_samples, 1)
  311. # Get output tokens via rejection sampling.
  312. output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
  313. bonus_token_ids.to("cuda"),
  314. draft_probs.to("cuda"),
  315. draft_token_ids.to("cuda"))
  316. # Remove bonus tokens
  317. output_token_ids = output_token_ids[:, :-1].flatten()
  318. # Estimate probability density function
  319. hist = torch.histogram(output_token_ids.to(dtype=torch.float,
  320. device="cpu"),
  321. bins=self.vocab_size,
  322. range=self.vocab_range,
  323. density=True)
  324. return hist.hist