소스 검색

chore: remove dead code from triton sampling kernels (#1049)

AlpinDale 2 달 전
부모
커밋
4593a3b306

+ 0 - 9
aphrodite/common/utils.py

@@ -842,15 +842,6 @@ def async_tensor_h2d(
     return t.to(device=target_device, non_blocking=True)
 
 
-def maybe_expand_dim(tensor: torch.Tensor,
-                     target_dims: int,
-                     size: int = 1) -> torch.Tensor:
-    """Expand the tensor to the target_dims."""
-    if tensor.ndim < target_dims:
-        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
-    return tensor
-
-
 def get_dtype_size(dtype: torch.dtype) -> int:
     """Get the size of the data type in bytes."""
     return torch.tensor([], dtype=dtype).element_size()

+ 0 - 0
aphrodite/modeling/layers/ops/__init__.py


+ 0 - 155
aphrodite/modeling/layers/ops/rand.py

@@ -1,155 +0,0 @@
-from typing import Optional, Union
-
-import torch
-import triton
-import triton.language as tl
-
-
-def seeded_uniform(
-    *size,
-    seeds: torch.Tensor,
-    out: Optional[torch.Tensor] = None,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[Union[torch.device, str]] = None,
-    pin_memory: Optional[bool] = False,
-) -> torch.Tensor:
-    """Similar to torch.rand, but allows for seeds to be set per row.
-    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
-    If it is 3d, the additional seeds needed will be derived automatically
-    in a deterministic fashion:
-    [
-        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
-    ]
-    """
-    n_dims = len(size)
-
-    if n_dims > 3:
-        raise ValueError("seeded_uniform only supports up to 3D tensors")
-
-    if out is None:
-        out = torch.empty(*size,
-                          dtype=dtype,
-                          device=device,
-                          pin_memory=pin_memory)
-    elif out.shape != size:
-        raise ValueError("shape of out and size must be the same")
-
-    if n_dims == 3:
-        n_rows, n_3d, n_cols = out.shape
-        stride_row = out.stride(0)
-        stride_3d = out.stride(1)
-    elif n_dims == 2:
-        n_rows, n_cols = out.shape
-        n_3d = 1
-        stride_row = out.stride(0)
-        stride_3d = 1
-    else:
-        n_cols = out.shape[0]
-        n_rows = 1
-        n_3d = 1
-        stride_row = 1
-        stride_3d = 1
-
-    if seeds.ndim != 1:
-        raise ValueError("seeds must be a 1D tensor")
-
-    if seeds.numel() != n_rows:
-        raise ValueError(
-            "seeds must have the same number of elements as out has rows")
-
-    # The philox PRNG Triton uses generates 4 random numbers at once.
-    # Therefore, the most efficient use of it is to divide the
-    # block size by 4, and then save the generated random numbers to
-    # each of the 4 slices of the tensor.
-    full_block_size = triton.next_power_of_2(n_cols)
-    philox_block_size = max(full_block_size // 4, 1)
-    n_slices = full_block_size // philox_block_size
-    num_warps = 4
-    # Manual tuning. This seems to give best performance on A100 for
-    # simple kernels like this.
-    if philox_block_size >= 8192:
-        num_warps = 32
-    elif philox_block_size >= 4096:
-        num_warps = 16
-    elif philox_block_size >= 2048:
-        num_warps = 8
-
-    _seeded_uniform_triton[(n_rows, n_3d)](
-        out,
-        seeds,
-        stride_row,
-        stride_3d,
-        seeds.stride(0),
-        n_rows,
-        n_3d,
-        n_cols,
-        n_slices=n_slices,
-        num_warps=num_warps,
-        block_size=philox_block_size,
-    )
-    return out
-
-
-@triton.jit
-def _seeded_uniform_triton(
-    out_ptr: torch.Tensor,
-    seed_ptr: torch.Tensor,
-    out_row_stride: int,
-    out_3d_stride: int,
-    seed_row_stride: int,
-    n_rows: int,
-    n_3d: int,
-    n_cols: int,
-    n_slices: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    """
-    Generate a random float32 number in [0, 1) for each element in the output
-    tensor. The random numbers in a row generated using the seed for that row.
-    Args:
-        out_ptr: The output tensor.
-        seed_ptr: The per-row seeds to use for random number generation.
-        out_row_stride: The stride between rows of the output tensor.
-        out_3d_stride: The stride between 3D slices of the output tensor.
-        seed_row_stride: The stride between rows of the seed tensor.
-        n_rows: The number of rows in the output tensor.
-        n_3d: The size of second dimension of the output tensor,
-            if output tensor is 3D.
-        n_cols: The number of columns in the output tensor.
-        n_slices: The number of philox outputs to use.
-    """
-    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
-
-    # Get the row index.
-    row_idx = tl.program_id(axis=0)
-    three_d_idx = tl.program_id(axis=1)
-
-    philox_offsets = tl.arange(0, block_size)
-    # Get the seed for the current element.
-    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
-    if three_d_idx > 0:
-        seed ^= three_d_idx
-    # Generate random numbers in [0, 1).
-    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)
-
-    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
-                            three_d_idx * out_3d_stride)
-    out1_offsets = philox_offsets
-    tl.store(output_row_start_ptr + out1_offsets,
-             out1,
-             mask=out1_offsets < n_cols)
-    if n_slices > 1:
-        out2_offsets = tl.arange(block_size, block_size * 2)
-        tl.store(output_row_start_ptr + out2_offsets,
-                 out2,
-                 mask=out2_offsets < n_cols)
-    if n_slices > 2:
-        out3_offsets = tl.arange(block_size * 2, block_size * 3)
-        tl.store(output_row_start_ptr + out3_offsets,
-                 out3,
-                 mask=out3_offsets < n_cols)
-    if n_slices > 3:
-        out4_offsets = tl.arange(block_size * 3, block_size * 4)
-        tl.store(output_row_start_ptr + out4_offsets,
-                 out4,
-                 mask=out4_offsets < n_cols)

+ 0 - 394
aphrodite/modeling/layers/ops/sample.py

@@ -1,394 +0,0 @@
-from typing import Optional, Tuple
-
-import torch
-import triton
-import triton.language as tl
-
-from aphrodite.modeling.layers.ops.rand import seeded_uniform
-from aphrodite.triton_utils.sample import get_num_triton_sampler_splits
-
-_EPS = 1e-6
-
-
-def _multi_split_sample(
-    probs: torch.Tensor,
-    seeds: torch.Tensor,
-    n_splits: int,
-    sampled_tokens_size: Tuple[int, int],
-    sampled_logprobs_size: Tuple[int, int],
-    sample_indices: torch.Tensor,
-    logprobs: torch.Tensor,
-    *,
-    modify_greedy_probs: bool = False,
-    save_logprobs: bool = False,
-):
-    """Sample tokens where vocab size is split into multiple parts
-    (too large for Triton otherwise)."""
-    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
-    split_probs = probs.tensor_split(n_splits, 1)
-    split_logprobs = logprobs.tensor_split(n_splits, 1)
-    sampled_tokens_tmp = [
-        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
-        for _ in range(n_splits)
-    ]
-    sampled_logprobs_tmp = [
-        torch.empty(sampled_logprobs_size,
-                    dtype=probs.dtype,
-                    device=probs.device) for _ in range(n_splits)
-    ]
-    # We are purposefuly using sampled_tokens_size as we need to always
-    # save modified probs in this case.
-    sampled_modified_probs_tmp = [
-        torch.empty(sampled_tokens_size,
-                    dtype=probs.dtype,
-                    device=probs.device) for _ in range(n_splits)
-    ]
-    for i in range(n_splits):
-        n_samples = sample_indices.shape[0]
-        n_cols = split_probs[i].shape[1]
-        n_best = sampled_tokens_tmp[i].shape[1]
-        uniform_noise = seeded_uniform(n_samples,
-                                       n_best,
-                                       n_cols,
-                                       seeds=seeds[i].flatten(),
-                                       device=split_probs[i].device,
-                                       dtype=split_probs[i].dtype)
-        # TODO: See if we can remove the contiguous() calls.
-        # Will need kernel support.
-        _sample(
-            split_probs[i].contiguous(),
-            split_logprobs[i].contiguous(),
-            sample_indices,
-            sampled_tokens_tmp[i],
-            sampled_logprobs_tmp[i],
-            sampled_modified_probs_tmp[i],
-            seeds[i],
-            uniform_noise,
-            modify_greedy_probs=False,
-            save_logprobs=save_logprobs,
-            save_modified_probs=True,
-        )
-        if i > 0:
-            # Add offset to sampled tokens
-            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
-    sampled_tokens = torch.stack(sampled_tokens_tmp)
-    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
-    # Reduce the results from the splits.
-    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
-                                                dim=0,
-                                                keepdim=True)
-    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
-    if save_logprobs:
-        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
-        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
-    else:
-        sampled_logprobs = None
-    sampled_modified_probs = sampled_modified_probs.squeeze(0)
-
-    if modify_greedy_probs:
-        # We need to modify the greedy probs for the sampled tokens.
-        # We can't do this in the kernel as we need to know the
-        # sampled tokens.
-        probs.fill_(0.0)
-        probs.scatter_(1, sampled_tokens, 1.0)
-
-    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
-
-
-def sample(
-    probs: torch.Tensor,
-    seeds: torch.Tensor,
-    *,
-    max_best_of: int = 1,
-    sample_indices: Optional[torch.Tensor] = None,
-    logprobs: Optional[torch.Tensor] = None,
-    modify_greedy_probs: bool = False,
-    save_logprobs: bool = False,
-    _save_modified_probs: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
-    """Sample tokens from probs. with per-sequence seeds.
-
-    Can sample from a subset of sequences through sample_indices.
-
-    Args:
-        probs: Probabilities to sample from.
-            shape = [batch_size, vocab_size]
-        seeds: Per-sequence seed values.
-            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
-        max_best_of: Number of samples to generate per sequence.
-            Sequence seed will be incremented by 1 each time.
-        sample_indices: Indices of sequences to sample from.
-            If not provided, will sample from all sequences.
-            shape = [n]
-        logprobs: Log-probabilities of the sampled tokens.
-            Only used for saving the logprobs if save_logprobs is True.
-            shape = [batch_size, vocab_size]
-        modify_greedy_probs: Whether to modify the greedy probabilities
-            for speculative sampling (sampled token = 1.0,
-            everything else = 0.0).
-        save_logprobs: Whether to save the log-probabilities of the
-            sampled tokens to a tensor.
-        _save_modified_probs: Whether to save the modified probabilities
-            (including gumbel noise) of the sampled tokens to a tensor.
-            DOES NOT include the modification done by modify_greedy_probs
-            (because we want to use the unmodified probs to pick the best
-            split in case of multi-split sampling).
-            This is exposed only for testing.
-
-    Returns:
-        sampled_tokens: shape = [n, max_best_of]
-        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
-        sampled_modified_probs: shape = [n, max_best_of]
-            if save_modified_probs else None
-    """
-    if sample_indices is None:
-        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)
-
-    sampled_tokens_size = (sample_indices.size(0), max_best_of)
-    if save_logprobs:
-        if logprobs is None:
-            raise ValueError(
-                "logprobs tensor must be provided if save_logprobs is True")
-        sampled_logprobs_size = sampled_tokens_size
-    else:
-        # Empty tensors to invoke the kernel
-        sampled_logprobs_size = (0, 0)
-        logprobs = probs
-
-    assert logprobs is not None
-    if _save_modified_probs:
-        sampled_modified_probs_size = sampled_tokens_size
-    else:
-        # Empty tensors to invoke the kernel
-        sampled_modified_probs_size = (0, 0)
-
-    # If the number of columns in probs is too large for Triton to handle,
-    # we split the tensor and sample from each split separately, and then
-    # do an argmax+gather to combine the results.
-    n_splits = get_num_triton_sampler_splits(probs.shape[1])
-    if n_splits > 1:
-        (sampled_tokens, sampled_logprobs,
-         sampled_modified_probs) = _multi_split_sample(
-             probs,
-             seeds,
-             n_splits,
-             sampled_tokens_size,
-             sampled_logprobs_size,
-             sample_indices,
-             logprobs=logprobs,
-             modify_greedy_probs=modify_greedy_probs,
-             save_logprobs=save_logprobs)
-    else:
-        sampled_tokens = torch.empty(sampled_tokens_size,
-                                     dtype=torch.long,
-                                     device=probs.device)
-        sampled_logprobs = torch.empty(sampled_logprobs_size,
-                                       dtype=probs.dtype,
-                                       device=probs.device)
-        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
-                                             dtype=probs.dtype,
-                                             device=probs.device)
-        n_samples = sample_indices.shape[0]
-        n_cols = probs.shape[1]
-        uniform_noise = seeded_uniform(n_samples,
-                                       max_best_of,
-                                       n_cols,
-                                       seeds=seeds.flatten(),
-                                       device=probs.device,
-                                       dtype=probs.dtype)
-
-        _sample(
-            probs,
-            logprobs,
-            sample_indices,
-            sampled_tokens,
-            sampled_logprobs,
-            sampled_modified_probs,
-            seeds,
-            uniform_noise,
-            modify_greedy_probs=modify_greedy_probs,
-            save_logprobs=save_logprobs,
-            save_modified_probs=_save_modified_probs,
-        )
-    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
-            sampled_modified_probs if _save_modified_probs else None)
-
-
-def _sample(probs: torch.Tensor,
-            logprobs: torch.Tensor,
-            sample_indices: torch.Tensor,
-            output_samples: torch.Tensor,
-            output_logprobs: torch.Tensor,
-            output_modified_probs: torch.Tensor,
-            seeds: torch.Tensor,
-            uniform_noise: torch.Tensor,
-            *,
-            modify_greedy_probs: bool = False,
-            save_logprobs: bool = True,
-            save_modified_probs: bool = False) -> torch.Tensor:
-    """Sample tokens from probs.
-
-    Args:
-        probs [batch_size, vocab_size]: probs to sample from.
-        logprobs [batch_size, vocab_size]: logprobs (used when
-            save_logprobsis True).
-        sample_indices [n]: Indices of the samples to use for each row of probs.
-        output_samples [n, n_best]: Output tensor to store samples in.
-        output_logprobs [n, n_best]: Output tensor to store logprobs in.
-        output_modified_probs [n, n_best]: Output tensor to store
-            probs of chosen tokens in (modified with noise).
-        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
-            greedy sampling. Note this is ONLY used for determining
-            whether to use random sampling or not. The actual random
-            noise should be passed as uniform_noise.
-        uniform_noise [batch_size, n_best, vocab_size]: Uniform
-            noise to use for random sampling (will be converted
-            to exponential gumbel noise by the kernel).
-        modify_greedy_probs: If True, we modify the probs tensor in-place
-            to encode the sampling method used for each row. This is used
-            in speculative decoding. Only applies in greedy decoding.
-        save_logprobs: If True, we save the logprobs of the sampled tokens
-            in the output_logprobs tensor.
-        save_modified_probs: If True, we save the modified probs (with noise)
-            of the sampled tokens in the output_modified_probs tensor.
-            DOES NOT include the modification done by modify_greedy_probs
-            (because we want to use the unmodified probs to pick the best
-            split in case of multi-split sampling).
-    """
-    n_samples = sample_indices.shape[0]
-    n_cols = probs.shape[1]
-    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1
-
-    # The block size is the smallest power of two greater than the number of
-    # columns in probs
-    block_size = triton.next_power_of_2(n_cols)
-    num_warps = 4
-    # Manual tuning. This seems to give best performance on A100 for
-    # simple kernels like this.
-    if block_size >= 8192:
-        num_warps = 32
-    elif block_size >= 4096:
-        num_warps = 16
-    elif block_size >= 2048:
-        num_warps = 8
-
-    # Enqueue kernel. The 1D launch grid is simple: we have one kernel
-    # instance per row of the probs matrix
-    _sample_triton[(n_samples, n_best)](
-        sample_indices,
-        output_samples,
-        output_logprobs,
-        output_modified_probs,
-        probs,
-        logprobs,
-        seeds,
-        uniform_noise,
-        output_samples.stride(0),
-        probs.stride(0),
-        uniform_noise.stride(0),
-        uniform_noise.stride(1) if n_best > 1 else 1,
-        n_samples,
-        n_cols,
-        n_best,
-        num_warps=num_warps,
-        block_size=block_size,
-        modify_greedy_probs=modify_greedy_probs,
-        save_logprobs=save_logprobs,
-        save_modified_probs=save_modified_probs,
-    )
-    return output_samples, output_logprobs, output_modified_probs
-
-
-@triton.jit
-def _uniform_to_exponential(uniform_noise):
-    """Convert uniform samples to exponential samples."""
-    # tl.rand returns values in [0, 1), so we clamp lower bound
-    # to _EPS to avoid log(0) and thus division by 0 later
-    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
-    uniform_noise = tl.maximum(uniform_noise, lb)
-    # Use the inversion method to turn uniform samples
-    # into exponential samples
-    exponential_noise = -tl.log(uniform_noise)
-    return exponential_noise
-
-
-@triton.jit
-def _sample_triton(
-        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
-        output_logprobs_ptr: torch.Tensor,
-        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
-        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
-        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
-        probs_row_stride: int, uniform_noise_row_stride: int,
-        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
-        n_best: int, block_size: tl.constexpr,
-        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
-        save_modified_probs: tl.constexpr):
-    # The rows are independent, so we parallelize across those
-    sample_idx = tl.program_id(0)
-    best_idx = tl.program_id(1)
-
-    # Load the row index from DRAM
-    row_idx = tl.load(sample_indices_ptr + sample_idx)
-    seed = tl.load(seeds_ptr + sample_idx)
-    uses_random_sampling = seed != 0
-
-    # The stride represents how much we need to increase the
-    # pointer to advance 1 row
-    row_start_ptr = probs_ptr + row_idx * probs_row_stride
-
-    # The block size is the next power of two greater than n_cols,
-    # so we can fit each row in a single block
-    col_offsets = tl.arange(0, block_size)
-
-    # Load the row into SRAM, using a mask since block_size may be > than n_cols
-    row = tl.load(row_start_ptr + col_offsets,
-                  mask=col_offsets < n_cols,
-                  other=float("-inf"))
-
-    if uses_random_sampling:
-        uniform_noise_start_ptr = (uniform_noise_ptr +
-                                   sample_idx * uniform_noise_row_stride +
-                                   best_idx * uniform_noise_best_stride)
-        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
-                                mask=col_offsets < n_cols,
-                                other=0.5)
-        exponential_noise = _uniform_to_exponential(uniform_noise)
-        row /= exponential_noise
-
-    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
-    # clamp sampled token to n_cols - 1
-    # this should not be necessary, but we do it
-    # just in case
-    if sampled_token >= n_cols:
-        sampled_token = n_cols - 1
-    # Write back output to DRAM
-    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
-                            best_idx)
-    tl.store(output_row_start_ptr, sampled_token)
-
-    if modify_greedy_probs:  # noqa
-        if not uses_random_sampling:
-            # Set the probability of the sampled token to 1, all other
-            # tokens to zero. This is used in speculative decoding where
-            # the sampling method must be encoded within the sampled
-            # probability distributions.
-            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
-            tl.store(row_start_ptr + col_offsets,
-                     row,
-                     mask=col_offsets < n_cols)
-
-    if save_modified_probs:
-        output_row_start_ptr = (output_modified_probs_ptr +
-                                sample_idx * output_row_stride + best_idx)
-        tl.store(output_row_start_ptr, sampled_value)
-
-    if save_logprobs:
-        # Load the row into SRAM, using a mask since block_size
-        # may be > than n_cols
-        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
-                                  sampled_token)
-        # Write back output to DRAM
-        output_row_start_ptr = (output_logprobs_ptr +
-                                sample_idx * output_row_stride + best_idx)
-        tl.store(output_row_start_ptr, sampled_logprob)

+ 2 - 84
aphrodite/modeling/layers/sampler.py

@@ -18,15 +18,10 @@ from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        PromptLogprobs, SampleLogprobs,
                                        SequenceOutput)
 from aphrodite.common.utils import is_cpu
-from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
-from aphrodite.triton_utils import HAS_TRITON
-
-if HAS_TRITON:
-    from aphrodite.modeling.layers.ops.sample import sample as sample_triton
-
 from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors,
                                                   SequenceGroupToSample)
+from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
@@ -1431,7 +1426,7 @@ def _sample_with_torch(
     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        sample_indices = categorized_sample_indices[sampling_type][:, 0]
+        sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
@@ -1515,80 +1510,6 @@ def _sample_with_torch(
         )
 
 
-def _sample_with_triton_kernel(
-    probs: torch.Tensor,
-    logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-    sampling_tensors: SamplingTensors,
-) -> List[Tuple[List[int], List[int]]]:
-    categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = sampling_metadata.categorized_sample_indices
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        sampling_params = seq_group.sampling_params
-        sampling_type = sampling_params.sampling_type
-        categorized_seq_group_ids[sampling_type].append(i)
-
-    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
-    sample_metadata = {}
-    max_best_of_in_batch = 1
-    # Counterintuitively, having two loops here is actually faster.
-    # The first loop can run without waiting on GPU<->CPU sync.
-    for sampling_type in SamplingType:
-        sample_indices = categorized_sample_indices[sampling_type][:, 0]
-        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
-        num_tokens = len(sample_indices)
-        if num_tokens == 0:
-            continue
-        seq_group_id = categorized_seq_group_ids[sampling_type]
-        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
-        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
-                                          sample_indices,
-                                          sampled_token_indices)
-        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
-                             SamplingType.RANDOM_SEED):
-            for seq_group in seq_groups:
-                if seq_group.is_prompt:
-                    sampling_params = seq_group.sampling_params
-                    max_best_of_in_batch = max(max_best_of_in_batch,
-                                               sampling_params.best_of)
-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
-        else:
-            raise ValueError(f"Unsupported sampling type: {sampling_type}")
-    sampled_tokens, _, _ = sample_triton(
-        probs=probs,
-        seeds=sampling_tensors.sampling_seeds,
-        max_best_of=max_best_of_in_batch,
-        sample_indices=sampling_tensors.sample_indices,
-        logprobs=logprobs,
-        # don't save logprobs because we have logic for that below
-        # TODO: use this instead of the CPU-based logic below
-        save_logprobs=False,
-    )
-    # GPU<->CPU sync happens in the loop below.
-    for sampling_type in SamplingType:
-        if sampling_type not in sample_metadata:
-            continue
-        (seq_group_id, seq_groups, sample_indices,
-         sampled_token_indices) = sample_metadata[sampling_type]
-        if sampling_type == SamplingType.GREEDY:
-            sample_results = _greedy_sample(
-                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
-        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-            sample_results = _random_sample(
-                seq_groups, sampled_tokens[sampled_token_indices])
-        elif sampling_type == SamplingType.BEAM:
-            sample_results = _beam_search_sample(seq_groups,
-                                                 beam_search_logprobs)
-        sample_results_dict.update(zip(seq_group_id, sample_results))
-
-    sample_results = [
-        sample_results_dict.get(i, ([], []))
-        for i in range(len(sampling_metadata.seq_groups))
-    ]
-    return sample_results
-
-
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
@@ -1616,9 +1537,6 @@ def _sample(
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
-    # TODO: Enable once Triton kernel & associated code is faster.
-    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
-    #                                   sampling_tensors)
 
 
 def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:

+ 103 - 164
aphrodite/modeling/sampling_metadata.py

@@ -1,4 +1,3 @@
-import random
 from array import array
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -9,14 +8,10 @@ from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
 from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                     is_pin_memory_available,
-                                    make_tensor_with_pad, maybe_expand_dim)
+                                    make_tensor_with_pad)
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
-from aphrodite.triton_utils.sample import get_num_triton_sampler_splits
 
 _SAMPLING_EPS = 1e-5
-_SEED_0_REPLACEMENT = 3403598558
-# Some triton sampler related code is guarded before it is ready.
-_USE_TRITON_SAMPLER = False
 
 
 @dataclass
@@ -168,11 +163,12 @@ class SamplingMetadata:
                                                   target_device=device,
                                                   pin_memory=pin_memory)
         categorized_sample_indices = {
-            t: maybe_expand_dim(
-                async_tensor_h2d(seq_ids,
-                                 dtype=torch.int,
-                                 target_device=device,
-                                 pin_memory=pin_memory), 2, 2)
+            t: async_tensor_h2d(
+                seq_ids,
+                dtype=torch.int,
+                target_device=device,
+                pin_memory=pin_memory,
+            )
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -199,8 +195,8 @@ def _prepare_seq_groups(
     device: str,
     generators: Optional[Dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
-        SamplingType, List[Tuple[int, int]]], int]:
+) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
+                                                        List[int]], int, ]:
     """Prepare sequence groups and indices for sampling.
 
     Args:
@@ -231,16 +227,13 @@ def _prepare_seq_groups(
     # Sampling type -> (
     # indices to sample/prompt logprob within pruned output logits,
     # indices to sample within pruned logits)
-    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
+    categorized_sample_indices: Dict[SamplingType, List[int]] = {
         t: []
         for t in SamplingType
     }
     # Index of logits to compute logprob. Logits include both prompt logprob
     # and sample logprob indices.
     logit_idx = 0
-    # Index to sample from a sample tensor. It is used by triton sample kernel.
-    # See `_sample_with_triton_kernel` for more details.
-    sample_idx = 0
     # Total number of prompts from given sequence groups.
     num_prompts = 0
 
@@ -261,10 +254,10 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = \
-            sample_obj.prompt_logprob_indices if cache is not None else []
-        sample_indices: List[int] = \
-            sample_obj.sample_indices if cache is not None else []
+        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
+                                             if cache is not None else [])
+        sample_indices: List[int] = (sample_obj.sample_indices
+                                     if cache is not None else [])
         do_sample = seq_group_metadata.do_sample
 
         if seq_group_metadata.is_prompt:
@@ -330,11 +323,7 @@ def _prepare_seq_groups(
         if do_sample:
             sample_indices.extend(range(logit_idx, logit_idx + sample_len))
             categorized_sample_indices[sampling_params.sampling_type].extend(
-                list(
-                    zip(range(logit_idx, logit_idx + sample_len),
-                        range(sample_idx, sample_idx + sample_len))))
-            logit_idx += sample_len
-            sample_idx += sample_len
+                list(range(logit_idx, logit_idx + sample_len)))
 
         if cache is not None:
             sample_obj.sampling_params = sampling_params
@@ -353,7 +342,8 @@ def _prepare_seq_groups(
                 generator=generator,
                 is_prompt=is_prompt,
                 prompt_logprob_indices=list(prompt_logprob_indices),
-                sample_indices=list(sample_indices))
+                sample_indices=list(sample_indices),
+            )
 
         seq_groups.append(sample_obj)
 
@@ -395,9 +385,6 @@ class SamplingTensors:
     dry_sequence_breaker_ids: torch.Tensor
     dry_ranges: torch.Tensor
     skews: torch.Tensor
-    sampling_seeds: torch.Tensor
-    sample_indices: torch.Tensor
-    extra_seeds: Optional[torch.Tensor]
     prompt_tokens: torch.Tensor
     output_tokens: torch.Tensor
 
@@ -408,16 +395,8 @@ class SamplingTensors:
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-        *,
-        extra_seeds_to_generate: int = 0,
-        extra_entropy: Optional[Tuple[int, ...]] = None
     ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
                bool, bool, bool, bool, bool, bool, bool, bool, bool]:
-        """
-        extra_seeds_to_generate: extra seeds to generate using the
-            user-defined seed for each sequence.
-        extra_entropy: extra entropy to use when generating seeds.
-        """
         prompt_tokens: List[array] = []
         output_tokens: List[array] = []
         top_ks: List[int] = []
@@ -442,8 +421,6 @@ class SamplingTensors:
         xtc_thresholds: List[float] = []
         xtc_probabilities: List[float] = []
         nsigmas: List[float] = []
-        sampling_seeds: List[List[int]] = []
-        sample_indices: List[int] = []
         dry_multipliers: List[float] = []
         dry_bases: List[float] = []
         dry_allowed_lengths: List[int] = []
@@ -468,13 +445,6 @@ class SamplingTensors:
         do_skews = False
         do_temp_last = False
 
-        if _USE_TRITON_SAMPLER:
-            prompt_best_of: List[int] = []
-
-            # We need one base seed per Triton slice.
-            seeds_to_generate = (extra_seeds_to_generate +
-                                 get_num_triton_sampler_splits(vocab_size))
-
         assert sampling_metadata.seq_groups is not None
         for seq_group in sampling_metadata.seq_groups:
             seq_ids = seq_group.seq_ids
@@ -515,7 +485,6 @@ class SamplingTensors:
 
             do_temp_last |= params.temperature_last
 
-            is_prompt = seq_group.is_prompt
             wants_prompt_logprobs = params.prompt_logprobs is not None
 
             n_seqs = 0
@@ -557,28 +526,6 @@ class SamplingTensors:
             dry_ranges += [params.dry_range] * n_seqs
             skews += [params.skew] * n_seqs
 
-            if _USE_TRITON_SAMPLER:
-                if is_prompt:
-                    prompt_best_of.append(params.best_of)
-                    query_len = seq_group.query_len
-                    assert query_len is not None
-
-                seed = params.seed
-                is_greedy = params.sampling_type == SamplingType.GREEDY
-
-                for seq_id in seq_ids:
-                    seq_data = seq_group.seq_data[seq_id]
-                    extra_entropy = extra_entropy or ()
-                    seq_seeds = cls._get_sequence_seeds(
-                        seed,
-                        seq_data.get_len(),
-                        *extra_entropy,
-                        seq_id,
-                        seeds_to_generate=seeds_to_generate,
-                        is_greedy=is_greedy)
-                    sampling_seeds.append(seq_seeds)
-                sample_indices.extend(seq_group.sample_indices)
-
         if do_penalties or do_dry or do_no_repeat_ngrams:
             for seq_group in sampling_metadata.seq_groups:
                 seq_ids = seq_group.seq_ids
@@ -598,43 +545,94 @@ class SamplingTensors:
                         output_tokens.append(seq_data.output_token_ids_array)
 
         sampling_tensors = SamplingTensors.from_lists(
-            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
-            temperature_lasts, top_ps, top_ks, top_as, min_ps,
-            presence_penalties, frequency_penalties, repetition_penalties,
-            no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
-            typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
-            xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
-            dry_allowed_lengths, dry_sequence_breaker_ids, dry_ranges, skews,
-            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
-            vocab_size, extra_seeds_to_generate, device, dtype)
-        return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
-                do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
-                do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
-                do_quadratic, do_xtc, do_nsigmas, do_dry, do_skews,
-                do_temp_last)
+            temperatures,
+            dynatemp_mins,
+            dynatemp_maxs,
+            dynatemp_exps,
+            temperature_lasts,
+            top_ps,
+            top_ks,
+            top_as,
+            min_ps,
+            presence_penalties,
+            frequency_penalties,
+            repetition_penalties,
+            no_repeat_ngram_sizes,
+            tfss,
+            eta_cutoffs,
+            epsilon_cutoffs,
+            typical_ps,
+            smoothing_factors,
+            smoothing_curves,
+            xtc_thresholds,
+            xtc_probabilities,
+            nsigmas,
+            dry_multipliers,
+            dry_bases,
+            dry_allowed_lengths,
+            dry_sequence_breaker_ids,
+            dry_ranges,
+            skews,
+            prompt_tokens,
+            output_tokens,
+            vocab_size,
+            device,
+            dtype)
+        return (
+            sampling_tensors,
+            do_penalties,
+            do_no_repeat_ngrams,
+            do_temperatures,
+            do_top_p_top_k,
+            do_top_as,
+            do_min_p,
+            do_tfss,
+            do_eta_cutoffs,
+            do_epsilon_cutoffs,
+            do_typical_ps,
+            do_quadratic,
+            do_xtc,
+            do_nsigmas,
+            do_dry,
+            do_skews,
+            do_temp_last)
 
     @classmethod
-    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
-                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
-                   temperature_lasts: List[bool], top_ps: List[float],
-                   top_ks: List[int], top_as: List[float],
-                   min_ps: List[float], presence_penalties: List[float],
-                   frequency_penalties: List[float],
-                   repetition_penalties: List[float],
-                   no_repeat_ngram_sizes: List[int], tfss: List[float],
-                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
-                   typical_ps: List[float], smoothing_factors: List[float],
-                   smoothing_curves: List[float], xtc_thresholds: List[float],
-                   xtc_probabilities: List[float], nsigmas: List[float],
-                   dry_multipliers: List[float], dry_bases: List[float],
-                   dry_allowed_lengths: List[int],
-                   dry_sequence_breaker_ids: List[List[int]],
-                   dry_ranges: List[int], skews: List[float],
-                   sampling_seeds: List[List[int]],
-                   sample_indices: List[int], prompt_tokens: List[array],
-                   output_tokens: List[array], vocab_size: int,
-                   extra_seeds_to_generate: int, device: torch.device,
-                   dtype: torch.dtype) -> "SamplingTensors":
+    def from_lists(
+        cls,
+        temperatures: List[float],
+        dynatemp_mins: List[float],
+        dynatemp_maxs: List[float],
+        dynatemp_exps: List[float],
+        temperature_lasts: List[bool],
+        top_ps: List[float],
+        top_ks: List[int],
+        top_as: List[float],
+        min_ps: List[float],
+        presence_penalties: List[float],
+        frequency_penalties: List[float],
+        repetition_penalties: List[float],
+        no_repeat_ngram_sizes: List[int],
+        tfss: List[float],
+        eta_cutoffs: List[float],
+        epsilon_cutoffs: List[float],
+        typical_ps: List[float],
+        smoothing_factors: List[float],
+        smoothing_curves: List[float],
+        xtc_thresholds: List[float],
+        xtc_probabilities: List[float],
+        nsigmas: List[float],
+        dry_multipliers: List[float],
+        dry_bases: List[float],
+        dry_allowed_lengths: List[int],
+        dry_sequence_breaker_ids: List[List[int]],
+        dry_ranges: List[int],
+        skews: List[float],
+        prompt_tokens: List[array],
+        output_tokens: List[array],
+        vocab_size: int,
+        device: torch.device,
+        dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
         pin_memory = is_pin_memory_available()
@@ -811,34 +809,9 @@ class SamplingTensors:
             pin_memory=pin_memory,
         )
 
-        sample_indices_t = torch.tensor(
-            sample_indices,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
-        # need to transpose and make contiguous to
-        # copy the tensor correctly.
-        # [batch_size, n_seeds] -> [n_seeds, batch_size]
-        sampling_seeds_t = torch.tensor(
-            sampling_seeds,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        ).t().contiguous()
-
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
 
-        # How many seeds the sample operation itself will need.
-        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
-        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
-                                                 non_blocking=True)
-        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
-        if not extra_seeds_gpu.numel():
-            extra_seeds_gpu = None
-        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
-
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
@@ -882,38 +855,4 @@ class SamplingTensors:
             typical_ps=typical_ps_t.to(device=device, non_blocking=True),
             prompt_tokens=prompt_t.to(device=device, non_blocking=True),
             output_tokens=output_t.to(device=device, non_blocking=True),
-            sampling_seeds=sampling_seeds_gpu,
-            sample_indices=sample_indices_t.to(device=device,
-                                               non_blocking=True),
-            extra_seeds=extra_seeds_gpu,
         )
-
-    @staticmethod
-    def _get_sequence_seeds(
-        seed: int|None,
-        *extra_entropy: int,
-        seeds_to_generate: int,
-        is_greedy: bool,
-    ):
-        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
-        if not is_greedy:
-            if seed is None:
-                randint_fn = random.randint
-            else:
-                generator = random.Random(str((seed, ) + extra_entropy))
-                randint_fn = generator.randint
-            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
-            # If the user/random sets seed = 0 but request should
-            # have sampling, we need to change it to something
-            # else. We use a constant in that case.
-            # This way we don't need to create and load a bool
-            # matrix in the sampling kernel, which reduces CPU
-            # overhead and latency.
-            seq_seeds = [
-                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
-                for _ in range(seeds_to_generate)
-            ]
-        else:
-            # For the kernel, seed == 0 means greedy decoding.
-            seq_seeds = [0] * seeds_to_generate
-        return seq_seeds

+ 0 - 12
aphrodite/triton_utils/sample.py

@@ -1,12 +0,0 @@
-import math
-
-# This is a hardcoded limit in Triton (max block size).
-MAX_TRITON_N_COLS = 131072
-
-
-def get_num_triton_sampler_splits(n_cols: int) -> int:
-    """Get the number of splits to use for Triton sampling.
-    Triton has a limit on the number of columns it can handle, so we need to
-    split the tensor and call the kernel multiple times if it's too large.
-    """
-    return math.ceil(n_cols / MAX_TRITON_N_COLS)

+ 0 - 52
tests/kernels/test_rand.py

@@ -1,52 +0,0 @@
-import random
-
-import pytest
-import torch
-
-from aphrodite.modeling.layers.ops.rand import seeded_uniform
-from aphrodite.modeling.utils import set_random_seed
-
-
-@pytest.mark.parametrize("dtype",
-                         [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("use_3d", [True, False])
-def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
-    device = "cuda"
-    for seed in range(512):
-        set_random_seed(seed)
-        rows = random.randint(1, 512)
-        cols = random.randint(1, 64000)
-        if use_3d:
-            third_dim = random.randint(2, 10)
-            dims = [rows, third_dim, cols]
-        else:
-            dims = [rows, cols]
-        seeds = torch.randint(torch.iinfo(torch.long).min,
-                              torch.iinfo(torch.long).max, (rows, ),
-                              device=device)
-
-        # Test that the same seed produces the same output
-        out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
-        out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
-        torch.testing.assert_close(out, out2)
-        # del to save memory
-        del out2
-
-        out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
-        torch.testing.assert_close(out, out3)
-        # del to save memory
-        del out3
-
-        # Initialize out tensor with garbage to ensure that it is overwritten
-        out_with_tensor = seeded_uniform(
-            *dims,
-            out=torch.full(
-                (*dims, ),
-                -1,
-                dtype=dtype,
-                device=device,
-            ),
-            seeds=seeds,
-            dtype=dtype,
-        )
-        torch.testing.assert_close(out, out_with_tensor)

+ 0 - 209
tests/kernels/test_sampler.py

@@ -1,209 +0,0 @@
-import gc
-from unittest.mock import patch
-
-import pytest
-import torch
-import triton
-import triton.language as tl
-
-from aphrodite.modeling.layers.ops.sample import (_sample_triton,
-                                                  _uniform_to_exponential,
-                                                  sample)
-from aphrodite.modeling.sampling_metadata import SamplingTensors
-from aphrodite.modeling.utils import set_random_seed
-from aphrodite.triton_utils.libentry import LibEntry
-from aphrodite.triton_utils.sample import (MAX_TRITON_N_COLS,
-                                           get_num_triton_sampler_splits)
-
-SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
-MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100
-
-
-@pytest.fixture(autouse=True)
-def _cleanup():
-    yield
-    gc.collect()
-    torch.cuda.empty_cache()
-
-
-@triton.jit
-def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
-    idx = tl.arange(0, n)
-    x = tl.load(input + idx)
-    y = _uniform_to_exponential(x)
-    tl.store(output + idx, y)
-
-
-def test_uniform_to_exponential():
-    """Test that we can convert uniform to exponential without div by 0."""
-    input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
-                         dtype=torch.float32,
-                         device="cuda")
-    output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
-    _uniform_to_exponential_kernel[(1, )](input, output, 2)
-    assert torch.all(torch.isfinite(output))
-    assert torch.all(output > 0)
-    assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))
-
-
-@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
-@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
-@pytest.mark.parametrize("modify_greedy_probs", [True, False])
-@pytest.mark.parametrize("seed", [1337])
-@pytest.mark.parametrize("vocab_size",
-                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
-@pytest.mark.parametrize("save_logprobs", [True, False])
-def test_sample_decoding_only(random_sampling, max_best_of,
-                              modify_greedy_probs, seed, vocab_size,
-                              save_logprobs):
-    set_random_seed(seed)
-    bs = 8
-    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
-    for i in range(bs):
-        probs[i, i * (vocab_size // bs)] = 1.0
-    logprobs = torch.rand_like(probs)
-    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
-    n_splits = get_num_triton_sampler_splits(probs.shape[1])
-    if random_sampling == "mixed":
-        random_sampling_mask = (torch.rand(
-            (1, bs), device="cuda") < 0.5).expand(n_splits, bs)
-    elif random_sampling:
-        random_sampling_mask = torch.ones((n_splits, bs),
-                                          dtype=torch.bool,
-                                          device="cuda")
-    else:
-        random_sampling_mask = torch.zeros((n_splits, bs),
-                                           dtype=torch.bool,
-                                           device="cuda")
-
-    seeds = torch.randint(1,
-                          torch.iinfo(torch.long).max, (n_splits, bs),
-                          device="cuda").mul_(random_sampling_mask)
-    #The current _sample_triton does not utilize the
-    # libentry decoration. The purpose of adding this patch is to test
-    # the correctness of libentry.
-    with patch("aphrodite.model_executor.layers.ops.sample._sample_triton",
-               LibEntry(_sample_triton)):
-        sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
-            probs=probs,
-            logprobs=logprobs,
-            sample_indices=sample_indices,
-            seeds=seeds,
-            max_best_of=max_best_of,
-            modify_greedy_probs=modify_greedy_probs,
-            save_logprobs=save_logprobs,
-            _save_modified_probs=True)
-    assert sampled_tokens.shape == (bs, max_best_of)
-    for i in range(bs):
-        assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
-        request_uses_random_sampling = random_sampling_mask[0, i]
-        if modify_greedy_probs and not request_uses_random_sampling:
-            # If we are modifying greedy probs and the request is greedy,
-            # we want to make sure the probs tensor is modified in place
-            torch.testing.assert_close(
-                probs[i][sampled_tokens[i]],
-                torch.full_like(probs[i][sampled_tokens[i]], 1.0))
-            assert torch.sum(probs[i]) == 1.0
-            torch.testing.assert_close(
-                sampled_modified_probs[i][0],
-                torch.full_like(sampled_modified_probs[i][0], 1.0))
-        elif request_uses_random_sampling:
-            # If the request is random, we want to make sure
-            # sampled_modified_probs tensor has noise added
-            # (and thus is different from probs tensor)
-            assert not torch.allclose(sampled_modified_probs[i][0],
-                                      probs[i][sampled_tokens[i]])
-        elif not request_uses_random_sampling:
-            # If the request is greedy and we are not modifying greedy probs,
-            # we want to make sure sampled_modified_probs tensor is the same as
-            # the probs tensor.
-            torch.testing.assert_close(sampled_modified_probs[i],
-                                       probs[i][sampled_tokens[i]])
-
-    if save_logprobs:
-        assert sampled_logprobs.shape == (bs, max_best_of)
-        for i in range(bs):
-            for best_of in range(max_best_of):
-                assert torch.all(sampled_logprobs[i] == logprobs[i][
-                    sampled_tokens[i, best_of]])
-    else:
-        assert sampled_logprobs is None
-
-
-@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
-@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
-@pytest.mark.parametrize("modify_greedy_probs", [True, False])
-@pytest.mark.parametrize("seed", [1337])
-@pytest.mark.parametrize("vocab_size",
-                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
-def test_sample_prompt_logprobs(random_sampling, max_best_of,
-                                modify_greedy_probs, seed, vocab_size):
-
-    set_random_seed(seed)
-    prompt_sizes = [16, 32, 64, 128] * 2
-    samples = 8
-    bs = samples + sum(prompt_sizes)
-    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
-    for i in range(bs):
-        probs[i, i * (vocab_size // bs)] = 1.0
-    logprobs = torch.rand_like(probs)
-    sample_indices = torch.tensor(prompt_sizes,
-                                  dtype=torch.long,
-                                  device="cuda").cumsum_(0)
-    n_splits = get_num_triton_sampler_splits(probs.shape[1])
-    if random_sampling == "mixed":
-        random_sampling_mask = torch.rand(
-            (n_splits, samples), device="cuda") < 0.5
-    elif random_sampling:
-        random_sampling_mask = torch.ones((n_splits, samples),
-                                          dtype=torch.bool,
-                                          device="cuda")
-    else:
-        random_sampling_mask = torch.zeros((n_splits, samples),
-                                           dtype=torch.bool,
-                                           device="cuda")
-
-    seeds = torch.randint(1,
-                          torch.iinfo(torch.long).max, (n_splits, samples),
-                          device="cuda").mul_(random_sampling_mask)
-    #ditto
-    with patch("aphrodite.model_executor.layers.ops.sample._sample_triton",
-               LibEntry(_sample_triton)):
-        sampled_tokens, sampled_logprobs, _ = sample(
-            probs=probs,
-            logprobs=logprobs,
-            sample_indices=sample_indices,
-            seeds=seeds,
-            max_best_of=max_best_of,
-            modify_greedy_probs=modify_greedy_probs,
-            save_logprobs=True)
-    assert sampled_tokens.shape == (samples, max_best_of)
-    assert sampled_logprobs.shape == (samples, max_best_of)
-    for i, t in enumerate(sample_indices):
-        assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
-        for best_of in range(max_best_of):
-            assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
-                             [sampled_tokens[i, best_of]])
-
-
-@pytest.mark.parametrize("seed", list(range(16)))
-def test_get_sequence_seeds(seed):
-    """Ensure that we get a different child seed from base 
-    seed + extra entropy"""
-    starting_seed = seed
-    seq_seed = None
-    extra_entropy = 1
-    for i in range(512):
-        new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
-                                                           i,
-                                                           seeds_to_generate=1,
-                                                           is_greedy=False)[0]
-        new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
-            starting_seed,
-            i,
-            extra_entropy,
-            seeds_to_generate=1,
-            is_greedy=False)[0]
-        assert new_seq_seed_extra_entropy != new_seq_seed
-        assert seq_seed != new_seq_seed
-        seq_seed = new_seq_seed