import math
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from aphrodite.modeling.layers.ops.rand import seeded_uniform

_EPS = 1e-6

# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.
    """
    return math.ceil(n_cols / MAX_TRITON_N_COLS)
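
# For example (illustrative values): a 32,000-column vocab fits in a single
# kernel launch, while a 262,144-column vocab must be split in two:
#   get_num_triton_sampler_splits(32_000)  == 1
#   get_num_triton_sampler_splits(262_144) == 2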


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO: See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Add an offset so the sampled tokens index into the full
            # (unsplit) vocab. tensor_split may produce splits of unequal
            # width, so accumulate the actual widths of preceding splits.
            sampled_tokens_tmp[i].add_(
                sum(s.shape[1] for s in split_probs[:i]))
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
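
# The reduction above works because each per-split kernel call saves the
# noise-modified probability of its winning token: taking the max of those
# winners across splits is equivalent (in distribution) to a single argmax
# over the full vocab. A torch-only sketch of the combine step, with
# hypothetical per-split results `values`/`tokens` of shape
# [n_splits, n, n_best]:
#
#   best_values, idx = torch.max(values, dim=0, keepdim=True)
#   best_tokens = tokens.gather(0, idx).squeeze(0)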


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
  108. """Sample tokens from probs. with per-sequence seeds.
  109. Can sample from a subset of sequences through sample_indices.
  110. Args:
  111. probs: Probabilities to sample from.
  112. shape = [batch_size, vocab_size]
  113. seeds: Per-sequence seed values.
  114. shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
  115. max_best_of: Number of samples to generate per sequence.
  116. Sequence seed will be incremented by 1 each time.
  117. sample_indices: Indices of sequences to sample from.
  118. If not provided, will sample from all sequences.
  119. shape = [n]
  120. logprobs: Log-probabilities of the sampled tokens.
  121. Only used for saving the logprobs if save_logprobs is True.
  122. shape = [batch_size, vocab_size]
  123. modify_greedy_probs: Whether to modify the greedy probabilities
  124. for speculative sampling (sampled token = 1.0,
  125. everything else = 0.0).
  126. save_logprobs: Whether to save the log-probabilities of the
  127. sampled tokens to a tensor.
  128. _save_modified_probs: Whether to save the modified probabilities
  129. (including gumbel noise) of the sampled tokens to a tensor.
  130. DOES NOT include the modification done by modify_greedy_probs
  131. (because we want to use the unmodified probs to pick the best
  132. split in case of multi-split sampling).
  133. This is exposed only for testing.
  134. Returns:
  135. sampled_tokens: shape = [n, max_best_of]
  136. sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
  137. sampled_modified_probs: shape = [n, max_best_of]
  138. if save_modified_probs else None
  139. """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_logprobs_size = (0, 0)
        logprobs = probs
    assert logprobs is not None
    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_modified_probs_size = (0, 0)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)
        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)
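
# Usage sketch (hedged; the names below are illustrative, not part of this
# module): given logits of shape [batch, vocab], draw one random sample per
# row with nonzero per-sequence seeds:
#
#   probs = torch.softmax(logits, dim=-1)
#   logprobs = torch.log_softmax(logits, dim=-1)
#   n_splits = get_num_triton_sampler_splits(probs.shape[1])
#   seeds = torch.randint(1, 2**31, (n_splits, probs.shape[0]),
#                         device=probs.device)
#   tokens, token_logprobs, _ = sample(probs, seeds, logprobs=logprobs,
#                                      save_logprobs=True)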


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [batch_size, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than or equal to
    # the number of columns in probs.
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    # Enqueue kernel. The 2D launch grid is simple: one kernel instance
    # per (row, best-of) pair of the probs matrix.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs
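
# Worked example of the launch configuration (illustrative numbers): for a
# 32,000-column vocab, block_size = next_power_of_2(32000) = 32768 >= 8192,
# so the kernel runs with num_warps = 32; a batch of 8 sampled rows with
# n_best = 2 launches an (8, 2) grid, i.e. 16 independent programs.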


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # tl.rand returns values in [0, 1), so we clamp lower bound
    # to _EPS to avoid log(0) and thus division by 0 later
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise
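
# Why inversion works: if U ~ Uniform(0, 1), then E = -log(U) satisfies
# P(E > t) = P(U < exp(-t)) = exp(-t), which is exactly the survival
# function of an Exp(1) random variable.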


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than or equal to
    # n_cols, so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be
    # greater than n_cols
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
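        # Dividing each probability by an i.i.d. Exp(1) variate and taking
        # the argmax draws token i with probability proportional to p_i
        # (the exponential-race form of the Gumbel-max trick):
        # argmax p_i / E_i = argmax (log p_i - log E_i), and -log E_i is
        # standard Gumbel noise.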
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # Clamp the sampled token to n_cols - 1. This should not be necessary,
    # but we do it just in case.
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1

    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the sampled token's logprob; logprobs shares its row
        # stride with probs.
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
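

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch only, assuming a CUDA
    # device with a working Triton install. Sizes are arbitrary.
    torch.manual_seed(0)
    batch_size, vocab_size = 4, 32000
    logits = torch.randn(batch_size, vocab_size, device="cuda")
    probs = torch.softmax(logits, dim=-1)
    logprobs = torch.log_softmax(logits, dim=-1)
    # One seed row per vocab split; nonzero seeds select random sampling.
    n_splits = get_num_triton_sampler_splits(vocab_size)
    seeds = torch.randint(1, 2**31, (n_splits, batch_size), device="cuda")
    tokens, token_logprobs, _ = sample(probs,
                                       seeds,
                                       logprobs=logprobs,
                                       save_logprobs=True)
    assert tokens.shape == (batch_size, 1)
    assert ((tokens >= 0) & (tokens < vocab_size)).all()
    print("sampled tokens:", tokens.flatten().tolist())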