sample.py 422 B

123456789101112
  1. import math
  2. # This is a hardcoded limit in Triton (max block size).
  3. MAX_TRITON_N_COLS = 131072
  4. def get_num_triton_sampler_splits(n_cols: int) -> int:
  5. """Get the number of splits to use for Triton sampling.
  6. Triton has a limit on the number of columns it can handle, so we need to
  7. split the tensor and call the kernel multiple times if it's too large.
  8. """
  9. return math.ceil(n_cols / MAX_TRITON_N_COLS)