123456789101112 |
import math

# Maximum number of columns a single Triton sampler kernel launch handles.
# Per the original note this is a hardcoded limit in Triton (max block size);
# tensors wider than this are split across multiple kernel launches.
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(
    n_cols: int, *, max_n_cols: int = MAX_TRITON_N_COLS
) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.

    Args:
        n_cols: Number of columns in the tensor to be sampled.
        max_n_cols: Maximum number of columns one kernel launch handles.
            Defaults to ``MAX_TRITON_N_COLS``.

    Returns:
        ``ceil(n_cols / max_n_cols)`` as a Python ``int`` (0 when ``n_cols``
        is 0).

    Raises:
        ValueError: If ``max_n_cols`` is not positive.
    """
    if max_n_cols <= 0:
        raise ValueError(f"max_n_cols must be positive, got {max_n_cols}")
    # Ceiling division in pure integer arithmetic. The float form
    # ``math.ceil(n_cols / max_n_cols)`` rounds through a double and can
    # undercount splits once n_cols exceeds 2**53.
    return int(-(-n_cols // max_n_cols))
|