from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from aphrodite.modeling.layers.ops.rand import seeded_uniform
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

_EPS = 1e-6


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where the vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We purposefully use sampled_tokens_size here because the modified
    # probs must always be saved in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # Running offset used to convert split-local token ids back to
    # global vocab ids (updated at the end of each loop iteration).
    split_offset = 0
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO: See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Shift split-local token ids by the combined width of all
            # preceding splits. tensor_split() may produce splits of
            # unequal width, so we accumulate the actual widths instead
            # of assuming a uniform split size.
            sampled_tokens_tmp[i].add_(split_offset)
        split_offset += split_probs[i].shape[1]
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)
    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)
    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
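

# Illustrative sketch of the reduction above (not part of the original
# module; all names are hypothetical). Each split reports the max of its
# noise-perturbed probs plus the (already offset) token id, and torch.max
# over the split dimension selects the winning split per (row, best) slot.
def _example_split_reduction() -> torch.Tensor:
    # Per-split modified probs and token ids: shape [n_splits=2, n=2, 1].
    modified = torch.tensor([[[0.3], [0.9]], [[0.7], [0.2]]])
    tokens = torch.tensor([[[1], [2]], [[5], [6]]])
    _, indices = torch.max(modified, dim=0, keepdim=True)
    # Row 0 picks split 1 (token 5); row 1 picks split 0 (token 2).
    return tokens.gather(0, indices).squeeze(0)  # tensor([[5], [2]])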


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
- """Sample tokens from probs. with per-sequence seeds.
- Can sample from a subset of sequences through sample_indices.
- Args:
- probs: Probabilities to sample from.
- shape = [batch_size, vocab_size]
- seeds: Per-sequence seed values.
- shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
- max_best_of: Number of samples to generate per sequence.
- Sequence seed will be incremented by 1 each time.
- sample_indices: Indices of sequences to sample from.
- If not provided, will sample from all sequences.
- shape = [n]
- logprobs: Log-probabilities of the sampled tokens.
- Only used for saving the logprobs if save_logprobs is True.
- shape = [batch_size, vocab_size]
- modify_greedy_probs: Whether to modify the greedy probabilities
- for speculative sampling (sampled token = 1.0,
- everything else = 0.0).
- save_logprobs: Whether to save the log-probabilities of the
- sampled tokens to a tensor.
- _save_modified_probs: Whether to save the modified probabilities
- (including gumbel noise) of the sampled tokens to a tensor.
- DOES NOT include the modification done by modify_greedy_probs
- (because we want to use the unmodified probs to pick the best
- split in case of multi-split sampling).
- This is exposed only for testing.
- Returns:
- sampled_tokens: shape = [n, max_best_of]
- sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
- sampled_modified_probs: shape = [n, max_best_of]
- if save_modified_probs else None
- """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)
    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Zero-sized dummy tensor so we can still invoke the kernel.
        sampled_logprobs_size = (0, 0)
        logprobs = probs
    assert logprobs is not None
    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Zero-sized dummy tensor so we can still invoke the kernel.
        sampled_modified_probs_size = (0, 0)
    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)
        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)
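

# Illustrative usage sketch (not part of the original module; the helper
# name is hypothetical). Shows how a caller might draw one random token
# per sequence, assuming a CUDA device with Triton available.
def _example_sample_usage(logits: torch.Tensor) -> torch.Tensor:
    """Draw one token per row of `logits` (shape [batch, vocab])."""
    probs = torch.softmax(logits, dim=-1)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    # Nonzero seeds select random (Gumbel) sampling; a seed of 0 would
    # select greedy sampling for that row.
    seeds = torch.randint(1,
                          2**31 - 1, (n_splits, probs.shape[0]),
                          dtype=torch.long,
                          device=probs.device)
    sampled_tokens, _, _ = sample(probs, seeds, max_best_of=1)
    return sampled_tokens  # shape [batch, 1]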


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Sample tokens from probs.
- Args:
- probs [batch_size, vocab_size]: probs to sample from.
- logprobs [batch_size, vocab_size]: logprobs (used when
- save_logprobsis True).
- sample_indices [n]: Indices of the samples to use for each row of probs.
- output_samples [n, n_best]: Output tensor to store samples in.
- output_logprobs [n, n_best]: Output tensor to store logprobs in.
- output_modified_probs [n, n_best]: Output tensor to store
- probs of chosen tokens in (modified with noise).
- seeds [n]: Seeds to use for sampling. If the seed is 0, we use
- greedy sampling. Note this is ONLY used for determining
- whether to use random sampling or not. The actual random
- noise should be passed as uniform_noise.
- uniform_noise [batch_size, n_best, vocab_size]: Uniform
- noise to use for random sampling (will be converted
- to exponential gumbel noise by the kernel).
- modify_greedy_probs: If True, we modify the probs tensor in-place
- to encode the sampling method used for each row. This is used
- in speculative decoding. Only applies in greedy decoding.
- save_logprobs: If True, we save the logprobs of the sampled tokens
- in the output_logprobs tensor.
- save_modified_probs: If True, we save the modified probs (with noise)
- of the sampled tokens in the output_modified_probs tensor.
- DOES NOT include the modification done by modify_greedy_probs
- (because we want to use the unmodified probs to pick the best
- split in case of multi-split sampling).
- """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1
    # The block size is the smallest power of two greater than or equal
    # to the number of columns in probs.
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give the best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    # Enqueue the kernel. The 2D launch grid is simple: we have one kernel
    # instance per (row, best) pair of the probs matrix.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs
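

# Reference semantics of _sample for rows that use random sampling (an
# illustrative, hypothetical helper in pure PyTorch, useful as a mental
# model or test oracle): divide probs by exponential noise and argmax.
def _example_sample_reference(probs: torch.Tensor,
                              uniform_noise: torch.Tensor) -> torch.Tensor:
    """probs: [n, vocab_size]; uniform_noise: [n, n_best, vocab_size]."""
    exponential_noise = -torch.log(uniform_noise.clamp_min(_EPS))
    perturbed = probs.unsqueeze(1) / exponential_noise  # [n, n_best, vocab]
    return perturbed.argmax(dim=-1)  # [n, n_best]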


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # The uniform noise lies in [0, 1), so we clamp the lower bound
    # to _EPS to avoid log(0) and thus division by 0 later.
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples.
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise
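

# Why dividing probs by exponential noise samples correctly (a brief
# note, not from the original source): if E_i ~ Exp(1) i.i.d., then
# argmax_i p_i / E_i = argmin_i E_i / p_i, and the minimum of independent
# exponential clocks with rates p_i lands on index i with probability
# p_i / sum_j p_j. The hypothetical CPU check below shows the empirical
# frequencies approaching p.
def _example_exponential_race_check() -> torch.Tensor:
    p = torch.tensor([0.1, 0.2, 0.3, 0.4])
    uniform = torch.rand(100_000, 4)
    exponential = -torch.log(uniform.clamp_min(_EPS))
    winners = (p / exponential).argmax(dim=1)
    return torch.bincount(winners, minlength=4) / 100_000  # approx. p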


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across them.
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)
    # Load the row index from DRAM.
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0
    # The stride represents how much we need to increase the
    # pointer to advance one row.
    row_start_ptr = probs_ptr + row_idx * probs_row_stride
    # The block size is the next power of two greater than or equal to
    # n_cols, so we can fit each row in a single block.
    col_offsets = tl.arange(0, block_size)
    # Load the row into SRAM, using a mask since block_size may be
    # greater than n_cols.
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))
    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise
    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # Clamp the sampled token to n_cols - 1. This should not be
    # necessary, but we do it just in case.
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back the output to DRAM.
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)
    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)
    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)
    if save_logprobs:
        # Look up the logprob of the sampled token in the logprobs row.
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back the output to DRAM.
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
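

# A minimal end-to-end smoke test (an illustrative sketch, not from the
# original source). Requires a CUDA device with Triton available and
# assumes get_num_triton_sampler_splits(32) == 1, i.e. seeds of shape
# [1, batch_size].
def _example_smoke_test():
    probs = torch.softmax(torch.randn(2, 32, device="cuda"), dim=-1)
    seeds = torch.tensor([[1, 2]], dtype=torch.long, device="cuda")
    tokens, logprobs, _ = sample(probs,
                                 seeds,
                                 logprobs=torch.log(probs),
                                 save_logprobs=True)
    assert tokens.shape == (2, 1)
    assert logprobs is not None and logprobs.shape == (2, 1)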