sampler.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. from math import inf
  4. from typing import Dict, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. from aphrodite.common.sampling_params import SamplingType
  8. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  9. PromptLogprobs, SampleLogprobs,
  10. SamplerOutput, SequenceOutput)
  11. from aphrodite.triton_utils import HAS_TRITON
  12. if HAS_TRITON:
  13. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  14. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  15. SamplingTensors,
  16. SequenceGroupToSample)
# One entry per sequence group: a pair (next_token_ids, parent_ids), as
# returned by _greedy_sample / _random_sample / _beam_search_sample. Groups
# that do not sample contribute ([], []).
SampleResultType = List[Tuple[List[int], List[int]]]
  19. class Sampler(nn.Module):
  20. """Samples the next tokens from the model's outputs.
  21. This layer does the following:
  22. 1. Discard the hidden states that are not used for sampling (i.e., all
  23. tokens except the final one in each prompt).
  24. 2. Compute the logits for the next tokens.
  25. 3. Apply presence, frequency and repetition penalties.
  26. 4. Apply temperature scaling.
  27. 5. Apply top-p and top-k truncation.
  28. 6. Sample the next tokens.
  29. Here, each sequence group within the batch can have different sampling
  30. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  31. The structure of the logits tensor is coupled with the seq_groups in
  32. sampling_metadata. Typically, each sequence in each seq_group has one row in
  33. logits for the next token to be sampled; however, for a seq_group with a
  34. prompt request with the prompt_logprobs sampling parameter, there are rows
  35. in logits for each token in the input prompt.
  36. """
  37. def __init__(self):
  38. super().__init__()
  39. # Whether or not the SamplerOutput should have on-device tensors
  40. # containing the sampled token ids and probabilities. This is used by
  41. # speculative decoding.
  42. self.include_gpu_probs_tensor = False
  43. self.should_modify_greedy_probs_inplace = False
  44. def _init_sampling_tensors(
  45. self,
  46. logits: torch.Tensor,
  47. sampling_metadata: SamplingMetadata,
  48. ):
  49. """The goal here is to reuse sampling tensors between similar decode
  50. runs. This is possible because sampling logic does not change between
  51. decodes of the same sequences.
  52. """
  53. _, vocab_size = logits.shape
  54. # First free any existing stored sampling tensors.
  55. # This is necessary because some sampling tensors may
  56. # have pinned memory.
  57. self._sampling_tensors = None
  58. # Initialize new sampling tensors
  59. (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
  60. do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
  61. do_typical_ps, do_quadratic, do_xtc, do_temp_last
  62. ) = SamplingTensors.from_sampling_metadata(
  63. sampling_metadata, vocab_size, logits.device, logits.dtype)
  64. self._sampling_tensors = sampling_tensors
  65. self._do_penalties = do_penalties
  66. self._do_temperatures = do_temperatures
  67. self._do_top_p_top_k = do_top_p_top_k
  68. self._do_top_as = do_top_as
  69. self._do_min_p = do_min_p
  70. self._do_tfss = do_tfss
  71. self._do_eta_cutoffs = do_eta_cutoffs
  72. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  73. self._do_typical_ps = do_typical_ps
  74. self._do_quadratic = do_quadratic
  75. self._do_xtc = do_xtc
  76. self._do_temp_last = do_temp_last
  77. def forward(
  78. self,
  79. logits: torch.Tensor,
  80. sampling_metadata: SamplingMetadata,
  81. ) -> Optional[SamplerOutput]:
  82. """
  83. Args:
  84. logits: (num_tokens, vocab_size).
  85. sampling_metadata: Metadata for sampling.
  86. """
  87. assert logits is not None
  88. _, vocab_size = logits.shape
  89. # Prepare sampling tensors with pinned memory to avoid blocking.
  90. if not sampling_metadata.reuse_sampling_tensors:
  91. self._init_sampling_tensors(logits, sampling_metadata)
  92. elif self._do_penalties:
  93. # In this case, the sampling tensors logic depends on
  94. # "output_tokens" of a sequence. As a result, we cannot
  95. # reuse sampling tensors, since "output_tokens" changes
  96. # between decode runs.
  97. self._init_sampling_tensors(logits, sampling_metadata)
  98. assert self._sampling_tensors is not None
  99. sampling_tensors = self._sampling_tensors
  100. do_penalties = self._do_penalties
  101. do_temperatures = self._do_temperatures
  102. do_top_p_top_k = self._do_top_p_top_k
  103. do_top_as = self._do_top_as
  104. do_min_p = self._do_min_p
  105. do_tfss = self._do_tfss
  106. do_eta_cutoffs = self._do_eta_cutoffs
  107. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  108. do_typical_ps = self._do_typical_ps
  109. do_quadratic = self._do_quadratic
  110. do_xtc = self._do_xtc
  111. do_temp_last = self._do_temp_last
  112. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  113. # Apply presence and frequency penalties.
  114. if do_penalties:
  115. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  116. sampling_tensors.output_tokens,
  117. sampling_tensors.presence_penalties,
  118. sampling_tensors.frequency_penalties,
  119. sampling_tensors.repetition_penalties)
  120. # Apply temperature scaling if not doing temp_last.
  121. if do_temperatures and not do_temp_last:
  122. _apply_temperatures(logits, sampling_tensors.temperatures,
  123. sampling_tensors.dynatemp_mins,
  124. sampling_tensors.dynatemp_maxs,
  125. sampling_tensors.dynatemp_exps)
  126. if do_top_p_top_k:
  127. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  128. sampling_tensors.top_ks)
  129. if do_top_as:
  130. logits = _apply_top_a(logits, sampling_tensors.top_as)
  131. if do_min_p:
  132. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  133. if do_tfss:
  134. logits = _apply_tfs(logits, sampling_tensors.tfss)
  135. if do_eta_cutoffs:
  136. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  137. if do_epsilon_cutoffs:
  138. logits = _apply_epsilon_cutoff(logits,
  139. sampling_tensors.epsilon_cutoffs)
  140. if do_typical_ps:
  141. logits = _apply_typical_sampling(logits,
  142. sampling_tensors.typical_ps)
  143. if do_quadratic:
  144. logits = _apply_quadratic_sampling(
  145. logits, sampling_tensors.smoothing_factors,
  146. sampling_tensors.smoothing_curves)
  147. if do_xtc:
  148. logits = _apply_xtc_sampling(
  149. logits, sampling_tensors.xtc_thresholds,
  150. sampling_tensors.xtc_probabilities)
  151. if do_temperatures and do_temp_last:
  152. _apply_temperatures(logits, sampling_tensors.temperatures,
  153. sampling_tensors.dynatemp_mins,
  154. sampling_tensors.dynatemp_maxs,
  155. sampling_tensors.dynatemp_exps)
  156. banned_tokens = _get_custom_token_bans(sampling_metadata)
  157. logits = _apply_token_bans(logits, banned_tokens)
  158. # We use float32 for probabilities and log probabilities.
  159. # Compute the probabilities.
  160. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  161. # Compute the log probabilities.
  162. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  163. # Sample the next tokens.
  164. sample_results, maybe_sampled_tokens_tensor = _sample(
  165. probs,
  166. logprobs,
  167. sampling_metadata,
  168. sampling_tensors,
  169. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  170. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  171. )
  172. if self.include_gpu_probs_tensor:
  173. assert maybe_sampled_tokens_tensor is not None
  174. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  175. else:
  176. on_device_tensors = None
  177. # Get the logprobs query results.
  178. prompt_logprobs = None
  179. sample_logprobs = None
  180. if not sampling_metadata.skip_sampler_cpu_output:
  181. prompt_logprobs, sample_logprobs = _get_logprobs(
  182. logprobs, sampling_metadata, sample_results)
  183. return _build_sampler_output(
  184. sample_results,
  185. sampling_metadata,
  186. prompt_logprobs,
  187. sample_logprobs,
  188. on_device_tensors=on_device_tensors,
  189. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  190. @property
  191. def _should_modify_greedy_probs_inplace(self) -> bool:
  192. """Whether or not the sampler should modify the probability distribution
  193. of greedily-sampled tokens such that multinomial sampling would sample
  194. the greedily-sampled token.
  195. In other words, if True then we set the probability of the greedily-
  196. sampled token to 1.
  197. This is used by speculative decoding, which requires that the sampling
  198. method be encoded into the probability distribution.
  199. """
  200. return self.should_modify_greedy_probs_inplace
  201. def _get_bin_counts_and_mask(
  202. tokens: torch.Tensor,
  203. vocab_size: int,
  204. num_seqs: int,
  205. ) -> Tuple[torch.Tensor, torch.Tensor]:
  206. # Compute the bin counts for the tokens.
  207. # vocab_size + 1 for padding.
  208. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  209. dtype=torch.long,
  210. device=tokens.device)
  211. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  212. bin_counts = bin_counts[:, :vocab_size]
  213. mask = bin_counts > 0
  214. return bin_counts, mask
  215. def _get_custom_token_bans(
  216. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  217. assert sampling_metadata.seq_groups is not None
  218. banned_tokens: List[List[int]] = []
  219. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  220. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  221. seq_ids = seq_group.seq_ids
  222. custom_token_bans = sampling_params.custom_token_bans
  223. if (i < sampling_metadata.num_prompts
  224. and sampling_params.prompt_logprobs is not None):
  225. prompt_len = len(seq_group.prompt_logprob_indices)
  226. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  227. banned_tokens += [custom_token_bans] * len(seq_ids)
  228. return banned_tokens
  229. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  230. output_tokens_tensor: torch.Tensor,
  231. presence_penalties: torch.Tensor,
  232. frequency_penalties: torch.Tensor,
  233. repetition_penalties: torch.Tensor) -> torch.Tensor:
  234. num_seqs, vocab_size = logits.shape
  235. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  236. num_seqs)
  237. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  238. output_tokens_tensor, vocab_size, num_seqs)
  239. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  240. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  241. logits = torch.where(logits > 0, logits / repetition_penalties,
  242. logits * repetition_penalties)
  243. # We follow the definition in OpenAI API.
  244. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  245. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  246. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  247. return logits
  248. def _apply_temperatures(
  249. logits: torch.Tensor,
  250. temperatures: torch.Tensor,
  251. dynatemp_mins: torch.Tensor,
  252. dynatemp_maxs: torch.Tensor,
  253. dynatemp_exps: torch.Tensor,
  254. ) -> None:
  255. dynatemp_mask = dynatemp_exps != 0
  256. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  257. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  258. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  259. dynatemp_logits = logits[dynatemp_mask]
  260. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  261. dynatemp_probs = dynatemp_shifted_logits.exp()
  262. dynatemp_entropies = -(dynatemp_probs *
  263. dynatemp_shifted_logits).nansum(dim=-1)
  264. dynatemp_max_entropies = torch.log_(
  265. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  266. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  267. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  268. normalized_entropies.pow_(dynatemp_exps))
  269. temperatures[dynatemp_mask] = dyn_temp
  270. temperatures[temperatures <= 0.0] = 1.0
  271. # Use float32 to apply temp.
  272. # Use in-place division to avoid creating a new tensor.
  273. logits = logits.to(torch.float)
  274. logits.div_(temperatures.unsqueeze(dim=1))
  275. def _apply_token_bans(logits: torch.Tensor,
  276. banned_tokens: List[List[int]]) -> torch.Tensor:
  277. for i, banned_token_ids in enumerate(banned_tokens):
  278. if not banned_token_ids:
  279. continue
  280. logits[i, banned_token_ids] = -float("inf")
  281. return logits
  282. def _apply_min_tokens_penalty(
  283. logits: torch.Tensor,
  284. sampling_metadata: SamplingMetadata,
  285. ) -> torch.Tensor:
  286. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  287. have not been generated yet
  288. """
  289. # list of indices in logits that will be set to -inf
  290. logits_to_penalize = []
  291. logits_applied = 0
  292. for seq_group in sampling_metadata.seq_groups:
  293. seq_ids = seq_group.seq_ids
  294. sampling_params = seq_group.sampling_params
  295. sample_indices = seq_group.sample_indices
  296. logits_applied += len(sample_indices) + len(
  297. seq_group.prompt_logprob_indices)
  298. if not seq_group.do_sample:
  299. continue
  300. start_idx = sample_indices[0]
  301. min_tokens = sampling_params.min_tokens
  302. token_ids_to_penalize = sampling_params.all_stop_token_ids
  303. if min_tokens > 0 and token_ids_to_penalize:
  304. seqs_to_penalize = []
  305. for j, seq_id in enumerate(seq_ids):
  306. seq_data = seq_group.seq_data[seq_id]
  307. if len(seq_data.output_token_ids_array) < min_tokens:
  308. seqs_to_penalize.append(j)
  309. if seqs_to_penalize:
  310. # convert to the index into logits
  311. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  312. # itertools.product pairs each seq index with every token id
  313. logits_to_penalize.extend(
  314. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  315. if logits_to_penalize:
  316. # use zip and * to group indices along each dimension
  317. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  318. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  319. # verifies that no rows in logits were missed unexpectedly
  320. assert logits_applied == logits.shape[0]
  321. return logits
  322. def _apply_top_k_top_p(
  323. logits: torch.Tensor,
  324. p: torch.Tensor,
  325. k: torch.Tensor,
  326. ) -> torch.Tensor:
  327. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  328. # Apply top-k.
  329. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  330. # Get all the top_k values.
  331. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  332. top_k_mask = logits_sort < top_k_mask
  333. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  334. # Apply top-p.
  335. probs_sort = logits_sort.softmax(dim=-1)
  336. probs_sum = probs_sort.cumsum(dim=-1)
  337. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  338. # at least one
  339. top_p_mask[:, -1] = False
  340. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  341. # Re-sort the probabilities.
  342. src = torch.arange(logits_idx.shape[-1],
  343. device=logits_idx.device).expand_as(logits_idx)
  344. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  345. index=logits_idx,
  346. src=src)
  347. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  348. return logits
  349. def _apply_min_p(
  350. logits: torch.Tensor,
  351. min_p: torch.Tensor,
  352. ) -> torch.Tensor:
  353. """
  354. Adapted from
  355. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  356. """
  357. probs = torch.softmax(logits, dim=-1)
  358. top_probs, _ = probs.max(dim=-1, keepdim=True)
  359. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  360. tokens_to_remove = probs < scaled_min_p
  361. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  362. return logits
  363. def _apply_top_a(
  364. logits: torch.Tensor,
  365. top_a: torch.Tensor,
  366. ) -> torch.Tensor:
  367. probs = torch.softmax(logits, dim=-1)
  368. top_probs, _ = probs.max(dim=-1, keepdim=True)
  369. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  370. tokens_to_remove = probs < threshold
  371. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  372. return logits
  373. def _apply_tfs(
  374. logits: torch.Tensor,
  375. tfs: torch.Tensor,
  376. ) -> torch.Tensor:
  377. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  378. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  379. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  380. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  381. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  382. tfs_mask = torch.cat(
  383. (
  384. torch.zeros(
  385. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  386. tfs_mask,
  387. torch.ones(
  388. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  389. ),
  390. dim=-1,
  391. )
  392. logits_sort[tfs_mask] = -float("inf")
  393. logits = torch.gather(logits_sort,
  394. dim=-1,
  395. index=torch.argsort(logits_idx, dim=-1))
  396. return logits
  397. def _apply_eta_cutoff(
  398. logits: torch.Tensor,
  399. eta_cutoff: torch.Tensor,
  400. ) -> torch.Tensor:
  401. shifted_logits = torch.log_softmax(logits, dim=-1)
  402. probs = shifted_logits.exp()
  403. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  404. eps = torch.min(eta_cutoff,
  405. torch.sqrt(eta_cutoff) *
  406. torch.exp(neg_entropy)).unsqueeze(dim=1)
  407. eta_mask = probs < eps
  408. # guard against nulling out all the logits
  409. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  410. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  411. logits[eta_mask] = -float("inf")
  412. return logits
  413. def _apply_epsilon_cutoff(
  414. logits: torch.Tensor,
  415. epsilon_cutoff: torch.Tensor,
  416. ) -> torch.Tensor:
  417. probs = logits.softmax(dim=-1)
  418. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  419. # guard against nulling out all the logits
  420. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  421. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  422. logits[eps_mask] = -float("inf")
  423. return logits
  424. def _apply_typical_sampling(
  425. logits: torch.Tensor,
  426. typical_p: torch.Tensor,
  427. ) -> torch.Tensor:
  428. shifted_logits = torch.log_softmax(logits, dim=-1)
  429. probs = shifted_logits.exp()
  430. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  431. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  432. _, indices = torch.sort(surprisal_deviations)
  433. reordered_probs = probs.gather(-1, indices)
  434. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  435. dim=1)
  436. min_tokens_to_keep = 1
  437. # Keep at least min_tokens_to_keep
  438. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  439. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  440. logits[typ_mask] = -float("inf")
  441. return logits
  442. def _apply_quadratic_sampling(
  443. logits: torch.Tensor,
  444. smoothing_factor: torch.Tensor,
  445. smoothing_curve: torch.Tensor,
  446. ) -> torch.Tensor:
  447. """
  448. Applies a quadratic transformation to the logits based on the
  449. provided smoothing factors and curves. The transformation is
  450. centered around the maximum logit value in the batch.
  451. The transformation involves a quadratic and cubic term, with the
  452. cubic term controlled by the smoothing curve. The quadratic term is
  453. scaled by the smoothing factor, and the cubic term is scaled by the
  454. product of the smoothing factor and the smoothing curve.
  455. params:
  456. logits (torch.Tensor): The logits to be transformed.
  457. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  458. term in the transformation.
  459. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  460. in the transformation.
  461. returns:
  462. torch.Tensor: The transformed logits.
  463. Credits: @kalomaze
  464. """
  465. mask = smoothing_factor != 0
  466. smoothing_factor.unsqueeze_(dim=1)
  467. smoothing_curve.unsqueeze_(dim=1)
  468. k = smoothing_factor * (3 - smoothing_curve) / 2
  469. s = smoothing_factor * (smoothing_curve - 1) / 2
  470. quadlogits = logits[mask] # limit to logits using this sampler
  471. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  472. # Construct the delta from each logit to its new value
  473. diff = quadlogits - max_logits
  474. diff -= diff**2 * (s[mask] * diff - k[mask])
  475. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  476. logits[mask] -= diff
  477. return logits
  478. def _apply_xtc_sampling(
  479. logits: torch.Tensor,
  480. xtc_thresholds: torch.Tensor,
  481. xtc_probabilities: torch.Tensor,
  482. ) -> torch.Tensor:
  483. """Apply Exclude Top Choices (XTC) sampling to the logits.
  484. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  485. Args:
  486. logits: (num_tokens, vocab_size) The input logits.
  487. xtc_thresholds: (num_tokens,) The threshold for each token.
  488. xtc_probabilities: (num_tokens,) The probability of applying XTC
  489. for each token.
  490. Returns:
  491. torch.Tensor: The modified logits.
  492. """
  493. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  494. if not apply_xtc.any():
  495. return logits
  496. probs = torch.softmax(logits, dim=-1)
  497. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  498. # Find indices where the next probability is above the threshold
  499. # Skips the top choice, which later on becomes skipping the last choice.
  500. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  501. # Apply XTC only to rows where it should be applied
  502. for i in range(logits.shape[0]):
  503. if apply_xtc[i]:
  504. # Count logits above the threshold (skipping the first)
  505. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  506. if indices_to_remove > 0:
  507. # Implies the top logit and at least one other is >= threshold.
  508. # Mask out above_thresh logits except the last/lowest one.
  509. logits[i].scatter_(
  510. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  511. return logits
  512. def _greedy_sample(
  513. selected_seq_groups: List[SequenceGroupToSample],
  514. samples: torch.Tensor,
  515. ) -> List[Tuple[List[int], List[int]]]:
  516. """Run greedy sampling on a given samples.
  517. Args:
  518. selected_seq_groups: A list of sequence groups batched.
  519. samples: (num_selected_samples,) A tensor of samples. The length of
  520. samples could be smaller than selected_seq_groups if
  521. seq_group.do_sample is False.
  522. Returns:
  523. Tuple of (next_token_ids, parent_ids). The length of returned list is
  524. same as the length of selected_seq_groups. If the corresponding
  525. seq_group has do_sample=False, tuple contains ([], [])
  526. """
  527. samples = samples.tolist()
  528. sample_idx = 0
  529. results = []
  530. for seq_group in selected_seq_groups:
  531. if not seq_group.do_sample:
  532. results.append(([], []))
  533. continue
  534. seq_ids = seq_group.seq_ids
  535. num_parent_seqs = len(seq_ids)
  536. assert num_parent_seqs == 1, (
  537. "Greedy sampling should have only one seq.")
  538. parent_ids = list(range(num_parent_seqs))
  539. next_token_ids = [samples[sample_idx]]
  540. results.append((next_token_ids, parent_ids))
  541. sample_idx += num_parent_seqs
  542. return results
  543. def _random_sample(
  544. selected_seq_groups: List[SequenceGroupToSample],
  545. random_samples: torch.Tensor,
  546. ) -> List[Tuple[List[int], List[int]]]:
  547. """Run random sampling on a given samples.
  548. Args:
  549. selected_seq_groups: A list of sequence groups batched.
  550. random_samples: (num_selected_samples,) A tensor of samples. The
  551. length of samples could be smaller than selected_seq_groups if
  552. seq_group.do_sample is False.
  553. Returns:
  554. Tuple of (next_token_ids, parent_ids). The length of returned list is
  555. same as the length of selected_seq_groups. If the corresponding
  556. seq_group has do_sample=False, tuple contains ([], [])
  557. """
  558. # Find the maximum best_of value of the prompt phase requests.
  559. random_samples = random_samples.cpu()
  560. sample_idx = 0
  561. results = []
  562. for seq_group in selected_seq_groups:
  563. if not seq_group.do_sample:
  564. results.append(([], []))
  565. continue
  566. seq_ids = seq_group.seq_ids
  567. sampling_params = seq_group.sampling_params
  568. is_prompt = seq_group.is_prompt
  569. num_parent_seqs = len(seq_ids)
  570. if is_prompt:
  571. # Prompt phase.
  572. parent_ids = [0] * sampling_params.best_of
  573. next_token_ids = random_samples[
  574. sample_idx, :sampling_params.best_of].tolist()
  575. else:
  576. # Generation phase.
  577. parent_ids = list(range(num_parent_seqs))
  578. next_token_ids = random_samples[sample_idx:sample_idx +
  579. num_parent_seqs, 0].tolist()
  580. results.append((next_token_ids, parent_ids))
  581. sample_idx += num_parent_seqs
  582. return results
  583. def _beam_search_sample(
  584. selected_seq_groups: List[SequenceGroupToSample],
  585. logprobs: torch.Tensor,
  586. ) -> List[Tuple[List[int], List[int]]]:
  587. """Run beam sampling on a given samples.
  588. Args:
  589. selected_seq_groups: A list of sequence groups batched.
  590. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  591. on selected sample indices.
  592. Returns:
  593. Tuple of (next_token_ids, parent_ids). The length of returned list is
  594. same as the length of selected_seq_groups. If the corresponding
  595. seq_group has do_sample=False, tuple contains ([], [])
  596. """
  597. # We sample 2 * beam_width candidates to make sure that with high
  598. # probability we can get `beam_width` candidates in addition to
  599. # the finished sequences for the next iteration. See
  600. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  601. # for details. See also HF reference:
  602. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  603. #
  604. # NOTE: Beam search is not vectorized, so its speed can be slower than
  605. # other sampling methods.
  606. sample_idx = 0
  607. results = []
  608. for seq_group in selected_seq_groups:
  609. if not seq_group.do_sample:
  610. results.append(([], []))
  611. continue
  612. is_prompt = seq_group.is_prompt
  613. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  614. num_parent_seqs = len(seq_ids)
  615. beam_width = sampling_params.best_of
  616. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  617. if is_prompt:
  618. # Prompt phase.
  619. assert num_parent_seqs == 1, (
  620. "Prompt input should have only one seq.")
  621. parent_ids = [0] * (2 * beam_width)
  622. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  623. 2 * beam_width)
  624. next_token_ids = next_token_ids.tolist()
  625. else:
  626. # Generation phase.
  627. cumulative_logprobs = [
  628. seq_group.seq_data[seq_id].cumulative_logprob
  629. for seq_id in seq_ids
  630. ]
  631. cumulative_logprobs = torch.tensor(
  632. cumulative_logprobs,
  633. dtype=torch.float,
  634. device=seq_group_logprobs.device)
  635. seq_group_logprobs = (seq_group_logprobs +
  636. cumulative_logprobs.unsqueeze(dim=1))
  637. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  638. 2 * beam_width)
  639. topk_ids = topk_ids.tolist()
  640. vocab_size = seq_group_logprobs.size(-1)
  641. parent_ids = [i // vocab_size for i in topk_ids]
  642. next_token_ids = [i % vocab_size for i in topk_ids]
  643. results.append((next_token_ids, parent_ids))
  644. sample_idx += num_parent_seqs
  645. assert sample_idx == logprobs.size(0)
  646. return results
  647. # torch.multinomial forces a GPU<->CPU sync.
  648. # Therefore, we use an optimized implementation instead.
  649. # Note that we always sample with replacement.
  650. # probs will be modified in place, but this is fine, as we pass
  651. # in a copy already.
  652. def _multinomial(
  653. probs: torch.Tensor,
  654. num_samples: int,
  655. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  656. ) -> torch.Tensor:
  657. if num_samples > 1:
  658. # This is equivalent to torch.repeat_interleaved (which also
  659. # forces a GPU<->CPU sync).
  660. # This allows us to do sampling with replacement by creating
  661. # num_samples copies of each row in the tensor, and then
  662. # batch sampling the resulting tensor.
  663. probs = probs[:, None, :].expand(probs.shape[0], num_samples,
  664. probs.shape[1]).contiguous().view(
  665. -1, probs.shape[1])
  666. q = torch.empty_like(probs)
  667. if seq_groups is None:
  668. q.exponential_()
  669. else:
  670. sample_idx = 0
  671. for seq_group in seq_groups:
  672. seq_ids = seq_group.seq_ids
  673. next_sample_idx = sample_idx + len(seq_ids) * num_samples
  674. q[sample_idx:next_sample_idx].exponential_(
  675. generator=seq_group.generator)
  676. sample_idx = next_sample_idx
  677. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """Sample next tokens for the batch with PyTorch ops, per SamplingType.

    Args:
        probs: Per-row probabilities. Rows are picked with the sample indices
            from ``sampling_metadata.categorized_sample_indices``. Greedy rows
            may be rewritten in place (see ``modify_greedy_probs``).
        logprobs: Per-row logprobs, indexed by the same sample indices.
        sampling_metadata: Batch metadata. Supplies the sequence groups, the
            per-SamplingType sample indices and ``skip_sampler_cpu_output``.
        include_gpu_probs_tensor: If True, also return a
            ``(logprobs.shape[0], 1)`` long tensor on ``logprobs.device``
            holding the sampled token id of every greedy/random row.
        modify_greedy_probs: If True, overwrite ``probs`` for greedy rows so
            the greedy token has probability 1.0.

    Returns:
        ``(sample_results, sampled_token_ids_tensor)``. ``sample_results``
        holds one ``(next_token_ids, parent_ids)`` tuple per sequence group
        (``([], [])`` for groups without a result). It is ``[]`` when
        ``sampling_metadata.skip_sampler_cpu_output`` is set.
        ``sampled_token_ids_tensor`` is None unless
        ``include_gpu_probs_tensor`` is True.
    """
    # Map each SamplingType to the indices of the sequence groups using it.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    # NOTE(review): rows of beam-search groups are never written below, so
    # they keep torch.empty's uninitialized contents.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop only launches device work and never waits on a
    # GPU<->CPU sync; the second loop does the syncing conversions.
    for sampling_type in SamplingType:
        # Column 0 holds the row indices into probs/logprobs.
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Prompt-phase groups may request best_of > 1 samples; draw
            # enough columns for the largest such request in this category.
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            # Seeded groups pass their groups so _multinomial can use each
            # group's generator.
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
            }

            # Advanced indexing yields a copy, so _multinomial's in-place
            # division does not touch `probs`.
            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                # NOTE(review): the samples are (num_rows,
                # max_best_of_in_batch); assigning into this (num_rows, 1)
                # slice presumably assumes best_of == 1 here. Confirm.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))

        # Restore batch order; groups without a result get ([], []).
        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []

    return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens with the Triton ``sample_triton`` kernel.

    Currently unused: ``_sample`` calls ``_sample_with_torch`` and keeps this
    path behind a TODO. Unlike the torch path, it does not consult
    ``sampling_metadata.skip_sampler_cpu_output`` and returns no GPU tensor
    of sampled ids.

    Args:
        probs: Per-row probabilities passed to the kernel.
        logprobs: Per-row logprobs (kernel input and beam-search source).
        sampling_metadata: Batch metadata with the sequence groups and the
            per-SamplingType sample indices.
        sampling_tensors: Supplies ``sampling_seeds`` and ``sample_indices``
            for the kernel.

    Returns:
        One ``(next_token_ids, parent_ids)`` tuple per sequence group, with
        ``([], [])`` for groups without a result.
    """
    # Map each SamplingType to the indices of the sequence groups using it.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    # Shared across greedy/random categories: one kernel call serves all.
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        # Column 0: rows into probs/logprobs; column 1: rows into the
        # kernel's sampled-token output (assumed layout; TODO confirm
        # against sample_triton).
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            # Greedy needs only the first sampled column.
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    # Restore batch order; groups without a result get ([], []).
    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results
  842. def _sample(
  843. probs: torch.Tensor, logprobs: torch.Tensor,
  844. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  845. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  846. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  847. """
  848. Args:
  849. probs: (num_query_tokens_in_batch, num_vocab)
  850. logprobs: (num_query_tokens_in_batch, num_vocab)
  851. sampling_metadata: The metadata for a batch for sampling.
  852. sampling_tensors: Tensors that include sampling related metadata.
  853. Returns:
  854. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  855. If sampling is skipped, it returns ([], [])
  856. sampled_token_ids_tensor: A tensor of sampled token ids.
  857. """
  858. return _sample_with_torch(
  859. probs,
  860. logprobs,
  861. sampling_metadata,
  862. include_gpu_probs_tensor=include_gpu_probs_tensor,
  863. modify_greedy_probs=modify_greedy_probs,
  864. )
  865. # TODO: Enable once Triton kernel & associated code is faster.
  866. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  867. # sampling_tensors)
  868. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  869. """
  870. This function calculates the ranks of the chosen tokens in a logprob tensor.
  871. Args:
  872. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  873. where N is the no. of tokens and M is the vocab dim.
  874. indices (torch.Tensor): List of chosen token indices.
  875. Returns:
  876. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  877. Each element in the returned tensor represents the rank
  878. of the chosen token in the input logprob tensor.
  879. """
  880. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  881. indices]
  882. return (x > vals[:, None]).long().sum(1).add_(1)
  883. def _get_logprobs(
  884. logprobs: torch.Tensor,
  885. sampling_metadata: SamplingMetadata,
  886. sample_results: List[Tuple[List[int], List[int]]],
  887. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  888. """Return sample lobprobs and prompt logprobs.
  889. The logic consists of 3 parts.
  890. - Select indices to compute logprob from, ranks of token ids, and
  891. the top k token ids from logprobs.
  892. - Compute prompt logprobs if required.
  893. - Compute sample logprobs if required.
  894. Args:
  895. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  896. logprob per vocab. Sequence groups' query tokens are batched in a
  897. single flattened tensor. For example, assuming there are N
  898. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  899. prompt logprob is enabled), decode tokens for seq_group_1 (if
  900. sampling is required), prefill tokens for seq_group_2, ...
  901. sampling_metadata: The sampling metadata.
  902. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  903. parent_ids) for each sequence group. When beam search is enabled,
  904. sample_results can contain different number of seq_ids from
  905. sampling_metadata.seq_groups. It is because beam search creates
  906. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  907. BEAM_WIDTH number of seq_ids).
  908. Returns:
  909. A tuple of prompt and sample logprobs per sequence group in a batch.
  910. """
  911. # The index of query token to calculate logprobs. It includes both
  912. # prompt and sample logprob indices.
  913. query_indices: List[int] = []
  914. # The next token ids to get the logprob value from.
  915. next_token_ids: List[int] = []
  916. # The largest requested number of logprobs. We find logprobs as many as the
  917. # largest num logprobs in this API. If every logprobs is None, it will be
  918. # set to -1.
  919. largest_num_logprobs = -1
  920. # If beam search is enabled.
  921. use_beam_search = False
  922. # Select indices to compute logprob from, ranks of token ids, and the top
  923. # k token ids from logprobs.
  924. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  925. sample_results):
  926. sampling_params = seq_group.sampling_params
  927. # Update indices and tokens for prompt logprobs.
  928. if (seq_group.is_prompt
  929. and sampling_params.prompt_logprobs is not None):
  930. largest_num_logprobs = max(largest_num_logprobs,
  931. sampling_params.prompt_logprobs)
  932. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  933. query_indices.extend(seq_group.prompt_logprob_indices)
  934. next_token_ids.extend(next_prompt_tokens)
  935. # Update indices and next tokenes for sample logprob.
  936. if seq_group.do_sample:
  937. token_ids, parent_seq_ids = sample_result
  938. # NOTE: We cannot directly use sample_indices because
  939. # sample_indices only contain parent seq_ids of a previous step.
  940. # The current step may have different number of seq_ids, and
  941. # we can obtain it from `sample_result[1]`.
  942. query_idx = seq_group.sample_indices[0]
  943. query_indices.extend(
  944. [query_idx + parent_id for parent_id in parent_seq_ids])
  945. next_token_ids.extend(token_ids)
  946. if sampling_params.logprobs is not None:
  947. largest_num_logprobs = max(largest_num_logprobs,
  948. sampling_params.logprobs)
  949. use_beam_search = use_beam_search or sampling_params.use_beam_search
  950. assert len(next_token_ids) == len(query_indices)
  951. if len(query_indices) == 0:
  952. empty_sampled_logprob = []
  953. empty_prompt_logprob = None
  954. return [empty_prompt_logprob], [empty_sampled_logprob]
  955. selected_logprobs, ranks = None, None
  956. top_logprobs, top_token_ids = None, None
  957. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  958. # skip the whole logprob calculation.
  959. if largest_num_logprobs >= 0 or use_beam_search:
  960. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  961. next_token_ids_gpu = torch.tensor(next_token_ids,
  962. device=logprobs.device)
  963. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  964. # contain duplicates if beam search is enabled.
  965. selected_logprobs = logprobs[[
  966. query_indices_gpu,
  967. next_token_ids_gpu,
  968. ]]
  969. ranks = _get_ranks(
  970. logprobs[query_indices_gpu],
  971. next_token_ids_gpu,
  972. )
  973. assert selected_logprobs.shape[0] == ranks.shape[0]
  974. # We need to compute top k only if there exists logprobs > 0.
  975. if largest_num_logprobs > 0:
  976. # Logprobs of topk tokens for a batch of sequence groups.
  977. # (num_query_tokens_across_batch).
  978. top_logprobs, top_token_ids = torch.topk(logprobs,
  979. largest_num_logprobs,
  980. dim=-1)
  981. top_logprobs = top_logprobs.to('cpu')
  982. top_token_ids = top_token_ids.to('cpu')
  983. selected_logprobs = selected_logprobs.to('cpu')
  984. ranks = ranks.to('cpu')
  985. # Find prompt/sample logprobs.
  986. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  987. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  988. top_logprob_idx = 0
  989. selected_logprobs_idx = 0
  990. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  991. sample_results):
  992. (prompt_logprobs, top_logprob_idx,
  993. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  994. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  995. selected_logprobs_idx, top_logprob_idx)
  996. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  997. (sampled_logprobs, top_logprob_idx,
  998. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  999. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1000. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1001. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1002. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1003. def _get_prompt_logprob_if_needed(
  1004. seq_group: SequenceGroupToSample,
  1005. selected_logprobs: torch.Tensor,
  1006. ranks: torch.Tensor,
  1007. top_token_ids: torch.Tensor,
  1008. top_logprobs: torch.Tensor,
  1009. selected_logprobs_idx: int,
  1010. top_logprob_idx: int,
  1011. ):
  1012. """Compute the prompt logprob from a sequence group if needed."""
  1013. sampling_params = seq_group.sampling_params
  1014. is_prompt = seq_group.is_prompt
  1015. # Find prompt logprobs
  1016. prompt_logprobs: Optional[PromptLogprobs] = None
  1017. if is_prompt and sampling_params.prompt_logprobs is not None:
  1018. prompt_logprobs = []
  1019. num_logprobs = sampling_params.prompt_logprobs
  1020. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1021. # Pre-select indexes and create a list. It is faster than calling .item
  1022. # repetitively.
  1023. selected_logprob_items = selected_logprobs[
  1024. selected_logprobs_idx:selected_logprobs_idx +
  1025. len(next_prompt_tokens)].tolist()
  1026. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1027. len(next_prompt_tokens)].tolist()
  1028. for idx, token_id in enumerate(next_prompt_tokens):
  1029. # Calculate the prompt logprob of the real prompt tokens.
  1030. # {token_id: (logprob, rank_from_vocab)}
  1031. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1032. token_id: (selected_logprob_items[idx], rank_items[idx])
  1033. }
  1034. # Add top K prompt logprobs along with its rank.
  1035. if num_logprobs > 0:
  1036. top_ids = top_token_ids[
  1037. top_logprob_idx, :num_logprobs].tolist()
  1038. top_probs = top_logprobs[
  1039. top_logprob_idx, :num_logprobs].tolist()
  1040. # Top K is already sorted by rank, so we can use 1 ~
  1041. # num_logprobs + 1 for rank.
  1042. top_ranks = range(1, num_logprobs + 1)
  1043. prompt_logprobs_dict.update({
  1044. top_id: (top_prob, rank)
  1045. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1046. top_ranks)
  1047. })
  1048. prompt_logprobs.append({
  1049. token_id: Logprob(*logprob_and_rank)
  1050. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1051. })
  1052. # + 1 to go to the next prompt token.
  1053. top_logprob_idx += 1
  1054. # + len(next_prompt_tokens) to go to the next prompt.
  1055. selected_logprobs_idx += len(next_prompt_tokens)
  1056. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1057. def _get_sampled_logprob_if_needed(
  1058. seq_group: SequenceGroupToSample,
  1059. sample_result: Tuple[List[int], List[int]],
  1060. selected_logprobs: torch.Tensor,
  1061. ranks: torch.Tensor,
  1062. top_token_ids: torch.Tensor,
  1063. top_logprobs: torch.Tensor,
  1064. selected_logprobs_idx: int,
  1065. top_logprob_idx: int,
  1066. ):
  1067. """Compute the sample logprob if needed."""
  1068. seq_ids = seq_group.seq_ids
  1069. num_logprobs = seq_group.sampling_params.logprobs
  1070. use_beam_search = seq_group.sampling_params.use_beam_search
  1071. sampled_logprobs: SampleLogprobs = []
  1072. next_token_ids, parent_seq_ids = sample_result
  1073. if seq_group.do_sample:
  1074. assert len(next_token_ids) > 0
  1075. if num_logprobs is None and not use_beam_search:
  1076. for next_token_id in next_token_ids:
  1077. # Use a dummy logprob
  1078. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1079. else:
  1080. # Pre-select items from tensor. tolist() is faster than repetitive
  1081. # `.item()` calls.
  1082. selected_logprob_items = selected_logprobs[
  1083. selected_logprobs_idx:selected_logprobs_idx +
  1084. len(next_token_ids)].tolist()
  1085. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1086. len(next_token_ids)].tolist()
  1087. for idx, (next_token_id, parent_id) in enumerate(
  1088. zip(next_token_ids, parent_seq_ids)):
  1089. # Get the logprob of a sampled token.
  1090. sampled_logprobs_dict = {
  1091. next_token_id:
  1092. (selected_logprob_items[idx], rank_items[idx])
  1093. }
  1094. if num_logprobs is not None and num_logprobs > 0:
  1095. # Get top K logprobs.
  1096. top_ids = top_token_ids[top_logprob_idx +
  1097. parent_id, :num_logprobs].tolist()
  1098. top_probs = top_logprobs[
  1099. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1100. # Top K is already sorted by rank, so we can use 1 ~
  1101. # num_logprobs + 1 for rank.
  1102. top_ranks = range(1, num_logprobs + 1)
  1103. sampled_logprobs_dict.update({
  1104. top_id: (top_prob, rank)
  1105. for top_id, top_prob, rank in zip(
  1106. top_ids, top_probs, top_ranks)
  1107. })
  1108. sampled_logprobs.append({
  1109. token_id: Logprob(*logprob_and_rank)
  1110. for token_id, logprob_and_rank in
  1111. sampled_logprobs_dict.items()
  1112. })
  1113. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1114. # logprobs for the current step, which has len(next_token_ids) tokens
  1115. # per sequence group. `logprobs` includes logprobs from the previous
  1116. # steps, which has len(seq_ids) tokens per sequence group.
  1117. # Iterate to the next sequence group in a batch.
  1118. selected_logprobs_idx += len(next_token_ids)
  1119. # Iterate to the next sequence group in a batch.
  1120. top_logprob_idx += len(seq_ids)
  1121. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1122. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1123. sample_indices: torch.Tensor,
  1124. greedy_samples: torch.Tensor) -> None:
  1125. """Modify the probability distributions of the greedily-sampled tokens such
  1126. that each sampled token has a "probability" of 1.0. This is required by
  1127. speculative decoding, which depends on the sampling method being encoded
  1128. within the probability distribution for correctness.
  1129. # Why do we only need to do this for greedy sampling?
  1130. Aphrodite's sampler performs the following steps for greedy or multinomial
  1131. (random) sampling:
  1132. 1. Get logits from model.
  1133. 2. Modify logits according to per-sequence sampling parameters.
  1134. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1135. according to their frequency, etc.
  1136. 3. Sample a token.
  1137. - Random sampling simply samples from the modified probability
  1138. distribution.
  1139. - Greedy sampling performs `argmax` to obtain the token with the
  1140. highest likelihood.
  1141. Ignoring greedy sampling for a moment, we find that the computed probability
  1142. distribution has the following property: we can sample from it independently
  1143. and find that the token sampled by the Sampler has a frequency corresponding
  1144. to how often we see it in our sampling. In other words, for tokens sampled
  1145. with Aphrodite's random SamplingType, the computed probability distribution
  1146. encodes the sampling methodology completely.
  1147. Greedy sampling does not normally have this property. Aphrodite modifies
  1148. logits according to sampling params, then performs `argmax`, then returns
  1149. the sampled token and the computed probability distribution. If we sample
  1150. from the distribution, we'll find the likelihood of the greedily-sampled
  1151. token is not always 1.0.
  1152. Since lossless speculative decoding requires that the sampling methodology
  1153. be encoded within the probability distribution, we are motivated to modify
  1154. the probability distribution such that the sampled token has probability 1
  1155. when speculative decoding is used.
  1156. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1157. greedy sampling using multinomial computation and unite the codepaths. This
  1158. has implications on the overall design of the sampler, e.g. how to record
  1159. accurate logprobs for the user, so this improvement is deferred to later.
  1160. """
  1161. # NOTE: logprobs are not modified so they can be returned to the user.
  1162. probs[sample_indices, :] = 0
  1163. probs[sample_indices, greedy_samples] = 1.0
  1164. def _build_sampler_output(
  1165. sample_results: SampleResultType,
  1166. sampling_metadata: SamplingMetadata,
  1167. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1168. sample_logprobs: Optional[List[SampleLogprobs]],
  1169. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1170. torch.Tensor]],
  1171. skip_sampler_cpu_output: bool = False,
  1172. ) -> SamplerOutput:
  1173. """Construct Python objects with the output of sampling.
  1174. Args:
  1175. on_device_tensors: Tuple containing on-device tensors with the
  1176. probabilities used in sampling and the sampled token ids. This
  1177. allows post-processing without copies to CPU/serialization, e.g. in
  1178. speculative decoding rejection sampling.
  1179. """
  1180. sampler_output: List[CompletionSequenceGroupOutput] = []
  1181. if not skip_sampler_cpu_output:
  1182. assert prompt_logprobs is not None
  1183. assert sample_logprobs is not None
  1184. for (seq_group, sample_result, group_prompt_logprobs,
  1185. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1186. sample_results, prompt_logprobs,
  1187. sample_logprobs):
  1188. seq_ids = seq_group.seq_ids
  1189. next_token_ids, parent_ids = sample_result
  1190. seq_outputs: List[SequenceOutput] = []
  1191. for parent_id, next_token_id, logprobs in zip(
  1192. parent_ids, next_token_ids, group_sample_logprobs):
  1193. seq_outputs.append(
  1194. SequenceOutput(seq_ids[parent_id], next_token_id,
  1195. logprobs))
  1196. sampler_output.append(
  1197. CompletionSequenceGroupOutput(seq_outputs,
  1198. group_prompt_logprobs))
  1199. # If not specified, store None values in SamplerOutput.
  1200. if on_device_tensors is not None:
  1201. (sampled_token_probs, logprobs_tensor,
  1202. sampled_token_ids) = on_device_tensors
  1203. else:
  1204. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1205. None)
  1206. return SamplerOutput(
  1207. outputs=sampler_output,
  1208. sampled_token_probs=sampled_token_probs,
  1209. sampled_token_ids=sampled_token_ids,
  1210. logprobs=logprobs_tensor,
  1211. )
  1212. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1213. """Get a list of next prompt tokens to compute logprob from a
  1214. given sequence group.
  1215. It is used to compute prompt logprob. Imagine you have logprob for each
  1216. query token. Query token needs to know the next prompt token id to compute
  1217. prompt logprob. This is a helper to obtain next prompt token ids.
  1218. This API has to be used only when the caller knows seq_group is in prefill
  1219. stage.
  1220. Returns:
  1221. A list of next prompt tokens to compute logprob.
  1222. """
  1223. assert seq_group.is_prompt, (
  1224. "Caller should ensure the sequence group is in a prefill stage.")
  1225. seq_ids = seq_group.seq_ids
  1226. query_len = seq_group.query_len
  1227. assert query_len is not None
  1228. # prompt has only 1 seq id.
  1229. assert len(seq_ids) == 1
  1230. seq_data = seq_group.seq_data[seq_ids[0]]
  1231. computed_len = seq_data.get_num_computed_tokens()
  1232. prompt_tokens = seq_data.prompt_token_ids
  1233. # +1 because we are looking for a next prompt token.
  1234. next_token_index_start = computed_len + 1
  1235. next_token_index_end = min(computed_len + query_len + 1,
  1236. len(prompt_tokens))
  1237. next_prompt_tokens = prompt_tokens[
  1238. next_token_index_start:next_token_index_end]
  1239. return next_prompt_tokens
  1240. # def _apply_mirostat_v2(logits: torch.Tensor,
  1241. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1242. # # Reduce our view to just the affected logits
  1243. # logit_view = logits[sampling_tensors.miro_indices]
  1244. # # Calculate surprise value per token
  1245. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1246. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1247. # # Mask out "too-surprising" tokens (surprisal > mu)
  1248. # mus = sampling_tensors.miro_mus
  1249. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1250. # # Unmask most-likely logit to guarantee a selection.
  1251. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1252. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1253. # # Apply logit mask (effectively a top-k filter).
  1254. # logit_view[miro_mask] = -float("inf")
  1255. # # Project logit changes made to the view onto the original.
  1256. # # I think this step might be redundant.
  1257. # logits[sampling_tensors.miro_indices] = logit_view
  1258. # return logits
  1259. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1260. # sample_results: List[Tuple[List[int], List[int]]],
  1261. # sampling_metadata: SamplingMetadata,
  1262. # output_metadata: OutputMetadata) -> None:
  1263. # """Based on whichever token was finally sampled, we calculate the
  1264. # final surprisal values to update the mus.
  1265. # Because a single sequence can have multiple samples, we must fork
  1266. # the mu accordingly."""
  1267. # assert sampling_metadata.seq_groups is not None
  1268. # seqid_to_tokens = {}
  1269. # seqid_to_indices = {}
  1270. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1271. # sample_results):
  1272. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1273. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1274. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1275. # seqids = args.miro_seqids
  1276. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1277. # device=logits.device,
  1278. # dtype=torch.long)
  1279. # # Clumsily, we recalculate token surprisals.
  1280. # logits_view = logits[args.miro_indices]
  1281. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1282. # dim=-1,
  1283. # index=picked_tokens) / -math.log(2)
  1284. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1285. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1286. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1287. # nu_mus = mus - (picked_surprise - taus) * etas
  1288. # # Record updated mu values for use in the next iteration
  1289. # # Note how each mu is split into multiple based on the number of samples.
  1290. # for seqid, seq_mus in zip(seqids, nu_mus):
  1291. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1292. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)