from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from aphrodite.modeling.layers.ops.rand import seeded_uniform
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

_EPS = 1e-6


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    logprobs: torch.Tensor,
    *,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where the vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO: See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Add offset to sampled tokens
            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
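

# Minimal PyTorch sketch (illustrative only, not used by the code above) of
# why reducing the per-split results with a max over the noise-modified probs
# matches a single global exponential-race draw: the per-split winners are
# compared on the same modified probs the global argmax would see. The
# function name is made up and float ties are ignored.
def _reference_multi_split_argmax(probs: torch.Tensor,
                                  exponential_noise: torch.Tensor,
                                  n_splits: int) -> torch.Tensor:
    modified = probs / exponential_noise
    # Global draw: argmax over the whole (noise-modified) vocab.
    global_tokens = modified.argmax(dim=-1)
    # Split draw: per-split argmax, then keep the split whose modified prob
    # is largest, offsetting token ids by the start of each split.
    best_vals, best_tokens, offset = [], [], 0
    for part in modified.tensor_split(n_splits, dim=-1):
        vals, toks = part.max(dim=-1)
        best_vals.append(vals)
        best_tokens.append(toks + offset)
        offset += part.shape[-1]
    winner = torch.stack(best_vals).argmax(dim=0, keepdim=True)
    split_tokens = torch.stack(best_tokens).gather(0, winner).squeeze(0)
    assert torch.equal(global_tokens, split_tokens)
    return split_tokens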


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sample tokens from probs with per-sequence seeds.

    Can sample from a subset of sequences through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            shape = [math.ceil(vocab_size / MAX_TRITON_N_COLS), n]
        max_best_of: Number of samples to generate per sequence.
            Sequence seed will be incremented by 1 each time.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if _save_modified_probs else None
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_logprobs_size = (0, 0)
        logprobs = probs

    assert logprobs is not None
    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_modified_probs_size = (0, 0)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)

        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)
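

# Hedged usage sketch (not called anywhere in this module): it shows the
# expected call pattern for sample() with illustrative sizes. The seed layout
# follows the assertion in _multi_split_sample ([n_splits, n]); the function
# name and all concrete numbers below are made up for illustration.
def _example_sample_usage():
    batch_size, vocab_size = 4, 32000
    logits = torch.randn(batch_size, vocab_size, device="cuda")
    probs = torch.softmax(logits, dim=-1)
    logprobs = torch.log_softmax(logits, dim=-1)
    # One seed row per Triton split; a seed of 0 would select greedy sampling
    # for that sequence.
    n_splits = get_num_triton_sampler_splits(vocab_size)
    seeds = torch.randint(1, 2**31, (n_splits, batch_size),
                          dtype=torch.long, device="cuda")
    tokens, token_logprobs, _ = sample(probs,
                                       seeds,
                                       max_best_of=1,
                                       logprobs=logprobs,
                                       save_logprobs=True)
    return tokens, token_logprobs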


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [batch_size, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than or equal to
    # the number of columns in probs.
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8
    # Enqueue kernel. The launch grid has one kernel instance per
    # (sample row, best-of draw) pair.
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs
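

# Worked example of the launch heuristic in _sample above (illustrative
# numbers, not a measured claim): a 32000-token vocabulary gives
#   block_size = triton.next_power_of_2(32000) = 32768  (>= 8192)
# so num_warps = 32, and the kernel is launched on an (n_samples, n_best)
# grid, i.e. one program instance per (sampled row, best-of draw).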


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # tl.rand returns values in [0, 1), so we clamp lower bound
    # to _EPS to avoid log(0) and thus division by 0 later
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise
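

# Host-side reference for the trick used in the kernel below (a sketch, not
# part of the sampling path; the function name is illustrative): dividing
# probs by Exp(1) noise and taking the argmax draws a token with probability
# proportional to probs. This is the Gumbel-max trick written with
# exponential noise, since argmax(p / e) == argmax(log(p) - log(e)).
def _reference_exponential_race(probs: torch.Tensor,
                                generator: Optional[torch.Generator] = None
                                ) -> torch.Tensor:
    uniform_noise = torch.rand(probs.shape,
                               generator=generator,
                               device=probs.device,
                               dtype=probs.dtype).clamp_min(_EPS)
    exponential_noise = -torch.log(uniform_noise)
    return (probs / exponential_noise).argmax(dim=-1)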


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than or equal to
    # n_cols, so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be
    # greater than n_cols
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # Clamp sampled token to n_cols - 1. This should not be necessary,
    # but we do it just in case.
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the logprob of the sampled token from DRAM
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
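

# Pure-PyTorch sketch of what one (sample_idx, best_idx) program instance of
# _sample_triton computes, kept here only as a readable reference; the name
# and signature are illustrative and this is not used by the kernel path.
def _reference_sample_row(probs_row: torch.Tensor,
                          uniform_noise_row: torch.Tensor,
                          seed: int) -> Tuple[int, torch.Tensor]:
    row = probs_row.clone()
    if seed != 0:
        # Random sampling: divide by exponential noise (exponential race /
        # Gumbel-max trick), mirroring _uniform_to_exponential above.
        row = row / -torch.log(uniform_noise_row.clamp_min(_EPS))
    # Greedy sampling (seed == 0) falls through to a plain argmax.
    sampled_value, sampled_token = torch.max(row, dim=0)
    return int(sampled_token), sampled_value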