123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- from typing import Optional, Union
- import torch
- import triton
- import triton.language as tl
def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
    ]

    Args:
        *size: Shape of the output, given either as separate ints
            (``seeded_uniform(2, 3, seeds=...)``) or as a single sequence
            (``seeded_uniform((2, 3), seeds=...)``), like ``torch.rand``.
        seeds: 1D tensor holding one seed per row of the output. A 1D output
            counts as a single row, so it needs exactly one seed.
        out: Optional preallocated output tensor whose shape must equal
            ``size``. Its last dimension must be contiguous (stride 1),
            because the kernel writes the columns of a row as consecutive
            elements.
        dtype, device, pin_memory: Forwarded to ``torch.empty`` when ``out``
            is not given; ignored otherwise.

    Returns:
        The filled output tensor (``out`` itself when it was given).

    Raises:
        ValueError: if ``size`` has no dimensions or more than 3, if ``out``
            does not have shape ``size``, if the last dimension of ``out``
            is not contiguous, or if ``seeds`` is not a 1D tensor with one
            element per row.
    """
    # Accept the torch.rand-style single-sequence form as well. Without this,
    # ``seeded_uniform((a, b))`` would be treated as a 1D size and only the
    # first ``a`` elements of the allocated tensor would be filled.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])
    n_dims = len(size)
    if n_dims == 0:
        raise ValueError("seeded_uniform requires at least a 1D tensor")
    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")
    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    # Map every supported rank onto the kernel's (row, 3d, column) layout.
    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        (n_cols,) = out.shape
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")
    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Nothing to generate. Returning here also avoids launching with an empty
    # grid or with next_power_of_2(0) == 0, which would give n_slices == 0.
    if out.numel() == 0:
        return out

    # The kernel addresses columns as ``row_start + column``, i.e. it assumes
    # unit stride along the last dimension. A size-1 last dimension is only
    # ever written at offset 0, so its stride does not matter.
    if out.shape[-1] > 1 and out.stride(-1) != 1:
        raise ValueError("the last dimension of out must be contiguous")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    # One program per (row, 3d-slice) pair.
    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out
@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row; 3D slices after the first XOR their slice index into that seed.

    Program axis 0 selects the row and program axis 1 selects the 3D slice.
    One call to tl.rand4x yields four blocks of ``block_size`` values, and the
    first ``n_slices`` of them are stored to consecutive column ranges. The
    program therefore covers columns [0, n_slices * block_size). Columns at or
    past ``n_cols`` are masked off. Columns are addressed as consecutive
    elements, so the last dimension of the output is assumed to have stride 1.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use (1 to 4).
        block_size: The number of columns written from each philox output.
            The caller is expected to choose it so that
            n_slices * block_size >= n_cols.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
    # Get the row index and the 3D slice index.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)
    philox_offsets = tl.arange(0, block_size)

    # Get the seed for the current row. XOR-ing with the slice index derives
    # a distinct seed per 3D slice. XOR with 0 is the identity, so slice 0
    # (the only slice of a 1D/2D launch) uses the row seed unchanged.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    seed ^= three_d_idx

    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    # Widen the program ids to int64 before multiplying by the strides.
    # Program ids are int32, and row_idx * out_row_stride can exceed 2**31
    # for large outputs, which would corrupt the pointer.
    output_row_start_ptr = (out_ptr
                            + row_idx.to(tl.int64) * out_row_stride
                            + three_d_idx.to(tl.int64) * out_3d_stride)

    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)
|