sampler.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import os
  4. import warnings
  5. from math import inf
  6. from typing import Dict, List, Optional, Tuple
  7. import torch
  8. import torch.nn as nn
  9. import aphrodite._custom_ops as ops
  10. from aphrodite.common.sampling_params import SamplingType
  11. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  12. PromptLogprobs, SampleLogprobs,
  13. SamplerOutput, SequenceOutput)
  14. from aphrodite.triton_utils import HAS_TRITON
  15. if HAS_TRITON:
  16. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  17. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  18. SamplingTensors,
  19. SequenceGroupToSample)
  20. # (num_token_ids, num_parent_ids) per sequence group.
  21. SampleResultType = List[Tuple[List[int], List[int]]]
  22. # There isn't a "safe" temperature range for fp16 logits.
  23. # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
  24. # that this temperature well-uses the fp16 space after the logits are offset.
  25. _TEMPERATURE_MINIMUM = 2e-5
  26. # If enabled, we switch to a more performant implementation
  27. # of top-k and top-p
  28. APHRODITE_USE_SAMPLING_KERNELS = bool(int(
  29. os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
  30. class Sampler(nn.Module):
  31. """Samples the next tokens from the model's outputs.
  32. This layer does the following:
  33. 1. Discard the hidden states that are not used for sampling (i.e., all
  34. tokens except the final one in each prompt).
  35. 2. Compute the logits for the next tokens.
  36. 3. Apply presence, frequency and repetition penalties.
  37. 4. Apply temperature scaling.
  38. 5. Apply top-p and top-k truncation.
  39. 6. Sample the next tokens.
  40. Here, each sequence group within the batch can have different sampling
  41. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  42. The structure of the logits tensor is coupled with the seq_groups in
  43. sampling_metadata. Typically, each sequence in each seq_group has one row in
  44. logits for the next token to be sampled; however, for a seq_group with a
  45. prompt request with the prompt_logprobs sampling parameter, there are rows
  46. in logits for each token in the input prompt.
  47. """
  48. def __init__(self):
  49. super().__init__()
  50. # Whether or not the SamplerOutput should have on-device tensors
  51. # containing the sampled token ids and probabilities. This is used by
  52. # speculative decoding.
  53. self.include_gpu_probs_tensor = False
  54. self.should_modify_greedy_probs_inplace = False
  55. def _init_sampling_tensors(
  56. self,
  57. logits: torch.Tensor,
  58. sampling_metadata: SamplingMetadata,
  59. ):
  60. """The goal here is to reuse sampling tensors between similar decode
  61. runs. This is possible because sampling logic does not change between
  62. decodes of the same sequences.
  63. """
  64. _, vocab_size = logits.shape
  65. # First free any existing stored sampling tensors.
  66. # This is necessary because some sampling tensors may
  67. # have pinned memory.
  68. self._sampling_tensors = None
  69. # Initialize new sampling tensors
  70. (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
  71. do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
  72. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
  73. do_dry, do_skew, do_temp_last
  74. ) = SamplingTensors.from_sampling_metadata(
  75. sampling_metadata, vocab_size, logits.device, logits.dtype)
  76. self._sampling_tensors = sampling_tensors
  77. self._do_penalties = do_penalties
  78. self._do_no_repeat_ngrams = do_no_repeat_ngrams
  79. self._do_temperatures = do_temperatures
  80. self._do_top_p_top_k = do_top_p_top_k
  81. self._do_top_as = do_top_as
  82. self._do_min_p = do_min_p
  83. self._do_tfss = do_tfss
  84. self._do_eta_cutoffs = do_eta_cutoffs
  85. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  86. self._do_typical_ps = do_typical_ps
  87. self._do_quadratic = do_quadratic
  88. self._do_xtc = do_xtc
  89. self._do_nsgimas = do_nsigmas
  90. self._do_dry = do_dry
  91. self._do_skew = do_skew
  92. self._do_temp_last = do_temp_last
  93. def forward(
  94. self,
  95. logits: torch.Tensor,
  96. sampling_metadata: SamplingMetadata,
  97. ) -> Optional[SamplerOutput]:
  98. """
  99. Args:
  100. logits: (num_tokens, vocab_size).
  101. sampling_metadata: Metadata for sampling.
  102. """
  103. assert logits is not None
  104. _, vocab_size = logits.shape
  105. # Prepare sampling tensors with pinned memory to avoid blocking.
  106. if not sampling_metadata.reuse_sampling_tensors:
  107. self._init_sampling_tensors(logits, sampling_metadata)
  108. elif self._do_penalties:
  109. # In this case, the sampling tensors logic depends on
  110. # "output_tokens" of a sequence. As a result, we cannot
  111. # reuse sampling tensors, since "output_tokens" changes
  112. # between decode runs.
  113. self._init_sampling_tensors(logits, sampling_metadata)
  114. assert self._sampling_tensors is not None
  115. sampling_tensors = self._sampling_tensors
  116. do_penalties = self._do_penalties
  117. do_no_repeat_ngrams = self._do_no_repeat_ngrams
  118. do_temperatures = self._do_temperatures
  119. do_top_p_top_k = self._do_top_p_top_k
  120. do_top_as = self._do_top_as
  121. do_min_p = self._do_min_p
  122. do_tfss = self._do_tfss
  123. do_eta_cutoffs = self._do_eta_cutoffs
  124. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  125. do_typical_ps = self._do_typical_ps
  126. do_quadratic = self._do_quadratic
  127. do_xtc = self._do_xtc
  128. do_nsigmas = self._do_nsgimas
  129. do_dry = self._do_dry
  130. do_skew = self._do_skew
  131. do_temp_last = self._do_temp_last
  132. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  133. if do_dry:
  134. logits = _apply_dry(
  135. logits,
  136. sampling_tensors.prompt_tokens,
  137. sampling_tensors.dry_multipliers,
  138. sampling_tensors.dry_bases,
  139. sampling_tensors.dry_allowed_lengths,
  140. sampling_tensors.dry_sequence_breaker_ids
  141. )
  142. # Apply presence and frequency penalties.
  143. if do_penalties:
  144. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  145. sampling_tensors.output_tokens,
  146. sampling_tensors.presence_penalties,
  147. sampling_tensors.frequency_penalties,
  148. sampling_tensors.repetition_penalties)
  149. if do_no_repeat_ngrams:
  150. logits = _apply_no_repeat_ngram(
  151. logits,
  152. sampling_tensors.prompt_tokens,
  153. sampling_tensors.no_repeat_ngram_sizes)
  154. # Apply temperature scaling if not doing temp_last.
  155. if do_temperatures and not do_temp_last:
  156. _apply_temperatures(logits, sampling_tensors.temperatures,
  157. sampling_tensors.dynatemp_mins,
  158. sampling_tensors.dynatemp_maxs,
  159. sampling_tensors.dynatemp_exps)
  160. if do_nsigmas:
  161. logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
  162. if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
  163. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  164. sampling_tensors.top_ks)
  165. if do_top_as:
  166. logits = _apply_top_a(logits, sampling_tensors.top_as)
  167. if do_min_p:
  168. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  169. if do_tfss:
  170. logits = _apply_tfs(logits, sampling_tensors.tfss)
  171. if do_eta_cutoffs:
  172. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  173. if do_epsilon_cutoffs:
  174. logits = _apply_epsilon_cutoff(logits,
  175. sampling_tensors.epsilon_cutoffs)
  176. if do_typical_ps:
  177. logits = _apply_typical_sampling(logits,
  178. sampling_tensors.typical_ps)
  179. if do_quadratic:
  180. logits = _apply_quadratic_sampling(
  181. logits, sampling_tensors.smoothing_factors,
  182. sampling_tensors.smoothing_curves)
  183. if do_xtc:
  184. logits = _apply_xtc_sampling(
  185. logits, sampling_tensors.xtc_thresholds,
  186. sampling_tensors.xtc_probabilities)
  187. if do_temperatures and do_temp_last:
  188. _apply_temperatures(logits, sampling_tensors.temperatures,
  189. sampling_tensors.dynatemp_mins,
  190. sampling_tensors.dynatemp_maxs,
  191. sampling_tensors.dynatemp_exps)
  192. banned_tokens = _get_custom_token_bans(sampling_metadata)
  193. logits = _apply_token_bans(logits, banned_tokens)
  194. # We use float32 for probabilities and log probabilities.
  195. # Compute the probabilities.
  196. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  197. # skew needs to be applied post-softmax
  198. if do_skew:
  199. # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
  200. cum_probs = torch.cumsum(probs, dim=-1)
  201. cum_probs = torch.pow(cum_probs, torch.exp(
  202. sampling_tensors.skews).unsqueeze(dim=1))
  203. probs = torch.diff(cum_probs, dim=-1,
  204. prepend=torch.zeros_like(cum_probs[..., :1]))
  205. logits = torch.log(probs)
  206. # Compute the log probabilities.
  207. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  208. # Sample the next tokens.
  209. sample_results, maybe_sampled_tokens_tensor = _sample(
  210. probs,
  211. logprobs,
  212. sampling_metadata,
  213. sampling_tensors,
  214. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  215. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  216. )
  217. if self.include_gpu_probs_tensor:
  218. assert maybe_sampled_tokens_tensor is not None
  219. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  220. else:
  221. on_device_tensors = None
  222. # Get the logprobs query results.
  223. prompt_logprobs = None
  224. sample_logprobs = None
  225. if not sampling_metadata.skip_sampler_cpu_output:
  226. prompt_logprobs, sample_logprobs = _get_logprobs(
  227. logprobs, sampling_metadata, sample_results)
  228. return _build_sampler_output(
  229. sample_results,
  230. sampling_metadata,
  231. prompt_logprobs,
  232. sample_logprobs,
  233. on_device_tensors=on_device_tensors,
  234. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  235. @property
  236. def _should_modify_greedy_probs_inplace(self) -> bool:
  237. """Whether or not the sampler should modify the probability distribution
  238. of greedily-sampled tokens such that multinomial sampling would sample
  239. the greedily-sampled token.
  240. In other words, if True then we set the probability of the greedily-
  241. sampled token to 1.
  242. This is used by speculative decoding, which requires that the sampling
  243. method be encoded into the probability distribution.
  244. """
  245. return self.should_modify_greedy_probs_inplace
  246. def _get_bin_counts_and_mask(
  247. tokens: torch.Tensor,
  248. vocab_size: int,
  249. num_seqs: int,
  250. ) -> Tuple[torch.Tensor, torch.Tensor]:
  251. # Compute the bin counts for the tokens.
  252. # vocab_size + 1 for padding.
  253. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  254. dtype=torch.long,
  255. device=tokens.device)
  256. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  257. bin_counts = bin_counts[:, :vocab_size]
  258. mask = bin_counts > 0
  259. return bin_counts, mask
  260. def _get_custom_token_bans(
  261. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  262. assert sampling_metadata.seq_groups is not None
  263. banned_tokens: List[List[int]] = []
  264. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  265. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  266. seq_ids = seq_group.seq_ids
  267. custom_token_bans = sampling_params.custom_token_bans
  268. if (i < sampling_metadata.num_prompts
  269. and sampling_params.prompt_logprobs is not None):
  270. prompt_len = len(seq_group.prompt_logprob_indices)
  271. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  272. banned_tokens += [custom_token_bans] * len(seq_ids)
  273. return banned_tokens
  274. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  275. output_tokens_tensor: torch.Tensor,
  276. presence_penalties: torch.Tensor,
  277. frequency_penalties: torch.Tensor,
  278. repetition_penalties: torch.Tensor) -> torch.Tensor:
  279. num_seqs, vocab_size = logits.shape
  280. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  281. num_seqs)
  282. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  283. output_tokens_tensor, vocab_size, num_seqs)
  284. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  285. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  286. logits = torch.where(logits > 0, logits / repetition_penalties,
  287. logits * repetition_penalties)
  288. # We follow the definition in OpenAI API.
  289. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  290. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  291. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  292. return logits
  293. def _apply_temperatures(
  294. logits: torch.Tensor,
  295. temperatures: torch.Tensor,
  296. dynatemp_mins: torch.Tensor,
  297. dynatemp_maxs: torch.Tensor,
  298. dynatemp_exps: torch.Tensor,
  299. ) -> None:
  300. dynatemp_mask = (dynatemp_mins != 0) | (dynatemp_maxs != 0)
  301. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  302. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  303. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  304. dynatemp_logits = logits[dynatemp_mask]
  305. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  306. dynatemp_probs = dynatemp_shifted_logits.exp()
  307. dynatemp_entropies = -(dynatemp_probs *
  308. dynatemp_shifted_logits).nansum(dim=-1)
  309. dynatemp_max_entropies = torch.log_(
  310. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  311. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  312. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  313. normalized_entropies.pow_(dynatemp_exps))
  314. temperatures[dynatemp_mask] = dyn_temp
  315. temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
  316. temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
  317. # To prevent saturation of top logits, we shift the range to [-inf, 1]
  318. # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
  319. # Why mask? So we aren't potentially discarding data in milder temps.
  320. low_temps = temperatures < 0.1
  321. logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
  322. logits.div_(temperatures.unsqueeze(dim=1))
  323. def _apply_token_bans(logits: torch.Tensor,
  324. banned_tokens: List[List[int]]) -> torch.Tensor:
  325. for i, banned_token_ids in enumerate(banned_tokens):
  326. if i >= logits.size(0):
  327. break
  328. if not banned_token_ids:
  329. continue
  330. logits[i, banned_token_ids] = -float("inf")
  331. return logits
  332. def _apply_min_tokens_penalty(
  333. logits: torch.Tensor,
  334. sampling_metadata: SamplingMetadata,
  335. ) -> torch.Tensor:
  336. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  337. have not been generated yet
  338. """
  339. # list of indices in logits that will be set to -inf
  340. logits_to_penalize = []
  341. logits_applied = 0
  342. for seq_group in sampling_metadata.seq_groups:
  343. seq_ids = seq_group.seq_ids
  344. sampling_params = seq_group.sampling_params
  345. sample_indices = seq_group.sample_indices
  346. logits_applied += len(sample_indices) + len(
  347. seq_group.prompt_logprob_indices)
  348. if not seq_group.do_sample:
  349. continue
  350. start_idx = sample_indices[0]
  351. min_tokens = sampling_params.min_tokens
  352. token_ids_to_penalize = sampling_params.all_stop_token_ids
  353. if min_tokens > 0 and token_ids_to_penalize:
  354. seqs_to_penalize = []
  355. for j, seq_id in enumerate(seq_ids):
  356. seq_data = seq_group.seq_data[seq_id]
  357. if len(seq_data.output_token_ids_array) < min_tokens:
  358. seqs_to_penalize.append(j)
  359. if seqs_to_penalize:
  360. # convert to the index into logits
  361. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  362. # itertools.product pairs each seq index with every token id
  363. logits_to_penalize.extend(
  364. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  365. if logits_to_penalize:
  366. # use zip and * to group indices along each dimension
  367. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  368. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  369. # verifies that no rows in logits were missed unexpectedly
  370. assert logits_applied == logits.shape[0]
  371. return logits
  372. def _apply_dry(
  373. logits: torch.Tensor,
  374. input_ids: torch.Tensor,
  375. multipliers: torch.Tensor,
  376. bases: torch.Tensor,
  377. allowed_lengths: torch.Tensor,
  378. sequence_breakers_ids: torch.Tensor
  379. ) -> torch.Tensor:
  380. """
  381. Apply Exclude Don't Repeat Yourself (DRY) sampling to the logits.
  382. Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
  383. """
  384. # Don't apply dry penalties if multiplier is 0
  385. if torch.all(multipliers == 0):
  386. return logits
  387. # Process each sequence in the batch
  388. for i, (input_ids_row, logits_row) in enumerate(zip(input_ids, logits)):
  389. multiplier = multipliers[i].item()
  390. if multiplier == 0:
  391. continue # Skip processing for this sequence
  392. # Get the last token
  393. last_token = input_ids_row[-1].item()
  394. # Skip if last token is a sequence breaker
  395. if last_token in sequence_breakers_ids:
  396. continue
  397. # Find matches of the last token, excluding the last position
  398. match_indices = (input_ids_row[:-1] == last_token).nonzero()
  399. # Track max matching sequence length for each potential next token
  400. match_lengths = {}
  401. # Process each match
  402. for idx in match_indices:
  403. # Convert to scalar
  404. idx = idx.item()
  405. # Get the token that followed this match in the input
  406. next_token = input_ids_row[idx + 1].item()
  407. # Skip if next token is a sequence breaker
  408. if next_token in sequence_breakers_ids:
  409. continue
  410. # We found last_token matches at this index, so match length starts
  411. # at 1
  412. match_length = 1
  413. # Try to extend match backwards
  414. while match_length < 50:
  415. j = idx - match_length
  416. k = len(input_ids_row) - match_length - 1
  417. if j < 0 or k < 0:
  418. # Reached start of input
  419. break
  420. if input_ids_row[j].item() != input_ids_row[k].item():
  421. # No more matches
  422. break
  423. if input_ids_row[k].item() in sequence_breakers_ids:
  424. # Hit a sequence breaker
  425. break
  426. match_length += 1
  427. # Update max match length for this next token
  428. if next_token in match_lengths:
  429. match_lengths[next_token] = max(
  430. match_length, match_lengths[next_token])
  431. else:
  432. match_lengths[next_token] = match_length
  433. # Apply penalties based on match lengths
  434. allowed_length = allowed_lengths[i]
  435. multiplier = multipliers[i]
  436. base = bases[i]
  437. for token, match_length in match_lengths.items():
  438. if match_length >= allowed_length:
  439. penalty = multiplier * (base ** (match_length - allowed_length))
  440. logits_row[token] -= penalty
  441. return logits
  442. def _apply_no_repeat_ngram(
  443. logits: torch.Tensor,
  444. input_ids: torch.Tensor,
  445. ngram_size: torch.Tensor,
  446. ) -> torch.Tensor:
  447. """Apply no-repeat-ngram penalty which sets logits to -inf for tokens that
  448. would create a repeated n-gram.
  449. """
  450. if torch.all(ngram_size == 0):
  451. return logits
  452. batch_size = logits.shape[0]
  453. for i in range(batch_size):
  454. size = int(ngram_size[i].item())
  455. if size == 0:
  456. continue
  457. cur_len = len(input_ids[i])
  458. if cur_len < size:
  459. continue
  460. banned_tokens = _calc_banned_ngram_tokens(
  461. ngram_size=size,
  462. prev_input_ids=input_ids[i],
  463. cur_len=cur_len-1
  464. )
  465. if banned_tokens:
  466. logits[i, banned_tokens] = -float("inf")
  467. return logits
  468. def _apply_top_k_top_p(
  469. logits: torch.Tensor,
  470. p: torch.Tensor,
  471. k: torch.Tensor,
  472. ) -> torch.Tensor:
  473. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  474. # Apply top-k.
  475. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  476. # Get all the top_k values.
  477. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  478. top_k_mask = logits_sort < top_k_mask
  479. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  480. # Apply top-p.
  481. probs_sort = logits_sort.softmax(dim=-1)
  482. probs_sum = probs_sort.cumsum(dim=-1)
  483. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  484. # at least one
  485. top_p_mask[:, -1] = False
  486. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  487. # Re-sort the probabilities.
  488. src = torch.arange(logits_idx.shape[-1],
  489. device=logits_idx.device).expand_as(logits_idx)
  490. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  491. index=logits_idx,
  492. src=src)
  493. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  494. return logits
  495. def _apply_min_p(
  496. logits: torch.Tensor,
  497. min_p: torch.Tensor,
  498. ) -> torch.Tensor:
  499. """
  500. Adapted from
  501. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  502. """
  503. probs = torch.softmax(logits, dim=-1)
  504. top_probs, _ = probs.max(dim=-1, keepdim=True)
  505. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  506. tokens_to_remove = probs < scaled_min_p
  507. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  508. return logits
  509. def _apply_top_a(
  510. logits: torch.Tensor,
  511. top_a: torch.Tensor,
  512. ) -> torch.Tensor:
  513. probs = torch.softmax(logits, dim=-1)
  514. top_probs, _ = probs.max(dim=-1, keepdim=True)
  515. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  516. tokens_to_remove = probs < threshold
  517. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  518. return logits
  519. def _apply_tfs(
  520. logits: torch.Tensor,
  521. tfs: torch.Tensor,
  522. ) -> torch.Tensor:
  523. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  524. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  525. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  526. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  527. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  528. tfs_mask = torch.cat(
  529. (
  530. torch.zeros(
  531. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  532. tfs_mask,
  533. torch.ones(
  534. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  535. ),
  536. dim=-1,
  537. )
  538. logits_sort[tfs_mask] = -float("inf")
  539. logits = torch.gather(logits_sort,
  540. dim=-1,
  541. index=torch.argsort(logits_idx, dim=-1))
  542. return logits
  543. def _apply_eta_cutoff(
  544. logits: torch.Tensor,
  545. eta_cutoff: torch.Tensor,
  546. ) -> torch.Tensor:
  547. shifted_logits = torch.log_softmax(logits, dim=-1)
  548. probs = shifted_logits.exp()
  549. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  550. eps = torch.min(eta_cutoff,
  551. torch.sqrt(eta_cutoff) *
  552. torch.exp(neg_entropy)).unsqueeze(dim=1)
  553. eta_mask = probs < eps
  554. # guard against nulling out all the logits
  555. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  556. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  557. logits[eta_mask] = -float("inf")
  558. return logits
  559. def _apply_epsilon_cutoff(
  560. logits: torch.Tensor,
  561. epsilon_cutoff: torch.Tensor,
  562. ) -> torch.Tensor:
  563. probs = logits.softmax(dim=-1)
  564. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  565. # guard against nulling out all the logits
  566. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  567. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  568. logits[eps_mask] = -float("inf")
  569. return logits
  570. def _apply_typical_sampling(
  571. logits: torch.Tensor,
  572. typical_p: torch.Tensor,
  573. ) -> torch.Tensor:
  574. shifted_logits = torch.log_softmax(logits, dim=-1)
  575. probs = shifted_logits.exp()
  576. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  577. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  578. _, indices = torch.sort(surprisal_deviations)
  579. reordered_probs = probs.gather(-1, indices)
  580. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  581. dim=1)
  582. min_tokens_to_keep = 1
  583. # Keep at least min_tokens_to_keep
  584. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  585. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  586. logits[typ_mask] = -float("inf")
  587. return logits
  588. def _apply_quadratic_sampling(
  589. logits: torch.Tensor,
  590. smoothing_factor: torch.Tensor,
  591. smoothing_curve: torch.Tensor,
  592. ) -> torch.Tensor:
  593. """
  594. Applies a quadratic transformation to the logits based on the
  595. provided smoothing factors and curves. The transformation is
  596. centered around the maximum logit value in the batch.
  597. The transformation involves a quadratic and cubic term, with the
  598. cubic term controlled by the smoothing curve. The quadratic term is
  599. scaled by the smoothing factor, and the cubic term is scaled by the
  600. product of the smoothing factor and the smoothing curve.
  601. params:
  602. logits (torch.Tensor): The logits to be transformed.
  603. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  604. term in the transformation.
  605. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  606. in the transformation.
  607. returns:
  608. torch.Tensor: The transformed logits.
  609. Credits: @kalomaze
  610. """
  611. mask = smoothing_factor != 0
  612. smoothing_factor.unsqueeze_(dim=1)
  613. smoothing_curve.unsqueeze_(dim=1)
  614. k = smoothing_factor * (3 - smoothing_curve) / 2
  615. s = smoothing_factor * (smoothing_curve - 1) / 2
  616. quadlogits = logits[mask] # limit to logits using this sampler
  617. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  618. # Construct the delta from each logit to its new value
  619. diff = quadlogits - max_logits
  620. diff -= diff**2 * (s[mask] * diff - k[mask])
  621. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  622. logits[mask] -= diff
  623. return logits
  624. def _apply_xtc_sampling(
  625. logits: torch.Tensor,
  626. xtc_thresholds: torch.Tensor,
  627. xtc_probabilities: torch.Tensor,
  628. ) -> torch.Tensor:
  629. """Apply Exclude Top Choices (XTC) sampling to the logits.
  630. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  631. Args:
  632. logits: (num_tokens, vocab_size) The input logits.
  633. xtc_thresholds: (num_tokens,) The threshold for each token.
  634. xtc_probabilities: (num_tokens,) The probability of applying XTC
  635. for each token.
  636. Returns:
  637. torch.Tensor: The modified logits.
  638. """
  639. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  640. if not apply_xtc.any():
  641. return logits
  642. probs = torch.softmax(logits, dim=-1)
  643. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  644. # Find indices where the next probability is above the threshold
  645. # Skips the top choice, which later on becomes skipping the last choice.
  646. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  647. # Apply XTC only to rows where it should be applied
  648. for i in range(logits.shape[0]):
  649. if apply_xtc[i]:
  650. # Count logits above the threshold (skipping the first)
  651. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  652. if indices_to_remove > 0:
  653. # Implies the top logit and at least one other is >= threshold.
  654. # Mask out above_thresh logits except the last/lowest one.
  655. logits[i].scatter_(
  656. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  657. return logits
  658. def _apply_top_nsigma(
  659. logits: torch.Tensor,
  660. nsigma: torch.Tensor,
  661. ) -> torch.Tensor:
  662. """Apply top-nsigma truncation to the logits.
  663. Reference: https://arxiv.org/abs/2411.07641
  664. Args:
  665. logits: Logits of shape (num_tokens, vocab_size)
  666. nsigma: Number of standard deviations to use as threshold
  667. Returns:
  668. Modified logits with values below threshold set to -inf
  669. """
  670. std = logits.std(dim=-1, keepdim=True)
  671. threshold = (logits.max(dim=-1, keepdim=True).values -
  672. nsigma.unsqueeze(dim=1) * std)
  673. logits[logits < threshold] = float("-inf")
  674. return logits
  675. def _greedy_sample(
  676. selected_seq_groups: List[SequenceGroupToSample],
  677. samples: torch.Tensor,
  678. ) -> List[Tuple[List[int], List[int]]]:
  679. """Run greedy sampling on a given samples.
  680. Args:
  681. selected_seq_groups: A list of sequence groups batched.
  682. samples: (num_selected_samples,) A tensor of samples. The length of
  683. samples could be smaller than selected_seq_groups if
  684. seq_group.do_sample is False.
  685. Returns:
  686. Tuple of (next_token_ids, parent_ids). The length of returned list is
  687. same as the length of selected_seq_groups. If the corresponding
  688. seq_group has do_sample=False, tuple contains ([], [])
  689. """
  690. samples = samples.tolist()
  691. sample_idx = 0
  692. results = []
  693. for seq_group in selected_seq_groups:
  694. if not seq_group.do_sample:
  695. results.append(([], []))
  696. continue
  697. seq_ids = seq_group.seq_ids
  698. num_parent_seqs = len(seq_ids)
  699. assert num_parent_seqs == 1, (
  700. "Greedy sampling should have only one seq.")
  701. parent_ids = list(range(num_parent_seqs))
  702. next_token_ids = [samples[sample_idx]]
  703. results.append((next_token_ids, parent_ids))
  704. sample_idx += num_parent_seqs
  705. return results
  706. def _random_sample(
  707. selected_seq_groups: List[SequenceGroupToSample],
  708. random_samples: torch.Tensor,
  709. ) -> List[Tuple[List[int], List[int]]]:
  710. """Run random sampling on a given samples.
  711. Args:
  712. selected_seq_groups: A list of sequence groups batched.
  713. random_samples: (num_selected_samples,) A tensor of samples. The
  714. length of samples could be smaller than selected_seq_groups if
  715. seq_group.do_sample is False.
  716. Returns:
  717. Tuple of (next_token_ids, parent_ids). The length of returned list is
  718. same as the length of selected_seq_groups. If the corresponding
  719. seq_group has do_sample=False, tuple contains ([], [])
  720. """
  721. # Find the maximum best_of value of the prompt phase requests.
  722. random_samples = random_samples.cpu()
  723. sample_idx = 0
  724. results = []
  725. for seq_group in selected_seq_groups:
  726. if not seq_group.do_sample:
  727. results.append(([], []))
  728. continue
  729. seq_ids = seq_group.seq_ids
  730. sampling_params = seq_group.sampling_params
  731. is_prompt = seq_group.is_prompt
  732. num_parent_seqs = len(seq_ids)
  733. if is_prompt:
  734. # Prompt phase.
  735. parent_ids = [0] * sampling_params.best_of
  736. next_token_ids = random_samples[
  737. sample_idx, :sampling_params.best_of].tolist()
  738. else:
  739. # Generation phase.
  740. parent_ids = list(range(num_parent_seqs))
  741. next_token_ids = random_samples[sample_idx:sample_idx +
  742. num_parent_seqs, 0].tolist()
  743. results.append((next_token_ids, parent_ids))
  744. sample_idx += num_parent_seqs
  745. return results
  746. def _beam_search_sample(
  747. selected_seq_groups: List[SequenceGroupToSample],
  748. logprobs: torch.Tensor,
  749. ) -> List[Tuple[List[int], List[int]]]:
  750. """Run beam sampling on a given samples.
  751. Args:
  752. selected_seq_groups: A list of sequence groups batched.
  753. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  754. on selected sample indices.
  755. Returns:
  756. Tuple of (next_token_ids, parent_ids). The length of returned list is
  757. same as the length of selected_seq_groups. If the corresponding
  758. seq_group has do_sample=False, tuple contains ([], [])
  759. """
  760. # We sample 2 * beam_width candidates to make sure that with high
  761. # probability we can get `beam_width` candidates in addition to
  762. # the finished sequences for the next iteration. See
  763. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  764. # for details. See also HF reference:
  765. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  766. #
  767. # NOTE: Beam search is not vectorized, so its speed can be slower than
  768. # other sampling methods.
  769. sample_idx = 0
  770. results = []
  771. for seq_group in selected_seq_groups:
  772. if not seq_group.do_sample:
  773. results.append(([], []))
  774. continue
  775. is_prompt = seq_group.is_prompt
  776. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  777. num_parent_seqs = len(seq_ids)
  778. beam_width = sampling_params.best_of
  779. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  780. if is_prompt:
  781. # Prompt phase.
  782. assert num_parent_seqs == 1, (
  783. "Prompt input should have only one seq.")
  784. parent_ids = [0] * (2 * beam_width)
  785. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  786. 2 * beam_width)
  787. next_token_ids = next_token_ids.tolist()
  788. else:
  789. # Generation phase.
  790. cumulative_logprobs = [
  791. seq_group.seq_data[seq_id].cumulative_logprob
  792. for seq_id in seq_ids
  793. ]
  794. cumulative_logprobs = torch.tensor(
  795. cumulative_logprobs,
  796. dtype=torch.float,
  797. device=seq_group_logprobs.device)
  798. seq_group_logprobs = (seq_group_logprobs +
  799. cumulative_logprobs.unsqueeze(dim=1))
  800. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  801. 2 * beam_width)
  802. topk_ids = topk_ids.tolist()
  803. vocab_size = seq_group_logprobs.size(-1)
  804. parent_ids = [i // vocab_size for i in topk_ids]
  805. next_token_ids = [i % vocab_size for i in topk_ids]
  806. results.append((next_token_ids, parent_ids))
  807. sample_idx += num_parent_seqs
  808. assert sample_idx == logprobs.size(0)
  809. return results
  810. # torch.multinomial forces a GPU<->CPU sync.
  811. # Therefore, we use an optimized implementation instead.
  812. # Note that we always sample with replacement.
  813. # probs will be modified in place, but this is fine, as we pass
  814. # in a copy already.
  815. def _multinomial(
  816. probs: torch.Tensor,
  817. num_samples: int,
  818. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  819. ) -> torch.Tensor:
  820. if num_samples > 1:
  821. probs = probs.repeat_interleave(num_samples, dim=0)
  822. q = torch.empty_like(probs)
  823. if seq_groups is None:
  824. q.exponential_()
  825. else:
  826. sample_idx = 0
  827. for seq_group in seq_groups:
  828. seq_ids = seq_group.seq_ids
  829. stride = len(seq_ids) * num_samples
  830. assert seq_group.generator is not None
  831. q[sample_idx:sample_idx +
  832. stride].exponential_(generator=seq_group.generator)
  833. sample_idx += stride
  834. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  835. def _top_k_top_p_multinomial_with_kernels(
  836. probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
  837. num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
  838. max_top_k_round = 32
  839. if num_samples > 1:
  840. probs = probs.repeat_interleave(num_samples, dim=0)
  841. top_ks = top_ks.repeat_interleave(num_samples)
  842. top_ps = top_ps.repeat_interleave(num_samples)
  843. batch_size = probs.shape[0]
  844. uniform_samples = torch.empty((max_top_k_round, batch_size),
  845. device=probs.device)
  846. if seq_groups is None:
  847. uniform_samples.uniform_()
  848. else:
  849. sample_idx = 0
  850. for seq_group in seq_groups:
  851. seq_ids = seq_group.seq_ids
  852. stride = len(seq_ids) * num_samples
  853. assert seq_group.generator is not None
  854. uniform_samples[:, sample_idx:sample_idx +
  855. stride].uniform_(generator=seq_group.generator)
  856. sample_idx += stride
  857. batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
  858. probs,
  859. uniform_samples,
  860. top_ks,
  861. top_ps,
  862. )
  863. if not success.all():
  864. warnings.warn("CUDA rejection sampling failed, fallback.",
  865. stacklevel=1)
  866. probs = ops.top_k_renorm_prob(probs, top_ks)
  867. probs = ops.top_p_renorm_prob(probs, top_ps)
  868. batch_next_token_ids = ops.sampling_from_probs(
  869. probs, uniform_samples[0])
  870. return batch_next_token_ids.view(-1, num_samples)
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """Sample next tokens with PyTorch ops, one pass per sampling type.

    Sequence groups are bucketed by ``sampling_params.sampling_type``. The
    first loop launches the device work for each bucket: argmax for greedy,
    ``_multinomial`` or the sampling kernels for random/seeded, and row
    selection for beam search. The second loop converts the results to
    Python lists, which is where the GPU<->CPU sync happens.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab) probabilities.
        logprobs: (num_query_tokens_in_batch, num_vocab) log-probabilities.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Sampling-related tensors; only ``top_ks`` and
            ``top_ps`` are read here (kernel path).
        include_gpu_probs_tensor: If True, also return the sampled token ids
            as a (logprobs.shape[0], 1) long tensor on ``logprobs.device``.
        modify_greedy_probs: If True, greedily sampled rows are passed to
            ``_modify_greedy_probs_inplace`` together with logprobs/probs.

    Returns:
        ``(sample_results, sampled_token_ids_tensor)``.
        ``sample_results`` holds one ``(next_token_ids, parent_ids)`` per seq
        group (``([], [])`` for groups without a result), or is ``[]`` when
        ``sampling_metadata.skip_sampler_cpu_output`` is set.
        ``sampled_token_ids_tensor`` is None unless
        ``include_gpu_probs_tensor`` is set. NOTE(review): only greedy and
        random rows are written into it, so beam-search rows keep the
        uninitialized ``torch.empty`` contents.
    """
    # Map each SamplingType to the indices of the seq groups that use it.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)
    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    # sampling_type -> (seq_group_ids, seq_groups), only for non-empty types.
    sample_metadata = {}
    # sampling_type -> (rows, max_best_of) token-id tensor from the sampler.
    multinomial_samples = {}
    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None
    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        # Column 0 holds the row indices into probs/logprobs for this type.
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            # greedy_samples is consumed by the second loop below.
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)
            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)
            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Prompt-phase groups may need best_of samples per row; size
            # the draw for the largest one in this bucket.
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            # Unseeded rows use the global RNG; seeded rows pass their groups
            # so that per-group generators are used.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)
            # NOTE(review): this takes the kernel path whenever the
            # module-level flag is not None; if the flag is a plain bool,
            # False would still select the kernels. Confirm against its
            # definition.
            if APHRODITE_USE_SAMPLING_KERNELS is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_kernels(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_best_of_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_best_of_in_batch,
                    seq_groups=seq_groups_arg)
            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)
        elif sampling_type == SamplingType.BEAM:
            # Beam search is done on CPU-side Python in the second loop.
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            # Only types populated by the first loop have results (and the
            # greedy_samples / beam_search_logprobs locals they rely on).
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))
        # Restore the original seq-group order.
        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []
    return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens with the Triton ``sample_triton`` kernel.

    Greedy, random and seeded rows are all sampled by a single
    ``sample_triton`` call, sized by the largest prompt ``best_of``. Beam
    search rows are selected from ``logprobs`` and handled in Python. The
    ``_sample`` entry point currently routes to ``_sample_with_torch``
    instead; see its TODO.

    Args:
        probs: (num_query_tokens_in_batch, num_vocab) probabilities.
        logprobs: (num_query_tokens_in_batch, num_vocab) log-probabilities.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Provides ``sampling_seeds`` and ``sample_indices``
            for the kernel.

    Returns:
        One ``(next_token_ids, parent_ids)`` per seq group, in order;
        ``([], [])`` for groups without a result.
    """
    # Map each SamplingType to the indices of the seq groups that use it.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)
    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    max_best_of_in_batch = 1
    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        # Column 0: rows into probs/logprobs. Column 1: presumably rows into
        # the kernel's sampled_tokens output (used that way below); confirm.
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            # One kernel call serves all these types, so best_of is the max
            # over every prompt-phase group among them.
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )
    # GPU<->CPU sync happens in the loop below.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            # Greedy groups use only the first of the max_best_of columns.
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))
    # Restore the original seq-group order.
    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results
  1046. def _sample(
  1047. probs: torch.Tensor, logprobs: torch.Tensor,
  1048. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  1049. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  1050. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  1051. """
  1052. Args:
  1053. probs: (num_query_tokens_in_batch, num_vocab)
  1054. logprobs: (num_query_tokens_in_batch, num_vocab)
  1055. sampling_metadata: The metadata for a batch for sampling.
  1056. sampling_tensors: Tensors that include sampling related metadata.
  1057. Returns:
  1058. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  1059. If sampling is skipped, it returns ([], [])
  1060. sampled_token_ids_tensor: A tensor of sampled token ids.
  1061. """
  1062. return _sample_with_torch(
  1063. probs,
  1064. logprobs,
  1065. sampling_metadata,
  1066. sampling_tensors,
  1067. include_gpu_probs_tensor=include_gpu_probs_tensor,
  1068. modify_greedy_probs=modify_greedy_probs,
  1069. )
  1070. # TODO: Enable once Triton kernel & associated code is faster.
  1071. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  1072. # sampling_tensors)
  1073. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  1074. """
  1075. This function calculates the ranks of the chosen tokens in a logprob tensor.
  1076. Args:
  1077. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  1078. where N is the no. of tokens and M is the vocab dim.
  1079. indices (torch.Tensor): List of chosen token indices.
  1080. Returns:
  1081. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  1082. Each element in the returned tensor represents the rank
  1083. of the chosen token in the input logprob tensor.
  1084. """
  1085. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  1086. indices]
  1087. return (x > vals[:, None]).long().sum(1).add_(1)
  1088. def _get_logprobs(
  1089. logprobs: torch.Tensor,
  1090. sampling_metadata: SamplingMetadata,
  1091. sample_results: List[Tuple[List[int], List[int]]],
  1092. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  1093. """Return sample lobprobs and prompt logprobs.
  1094. The logic consists of 3 parts.
  1095. - Select indices to compute logprob from, ranks of token ids, and
  1096. the top k token ids from logprobs.
  1097. - Compute prompt logprobs if required.
  1098. - Compute sample logprobs if required.
  1099. Args:
  1100. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  1101. logprob per vocab. Sequence groups' query tokens are batched in a
  1102. single flattened tensor. For example, assuming there are N
  1103. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  1104. prompt logprob is enabled), decode tokens for seq_group_1 (if
  1105. sampling is required), prefill tokens for seq_group_2, ...
  1106. sampling_metadata: The sampling metadata.
  1107. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  1108. parent_ids) for each sequence group. When beam search is enabled,
  1109. sample_results can contain different number of seq_ids from
  1110. sampling_metadata.seq_groups. It is because beam search creates
  1111. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  1112. BEAM_WIDTH number of seq_ids).
  1113. Returns:
  1114. A tuple of prompt and sample logprobs per sequence group in a batch.
  1115. """
  1116. # The index of query token to calculate logprobs. It includes both
  1117. # prompt and sample logprob indices.
  1118. query_indices: List[int] = []
  1119. # The next token ids to get the logprob value from.
  1120. next_token_ids: List[int] = []
  1121. # The largest requested number of logprobs. We find logprobs as many as the
  1122. # largest num logprobs in this API. If every logprobs is None, it will be
  1123. # set to -1.
  1124. largest_num_logprobs = -1
  1125. # If beam search is enabled.
  1126. use_beam_search = False
  1127. # Select indices to compute logprob from, ranks of token ids, and the top
  1128. # k token ids from logprobs.
  1129. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  1130. sample_results):
  1131. sampling_params = seq_group.sampling_params
  1132. # Update indices and tokens for prompt logprobs.
  1133. if (seq_group.is_prompt
  1134. and sampling_params.prompt_logprobs is not None):
  1135. largest_num_logprobs = max(largest_num_logprobs,
  1136. sampling_params.prompt_logprobs)
  1137. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1138. query_indices.extend(seq_group.prompt_logprob_indices)
  1139. next_token_ids.extend(next_prompt_tokens)
  1140. # Update indices and next tokenes for sample logprob.
  1141. if seq_group.do_sample:
  1142. token_ids, parent_seq_ids = sample_result
  1143. # NOTE: We cannot directly use sample_indices because
  1144. # sample_indices only contain parent seq_ids of a previous step.
  1145. # The current step may have different number of seq_ids, and
  1146. # we can obtain it from `sample_result[1]`.
  1147. query_idx = seq_group.sample_indices[0]
  1148. query_indices.extend(
  1149. [query_idx + parent_id for parent_id in parent_seq_ids])
  1150. next_token_ids.extend(token_ids)
  1151. if sampling_params.logprobs is not None:
  1152. largest_num_logprobs = max(largest_num_logprobs,
  1153. sampling_params.logprobs)
  1154. use_beam_search = use_beam_search or sampling_params.use_beam_search
  1155. assert len(next_token_ids) == len(query_indices)
  1156. if len(query_indices) == 0:
  1157. empty_sampled_logprob = []
  1158. empty_prompt_logprob = None
  1159. return [empty_prompt_logprob], [empty_sampled_logprob]
  1160. selected_logprobs, ranks = None, None
  1161. top_logprobs, top_token_ids = None, None
  1162. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  1163. # skip the whole logprob calculation.
  1164. if largest_num_logprobs >= 0 or use_beam_search:
  1165. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  1166. next_token_ids_gpu = torch.tensor(next_token_ids,
  1167. device=logprobs.device)
  1168. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  1169. # contain duplicates if beam search is enabled.
  1170. selected_logprobs = logprobs[[
  1171. query_indices_gpu,
  1172. next_token_ids_gpu,
  1173. ]]
  1174. ranks = _get_ranks(
  1175. logprobs[query_indices_gpu],
  1176. next_token_ids_gpu,
  1177. )
  1178. assert selected_logprobs.shape[0] == ranks.shape[0]
  1179. # We need to compute top k only if there exists logprobs > 0.
  1180. if largest_num_logprobs > 0:
  1181. # Logprobs of topk tokens for a batch of sequence groups.
  1182. # (num_query_tokens_across_batch).
  1183. top_logprobs, top_token_ids = torch.topk(logprobs,
  1184. largest_num_logprobs,
  1185. dim=-1)
  1186. top_logprobs = top_logprobs.to('cpu')
  1187. top_token_ids = top_token_ids.to('cpu')
  1188. selected_logprobs = selected_logprobs.to('cpu')
  1189. ranks = ranks.to('cpu')
  1190. # Find prompt/sample logprobs.
  1191. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  1192. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  1193. top_logprob_idx = 0
  1194. selected_logprobs_idx = 0
  1195. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  1196. sample_results):
  1197. (prompt_logprobs, top_logprob_idx,
  1198. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  1199. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  1200. selected_logprobs_idx, top_logprob_idx)
  1201. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  1202. (sampled_logprobs, top_logprob_idx,
  1203. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  1204. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1205. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1206. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1207. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1208. def _get_prompt_logprob_if_needed(
  1209. seq_group: SequenceGroupToSample,
  1210. selected_logprobs: torch.Tensor,
  1211. ranks: torch.Tensor,
  1212. top_token_ids: torch.Tensor,
  1213. top_logprobs: torch.Tensor,
  1214. selected_logprobs_idx: int,
  1215. top_logprob_idx: int,
  1216. ):
  1217. """Compute the prompt logprob from a sequence group if needed."""
  1218. sampling_params = seq_group.sampling_params
  1219. is_prompt = seq_group.is_prompt
  1220. # Find prompt logprobs
  1221. prompt_logprobs: Optional[PromptLogprobs] = None
  1222. if is_prompt and sampling_params.prompt_logprobs is not None:
  1223. prompt_logprobs = []
  1224. num_logprobs = sampling_params.prompt_logprobs
  1225. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1226. # Pre-select indexes and create a list. It is faster than calling .item
  1227. # repetitively.
  1228. selected_logprob_items = selected_logprobs[
  1229. selected_logprobs_idx:selected_logprobs_idx +
  1230. len(next_prompt_tokens)].tolist()
  1231. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1232. len(next_prompt_tokens)].tolist()
  1233. for idx, token_id in enumerate(next_prompt_tokens):
  1234. # Calculate the prompt logprob of the real prompt tokens.
  1235. # {token_id: (logprob, rank_from_vocab)}
  1236. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1237. token_id: (selected_logprob_items[idx], rank_items[idx])
  1238. }
  1239. # Add top K prompt logprobs along with its rank.
  1240. if num_logprobs > 0:
  1241. top_ids = top_token_ids[
  1242. top_logprob_idx, :num_logprobs].tolist()
  1243. top_probs = top_logprobs[
  1244. top_logprob_idx, :num_logprobs].tolist()
  1245. # Top K is already sorted by rank, so we can use 1 ~
  1246. # num_logprobs + 1 for rank.
  1247. top_ranks = range(1, num_logprobs + 1)
  1248. prompt_logprobs_dict.update({
  1249. top_id: (top_prob, rank)
  1250. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1251. top_ranks)
  1252. })
  1253. prompt_logprobs.append({
  1254. token_id: Logprob(*logprob_and_rank)
  1255. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1256. })
  1257. # + 1 to go to the next prompt token.
  1258. top_logprob_idx += 1
  1259. # + len(next_prompt_tokens) to go to the next prompt.
  1260. selected_logprobs_idx += len(next_prompt_tokens)
  1261. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1262. def _get_sampled_logprob_if_needed(
  1263. seq_group: SequenceGroupToSample,
  1264. sample_result: Tuple[List[int], List[int]],
  1265. selected_logprobs: torch.Tensor,
  1266. ranks: torch.Tensor,
  1267. top_token_ids: torch.Tensor,
  1268. top_logprobs: torch.Tensor,
  1269. selected_logprobs_idx: int,
  1270. top_logprob_idx: int,
  1271. ):
  1272. """Compute the sample logprob if needed."""
  1273. seq_ids = seq_group.seq_ids
  1274. num_logprobs = seq_group.sampling_params.logprobs
  1275. use_beam_search = seq_group.sampling_params.use_beam_search
  1276. sampled_logprobs: SampleLogprobs = []
  1277. next_token_ids, parent_seq_ids = sample_result
  1278. if seq_group.do_sample:
  1279. assert len(next_token_ids) > 0
  1280. if num_logprobs is None and not use_beam_search:
  1281. for next_token_id in next_token_ids:
  1282. # Use a dummy logprob
  1283. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1284. else:
  1285. # Pre-select items from tensor. tolist() is faster than repetitive
  1286. # `.item()` calls.
  1287. selected_logprob_items = selected_logprobs[
  1288. selected_logprobs_idx:selected_logprobs_idx +
  1289. len(next_token_ids)].tolist()
  1290. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1291. len(next_token_ids)].tolist()
  1292. for idx, (next_token_id, parent_id) in enumerate(
  1293. zip(next_token_ids, parent_seq_ids)):
  1294. # Get the logprob of a sampled token.
  1295. sampled_logprobs_dict = {
  1296. next_token_id:
  1297. (selected_logprob_items[idx], rank_items[idx])
  1298. }
  1299. if num_logprobs is not None and num_logprobs > 0:
  1300. # Get top K logprobs.
  1301. top_ids = top_token_ids[top_logprob_idx +
  1302. parent_id, :num_logprobs].tolist()
  1303. top_probs = top_logprobs[
  1304. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1305. # Top K is already sorted by rank, so we can use 1 ~
  1306. # num_logprobs + 1 for rank.
  1307. top_ranks = range(1, num_logprobs + 1)
  1308. sampled_logprobs_dict.update({
  1309. top_id: (top_prob, rank)
  1310. for top_id, top_prob, rank in zip(
  1311. top_ids, top_probs, top_ranks)
  1312. })
  1313. sampled_logprobs.append({
  1314. token_id: Logprob(*logprob_and_rank)
  1315. for token_id, logprob_and_rank in
  1316. sampled_logprobs_dict.items()
  1317. })
  1318. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1319. # logprobs for the current step, which has len(next_token_ids) tokens
  1320. # per sequence group. `logprobs` includes logprobs from the previous
  1321. # steps, which has len(seq_ids) tokens per sequence group.
  1322. # Iterate to the next sequence group in a batch.
  1323. selected_logprobs_idx += len(next_token_ids)
  1324. # Iterate to the next sequence group in a batch.
  1325. top_logprob_idx += len(seq_ids)
  1326. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1327. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1328. sample_indices: torch.Tensor,
  1329. greedy_samples: torch.Tensor) -> None:
  1330. """Modify the probability distributions of the greedily-sampled tokens such
  1331. that each sampled token has a "probability" of 1.0. This is required by
  1332. speculative decoding, which depends on the sampling method being encoded
  1333. within the probability distribution for correctness.
  1334. # Why do we only need to do this for greedy sampling?
  1335. Aphrodite's sampler performs the following steps for greedy or multinomial
  1336. (random) sampling:
  1337. 1. Get logits from model.
  1338. 2. Modify logits according to per-sequence sampling parameters.
  1339. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1340. according to their frequency, etc.
  1341. 3. Sample a token.
  1342. - Random sampling simply samples from the modified probability
  1343. distribution.
  1344. - Greedy sampling performs `argmax` to obtain the token with the
  1345. highest likelihood.
  1346. Ignoring greedy sampling for a moment, we find that the computed probability
  1347. distribution has the following property: we can sample from it independently
  1348. and find that the token sampled by the Sampler has a frequency corresponding
  1349. to how often we see it in our sampling. In other words, for tokens sampled
  1350. with Aphrodite's random SamplingType, the computed probability distribution
  1351. encodes the sampling methodology completely.
  1352. Greedy sampling does not normally have this property. Aphrodite modifies
  1353. logits according to sampling params, then performs `argmax`, then returns
  1354. the sampled token and the computed probability distribution. If we sample
  1355. from the distribution, we'll find the likelihood of the greedily-sampled
  1356. token is not always 1.0.
  1357. Since lossless speculative decoding requires that the sampling methodology
  1358. be encoded within the probability distribution, we are motivated to modify
  1359. the probability distribution such that the sampled token has probability 1
  1360. when speculative decoding is used.
  1361. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1362. greedy sampling using multinomial computation and unite the codepaths. This
  1363. has implications on the overall design of the sampler, e.g. how to record
  1364. accurate logprobs for the user, so this improvement is deferred to later.
  1365. """
  1366. # NOTE: logprobs are not modified so they can be returned to the user.
  1367. probs[sample_indices, :] = 0
  1368. probs[sample_indices, greedy_samples] = 1.0
  1369. def _build_sampler_output(
  1370. sample_results: SampleResultType,
  1371. sampling_metadata: SamplingMetadata,
  1372. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1373. sample_logprobs: Optional[List[SampleLogprobs]],
  1374. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1375. torch.Tensor]],
  1376. skip_sampler_cpu_output: bool = False,
  1377. ) -> SamplerOutput:
  1378. """Construct Python objects with the output of sampling.
  1379. Args:
  1380. on_device_tensors: Tuple containing on-device tensors with the
  1381. probabilities used in sampling and the sampled token ids. This
  1382. allows post-processing without copies to CPU/serialization, e.g. in
  1383. speculative decoding rejection sampling.
  1384. """
  1385. sampler_output: List[CompletionSequenceGroupOutput] = []
  1386. if not skip_sampler_cpu_output:
  1387. assert prompt_logprobs is not None
  1388. assert sample_logprobs is not None
  1389. for (seq_group, sample_result, group_prompt_logprobs,
  1390. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1391. sample_results, prompt_logprobs,
  1392. sample_logprobs):
  1393. seq_ids = seq_group.seq_ids
  1394. next_token_ids, parent_ids = sample_result
  1395. seq_outputs: List[SequenceOutput] = []
  1396. for parent_id, next_token_id, logprobs in zip(
  1397. parent_ids, next_token_ids, group_sample_logprobs):
  1398. seq_outputs.append(
  1399. SequenceOutput(seq_ids[parent_id], next_token_id,
  1400. logprobs))
  1401. sampler_output.append(
  1402. CompletionSequenceGroupOutput(seq_outputs,
  1403. group_prompt_logprobs))
  1404. # If not specified, store None values in SamplerOutput.
  1405. if on_device_tensors is not None:
  1406. (sampled_token_probs, logprobs_tensor,
  1407. sampled_token_ids) = on_device_tensors
  1408. else:
  1409. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1410. None)
  1411. return SamplerOutput(
  1412. outputs=sampler_output,
  1413. sampled_token_probs=sampled_token_probs,
  1414. sampled_token_ids=sampled_token_ids,
  1415. logprobs=logprobs_tensor,
  1416. )
  1417. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1418. """Get a list of next prompt tokens to compute logprob from a
  1419. given sequence group.
  1420. It is used to compute prompt logprob. Imagine you have logprob for each
  1421. query token. Query token needs to know the next prompt token id to compute
  1422. prompt logprob. This is a helper to obtain next prompt token ids.
  1423. This API has to be used only when the caller knows seq_group is in prefill
  1424. stage.
  1425. Returns:
  1426. A list of next prompt tokens to compute logprob.
  1427. """
  1428. assert seq_group.is_prompt, (
  1429. "Caller should ensure the sequence group is in a prefill stage.")
  1430. seq_ids = seq_group.seq_ids
  1431. query_len = seq_group.query_len
  1432. assert query_len is not None
  1433. # prompt has only 1 seq id.
  1434. assert len(seq_ids) == 1
  1435. seq_data = seq_group.seq_data[seq_ids[0]]
  1436. computed_len = seq_data.get_num_computed_tokens()
  1437. prompt_tokens = seq_data.prompt_token_ids
  1438. # +1 because we are looking for a next prompt token.
  1439. next_token_index_start = computed_len + 1
  1440. next_token_index_end = min(computed_len + query_len + 1,
  1441. len(prompt_tokens))
  1442. next_prompt_tokens = prompt_tokens[
  1443. next_token_index_start:next_token_index_end]
  1444. return next_prompt_tokens
  1445. def _get_ngrams(
  1446. ngram_size: int,
  1447. prev_input_ids: torch.Tensor
  1448. ) -> Dict[Tuple[int, ...], List[int]]:
  1449. """Get dictionary of ngrams and the tokens that followed them.
  1450. Args:
  1451. ngram_size: Size of ngrams to track
  1452. prev_input_ids: 1D tensor of previous token ids
  1453. Returns:
  1454. Dictionary mapping ngram tuples to list of tokens that followed them
  1455. """
  1456. generated_ngrams = {}
  1457. gen_tokens = prev_input_ids.tolist()
  1458. for i in range(len(gen_tokens) - ngram_size + 1):
  1459. ngram = tuple(gen_tokens[i:i + ngram_size - 1])
  1460. next_token = gen_tokens[i + ngram_size - 1]
  1461. if ngram in generated_ngrams:
  1462. generated_ngrams[ngram].append(next_token)
  1463. else:
  1464. generated_ngrams[ngram] = [next_token]
  1465. return generated_ngrams
  1466. def _get_generated_ngrams(
  1467. banned_ngrams: Dict[Tuple[int, ...], List[int]],
  1468. prev_input_ids: torch.Tensor,
  1469. ngram_size: int,
  1470. cur_len: int
  1471. ) -> List[int]:
  1472. """Get list of tokens that would create a repeated ngram if generated next.
  1473. Args:
  1474. banned_ngrams: Dictionary of previously seen ngrams and their next
  1475. tokens
  1476. prev_input_ids: Previous token ids
  1477. ngram_size: Size of ngrams to check
  1478. cur_len: Current position in sequence
  1479. Returns:
  1480. List of token ids that would create a repeat ngram
  1481. """
  1482. start_idx = cur_len + 1 - ngram_size
  1483. current_ngram = tuple(prev_input_ids[start_idx:cur_len].tolist())
  1484. return banned_ngrams.get(current_ngram, [])
  1485. def _calc_banned_ngram_tokens(
  1486. ngram_size: int,
  1487. prev_input_ids: torch.Tensor,
  1488. cur_len: int
  1489. ) -> List[int]:
  1490. """Calculate tokens that would create repeated ngrams if generated next.
  1491. Args:
  1492. ngram_size: Size of ngrams to prevent repeating
  1493. prev_input_ids: Previous token ids in sequence
  1494. cur_len: Current position in sequence
  1495. Returns:
  1496. List of token ids that should be banned to prevent ngram repetition
  1497. """
  1498. if cur_len + 1 < ngram_size:
  1499. return []
  1500. generated_ngrams = _get_ngrams(ngram_size, prev_input_ids)
  1501. banned_tokens = _get_generated_ngrams(
  1502. generated_ngrams,
  1503. prev_input_ids,
  1504. ngram_size,
  1505. cur_len
  1506. )
  1507. return banned_tokens
  1508. # def _apply_mirostat_v2(logits: torch.Tensor,
  1509. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1510. # # Reduce our view to just the affected logits
  1511. # logit_view = logits[sampling_tensors.miro_indices]
  1512. # # Calculate surprise value per token
  1513. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1514. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1515. # # Mask out "too-surprising" tokens (surprisal > mu)
  1516. # mus = sampling_tensors.miro_mus
  1517. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1518. # # Unmask most-likely logit to guarantee a selection.
  1519. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1520. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1521. # # Apply logit mask (effectively a top-k filter).
  1522. # logit_view[miro_mask] = -float("inf")
  1523. # # Project logit changes made to the view onto the original.
  1524. # # I think this step might be redundant.
  1525. # logits[sampling_tensors.miro_indices] = logit_view
  1526. # return logits
  1527. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1528. # sample_results: List[Tuple[List[int], List[int]]],
  1529. # sampling_metadata: SamplingMetadata,
  1530. # output_metadata: OutputMetadata) -> None:
  1531. # """Based on whichever token was finally sampled, we calculate the
  1532. # final surprisal values to update the mus.
  1533. # Because a single sequence can have multiple samples, we must fork
  1534. # the mu accordingly."""
  1535. # assert sampling_metadata.seq_groups is not None
  1536. # seqid_to_tokens = {}
  1537. # seqid_to_indices = {}
  1538. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1539. # sample_results):
  1540. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1541. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1542. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1543. # seqids = args.miro_seqids
  1544. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1545. # device=logits.device,
  1546. # dtype=torch.long)
  1547. # # Clumsily, we recalculate token surprisals.
  1548. # logits_view = logits[args.miro_indices]
  1549. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1550. # dim=-1,
  1551. # index=picked_tokens) / -math.log(2)
  1552. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1553. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1554. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1555. # nu_mus = mus - (picked_surprise - taus) * etas
  1556. # # Record updated mu values for use in the next iteration
  1557. # # Note how each mu is split into multiple based on the number of samples.
  1558. # for seqid, seq_mus in zip(seqids, nu_mus):
  1559. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1560. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)