1
0

rand.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import torch
  2. import triton
  3. import triton.language as tl
  4. from typing import Optional, Union
  5. def seeded_uniform(
  6. *size,
  7. seeds: torch.Tensor,
  8. out: Optional[torch.Tensor] = None,
  9. dtype: Optional[torch.dtype] = None,
  10. device: Optional[Union[torch.device, str]] = None,
  11. pin_memory: Optional[bool] = False,
  12. ) -> torch.Tensor:
  13. """Similar to torch.rand, but allows for seeds to be set per row.
  14. seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
  15. If it is 3d, the additional seeds needed will be derived automatically
  16. in a deterministic fashion:
  17. [
  18. row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
  19. ]
  20. """
  21. n_dims = len(size)
  22. if n_dims > 3:
  23. raise ValueError("seeded_uniform only supports up to 3D tensors")
  24. if out is None:
  25. out = torch.empty(*size,
  26. dtype=dtype,
  27. device=device,
  28. pin_memory=pin_memory)
  29. elif out.shape != size:
  30. raise ValueError("shape of out and size must be the same")
  31. if n_dims == 3:
  32. n_rows, n_3d, n_cols = out.shape
  33. stride_row = out.stride(0)
  34. stride_3d = out.stride(1)
  35. elif n_dims == 2:
  36. n_rows, n_cols = out.shape
  37. n_3d = 1
  38. stride_row = out.stride(0)
  39. stride_3d = 1
  40. else:
  41. n_cols = out.shape[0]
  42. n_rows = 1
  43. n_3d = 1
  44. stride_row = 1
  45. stride_3d = 1
  46. if seeds.ndim != 1:
  47. raise ValueError("seeds must be a 1D tensor")
  48. if seeds.numel() != n_rows:
  49. raise ValueError(
  50. "seeds must have the same number of elements as out has rows")
  51. # The philox PRNG Triton uses generates 4 random numbers at once.
  52. # Therefore, the most efficient use of it is to divide the
  53. # block size by 4, and then save the generated random numbers to
  54. # each of the 4 slices of the tensor.
  55. full_block_size = triton.next_power_of_2(n_cols)
  56. philox_block_size = max(full_block_size // 4, 1)
  57. n_slices = full_block_size // philox_block_size
  58. num_warps = 4
  59. # Manual tuning. This seems to give best performance on A100 for
  60. # simple kernels like this.
  61. if philox_block_size >= 8192:
  62. num_warps = 32
  63. elif philox_block_size >= 4096:
  64. num_warps = 16
  65. elif philox_block_size >= 2048:
  66. num_warps = 8
  67. _seeded_uniform_triton[(n_rows, n_3d)](
  68. out,
  69. seeds,
  70. stride_row,
  71. stride_3d,
  72. seeds.stride(0),
  73. n_rows,
  74. n_3d,
  75. n_cols,
  76. n_slices=n_slices,
  77. num_warps=num_warps,
  78. block_size=philox_block_size,
  79. )
  80. return out
  81. @triton.jit
  82. def _seeded_uniform_triton(
  83. out_ptr: torch.Tensor,
  84. seed_ptr: torch.Tensor,
  85. out_row_stride: int,
  86. out_3d_stride: int,
  87. seed_row_stride: int,
  88. n_rows: int,
  89. n_3d: int,
  90. n_cols: int,
  91. n_slices: tl.constexpr,
  92. block_size: tl.constexpr,
  93. ):
  94. """
  95. Generate a random float32 number in [0, 1) for each element in the output
  96. tensor. The random numbers in a row generated using the seed for that row.
  97. Args:
  98. out_ptr: The output tensor.
  99. seed_ptr: The per-row seeds to use for random number generation.
  100. out_row_stride: The stride between rows of the output tensor.
  101. out_3d_stride: The stride between 3D slices of the output tensor.
  102. seed_row_stride: The stride between rows of the seed tensor.
  103. n_rows: The number of rows in the output tensor.
  104. n_3d: The size of second dimension of the output tensor,
  105. if output tensor is 3D.
  106. n_cols: The number of columns in the output tensor.
  107. n_slices: The number of philox outputs to use.
  108. """
  109. tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
  110. # Get the row index.
  111. row_idx = tl.program_id(axis=0)
  112. three_d_idx = tl.program_id(axis=1)
  113. philox_offsets = tl.arange(0, block_size)
  114. # Get the seed for the current element.
  115. seed = tl.load(seed_ptr + row_idx * seed_row_stride)
  116. if three_d_idx > 0:
  117. seed ^= three_d_idx
  118. # Generate random numbers in [0, 1).
  119. out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)
  120. output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
  121. three_d_idx * out_3d_stride)
  122. out1_offsets = philox_offsets
  123. tl.store(output_row_start_ptr + out1_offsets,
  124. out1,
  125. mask=out1_offsets < n_cols)
  126. if n_slices > 1:
  127. out2_offsets = tl.arange(block_size, block_size * 2)
  128. tl.store(output_row_start_ptr + out2_offsets,
  129. out2,
  130. mask=out2_offsets < n_cols)
  131. if n_slices > 2:
  132. out3_offsets = tl.arange(block_size * 2, block_size * 3)
  133. tl.store(output_row_start_ptr + out3_offsets,
  134. out3,
  135. mask=out3_offsets < n_cols)
  136. if n_slices > 3:
  137. out4_offsets = tl.arange(block_size * 3, block_size * 4)
  138. tl.store(output_row_start_ptr + out4_offsets,
  139. out4,
  140. mask=out4_offsets < n_cols)