import gc
from unittest.mock import patch

import pytest
import torch
import triton
import triton.language as tl

from aphrodite.modeling.layers.ops.sample import (_sample_triton,
                                                  _uniform_to_exponential,
                                                  sample)
from aphrodite.modeling.sampling_metadata import SamplingTensors
from aphrodite.modeling.utils import set_random_seed
from aphrodite.triton_utils.libentry import LibEntry
from aphrodite.triton_utils.sample import (MAX_TRITON_N_COLS,
                                           get_num_triton_sampler_splits)

SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100


@pytest.fixture(autouse=True)
def _cleanup():
    yield
    gc.collect()
    torch.cuda.empty_cache()


@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
    idx = tl.arange(0, n)
    x = tl.load(input + idx)
    y = _uniform_to_exponential(x)
    tl.store(output + idx, y)


def test_uniform_to_exponential():
    """Test that we can convert uniform to exponential without div by 0."""
    input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
                         dtype=torch.float32,
                         device="cuda")
    output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
    _uniform_to_exponential_kernel[(1, )](input, output, 2)
    assert torch.all(torch.isfinite(output))
    assert torch.all(output > 0)
    assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
                              modify_greedy_probs, seed, vocab_size,
                              save_logprobs):
    set_random_seed(seed)
    bs = 8
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = (torch.rand(
            (1, bs), device="cuda") < 0.5).expand(n_splits, bs)
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, bs),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, bs),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, bs),
                          device="cuda").mul_(random_sampling_mask)
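    # NOTE (assumption, not from the original comments): the mul_ above
    # zeroes the seeds of requests the mask marks as greedy, so presumably
    # the sampler treats a seed of 0 as "sample greedily" and only the
    # random requests keep their nonzero seeds.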
    # The current _sample_triton is not decorated with LibEntry.
    # Patching it in here lets this test also verify the correctness
    # of LibEntry.
    with patch("aphrodite.modeling.layers.ops.sample._sample_triton",
               LibEntry(_sample_triton)):
        sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
            probs=probs,
            logprobs=logprobs,
            sample_indices=sample_indices,
            seeds=seeds,
            max_best_of=max_best_of,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            _save_modified_probs=True)
    assert sampled_tokens.shape == (bs, max_best_of)
    for i in range(bs):
        assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
        request_uses_random_sampling = random_sampling_mask[0, i]
        if modify_greedy_probs and not request_uses_random_sampling:
            # If we are modifying greedy probs and the request is greedy,
            # we want to make sure the probs tensor is modified in place.
            torch.testing.assert_close(
                probs[i][sampled_tokens[i]],
                torch.full_like(probs[i][sampled_tokens[i]], 1.0))
            assert torch.sum(probs[i]) == 1.0
            torch.testing.assert_close(
                sampled_modified_probs[i][0],
                torch.full_like(sampled_modified_probs[i][0], 1.0))
        elif request_uses_random_sampling:
            # If the request is random, we want to make sure the
            # sampled_modified_probs tensor has noise added
            # (and thus differs from the probs tensor).
            assert not torch.allclose(sampled_modified_probs[i][0],
                                      probs[i][sampled_tokens[i]])
        elif not request_uses_random_sampling:
            # If the request is greedy and we are not modifying greedy
            # probs, we want to make sure the sampled_modified_probs
            # tensor is the same as the probs tensor.
            torch.testing.assert_close(sampled_modified_probs[i],
                                       probs[i][sampled_tokens[i]])

    if save_logprobs:
        assert sampled_logprobs.shape == (bs, max_best_of)
        for i in range(bs):
            for best_of in range(max_best_of):
                assert torch.all(sampled_logprobs[i] == logprobs[i][
                    sampled_tokens[i, best_of]])
    else:
        assert sampled_logprobs is None


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                modify_greedy_probs, seed, vocab_size):
    set_random_seed(seed)
    prompt_sizes = [16, 32, 64, 128] * 2
    samples = 8
    bs = samples + sum(prompt_sizes)
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.tensor(prompt_sizes,
                                  dtype=torch.long,
                                  device="cuda").cumsum_(0)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = torch.rand(
            (n_splits, samples), device="cuda") < 0.5
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, samples),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, samples),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, samples),
                          device="cuda").mul_(random_sampling_mask)
    # Ditto: patch in LibEntry to also verify its correctness.
    with patch("aphrodite.modeling.layers.ops.sample._sample_triton",
               LibEntry(_sample_triton)):
        sampled_tokens, sampled_logprobs, _ = sample(
            probs=probs,
            logprobs=logprobs,
            sample_indices=sample_indices,
            seeds=seeds,
            max_best_of=max_best_of,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=True)
    assert sampled_tokens.shape == (samples, max_best_of)
    assert sampled_logprobs.shape == (samples, max_best_of)
    for i, t in enumerate(sample_indices):
        assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
        for best_of in range(max_best_of):
            assert torch.all(sampled_logprobs[i] == logprobs[
                sample_indices[i]][sampled_tokens[i, best_of]])
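

# A hedged sanity sketch (not part of the original suite): the two vocab
# sizes parametrized above are presumably chosen so that
# SINGLE_SPLIT_VOCAB_SIZE fits in a single Triton sampler split while
# MULTI_SPLIT_VOCAB_SIZE (MAX_TRITON_N_COLS + 100) forces the multi-split
# path. If that assumption holds, this check makes the intent explicit:
def test_vocab_size_split_counts():
    # One split suffices for a vocab that fits in MAX_TRITON_N_COLS columns.
    assert get_num_triton_sampler_splits(SINGLE_SPLIT_VOCAB_SIZE) == 1
    # Exceeding MAX_TRITON_N_COLS should require more than one split.
    assert get_num_triton_sampler_splits(MULTI_SPLIT_VOCAB_SIZE) > 1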


@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
    """Ensure that we get a different child seed from base seed + extra
    entropy."""
    starting_seed = seed
    seq_seed = None
    extra_entropy = 1
    for i in range(512):
        new_seq_seed = SamplingTensors._get_sequence_seeds(
            starting_seed, i, seeds_to_generate=1, is_greedy=False)[0]
        new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
            starting_seed,
            i,
            extra_entropy,
            seeds_to_generate=1,
            is_greedy=False)[0]
        assert new_seq_seed_extra_entropy != new_seq_seed
        assert seq_seed != new_seq_seed
        seq_seed = new_seq_seed
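

# A minimal reference sketch (an assumption, not the kernel's actual source):
# by the inversion method, _uniform_to_exponential presumably maps
# x ~ Uniform[0, 1) to y = -log(max(x, eps)), where the eps clamp is what
# keeps test_uniform_to_exponential's inputs of 0.0 and 1.0 - eps from
# producing log(0) and a later division by zero. A plain-PyTorch equivalent
# for eyeballing the kernel's outputs (eps value is hypothetical):
def _uniform_to_exponential_reference(x: torch.Tensor,
                                      eps: float = 1e-20) -> torch.Tensor:
    # Clamp away from 0 so -log stays finite, then invert the uniform CDF.
    return -torch.log(torch.clamp(x, min=eps))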