sampler.py 59 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import os
  4. import warnings
  5. from typing import Dict, List, Optional, Tuple
  6. import torch
  7. import torch.nn as nn
  8. from aphrodite import _custom_ops as ops
  9. from aphrodite.common.sampling_params import SamplingType
  10. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  11. PromptLogprobs, SampleLogprobs,
  12. SamplerOutput, SequenceOutput)
  13. from aphrodite.common.utils import print_warning_once
  14. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  15. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  16. SamplingTensors,
  17. SequenceGroupToSample)
  18. APHRODITE_USE_CUDA_KERNELS_FOR_SAMPLING = bool(
  19. int(os.getenv("APHRODITE_USE_CUDA_KERNELS_FOR_SAMPLING", 0)))
  20. if APHRODITE_USE_CUDA_KERNELS_FOR_SAMPLING:
  21. print_warning_once("Using CUDA kernels for sampling.")
  22. aphrodite_use_cuda_kernels_for_sampling = True
  23. else:
  24. aphrodite_use_cuda_kernels_for_sampling = False
  25. # (num_token_ids, num_parent_ids) per sequence group.
  26. SampleResultType = List[Tuple[List[int], List[int]]]
  27. class Sampler(nn.Module):
  28. """Samples the next tokens from the model's outputs.
  29. This layer does the following:
  30. 1. Discard the hidden states that are not used for sampling (i.e., all
  31. tokens except the final one in each prompt).
  32. 2. Compute the logits for the next tokens.
  33. 3. Apply presence, frequency and repetition penalties.
  34. 4. Apply temperature scaling.
  35. 5. Apply top-p and top-k truncation.
  36. 6. Sample the next tokens.
  37. Here, each sequence group within the batch can have different sampling
  38. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  39. The structure of the logits tensor is coupled with the seq_groups in
  40. sampling_metadata. Typically, each sequence in each seq_group has one row in
  41. logits for the next token to be sampled; however, for a seq_group with a
  42. prompt request with the prompt_logprobs sampling parameter, there are rows
  43. in logits for each token in the input prompt.
  44. """
  45. def __init__(self):
  46. super().__init__()
  47. # Whether or not the SamplerOutput should have on-device tensors
  48. # containing the sampled token ids and probabilities. This is used by
  49. # speculative decoding.
  50. self.include_gpu_probs_tensor = False
  51. def _init_sampling_tensors(
  52. self,
  53. logits: torch.Tensor,
  54. sampling_metadata: SamplingMetadata,
  55. ):
  56. """The goal here is to reuse sampling tensors between similar decode
  57. runs. This is possible because sampling logic does not change between
  58. decodes of the same sequences.
  59. """
  60. _, vocab_size = logits.shape
  61. # First free any existing stored sampling tensors.
  62. # This is necessary because some sampling tensors may
  63. # have pinned memory.
  64. self._sampling_tensors = None
  65. # Initialize new sampling tensors
  66. (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
  67. do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
  68. do_quadratic) = SamplingTensors.from_sampling_metadata(
  69. sampling_metadata, vocab_size, logits.device, logits.dtype)
  70. self._sampling_tensors = sampling_tensors
  71. self._do_penalties = do_penalties
  72. self._do_top_p_top_k = do_top_p_top_k
  73. self._do_top_as = do_top_as
  74. self._do_min_p = do_min_p
  75. self._do_tfss = do_tfss
  76. self._do_eta_cutoffs = do_eta_cutoffs
  77. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  78. self._do_typical_ps = do_typical_ps
  79. self._do_quadratic = do_quadratic
  80. def forward(
  81. self,
  82. logits: torch.Tensor,
  83. sampling_metadata: SamplingMetadata,
  84. ) -> Optional[SamplerOutput]:
  85. """
  86. Args:
  87. logits: (num_tokens, vocab_size).
  88. sampling_metadata: Metadata for sampling.
  89. """
  90. assert logits is not None
  91. _, vocab_size = logits.shape
  92. # Prepare sampling tensors with pinned memory to avoid blocking.
  93. if not sampling_metadata.reuse_sampling_tensors:
  94. self._init_sampling_tensors(logits, sampling_metadata)
  95. elif self._do_penalties:
  96. # In this case, the sampling tensors logic depends on
  97. # "output_tokens" of a sequence. As a result, we cannot
  98. # reuse sampling tensors, since "output_tokens" changes
  99. # between decode runs.
  100. self._init_sampling_tensors(logits, sampling_metadata)
  101. assert self._sampling_tensors is not None
  102. sampling_tensors = self._sampling_tensors
  103. do_penalties = self._do_penalties
  104. do_top_p_top_k = self._do_top_p_top_k
  105. do_top_as = self._do_top_as
  106. do_min_p = self._do_min_p
  107. do_tfss = self._do_tfss
  108. do_eta_cutoffs = self._do_eta_cutoffs
  109. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  110. do_typical_ps = self._do_typical_ps
  111. do_quadratic = self._do_quadratic
  112. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  113. # Apply presence and frequency penalties.
  114. if do_penalties:
  115. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  116. sampling_tensors.output_tokens,
  117. sampling_tensors.presence_penalties,
  118. sampling_tensors.frequency_penalties,
  119. sampling_tensors.repetition_penalties)
  120. # Apply temperature scaling.
  121. # Use in-place division to avoid creating a new tensor.
  122. logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
  123. if do_top_p_top_k and not aphrodite_use_cuda_kernels_for_sampling:
  124. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  125. sampling_tensors.top_ks)
  126. if do_top_as:
  127. logits = _apply_top_a(logits, sampling_tensors.top_as)
  128. if do_min_p:
  129. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  130. if do_tfss:
  131. logits = _apply_tfs(logits, sampling_tensors.tfss)
  132. if do_eta_cutoffs:
  133. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  134. if do_epsilon_cutoffs:
  135. logits = _apply_epsilon_cutoff(logits,
  136. sampling_tensors.epsilon_cutoffs)
  137. if do_typical_ps:
  138. logits = _apply_typical_sampling(logits,
  139. sampling_tensors.typical_ps)
  140. if do_quadratic:
  141. logits = _apply_quadratic_sampling(
  142. logits, sampling_tensors.smoothing_factors,
  143. sampling_tensors.smoothing_curves)
  144. # banned_tokens = _get_custom_token_bans(sampling_metadata)
  145. # logits = _apply_token_bans(logits, banned_tokens)
  146. # We use float32 for probabilities and log probabilities.
  147. # Compute the probabilities.
  148. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  149. # Compute the log probabilities.
  150. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  151. # Sample the next tokens.
  152. sample_results, maybe_sampled_tokens_tensor = _sample(
  153. probs,
  154. logprobs,
  155. sampling_metadata,
  156. sampling_tensors,
  157. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  158. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  159. )
  160. if self.include_gpu_probs_tensor:
  161. assert maybe_sampled_tokens_tensor is not None
  162. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  163. else:
  164. on_device_tensors = None
  165. # Get the logprobs query results.
  166. prompt_logprobs = None
  167. sample_logprobs = None
  168. if not sampling_metadata.skip_sampler_cpu_output:
  169. prompt_logprobs, sample_logprobs = _get_logprobs(
  170. logprobs, sampling_metadata, sample_results)
  171. return _build_sampler_output(
  172. sample_results,
  173. sampling_metadata,
  174. prompt_logprobs,
  175. sample_logprobs,
  176. on_device_tensors=on_device_tensors,
  177. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  178. @property
  179. def _should_modify_greedy_probs_inplace(self) -> bool:
  180. """Whether or not the sampler should modify the probability distribution
  181. of greedily-sampled tokens such that multinomial sampling would sample
  182. the greedily-sampled token.
  183. In other words, if True then we set the probability of the greedily-
  184. sampled token to 1.
  185. This is used by speculative decoding, which requires that the sampling
  186. method be encoded into the probability distribution.
  187. """
  188. # Modify greedy probs if include_gpu_probs_tensor is set.
  189. return self.include_gpu_probs_tensor
  190. def _get_bin_counts_and_mask(
  191. tokens: torch.Tensor,
  192. vocab_size: int,
  193. num_seqs: int,
  194. ) -> Tuple[torch.Tensor, torch.Tensor]:
  195. # Compute the bin counts for the tokens.
  196. # vocab_size + 1 for padding.
  197. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  198. dtype=torch.long,
  199. device=tokens.device)
  200. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  201. bin_counts = bin_counts[:, :vocab_size]
  202. mask = bin_counts > 0
  203. return bin_counts, mask
  204. def _get_custom_token_bans(
  205. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  206. assert sampling_metadata.seq_groups is not None
  207. banned_tokens: List[List[int]] = []
  208. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  209. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  210. seq_ids = seq_group.seq_ids
  211. custom_token_bans = sampling_params.custom_token_bans
  212. if (i < sampling_metadata.num_prompts
  213. and sampling_params.prompt_logprobs is not None):
  214. prompt_len = len(seq_group.prompt_logprob_indices)
  215. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  216. banned_tokens += [custom_token_bans] * len(seq_ids)
  217. return banned_tokens
  218. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  219. output_tokens_tensor: torch.Tensor,
  220. presence_penalties: torch.Tensor,
  221. frequency_penalties: torch.Tensor,
  222. repetition_penalties: torch.Tensor) -> torch.Tensor:
  223. num_seqs, vocab_size = logits.shape
  224. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  225. num_seqs)
  226. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  227. output_tokens_tensor, vocab_size, num_seqs)
  228. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  229. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  230. logits = torch.where(logits > 0, logits / repetition_penalties,
  231. logits * repetition_penalties)
  232. # We follow the definition in OpenAI API.
  233. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  234. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  235. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  236. return logits
  237. def _apply_token_bans(logits: torch.Tensor,
  238. banned_tokens: List[List[int]]) -> torch.Tensor:
  239. for i, banned_token_ids in enumerate(banned_tokens):
  240. if not banned_token_ids:
  241. continue
  242. logits[i, banned_token_ids] = -float("inf")
  243. return logits
  244. def _apply_min_tokens_penalty(
  245. logits: torch.Tensor,
  246. sampling_metadata: SamplingMetadata,
  247. ) -> torch.Tensor:
  248. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  249. have not been generated yet
  250. """
  251. # list of indices in logits that will be set to -inf
  252. logits_to_penalize = []
  253. logits_applied = 0
  254. for seq_group in sampling_metadata.seq_groups:
  255. seq_ids = seq_group.seq_ids
  256. sampling_params = seq_group.sampling_params
  257. sample_indices = seq_group.sample_indices
  258. logits_applied += len(sample_indices) + len(
  259. seq_group.prompt_logprob_indices)
  260. if not seq_group.do_sample:
  261. continue
  262. start_idx = sample_indices[0]
  263. min_tokens = sampling_params.min_tokens
  264. token_ids_to_penalize = sampling_params.all_stop_token_ids
  265. if min_tokens > 0 and token_ids_to_penalize:
  266. seqs_to_penalize = []
  267. for j, seq_id in enumerate(seq_ids):
  268. seq_data = seq_group.seq_data[seq_id]
  269. if len(seq_data.output_token_ids) < min_tokens:
  270. seqs_to_penalize.append(j)
  271. if seqs_to_penalize:
  272. # convert to the index into logits
  273. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  274. # itertools.product pairs each seq index with every token id
  275. logits_to_penalize.extend(
  276. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  277. if logits_to_penalize:
  278. # use zip and * to group indices along each dimension
  279. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  280. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  281. # verifies that no rows in logits were missed unexpectedly
  282. assert logits_applied == logits.shape[0]
  283. return logits
  284. def _apply_top_k_top_p(
  285. logits: torch.Tensor,
  286. p: torch.Tensor,
  287. k: torch.Tensor,
  288. ) -> torch.Tensor:
  289. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  290. # Apply top-k.
  291. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  292. # Get all the top_k values.
  293. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  294. top_k_mask = logits_sort < top_k_mask
  295. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  296. # Apply top-p.
  297. probs_sort = logits_sort.softmax(dim=-1)
  298. probs_sum = probs_sort.cumsum(dim=-1)
  299. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  300. # at least one
  301. top_p_mask[:, -1] = False
  302. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  303. # Re-sort the probabilities.
  304. src = torch.arange(logits_idx.shape[-1],
  305. device=logits_idx.device).expand_as(logits_idx)
  306. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  307. index=logits_idx,
  308. src=src)
  309. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  310. return logits
  311. def _apply_min_p(
  312. logits: torch.Tensor,
  313. min_p: torch.Tensor,
  314. ) -> torch.Tensor:
  315. """
  316. Adapted from
  317. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  318. """
  319. probs = torch.softmax(logits, dim=-1)
  320. top_probs, _ = probs.max(dim=-1, keepdim=True)
  321. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  322. tokens_to_remove = probs < scaled_min_p
  323. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  324. return logits
  325. def _apply_top_a(
  326. logits: torch.Tensor,
  327. top_a: torch.Tensor,
  328. ) -> torch.Tensor:
  329. probs = torch.softmax(logits, dim=-1)
  330. top_probs, _ = probs.max(dim=-1, keepdim=True)
  331. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  332. tokens_to_remove = probs < threshold
  333. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  334. return logits
  335. def _apply_tfs(
  336. logits: torch.Tensor,
  337. tfs: torch.Tensor,
  338. ) -> torch.Tensor:
  339. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  340. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  341. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  342. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  343. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  344. tfs_mask = torch.cat(
  345. (
  346. torch.zeros(
  347. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  348. tfs_mask,
  349. torch.ones(
  350. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  351. ),
  352. dim=-1,
  353. )
  354. logits_sort[tfs_mask] = -float("inf")
  355. logits = torch.gather(logits_sort,
  356. dim=-1,
  357. index=torch.argsort(logits_idx, dim=-1))
  358. return logits
  359. def _apply_eta_cutoff(
  360. logits: torch.Tensor,
  361. eta_cutoff: torch.Tensor,
  362. ) -> torch.Tensor:
  363. shifted_logits = torch.log_softmax(logits, dim=-1)
  364. probs = shifted_logits.exp()
  365. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  366. eps = torch.min(eta_cutoff,
  367. torch.sqrt(eta_cutoff) *
  368. torch.exp(neg_entropy)).unsqueeze(dim=1)
  369. eta_mask = probs < eps
  370. # guard against nulling out all the logits
  371. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  372. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  373. logits[eta_mask] = -float("inf")
  374. return logits
  375. def _apply_epsilon_cutoff(
  376. logits: torch.Tensor,
  377. epsilon_cutoff: torch.Tensor,
  378. ) -> torch.Tensor:
  379. probs = logits.softmax(dim=-1)
  380. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  381. # guard against nulling out all the logits
  382. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  383. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  384. logits[eps_mask] = -float("inf")
  385. return logits
  386. def _apply_typical_sampling(
  387. logits: torch.Tensor,
  388. typical_p: torch.Tensor,
  389. ) -> torch.Tensor:
  390. shifted_logits = torch.log_softmax(logits, dim=-1)
  391. probs = shifted_logits.exp()
  392. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  393. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  394. _, indices = torch.sort(surprisal_deviations)
  395. reordered_probs = probs.gather(-1, indices)
  396. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  397. dim=1)
  398. min_tokens_to_keep = 1
  399. # Keep at least min_tokens_to_keep
  400. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  401. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  402. logits[typ_mask] = -float("inf")
  403. return logits
  404. def _apply_quadratic_sampling(
  405. logits: torch.Tensor,
  406. smoothing_factor: torch.Tensor,
  407. smoothing_curve: torch.Tensor,
  408. ) -> torch.Tensor:
  409. """
  410. Applies a quadratic transformation to the logits based on the
  411. provided smoothing factors and curves. The transformation is
  412. centered around the maximum logit value in the batch.
  413. The transformation involves a quadratic and cubic term, with the
  414. cubic term controlled by the smoothing curve. The quadratic term is
  415. scaled by the smoothing factor, and the cubic term is scaled by the
  416. product of the smoothing factor and the smoothing curve.
  417. params:
  418. logits (torch.Tensor): The logits to be transformed.
  419. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  420. term in the transformation.
  421. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  422. in the transformation.
  423. returns:
  424. torch.Tensor: The transformed logits.
  425. Credits: @kalomaze
  426. """
  427. max_logits = logits.max(dim=-1, keepdim=True).values
  428. diff = logits - max_logits
  429. smoothing_factor.unsqueeze_(dim=1)
  430. smoothing_curve.unsqueeze_(dim=1)
  431. k = (3 - smoothing_curve) / 2
  432. s = (smoothing_curve - 1) / 2
  433. mask = smoothing_factor > 0
  434. mask = mask.flatten()
  435. transformed_logits = torch.where(
  436. logits != float('-inf'), -(k * smoothing_factor * diff**2) +
  437. (s * smoothing_factor * diff**3) + max_logits, logits)
  438. logits[mask, :] = transformed_logits[mask, :]
  439. return logits
  440. def _greedy_sample(
  441. selected_seq_groups: List[SequenceGroupToSample],
  442. samples: torch.Tensor,
  443. ) -> List[Tuple[List[int], List[int]]]:
  444. """Run greedy sampling on a given samples.
  445. Args:
  446. selected_seq_groups: A list of sequence groups batched.
  447. samples: (num_selected_samples,) A tensor of samples. The length of
  448. samples could be smaller than selected_seq_groups if
  449. seq_group.do_sample is False.
  450. Returns:
  451. Tuple of (next_token_ids, parent_ids). The length of returned list is
  452. same as the length of selected_seq_groups. If the corresponding
  453. seq_group has do_sample=False, tuple contains ([], [])
  454. """
  455. samples = samples.tolist()
  456. sample_idx = 0
  457. results = []
  458. for seq_group in selected_seq_groups:
  459. if not seq_group.do_sample:
  460. results.append(([], []))
  461. continue
  462. seq_ids = seq_group.seq_ids
  463. num_parent_seqs = len(seq_ids)
  464. assert num_parent_seqs == 1, (
  465. "Greedy sampling should have only one seq.")
  466. parent_ids = list(range(num_parent_seqs))
  467. next_token_ids = [samples[sample_idx]]
  468. results.append((next_token_ids, parent_ids))
  469. sample_idx += num_parent_seqs
  470. return results
  471. def _random_sample(
  472. selected_seq_groups: List[SequenceGroupToSample],
  473. random_samples: torch.Tensor,
  474. ) -> List[Tuple[List[int], List[int]]]:
  475. """Run random sampling on a given samples.
  476. Args:
  477. selected_seq_groups: A list of sequence groups batched.
  478. random_samples: (num_selected_samples,) A tensor of samples. The
  479. length of samples could be smaller than selected_seq_groups if
  480. seq_group.do_sample is False.
  481. Returns:
  482. Tuple of (next_token_ids, parent_ids). The length of returned list is
  483. same as the length of selected_seq_groups. If the corresponding
  484. seq_group has do_sample=False, tuple contains ([], [])
  485. """
  486. # Find the maximum best_of value of the prompt phase requests.
  487. random_samples = random_samples.cpu()
  488. sample_idx = 0
  489. results = []
  490. for seq_group in selected_seq_groups:
  491. if not seq_group.do_sample:
  492. results.append(([], []))
  493. continue
  494. seq_ids = seq_group.seq_ids
  495. sampling_params = seq_group.sampling_params
  496. is_prompt = seq_group.is_prompt
  497. num_parent_seqs = len(seq_ids)
  498. if is_prompt:
  499. # Prompt phase.
  500. parent_ids = [0] * sampling_params.best_of
  501. next_token_ids = random_samples[
  502. sample_idx, :sampling_params.best_of].tolist()
  503. else:
  504. # Generation phase.
  505. parent_ids = list(range(num_parent_seqs))
  506. next_token_ids = random_samples[sample_idx:sample_idx +
  507. num_parent_seqs, 0].tolist()
  508. results.append((next_token_ids, parent_ids))
  509. sample_idx += num_parent_seqs
  510. return results
  511. def _beam_search_sample(
  512. selected_seq_groups: List[SequenceGroupToSample],
  513. logprobs: torch.Tensor,
  514. ) -> List[Tuple[List[int], List[int]]]:
  515. """Run beam sampling on a given samples.
  516. Args:
  517. selected_seq_groups: A list of sequence groups batched.
  518. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  519. on selected sample indices.
  520. Returns:
  521. Tuple of (next_token_ids, parent_ids). The length of returned list is
  522. same as the length of selected_seq_groups. If the corresponding
  523. seq_group has do_sample=False, tuple contains ([], [])
  524. """
  525. # We sample 2 * beam_width candidates to make sure that with high
  526. # probability we can get `beam_width` candidates in addition to
  527. # the finished sequences for the next iteration. See
  528. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  529. # for details. See also HF reference:
  530. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  531. #
  532. # NOTE: Beam search is not vectorized, so its speed can be slower than
  533. # other sampling methods.
  534. sample_idx = 0
  535. results = []
  536. for seq_group in selected_seq_groups:
  537. if not seq_group.do_sample:
  538. results.append(([], []))
  539. continue
  540. is_prompt = seq_group.is_prompt
  541. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  542. num_parent_seqs = len(seq_ids)
  543. beam_width = sampling_params.best_of
  544. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  545. if is_prompt:
  546. # Prompt phase.
  547. assert num_parent_seqs == 1, (
  548. "Prompt input should have only one seq.")
  549. parent_ids = [0] * (2 * beam_width)
  550. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  551. 2 * beam_width)
  552. next_token_ids = next_token_ids.tolist()
  553. else:
  554. # Generation phase.
  555. cumulative_logprobs = [
  556. seq_group.seq_data[seq_id].cumulative_logprob
  557. for seq_id in seq_ids
  558. ]
  559. cumulative_logprobs = torch.tensor(
  560. cumulative_logprobs,
  561. dtype=torch.float,
  562. device=seq_group_logprobs.device)
  563. seq_group_logprobs = (seq_group_logprobs +
  564. cumulative_logprobs.unsqueeze(dim=1))
  565. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  566. 2 * beam_width)
  567. topk_ids = topk_ids.tolist()
  568. vocab_size = seq_group_logprobs.size(-1)
  569. parent_ids = [i // vocab_size for i in topk_ids]
  570. next_token_ids = [i % vocab_size for i in topk_ids]
  571. results.append((next_token_ids, parent_ids))
  572. sample_idx += num_parent_seqs
  573. assert sample_idx == logprobs.size(0)
  574. return results
  575. # torch.multinomial forces a GPU<->CPU sync.
  576. # Therefore, we use an optimized implementation instead.
  577. # Note that we always sample with replacement.
  578. # probs will be modified in place, but this is fine, as we pass
  579. # in a copy already.
  580. def _multinomial(
  581. probs: torch.Tensor,
  582. num_samples: int,
  583. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  584. ) -> torch.Tensor:
  585. if num_samples > 1:
  586. probs = probs.repeat_interleave(num_samples, dim=0)
  587. q = torch.empty_like(probs)
  588. if seq_groups is None:
  589. q.exponential_()
  590. else:
  591. sample_idx = 0
  592. for seq_group in seq_groups:
  593. seq_ids = seq_group.seq_ids
  594. stride = len(seq_ids) * num_samples
  595. assert seq_group.generator is not None
  596. q[sample_idx:sample_idx +
  597. stride].exponential_(generator=seq_group.generator)
  598. sample_idx += stride
  599. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  600. def _top_k_top_p_multinomial_with_cuda_kernels(
  601. probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
  602. num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
  603. max_top_k_round = 32
  604. if num_samples > 1:
  605. probs = probs.repeat_interleave(num_samples, dim=0)
  606. top_ks = top_ks.repeat_interleave(num_samples)
  607. top_ps = top_ps.repeat_interleave(num_samples)
  608. batch_size = probs.shape[0]
  609. uniform_samples = torch.empty((max_top_k_round, batch_size),
  610. device=probs.device)
  611. if seq_groups is None:
  612. uniform_samples.uniform_()
  613. else:
  614. sample_idx = 0
  615. for seq_group in seq_groups:
  616. seq_ids = seq_group.seq_ids
  617. stride = len(seq_ids) * num_samples
  618. assert seq_group.generator is not None
  619. uniform_samples[:, sample_idx:sample_idx +
  620. stride].uniform_(generator=seq_group.generator)
  621. sample_idx += stride
  622. batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
  623. probs,
  624. uniform_samples,
  625. top_ks,
  626. top_ps,
  627. )
  628. if not success.all():
  629. warnings.warn("CUDA kernel rejection sampling failed, fallback.",
  630. stacklevel=1)
  631. probs = ops.top_k_renorm_prob(probs, top_ks)
  632. probs = ops.top_p_renorm_prob(probs, top_ps)
  633. batch_next_token_ids = ops.sampling_from_probs(probs,
  634. uniform_samples[0])
  635. return batch_next_token_ids.view(-1, num_samples)
  636. def _sample_with_torch(
  637. probs: torch.Tensor,
  638. logprobs: torch.Tensor,
  639. sampling_metadata: SamplingMetadata,
  640. sampling_tensors: SamplingTensors,
  641. include_gpu_probs_tensor: bool,
  642. modify_greedy_probs: bool,
  643. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  644. categorized_seq_group_ids = {t: [] for t in SamplingType}
  645. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  646. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  647. sampling_params = seq_group.sampling_params
  648. sampling_type = sampling_params.sampling_type
  649. categorized_seq_group_ids[sampling_type].append(i)
  650. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  651. sample_metadata = {}
  652. multinomial_samples = {}
  653. # Create output tensor for sampled token ids.
  654. if include_gpu_probs_tensor:
  655. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  656. 1,
  657. dtype=torch.long,
  658. device=logprobs.device)
  659. else:
  660. sampled_token_ids_tensor = None
  661. # Counterintuitively, having two loops here is actually faster.
  662. # The first loop can run without waiting on GPU<->CPU sync.
  663. for sampling_type in SamplingType:
  664. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  665. num_tokens = len(sample_indices)
  666. if num_tokens == 0:
  667. continue
  668. seq_group_id = categorized_seq_group_ids[sampling_type]
  669. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  670. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  671. long_sample_indices = sample_indices.long()
  672. if sampling_type == SamplingType.GREEDY:
  673. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  674. dim=-1)
  675. if include_gpu_probs_tensor:
  676. # Store sampled tokens in output tensor.
  677. sampled_token_ids_tensor[
  678. long_sample_indices] = greedy_samples.unsqueeze(-1)
  679. if modify_greedy_probs:
  680. # If required, modify the probabilities such that sampling from
  681. # the modified distribution would always sample the argmax
  682. # token id.
  683. _modify_greedy_probs_inplace(logprobs, probs,
  684. long_sample_indices,
  685. greedy_samples)
  686. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  687. max_best_of_in_batch = 1
  688. for seq_group in seq_groups:
  689. if seq_group.is_prompt:
  690. sampling_params = seq_group.sampling_params
  691. max_best_of_in_batch = max(max_best_of_in_batch,
  692. sampling_params.best_of)
  693. seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
  694. seq_groups)
  695. if aphrodite_use_cuda_kernels_for_sampling is not None:
  696. multinomial_samples[
  697. sampling_type] = _top_k_top_p_multinomial_with_cuda_kernels(
  698. probs[long_sample_indices],
  699. sampling_tensors.top_ks[long_sample_indices],
  700. sampling_tensors.top_ps[long_sample_indices],
  701. max_best_of_in_batch,
  702. seq_groups_arg,
  703. )
  704. else:
  705. multinomial_samples[sampling_type] = _multinomial(
  706. probs[long_sample_indices],
  707. max_best_of_in_batch,
  708. seq_groups=seq_groups_arg)
  709. if include_gpu_probs_tensor:
  710. # Store sampled tokens in output tensor.
  711. sampled_token_ids_tensor[long_sample_indices] = \
  712. multinomial_samples[sampling_type].to(torch.long)
  713. elif sampling_type == SamplingType.BEAM:
  714. beam_search_logprobs = logprobs[sample_indices]
  715. else:
  716. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  717. # GPU<->CPU sync happens in the loop below.
  718. # This also converts the sample output to Python objects.
  719. if not sampling_metadata.skip_sampler_cpu_output:
  720. for sampling_type in SamplingType:
  721. if sampling_type not in sample_metadata:
  722. continue
  723. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  724. if sampling_type == SamplingType.GREEDY:
  725. sample_results = _greedy_sample(seq_groups, greedy_samples)
  726. elif sampling_type in (SamplingType.RANDOM,
  727. SamplingType.RANDOM_SEED):
  728. sample_results = _random_sample(
  729. seq_groups, multinomial_samples[sampling_type])
  730. elif sampling_type == SamplingType.BEAM:
  731. sample_results = _beam_search_sample(seq_groups,
  732. beam_search_logprobs)
  733. sample_results_dict.update(zip(seq_group_id, sample_results))
  734. sample_results = [
  735. sample_results_dict.get(i, ([], []))
  736. for i in range(len(sampling_metadata.seq_groups))
  737. ]
  738. else:
  739. sample_results = []
  740. return sample_results, sampled_token_ids_tensor
  741. def _sample_with_triton_kernel(
  742. probs: torch.Tensor,
  743. logprobs: torch.Tensor,
  744. sampling_metadata: SamplingMetadata,
  745. sampling_tensors: SamplingTensors,
  746. ) -> List[Tuple[List[int], List[int]]]:
  747. categorized_seq_group_ids = {t: [] for t in SamplingType}
  748. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  749. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  750. sampling_params = seq_group.sampling_params
  751. sampling_type = sampling_params.sampling_type
  752. categorized_seq_group_ids[sampling_type].append(i)
  753. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  754. sample_metadata = {}
  755. max_best_of_in_batch = 1
  756. # Counterintuitively, having two loops here is actually faster.
  757. # The first loop can run without waiting on GPU<->CPU sync.
  758. for sampling_type in SamplingType:
  759. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  760. sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
  761. num_tokens = len(sample_indices)
  762. if num_tokens == 0:
  763. continue
  764. seq_group_id = categorized_seq_group_ids[sampling_type]
  765. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  766. sample_metadata[sampling_type] = (seq_group_id, seq_groups,
  767. sample_indices,
  768. sampled_token_indices)
  769. if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
  770. SamplingType.RANDOM_SEED):
  771. for seq_group in seq_groups:
  772. if seq_group.is_prompt:
  773. sampling_params = seq_group.sampling_params
  774. max_best_of_in_batch = max(max_best_of_in_batch,
  775. sampling_params.best_of)
  776. elif sampling_type == SamplingType.BEAM:
  777. beam_search_logprobs = logprobs[sample_indices]
  778. else:
  779. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  780. sampled_tokens, _, _ = sample_triton(
  781. probs=probs,
  782. seeds=sampling_tensors.sampling_seeds,
  783. max_best_of=max_best_of_in_batch,
  784. sample_indices=sampling_tensors.sample_indices,
  785. logprobs=logprobs,
  786. # don't save logprobs because we have logic for that below
  787. # TODO: use this instead of the CPU-based logic below
  788. save_logprobs=False,
  789. )
  790. # GPU<->CPU sync happens in the loop below.
  791. for sampling_type in SamplingType:
  792. if sampling_type not in sample_metadata:
  793. continue
  794. (seq_group_id, seq_groups, sample_indices,
  795. sampled_token_indices) = sample_metadata[sampling_type]
  796. if sampling_type == SamplingType.GREEDY:
  797. sample_results = _greedy_sample(
  798. seq_groups, sampled_tokens[sampled_token_indices][:, 0])
  799. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  800. sample_results = _random_sample(
  801. seq_groups, sampled_tokens[sampled_token_indices])
  802. elif sampling_type == SamplingType.BEAM:
  803. sample_results = _beam_search_sample(seq_groups,
  804. beam_search_logprobs)
  805. sample_results_dict.update(zip(seq_group_id, sample_results))
  806. sample_results = [
  807. sample_results_dict.get(i, ([], []))
  808. for i in range(len(sampling_metadata.seq_groups))
  809. ]
  810. return sample_results
  811. def _sample(
  812. probs: torch.Tensor,
  813. logprobs: torch.Tensor,
  814. sampling_metadata: SamplingMetadata,
  815. sampling_tensors: SamplingTensors,
  816. include_gpu_probs_tensor: bool,
  817. modify_greedy_probs: bool,
  818. ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
  819. """
  820. Args:
  821. probs: (num_query_tokens_in_batch, num_vocab)
  822. logprobs: (num_query_tokens_in_batch, num_vocab)
  823. sampling_metadata: The metadata for a batch for sampling.
  824. sampling_tensors: Tensors that include sampling related metadata.
  825. Returns:
  826. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  827. If sampling is skipped, it returns ([], [])
  828. sampled_token_ids_tensor: A tensor of sampled token ids.
  829. """
  830. return _sample_with_torch(
  831. probs,
  832. logprobs,
  833. sampling_metadata,
  834. sampling_tensors,
  835. include_gpu_probs_tensor=include_gpu_probs_tensor,
  836. modify_greedy_probs=modify_greedy_probs,
  837. )
  838. # TODO: Enable once Triton kernel & associated code is faster.
  839. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  840. # sampling_tensors)
  841. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  842. """
  843. This function calculates the ranks of the chosen tokens in a logprob tensor.
  844. Args:
  845. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  846. where N is the no. of tokens and M is the vocab dim.
  847. indices (torch.Tensor): List of chosen token indices.
  848. Returns:
  849. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  850. Each element in the returned tensor represents the rank
  851. of the chosen token in the input logprob tensor.
  852. """
  853. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  854. indices]
  855. return (x > vals[:, None]).long().sum(1).add_(1)
  856. def _get_logprobs(
  857. logprobs: torch.Tensor,
  858. sampling_metadata: SamplingMetadata,
  859. sample_results: List[Tuple[List[int], List[int]]],
  860. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  861. """Return sample lobprobs and prompt logprobs.
  862. The logic consists of 3 parts.
  863. - Select indices to compute logprob from, ranks of token ids, and
  864. the top k token ids from logprobs.
  865. - Compute prompt logprobs if required.
  866. - Compute sample logprobs if required.
  867. Args:
  868. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  869. logprob per vocab. Sequence groups' query tokens are batched in a
  870. single flattened tensor. For example, assuming there are N
  871. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  872. prompt logprob is enabled), decode tokens for seq_group_1 (if
  873. sampling is required), prefill tokens for seq_group_2, ...
  874. sampling_metadata: The sampling metadata.
  875. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  876. parent_ids) for each sequence group. When beam search is enabled,
  877. sample_results can contain different number of seq_ids from
  878. sampling_metadata.seq_groups. It is because beam search creates
  879. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  880. BEAM_WIDTH number of seq_ids).
  881. Returns:
  882. A tuple of prompt and sample logprobs per sequence group in a batch.
  883. """
  884. # The index of query token to calculate logprobs. It includes both
  885. # prompt and sample logprob indices.
  886. query_indices: List[int] = []
  887. # The next token ids to get the logprob value from.
  888. next_token_ids: List[int] = []
  889. # The largest requested number of logprobs. We find logprobs as many as the
  890. # largest num logprobs in this API.
  891. largest_num_logprobs = 1
  892. # Select indices to compute logprob from, ranks of token ids, and the top
  893. # k token ids from logprobs.
  894. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  895. sample_results):
  896. sampling_params = seq_group.sampling_params
  897. # Update indices and tokens for prompt logprobs.
  898. if (seq_group.is_prompt
  899. and sampling_params.prompt_logprobs is not None):
  900. largest_num_logprobs = max(largest_num_logprobs,
  901. sampling_params.prompt_logprobs)
  902. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  903. query_indices.extend(seq_group.prompt_logprob_indices)
  904. next_token_ids.extend(next_prompt_tokens)
  905. # Update indices and next tokenes for sample logprob.
  906. if seq_group.do_sample:
  907. token_ids, parent_seq_ids = sample_result
  908. # NOTE: We cannot directly use sample_indices because
  909. # sample_indices only contain parent seq_ids of a previous step.
  910. # The current step may have different number of seq_ids, and
  911. # we can obtain it from `sample_result[1]`.
  912. query_idx = seq_group.sample_indices[0]
  913. query_indices.extend(
  914. [query_idx + parent_id for parent_id in parent_seq_ids])
  915. next_token_ids.extend(token_ids)
  916. if sampling_params.logprobs is not None:
  917. largest_num_logprobs = max(largest_num_logprobs,
  918. sampling_params.logprobs)
  919. assert len(next_token_ids) == len(query_indices)
  920. if len(query_indices) == 0:
  921. empty_sampled_logprob = []
  922. empty_prompt_logprob = None
  923. return [empty_prompt_logprob], [empty_sampled_logprob]
  924. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  925. next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
  926. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  927. # contain duplicates if beam search is enabled.
  928. selected_logprobs = logprobs[[
  929. query_indices_gpu,
  930. next_token_ids_gpu,
  931. ]]
  932. ranks = _get_ranks(
  933. logprobs[query_indices_gpu],
  934. next_token_ids_gpu,
  935. )
  936. assert selected_logprobs.shape[0] == ranks.shape[0]
  937. # Logprobs of topk tokens for a batch of sequence groups.
  938. # (num_query_tokens_across_batch).
  939. if largest_num_logprobs > 0:
  940. top_logprobs, top_token_ids = torch.topk(logprobs,
  941. largest_num_logprobs,
  942. dim=-1)
  943. else:
  944. top_logprobs, top_token_ids = None, None
  945. selected_logprobs = selected_logprobs.to('cpu')
  946. ranks = ranks.to('cpu')
  947. if top_logprobs is not None and top_token_ids is not None:
  948. top_logprobs = top_logprobs.to('cpu')
  949. top_token_ids = top_token_ids.to('cpu')
  950. # Find prompt/sample logprobs.
  951. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  952. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  953. top_logprob_idx = 0
  954. selected_logprobs_idx = 0
  955. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  956. sample_results):
  957. (prompt_logprobs, top_logprob_idx,
  958. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  959. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  960. selected_logprobs_idx, top_logprob_idx)
  961. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  962. (sampled_logprobs, top_logprob_idx,
  963. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  964. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  965. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  966. sample_logprobs_per_seq_group.append(sampled_logprobs)
  967. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  968. def _get_prompt_logprob_if_needed(
  969. seq_group: SequenceGroupToSample,
  970. selected_logprobs: torch.Tensor,
  971. ranks: torch.Tensor,
  972. top_token_ids: torch.Tensor,
  973. top_logprobs: torch.Tensor,
  974. selected_logprobs_idx: int,
  975. top_logprob_idx: int,
  976. ):
  977. """Compute the prompt logprob from a sequence group if needed."""
  978. sampling_params = seq_group.sampling_params
  979. is_prompt = seq_group.is_prompt
  980. # Find prompt logprobs
  981. prompt_logprobs: Optional[PromptLogprobs] = None
  982. if is_prompt and sampling_params.prompt_logprobs is not None:
  983. prompt_logprobs = []
  984. num_logprobs = sampling_params.prompt_logprobs
  985. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  986. # Pre-select indexes and create a list. It is faster than calling .item
  987. # repetitively.
  988. selected_logprob_items = selected_logprobs[
  989. selected_logprobs_idx:selected_logprobs_idx +
  990. len(next_prompt_tokens)].tolist()
  991. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  992. len(next_prompt_tokens)].tolist()
  993. for idx, token_id in enumerate(next_prompt_tokens):
  994. # Calculate the prompt logprob of the real prompt tokens.
  995. # {token_id: (logprob, rank_from_vocab)}
  996. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  997. token_id: (selected_logprob_items[idx], rank_items[idx])
  998. }
  999. # Add top K prompt logprobs along with its rank.
  1000. if num_logprobs > 0:
  1001. top_ids = top_token_ids[
  1002. top_logprob_idx, :num_logprobs].tolist()
  1003. top_probs = top_logprobs[
  1004. top_logprob_idx, :num_logprobs].tolist()
  1005. # Top K is already sorted by rank, so we can use 1 ~
  1006. # num_logprobs + 1 for rank.
  1007. top_ranks = range(1, num_logprobs + 1)
  1008. prompt_logprobs_dict.update({
  1009. top_id: (top_prob, rank)
  1010. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1011. top_ranks)
  1012. })
  1013. prompt_logprobs.append({
  1014. token_id: Logprob(*logprob_and_rank)
  1015. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1016. })
  1017. # + 1 to go to the next prompt token.
  1018. top_logprob_idx += 1
  1019. # + len(next_prompt_tokens) to go to the next prompt.
  1020. selected_logprobs_idx += len(next_prompt_tokens)
  1021. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1022. def _get_sampled_logprob_if_needed(
  1023. seq_group: SequenceGroupToSample,
  1024. sample_result: Tuple[List[int], List[int]],
  1025. selected_logprobs: torch.Tensor,
  1026. ranks: torch.Tensor,
  1027. top_token_ids: torch.Tensor,
  1028. top_logprobs: torch.Tensor,
  1029. selected_logprobs_idx: int,
  1030. top_logprob_idx: int,
  1031. ):
  1032. """Compute the sample logprob if needed."""
  1033. seq_ids = seq_group.seq_ids
  1034. num_logprobs = seq_group.sampling_params.logprobs or 0
  1035. sampled_logprobs: SampleLogprobs = []
  1036. next_token_ids, parent_seq_ids = sample_result
  1037. if seq_group.do_sample:
  1038. assert len(next_token_ids) > 0
  1039. # Pre-select items from tensor. tolist() is faster than repetitive
  1040. # `.item()` calls.
  1041. selected_logprob_items = selected_logprobs[
  1042. selected_logprobs_idx:selected_logprobs_idx +
  1043. len(next_token_ids)].tolist()
  1044. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1045. len(next_token_ids)].tolist()
  1046. for idx, (next_token_id,
  1047. parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
  1048. # Get the logprob of a sampled token.
  1049. sampled_logprobs_dict = {
  1050. next_token_id: (selected_logprob_items[idx], rank_items[idx])
  1051. }
  1052. # Get top K logprobs.
  1053. if num_logprobs > 0:
  1054. top_ids = top_token_ids[top_logprob_idx +
  1055. parent_id, :num_logprobs].tolist()
  1056. top_probs = top_logprobs[top_logprob_idx +
  1057. parent_id, :num_logprobs].tolist()
  1058. # Top K is already sorted by rank, so we can use 1 ~
  1059. # num_logprobs + 1 for rank.
  1060. top_ranks = range(1, num_logprobs + 1)
  1061. sampled_logprobs_dict.update({
  1062. top_id: (top_prob, rank)
  1063. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1064. top_ranks)
  1065. })
  1066. sampled_logprobs.append({
  1067. token_id: Logprob(*logprob_and_rank)
  1068. for token_id, logprob_and_rank in
  1069. sampled_logprobs_dict.items()
  1070. })
  1071. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1072. # logprobs for the current step, which has len(next_token_ids) tokens
  1073. # per sequence group. `logprobs` includes logprobs from the previous
  1074. # steps, which has len(seq_ids) tokens per sequence group.
  1075. # Iterate to the next sequence group in a batch.
  1076. selected_logprobs_idx += len(next_token_ids)
  1077. # Iterate to the next sequence group in a batch.
  1078. top_logprob_idx += len(seq_ids)
  1079. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1080. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1081. sample_indices: torch.Tensor,
  1082. greedy_samples: torch.Tensor) -> None:
  1083. """Modify the probability distributions of the greedily-sampled tokens such
  1084. that each sampled token has a "probability" of 1.0. This is required by
  1085. speculative decoding, which depends on the sampling method being encoded
  1086. within the probability distribution for correctness.
  1087. # Why do we only need to do this for greedy sampling?
  1088. Aphrodite's sampler performs the following steps for greedy or multinomial
  1089. (random) sampling:
  1090. 1. Get logits from model.
  1091. 2. Modify logits according to per-sequence sampling parameters.
  1092. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1093. according to their frequency, etc.
  1094. 3. Sample a token.
  1095. - Random sampling simply samples from the modified probability
  1096. distribution.
  1097. - Greedy sampling performs `argmax` to obtain the token with the
  1098. highest likelihood.
  1099. Ignoring greedy sampling for a moment, we find that the computed probability
  1100. distribution has the following property: we can sample from it independently
  1101. and find that the token sampled by the Sampler has a frequency corresponding
  1102. to how often we see it in our sampling. In other words, for tokens sampled
  1103. with Aphrodite's random SamplingType, the computed probability distribution
  1104. encodes the sampling methodology completely.
  1105. Greedy sampling does not normally have this property. Aphrodite modifies
  1106. logits according to sampling params, then performs `argmax`, then returns
  1107. the sampled token and the computed probability distribution. If we sample
  1108. from the distribution, we'll find the likelihood of the greedily-sampled
  1109. token is not always 1.0.
  1110. Since lossless speculative decoding requires that the sampling methodology
  1111. be encoded within the probability distribution, we are motivated to modify
  1112. the probability distribution such that the sampled token has probability 1
  1113. when speculative decoding is used.
  1114. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1115. greedy sampling using multinomial computation and unite the codepaths. This
  1116. has implications on the overall design of the sampler, e.g. how to record
  1117. accurate logprobs for the user, so this improvement is deferred to later.
  1118. """
  1119. # NOTE: logprobs are not modified so they can be returned to the user.
  1120. probs[sample_indices, :] = 0
  1121. probs[sample_indices, greedy_samples] = 1.0
  1122. def _build_sampler_output(
  1123. sample_results: SampleResultType,
  1124. sampling_metadata: SamplingMetadata,
  1125. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1126. sample_logprobs: Optional[List[SampleLogprobs]],
  1127. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1128. torch.Tensor]],
  1129. skip_sampler_cpu_output: bool = False,
  1130. ) -> SamplerOutput:
  1131. """Construct Python objects with the output of sampling.
  1132. Args:
  1133. on_device_tensors: Tuple containing on-device tensors with the
  1134. probabilities used in sampling and the sampled token ids. This
  1135. allows post-processing without copies to CPU/serialization, e.g. in
  1136. speculative decoding rejection sampling.
  1137. """
  1138. sampler_output: List[CompletionSequenceGroupOutput] = []
  1139. if not skip_sampler_cpu_output:
  1140. assert prompt_logprobs is not None
  1141. assert sample_logprobs is not None
  1142. for (seq_group, sample_result, group_prompt_logprobs,
  1143. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1144. sample_results, prompt_logprobs,
  1145. sample_logprobs):
  1146. seq_ids = seq_group.seq_ids
  1147. next_token_ids, parent_ids = sample_result
  1148. seq_outputs: List[SequenceOutput] = []
  1149. for parent_id, next_token_id, logprobs in zip(
  1150. parent_ids, next_token_ids, group_sample_logprobs):
  1151. seq_outputs.append(
  1152. SequenceOutput(seq_ids[parent_id], next_token_id,
  1153. logprobs))
  1154. sampler_output.append(
  1155. CompletionSequenceGroupOutput(seq_outputs,
  1156. group_prompt_logprobs))
  1157. # If not specified, store None values in SamplerOutput.
  1158. if on_device_tensors is not None:
  1159. (sampled_token_probs, logprobs_tensor,
  1160. sampled_token_ids) = on_device_tensors
  1161. else:
  1162. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1163. None)
  1164. return SamplerOutput(
  1165. outputs=sampler_output,
  1166. sampled_token_probs=sampled_token_probs,
  1167. sampled_token_ids=sampled_token_ids,
  1168. logprobs=logprobs_tensor,
  1169. )
  1170. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1171. """Get a list of next prompt tokens to compute logprob from a
  1172. given sequence group.
  1173. It is used to compute prompt logprob. Imagine you have logprob for each
  1174. query token. Query token needs to know the next prompt token id to compute
  1175. prompt logprob. This is a helper to obtain next prompt token ids.
  1176. This API has to be used only when the caller knows seq_group is in prefill
  1177. stage.
  1178. Returns:
  1179. A list of next prompt tokens to compute logprob.
  1180. """
  1181. assert seq_group.is_prompt, (
  1182. "Caller should ensure the sequence group is in a prefill stage.")
  1183. seq_ids = seq_group.seq_ids
  1184. query_len = seq_group.query_len
  1185. assert query_len is not None
  1186. # prompt has only 1 seq id.
  1187. assert len(seq_ids) == 1
  1188. seq_data = seq_group.seq_data[seq_ids[0]]
  1189. computed_len = seq_data.get_num_computed_tokens()
  1190. prompt_tokens = seq_data.prompt_token_ids
  1191. # +1 because we are looking for a next prompt token.
  1192. next_token_index_start = computed_len + 1
  1193. next_token_index_end = min(computed_len + query_len + 1,
  1194. len(prompt_tokens))
  1195. next_prompt_tokens = prompt_tokens[
  1196. next_token_index_start:next_token_index_end]
  1197. return next_prompt_tokens
  1198. # def _apply_mirostat_v2(logits: torch.Tensor,
  1199. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1200. # # Reduce our view to just the affected logits
  1201. # logit_view = logits[sampling_tensors.miro_indices]
  1202. # # Calculate surprise value per token
  1203. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1204. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1205. # # Mask out "too-surprising" tokens (surprisal > mu)
  1206. # mus = sampling_tensors.miro_mus
  1207. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1208. # # Unmask most-likely logit to guarantee a selection.
  1209. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1210. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1211. # # Apply logit mask (effectively a top-k filter).
  1212. # logit_view[miro_mask] = -float("inf")
  1213. # # Project logit changes made to the view onto the original.
  1214. # # I think this step might be redundant.
  1215. # logits[sampling_tensors.miro_indices] = logit_view
  1216. # return logits
  1217. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1218. # sample_results: List[Tuple[List[int], List[int]]],
  1219. # sampling_metadata: SamplingMetadata,
  1220. # output_metadata: OutputMetadata) -> None:
  1221. # """Based on whichever token was finally sampled, we calculate the
  1222. # final surprisal values to update the mus.
  1223. # Because a single sequence can have multiple samples, we must fork
  1224. # the mu accordingly."""
  1225. # assert sampling_metadata.seq_groups is not None
  1226. # seqid_to_tokens = {}
  1227. # seqid_to_indices = {}
  1228. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1229. # sample_results):
  1230. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1231. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1232. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1233. # seqids = args.miro_seqids
  1234. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1235. # device=logits.device,
  1236. # dtype=torch.long)
  1237. # # Clumsily, we recalculate token surprisals.
  1238. # logits_view = logits[args.miro_indices]
  1239. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1240. # dim=-1,
  1241. # index=picked_tokens) / -math.log(2)
  1242. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1243. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1244. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1245. # nu_mus = mus - (picked_surprise - taus) * etas
  1246. # # Record updated mu values for use in the next iteration
  1247. # # Note how each mu is split into multiple based on the number of samples.
  1248. # for seqid, seq_mus in zip(seqids, nu_mus):
  1249. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1250. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)