import triton

# Largest number of elements a single fused Triton kernel launch may cover;
# requests above this would exceed the maximum CUDA block size we support.
MAX_FUSED_SIZE : int = 65536
next_power_of_2 = triton.next_power_of_2

# Calculate the optimal block size and number of warps for the layernorm kernel
# borrowed from https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/utils.py#L49-L59
def calculate_settings(n : int) -> tuple[int, int]:
    """Pick Triton launch parameters for a row of *n* elements.

    Args:
        n: number of elements each kernel instance must cover (e.g. the
           hidden dimension of a layernorm row).

    Returns:
        A ``(BLOCK_SIZE, num_warps)`` pair: ``BLOCK_SIZE`` is ``n`` rounded
        up to the next power of two, and ``num_warps`` scales with it
        (4 / 8 / 16 / 32).

    Raises:
        RuntimeError: if the rounded block size exceeds ``MAX_FUSED_SIZE``,
            the largest CUDA block size supported here.
    """
    BLOCK_SIZE : int = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    # Larger blocks need more warps to keep the SM occupied.
    num_warps : int = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps