# utils.py
  1. import triton
  2. MAX_FUSED_SIZE : int = 65536
  3. next_power_of_2 = triton.next_power_of_2
  4. # Calculate the optimal block size and number of warps for the layernorm kernel
  5. # borrowed from https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/utils.py#L49-L59
  6. def calculate_settings(n : int) -> tuple[int, int]:
  7. BLOCK_SIZE : int = next_power_of_2(n)
  8. if BLOCK_SIZE > MAX_FUSED_SIZE:
  9. raise RuntimeError(
  10. f"Cannot launch Triton kernel since n = {n} exceeds "
  11. f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
  12. num_warps : int = 4
  13. if BLOCK_SIZE >= 32768:
  14. num_warps = 32
  15. elif BLOCK_SIZE >= 8192:
  16. num_warps = 16
  17. elif BLOCK_SIZE >= 2048:
  18. num_warps = 8
  19. return BLOCK_SIZE, num_warps
  20. pass