1
0

sampler.py 62 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. from math import inf
  4. from typing import Dict, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. from aphrodite.common.sampling_params import SamplingType
  8. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  9. PromptLogprobs, SampleLogprobs,
  10. SamplerOutput, SequenceOutput)
  11. from aphrodite.triton_utils import HAS_TRITON
  12. if HAS_TRITON:
  13. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  14. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  15. SamplingTensors,
  16. SequenceGroupToSample)
  17. # (num_token_ids, num_parent_ids) per sequence group.
  18. SampleResultType = List[Tuple[List[int], List[int]]]
  19. # There isn't a "safe" temperature range for fp16 logits.
  20. # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
  21. # that this temperature well-uses the fp16 space after the logits are offset.
  22. _TEMPERATURE_MINIMUM = 2e-5
  23. class Sampler(nn.Module):
  24. """Samples the next tokens from the model's outputs.
  25. This layer does the following:
  26. 1. Discard the hidden states that are not used for sampling (i.e., all
  27. tokens except the final one in each prompt).
  28. 2. Compute the logits for the next tokens.
  29. 3. Apply presence, frequency and repetition penalties.
  30. 4. Apply temperature scaling.
  31. 5. Apply top-p and top-k truncation.
  32. 6. Sample the next tokens.
  33. Here, each sequence group within the batch can have different sampling
  34. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  35. The structure of the logits tensor is coupled with the seq_groups in
  36. sampling_metadata. Typically, each sequence in each seq_group has one row in
  37. logits for the next token to be sampled; however, for a seq_group with a
  38. prompt request with the prompt_logprobs sampling parameter, there are rows
  39. in logits for each token in the input prompt.
  40. """
  41. def __init__(self):
  42. super().__init__()
  43. # Whether or not the SamplerOutput should have on-device tensors
  44. # containing the sampled token ids and probabilities. This is used by
  45. # speculative decoding.
  46. self.include_gpu_probs_tensor = False
  47. self.should_modify_greedy_probs_inplace = False
  48. def _init_sampling_tensors(
  49. self,
  50. logits: torch.Tensor,
  51. sampling_metadata: SamplingMetadata,
  52. ):
  53. """The goal here is to reuse sampling tensors between similar decode
  54. runs. This is possible because sampling logic does not change between
  55. decodes of the same sequences.
  56. """
  57. _, vocab_size = logits.shape
  58. # First free any existing stored sampling tensors.
  59. # This is necessary because some sampling tensors may
  60. # have pinned memory.
  61. self._sampling_tensors = None
  62. # Initialize new sampling tensors
  63. (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
  64. do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
  65. do_typical_ps, do_quadratic, do_xtc, do_temp_last
  66. ) = SamplingTensors.from_sampling_metadata(
  67. sampling_metadata, vocab_size, logits.device, logits.dtype)
  68. self._sampling_tensors = sampling_tensors
  69. self._do_penalties = do_penalties
  70. self._do_temperatures = do_temperatures
  71. self._do_top_p_top_k = do_top_p_top_k
  72. self._do_top_as = do_top_as
  73. self._do_min_p = do_min_p
  74. self._do_tfss = do_tfss
  75. self._do_eta_cutoffs = do_eta_cutoffs
  76. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  77. self._do_typical_ps = do_typical_ps
  78. self._do_quadratic = do_quadratic
  79. self._do_xtc = do_xtc
  80. self._do_temp_last = do_temp_last
  81. def forward(
  82. self,
  83. logits: torch.Tensor,
  84. sampling_metadata: SamplingMetadata,
  85. ) -> Optional[SamplerOutput]:
  86. """
  87. Args:
  88. logits: (num_tokens, vocab_size).
  89. sampling_metadata: Metadata for sampling.
  90. """
  91. assert logits is not None
  92. _, vocab_size = logits.shape
  93. # Prepare sampling tensors with pinned memory to avoid blocking.
  94. if not sampling_metadata.reuse_sampling_tensors:
  95. self._init_sampling_tensors(logits, sampling_metadata)
  96. elif self._do_penalties:
  97. # In this case, the sampling tensors logic depends on
  98. # "output_tokens" of a sequence. As a result, we cannot
  99. # reuse sampling tensors, since "output_tokens" changes
  100. # between decode runs.
  101. self._init_sampling_tensors(logits, sampling_metadata)
  102. assert self._sampling_tensors is not None
  103. sampling_tensors = self._sampling_tensors
  104. do_penalties = self._do_penalties
  105. do_temperatures = self._do_temperatures
  106. do_top_p_top_k = self._do_top_p_top_k
  107. do_top_as = self._do_top_as
  108. do_min_p = self._do_min_p
  109. do_tfss = self._do_tfss
  110. do_eta_cutoffs = self._do_eta_cutoffs
  111. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  112. do_typical_ps = self._do_typical_ps
  113. do_quadratic = self._do_quadratic
  114. do_xtc = self._do_xtc
  115. do_temp_last = self._do_temp_last
  116. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  117. # Apply presence and frequency penalties.
  118. if do_penalties:
  119. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  120. sampling_tensors.output_tokens,
  121. sampling_tensors.presence_penalties,
  122. sampling_tensors.frequency_penalties,
  123. sampling_tensors.repetition_penalties)
  124. # Apply temperature scaling if not doing temp_last.
  125. if do_temperatures and not do_temp_last:
  126. _apply_temperatures(logits, sampling_tensors.temperatures,
  127. sampling_tensors.dynatemp_mins,
  128. sampling_tensors.dynatemp_maxs,
  129. sampling_tensors.dynatemp_exps)
  130. if do_top_p_top_k:
  131. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  132. sampling_tensors.top_ks)
  133. if do_top_as:
  134. logits = _apply_top_a(logits, sampling_tensors.top_as)
  135. if do_min_p:
  136. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  137. if do_tfss:
  138. logits = _apply_tfs(logits, sampling_tensors.tfss)
  139. if do_eta_cutoffs:
  140. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  141. if do_epsilon_cutoffs:
  142. logits = _apply_epsilon_cutoff(logits,
  143. sampling_tensors.epsilon_cutoffs)
  144. if do_typical_ps:
  145. logits = _apply_typical_sampling(logits,
  146. sampling_tensors.typical_ps)
  147. if do_quadratic:
  148. logits = _apply_quadratic_sampling(
  149. logits, sampling_tensors.smoothing_factors,
  150. sampling_tensors.smoothing_curves)
  151. if do_xtc:
  152. logits = _apply_xtc_sampling(
  153. logits, sampling_tensors.xtc_thresholds,
  154. sampling_tensors.xtc_probabilities)
  155. if do_temperatures and do_temp_last:
  156. _apply_temperatures(logits, sampling_tensors.temperatures,
  157. sampling_tensors.dynatemp_mins,
  158. sampling_tensors.dynatemp_maxs,
  159. sampling_tensors.dynatemp_exps)
  160. banned_tokens = _get_custom_token_bans(sampling_metadata)
  161. logits = _apply_token_bans(logits, banned_tokens)
  162. # We use float32 for probabilities and log probabilities.
  163. # Compute the probabilities.
  164. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  165. # Compute the log probabilities.
  166. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  167. # Sample the next tokens.
  168. sample_results, maybe_sampled_tokens_tensor = _sample(
  169. probs,
  170. logprobs,
  171. sampling_metadata,
  172. sampling_tensors,
  173. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  174. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  175. )
  176. if self.include_gpu_probs_tensor:
  177. assert maybe_sampled_tokens_tensor is not None
  178. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  179. else:
  180. on_device_tensors = None
  181. # Get the logprobs query results.
  182. prompt_logprobs = None
  183. sample_logprobs = None
  184. if not sampling_metadata.skip_sampler_cpu_output:
  185. prompt_logprobs, sample_logprobs = _get_logprobs(
  186. logprobs, sampling_metadata, sample_results)
  187. return _build_sampler_output(
  188. sample_results,
  189. sampling_metadata,
  190. prompt_logprobs,
  191. sample_logprobs,
  192. on_device_tensors=on_device_tensors,
  193. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  194. @property
  195. def _should_modify_greedy_probs_inplace(self) -> bool:
  196. """Whether or not the sampler should modify the probability distribution
  197. of greedily-sampled tokens such that multinomial sampling would sample
  198. the greedily-sampled token.
  199. In other words, if True then we set the probability of the greedily-
  200. sampled token to 1.
  201. This is used by speculative decoding, which requires that the sampling
  202. method be encoded into the probability distribution.
  203. """
  204. return self.should_modify_greedy_probs_inplace
  205. def _get_bin_counts_and_mask(
  206. tokens: torch.Tensor,
  207. vocab_size: int,
  208. num_seqs: int,
  209. ) -> Tuple[torch.Tensor, torch.Tensor]:
  210. # Compute the bin counts for the tokens.
  211. # vocab_size + 1 for padding.
  212. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  213. dtype=torch.long,
  214. device=tokens.device)
  215. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  216. bin_counts = bin_counts[:, :vocab_size]
  217. mask = bin_counts > 0
  218. return bin_counts, mask
  219. def _get_custom_token_bans(
  220. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  221. assert sampling_metadata.seq_groups is not None
  222. banned_tokens: List[List[int]] = []
  223. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  224. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  225. seq_ids = seq_group.seq_ids
  226. custom_token_bans = sampling_params.custom_token_bans
  227. if (i < sampling_metadata.num_prompts
  228. and sampling_params.prompt_logprobs is not None):
  229. prompt_len = len(seq_group.prompt_logprob_indices)
  230. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  231. banned_tokens += [custom_token_bans] * len(seq_ids)
  232. return banned_tokens
  233. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  234. output_tokens_tensor: torch.Tensor,
  235. presence_penalties: torch.Tensor,
  236. frequency_penalties: torch.Tensor,
  237. repetition_penalties: torch.Tensor) -> torch.Tensor:
  238. num_seqs, vocab_size = logits.shape
  239. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  240. num_seqs)
  241. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  242. output_tokens_tensor, vocab_size, num_seqs)
  243. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  244. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  245. logits = torch.where(logits > 0, logits / repetition_penalties,
  246. logits * repetition_penalties)
  247. # We follow the definition in OpenAI API.
  248. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  249. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  250. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  251. return logits
  252. def _apply_temperatures(
  253. logits: torch.Tensor,
  254. temperatures: torch.Tensor,
  255. dynatemp_mins: torch.Tensor,
  256. dynatemp_maxs: torch.Tensor,
  257. dynatemp_exps: torch.Tensor,
  258. ) -> None:
  259. dynatemp_mask = (dynatemp_mins != 0) | (dynatemp_maxs != 0)
  260. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  261. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  262. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  263. dynatemp_logits = logits[dynatemp_mask]
  264. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  265. dynatemp_probs = dynatemp_shifted_logits.exp()
  266. dynatemp_entropies = -(dynatemp_probs *
  267. dynatemp_shifted_logits).nansum(dim=-1)
  268. dynatemp_max_entropies = torch.log_(
  269. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  270. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  271. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  272. normalized_entropies.pow_(dynatemp_exps))
  273. temperatures[dynatemp_mask] = dyn_temp
  274. temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
  275. temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
  276. # To prevent saturation of top logits, we shift the range to [-inf, 1]
  277. # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
  278. # Why mask? So we aren't potentially discarding data in milder temps.
  279. low_temps = temperatures < 0.1
  280. logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
  281. logits.div_(temperatures.unsqueeze(dim=1))
  282. def _apply_token_bans(logits: torch.Tensor,
  283. banned_tokens: List[List[int]]) -> torch.Tensor:
  284. for i, banned_token_ids in enumerate(banned_tokens):
  285. if i >= logits.size(0):
  286. break
  287. if not banned_token_ids:
  288. continue
  289. logits[i, banned_token_ids] = -float("inf")
  290. return logits
  291. def _apply_min_tokens_penalty(
  292. logits: torch.Tensor,
  293. sampling_metadata: SamplingMetadata,
  294. ) -> torch.Tensor:
  295. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  296. have not been generated yet
  297. """
  298. # list of indices in logits that will be set to -inf
  299. logits_to_penalize = []
  300. logits_applied = 0
  301. for seq_group in sampling_metadata.seq_groups:
  302. seq_ids = seq_group.seq_ids
  303. sampling_params = seq_group.sampling_params
  304. sample_indices = seq_group.sample_indices
  305. logits_applied += len(sample_indices) + len(
  306. seq_group.prompt_logprob_indices)
  307. if not seq_group.do_sample:
  308. continue
  309. start_idx = sample_indices[0]
  310. min_tokens = sampling_params.min_tokens
  311. token_ids_to_penalize = sampling_params.all_stop_token_ids
  312. if min_tokens > 0 and token_ids_to_penalize:
  313. seqs_to_penalize = []
  314. for j, seq_id in enumerate(seq_ids):
  315. seq_data = seq_group.seq_data[seq_id]
  316. if len(seq_data.output_token_ids_array) < min_tokens:
  317. seqs_to_penalize.append(j)
  318. if seqs_to_penalize:
  319. # convert to the index into logits
  320. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  321. # itertools.product pairs each seq index with every token id
  322. logits_to_penalize.extend(
  323. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  324. if logits_to_penalize:
  325. # use zip and * to group indices along each dimension
  326. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  327. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  328. # verifies that no rows in logits were missed unexpectedly
  329. assert logits_applied == logits.shape[0]
  330. return logits
  331. def _apply_top_k_top_p(
  332. logits: torch.Tensor,
  333. p: torch.Tensor,
  334. k: torch.Tensor,
  335. ) -> torch.Tensor:
  336. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  337. # Apply top-k.
  338. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  339. # Get all the top_k values.
  340. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  341. top_k_mask = logits_sort < top_k_mask
  342. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  343. # Apply top-p.
  344. probs_sort = logits_sort.softmax(dim=-1)
  345. probs_sum = probs_sort.cumsum(dim=-1)
  346. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  347. # at least one
  348. top_p_mask[:, -1] = False
  349. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  350. # Re-sort the probabilities.
  351. src = torch.arange(logits_idx.shape[-1],
  352. device=logits_idx.device).expand_as(logits_idx)
  353. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  354. index=logits_idx,
  355. src=src)
  356. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  357. return logits
  358. def _apply_min_p(
  359. logits: torch.Tensor,
  360. min_p: torch.Tensor,
  361. ) -> torch.Tensor:
  362. """
  363. Adapted from
  364. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  365. """
  366. probs = torch.softmax(logits, dim=-1)
  367. top_probs, _ = probs.max(dim=-1, keepdim=True)
  368. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  369. tokens_to_remove = probs < scaled_min_p
  370. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  371. return logits
  372. def _apply_top_a(
  373. logits: torch.Tensor,
  374. top_a: torch.Tensor,
  375. ) -> torch.Tensor:
  376. probs = torch.softmax(logits, dim=-1)
  377. top_probs, _ = probs.max(dim=-1, keepdim=True)
  378. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  379. tokens_to_remove = probs < threshold
  380. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  381. return logits
  382. def _apply_tfs(
  383. logits: torch.Tensor,
  384. tfs: torch.Tensor,
  385. ) -> torch.Tensor:
  386. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  387. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  388. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  389. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  390. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  391. tfs_mask = torch.cat(
  392. (
  393. torch.zeros(
  394. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  395. tfs_mask,
  396. torch.ones(
  397. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  398. ),
  399. dim=-1,
  400. )
  401. logits_sort[tfs_mask] = -float("inf")
  402. logits = torch.gather(logits_sort,
  403. dim=-1,
  404. index=torch.argsort(logits_idx, dim=-1))
  405. return logits
  406. def _apply_eta_cutoff(
  407. logits: torch.Tensor,
  408. eta_cutoff: torch.Tensor,
  409. ) -> torch.Tensor:
  410. shifted_logits = torch.log_softmax(logits, dim=-1)
  411. probs = shifted_logits.exp()
  412. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  413. eps = torch.min(eta_cutoff,
  414. torch.sqrt(eta_cutoff) *
  415. torch.exp(neg_entropy)).unsqueeze(dim=1)
  416. eta_mask = probs < eps
  417. # guard against nulling out all the logits
  418. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  419. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  420. logits[eta_mask] = -float("inf")
  421. return logits
  422. def _apply_epsilon_cutoff(
  423. logits: torch.Tensor,
  424. epsilon_cutoff: torch.Tensor,
  425. ) -> torch.Tensor:
  426. probs = logits.softmax(dim=-1)
  427. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  428. # guard against nulling out all the logits
  429. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  430. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  431. logits[eps_mask] = -float("inf")
  432. return logits
  433. def _apply_typical_sampling(
  434. logits: torch.Tensor,
  435. typical_p: torch.Tensor,
  436. ) -> torch.Tensor:
  437. shifted_logits = torch.log_softmax(logits, dim=-1)
  438. probs = shifted_logits.exp()
  439. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  440. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  441. _, indices = torch.sort(surprisal_deviations)
  442. reordered_probs = probs.gather(-1, indices)
  443. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  444. dim=1)
  445. min_tokens_to_keep = 1
  446. # Keep at least min_tokens_to_keep
  447. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  448. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  449. logits[typ_mask] = -float("inf")
  450. return logits
  451. def _apply_quadratic_sampling(
  452. logits: torch.Tensor,
  453. smoothing_factor: torch.Tensor,
  454. smoothing_curve: torch.Tensor,
  455. ) -> torch.Tensor:
  456. """
  457. Applies a quadratic transformation to the logits based on the
  458. provided smoothing factors and curves. The transformation is
  459. centered around the maximum logit value in the batch.
  460. The transformation involves a quadratic and cubic term, with the
  461. cubic term controlled by the smoothing curve. The quadratic term is
  462. scaled by the smoothing factor, and the cubic term is scaled by the
  463. product of the smoothing factor and the smoothing curve.
  464. params:
  465. logits (torch.Tensor): The logits to be transformed.
  466. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  467. term in the transformation.
  468. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  469. in the transformation.
  470. returns:
  471. torch.Tensor: The transformed logits.
  472. Credits: @kalomaze
  473. """
  474. mask = smoothing_factor != 0
  475. smoothing_factor.unsqueeze_(dim=1)
  476. smoothing_curve.unsqueeze_(dim=1)
  477. k = smoothing_factor * (3 - smoothing_curve) / 2
  478. s = smoothing_factor * (smoothing_curve - 1) / 2
  479. quadlogits = logits[mask] # limit to logits using this sampler
  480. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  481. # Construct the delta from each logit to its new value
  482. diff = quadlogits - max_logits
  483. diff -= diff**2 * (s[mask] * diff - k[mask])
  484. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  485. logits[mask] -= diff
  486. return logits
  487. def _apply_xtc_sampling(
  488. logits: torch.Tensor,
  489. xtc_thresholds: torch.Tensor,
  490. xtc_probabilities: torch.Tensor,
  491. ) -> torch.Tensor:
  492. """Apply Exclude Top Choices (XTC) sampling to the logits.
  493. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  494. Args:
  495. logits: (num_tokens, vocab_size) The input logits.
  496. xtc_thresholds: (num_tokens,) The threshold for each token.
  497. xtc_probabilities: (num_tokens,) The probability of applying XTC
  498. for each token.
  499. Returns:
  500. torch.Tensor: The modified logits.
  501. """
  502. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  503. if not apply_xtc.any():
  504. return logits
  505. probs = torch.softmax(logits, dim=-1)
  506. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  507. # Find indices where the next probability is above the threshold
  508. # Skips the top choice, which later on becomes skipping the last choice.
  509. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  510. # Apply XTC only to rows where it should be applied
  511. for i in range(logits.shape[0]):
  512. if apply_xtc[i]:
  513. # Count logits above the threshold (skipping the first)
  514. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  515. if indices_to_remove > 0:
  516. # Implies the top logit and at least one other is >= threshold.
  517. # Mask out above_thresh logits except the last/lowest one.
  518. logits[i].scatter_(
  519. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  520. return logits
  521. def _greedy_sample(
  522. selected_seq_groups: List[SequenceGroupToSample],
  523. samples: torch.Tensor,
  524. ) -> List[Tuple[List[int], List[int]]]:
  525. """Run greedy sampling on a given samples.
  526. Args:
  527. selected_seq_groups: A list of sequence groups batched.
  528. samples: (num_selected_samples,) A tensor of samples. The length of
  529. samples could be smaller than selected_seq_groups if
  530. seq_group.do_sample is False.
  531. Returns:
  532. Tuple of (next_token_ids, parent_ids). The length of returned list is
  533. same as the length of selected_seq_groups. If the corresponding
  534. seq_group has do_sample=False, tuple contains ([], [])
  535. """
  536. samples = samples.tolist()
  537. sample_idx = 0
  538. results = []
  539. for seq_group in selected_seq_groups:
  540. if not seq_group.do_sample:
  541. results.append(([], []))
  542. continue
  543. seq_ids = seq_group.seq_ids
  544. num_parent_seqs = len(seq_ids)
  545. assert num_parent_seqs == 1, (
  546. "Greedy sampling should have only one seq.")
  547. parent_ids = list(range(num_parent_seqs))
  548. next_token_ids = [samples[sample_idx]]
  549. results.append((next_token_ids, parent_ids))
  550. sample_idx += num_parent_seqs
  551. return results
  552. def _random_sample(
  553. selected_seq_groups: List[SequenceGroupToSample],
  554. random_samples: torch.Tensor,
  555. ) -> List[Tuple[List[int], List[int]]]:
  556. """Run random sampling on a given samples.
  557. Args:
  558. selected_seq_groups: A list of sequence groups batched.
  559. random_samples: (num_selected_samples,) A tensor of samples. The
  560. length of samples could be smaller than selected_seq_groups if
  561. seq_group.do_sample is False.
  562. Returns:
  563. Tuple of (next_token_ids, parent_ids). The length of returned list is
  564. same as the length of selected_seq_groups. If the corresponding
  565. seq_group has do_sample=False, tuple contains ([], [])
  566. """
  567. # Find the maximum best_of value of the prompt phase requests.
  568. random_samples = random_samples.cpu()
  569. sample_idx = 0
  570. results = []
  571. for seq_group in selected_seq_groups:
  572. if not seq_group.do_sample:
  573. results.append(([], []))
  574. continue
  575. seq_ids = seq_group.seq_ids
  576. sampling_params = seq_group.sampling_params
  577. is_prompt = seq_group.is_prompt
  578. num_parent_seqs = len(seq_ids)
  579. if is_prompt:
  580. # Prompt phase.
  581. parent_ids = [0] * sampling_params.best_of
  582. next_token_ids = random_samples[
  583. sample_idx, :sampling_params.best_of].tolist()
  584. else:
  585. # Generation phase.
  586. parent_ids = list(range(num_parent_seqs))
  587. next_token_ids = random_samples[sample_idx:sample_idx +
  588. num_parent_seqs, 0].tolist()
  589. results.append((next_token_ids, parent_ids))
  590. sample_idx += num_parent_seqs
  591. return results
  592. def _beam_search_sample(
  593. selected_seq_groups: List[SequenceGroupToSample],
  594. logprobs: torch.Tensor,
  595. ) -> List[Tuple[List[int], List[int]]]:
  596. """Run beam sampling on a given samples.
  597. Args:
  598. selected_seq_groups: A list of sequence groups batched.
  599. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  600. on selected sample indices.
  601. Returns:
  602. Tuple of (next_token_ids, parent_ids). The length of returned list is
  603. same as the length of selected_seq_groups. If the corresponding
  604. seq_group has do_sample=False, tuple contains ([], [])
  605. """
  606. # We sample 2 * beam_width candidates to make sure that with high
  607. # probability we can get `beam_width` candidates in addition to
  608. # the finished sequences for the next iteration. See
  609. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  610. # for details. See also HF reference:
  611. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  612. #
  613. # NOTE: Beam search is not vectorized, so its speed can be slower than
  614. # other sampling methods.
  615. sample_idx = 0
  616. results = []
  617. for seq_group in selected_seq_groups:
  618. if not seq_group.do_sample:
  619. results.append(([], []))
  620. continue
  621. is_prompt = seq_group.is_prompt
  622. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  623. num_parent_seqs = len(seq_ids)
  624. beam_width = sampling_params.best_of
  625. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  626. if is_prompt:
  627. # Prompt phase.
  628. assert num_parent_seqs == 1, (
  629. "Prompt input should have only one seq.")
  630. parent_ids = [0] * (2 * beam_width)
  631. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  632. 2 * beam_width)
  633. next_token_ids = next_token_ids.tolist()
  634. else:
  635. # Generation phase.
  636. cumulative_logprobs = [
  637. seq_group.seq_data[seq_id].cumulative_logprob
  638. for seq_id in seq_ids
  639. ]
  640. cumulative_logprobs = torch.tensor(
  641. cumulative_logprobs,
  642. dtype=torch.float,
  643. device=seq_group_logprobs.device)
  644. seq_group_logprobs = (seq_group_logprobs +
  645. cumulative_logprobs.unsqueeze(dim=1))
  646. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  647. 2 * beam_width)
  648. topk_ids = topk_ids.tolist()
  649. vocab_size = seq_group_logprobs.size(-1)
  650. parent_ids = [i // vocab_size for i in topk_ids]
  651. next_token_ids = [i % vocab_size for i in topk_ids]
  652. results.append((next_token_ids, parent_ids))
  653. sample_idx += num_parent_seqs
  654. assert sample_idx == logprobs.size(0)
  655. return results
  656. # torch.multinomial forces a GPU<->CPU sync.
  657. # Therefore, we use an optimized implementation instead.
  658. # Note that we always sample with replacement.
  659. # probs will be modified in place, but this is fine, as we pass
  660. # in a copy already.
  661. def _multinomial(
  662. probs: torch.Tensor,
  663. num_samples: int,
  664. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  665. ) -> torch.Tensor:
  666. if num_samples > 1:
  667. # This is equivalent to torch.repeat_interleaved (which also
  668. # forces a GPU<->CPU sync).
  669. # This allows us to do sampling with replacement by creating
  670. # num_samples copies of each row in the tensor, and then
  671. # batch sampling the resulting tensor.
  672. probs = probs[:, None, :].expand(probs.shape[0], num_samples,
  673. probs.shape[1]).contiguous().view(
  674. -1, probs.shape[1])
  675. q = torch.empty_like(probs)
  676. if seq_groups is None:
  677. q.exponential_()
  678. else:
  679. sample_idx = 0
  680. for seq_group in seq_groups:
  681. seq_ids = seq_group.seq_ids
  682. next_sample_idx = sample_idx + len(seq_ids) * num_samples
  683. q[sample_idx:next_sample_idx].exponential_(
  684. generator=seq_group.generator)
  685. sample_idx = next_sample_idx
  686. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  687. def _sample_with_torch(
  688. probs: torch.Tensor,
  689. logprobs: torch.Tensor,
  690. sampling_metadata: SamplingMetadata,
  691. include_gpu_probs_tensor: bool,
  692. modify_greedy_probs: bool,
  693. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  694. categorized_seq_group_ids = {t: [] for t in SamplingType}
  695. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  696. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  697. sampling_params = seq_group.sampling_params
  698. sampling_type = sampling_params.sampling_type
  699. categorized_seq_group_ids[sampling_type].append(i)
  700. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  701. sample_metadata = {}
  702. multinomial_samples = {}
  703. # Create output tensor for sampled token ids.
  704. if include_gpu_probs_tensor:
  705. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  706. 1,
  707. dtype=torch.long,
  708. device=logprobs.device)
  709. else:
  710. sampled_token_ids_tensor = None
  711. # Counterintuitively, having two loops here is actually faster.
  712. # The first loop can run without waiting on GPU<->CPU sync.
  713. for sampling_type in SamplingType:
  714. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  715. num_tokens = len(sample_indices)
  716. if num_tokens == 0:
  717. continue
  718. seq_group_id = categorized_seq_group_ids[sampling_type]
  719. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  720. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  721. long_sample_indices = sample_indices.long()
  722. if sampling_type == SamplingType.GREEDY:
  723. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  724. dim=-1)
  725. if include_gpu_probs_tensor:
  726. # Store sampled tokens in output tensor.
  727. sampled_token_ids_tensor[
  728. long_sample_indices] = greedy_samples.unsqueeze(-1)
  729. if modify_greedy_probs:
  730. # If required, modify the probabilities such that sampling from
  731. # the modified distribution would always sample the argmax
  732. # token id.
  733. _modify_greedy_probs_inplace(logprobs, probs,
  734. long_sample_indices,
  735. greedy_samples)
  736. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  737. max_best_of_in_batch = 1
  738. for seq_group in seq_groups:
  739. if seq_group.is_prompt:
  740. sampling_params = seq_group.sampling_params
  741. max_best_of_in_batch = max(max_best_of_in_batch,
  742. sampling_params.best_of)
  743. seeded_args = {} if sampling_type == SamplingType.RANDOM else {
  744. "seq_groups": seq_groups,
  745. }
  746. multinomial_samples[sampling_type] = _multinomial(
  747. probs[long_sample_indices], max_best_of_in_batch,
  748. **seeded_args)
  749. if include_gpu_probs_tensor:
  750. # Store sampled tokens in output tensor.
  751. sampled_token_ids_tensor[
  752. long_sample_indices] = multinomial_samples[sampling_type]
  753. elif sampling_type == SamplingType.BEAM:
  754. beam_search_logprobs = logprobs[sample_indices]
  755. else:
  756. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  757. # GPU<->CPU sync happens in the loop below.
  758. # This also converts the sample output to Python objects.
  759. if not sampling_metadata.skip_sampler_cpu_output:
  760. for sampling_type in SamplingType:
  761. if sampling_type not in sample_metadata:
  762. continue
  763. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  764. if sampling_type == SamplingType.GREEDY:
  765. sample_results = _greedy_sample(seq_groups, greedy_samples)
  766. elif sampling_type in (SamplingType.RANDOM,
  767. SamplingType.RANDOM_SEED):
  768. sample_results = _random_sample(
  769. seq_groups, multinomial_samples[sampling_type])
  770. elif sampling_type == SamplingType.BEAM:
  771. sample_results = _beam_search_sample(seq_groups,
  772. beam_search_logprobs)
  773. sample_results_dict.update(zip(seq_group_id, sample_results))
  774. sample_results = [
  775. sample_results_dict.get(i, ([], []))
  776. for i in range(len(sampling_metadata.seq_groups))
  777. ]
  778. else:
  779. sample_results = []
  780. return sample_results, sampled_token_ids_tensor
  781. def _sample_with_triton_kernel(
  782. probs: torch.Tensor,
  783. logprobs: torch.Tensor,
  784. sampling_metadata: SamplingMetadata,
  785. sampling_tensors: SamplingTensors,
  786. ) -> List[Tuple[List[int], List[int]]]:
  787. categorized_seq_group_ids = {t: [] for t in SamplingType}
  788. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  789. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  790. sampling_params = seq_group.sampling_params
  791. sampling_type = sampling_params.sampling_type
  792. categorized_seq_group_ids[sampling_type].append(i)
  793. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  794. sample_metadata = {}
  795. max_best_of_in_batch = 1
  796. # Counterintuitively, having two loops here is actually faster.
  797. # The first loop can run without waiting on GPU<->CPU sync.
  798. for sampling_type in SamplingType:
  799. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  800. sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
  801. num_tokens = len(sample_indices)
  802. if num_tokens == 0:
  803. continue
  804. seq_group_id = categorized_seq_group_ids[sampling_type]
  805. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  806. sample_metadata[sampling_type] = (seq_group_id, seq_groups,
  807. sample_indices,
  808. sampled_token_indices)
  809. if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
  810. SamplingType.RANDOM_SEED):
  811. for seq_group in seq_groups:
  812. if seq_group.is_prompt:
  813. sampling_params = seq_group.sampling_params
  814. max_best_of_in_batch = max(max_best_of_in_batch,
  815. sampling_params.best_of)
  816. elif sampling_type == SamplingType.BEAM:
  817. beam_search_logprobs = logprobs[sample_indices]
  818. else:
  819. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  820. sampled_tokens, _, _ = sample_triton(
  821. probs=probs,
  822. seeds=sampling_tensors.sampling_seeds,
  823. max_best_of=max_best_of_in_batch,
  824. sample_indices=sampling_tensors.sample_indices,
  825. logprobs=logprobs,
  826. # don't save logprobs because we have logic for that below
  827. # TODO: use this instead of the CPU-based logic below
  828. save_logprobs=False,
  829. )
  830. # GPU<->CPU sync happens in the loop below.
  831. for sampling_type in SamplingType:
  832. if sampling_type not in sample_metadata:
  833. continue
  834. (seq_group_id, seq_groups, sample_indices,
  835. sampled_token_indices) = sample_metadata[sampling_type]
  836. if sampling_type == SamplingType.GREEDY:
  837. sample_results = _greedy_sample(
  838. seq_groups, sampled_tokens[sampled_token_indices][:, 0])
  839. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  840. sample_results = _random_sample(
  841. seq_groups, sampled_tokens[sampled_token_indices])
  842. elif sampling_type == SamplingType.BEAM:
  843. sample_results = _beam_search_sample(seq_groups,
  844. beam_search_logprobs)
  845. sample_results_dict.update(zip(seq_group_id, sample_results))
  846. sample_results = [
  847. sample_results_dict.get(i, ([], []))
  848. for i in range(len(sampling_metadata.seq_groups))
  849. ]
  850. return sample_results
  851. def _sample(
  852. probs: torch.Tensor, logprobs: torch.Tensor,
  853. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  854. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  855. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  856. """
  857. Args:
  858. probs: (num_query_tokens_in_batch, num_vocab)
  859. logprobs: (num_query_tokens_in_batch, num_vocab)
  860. sampling_metadata: The metadata for a batch for sampling.
  861. sampling_tensors: Tensors that include sampling related metadata.
  862. Returns:
  863. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  864. If sampling is skipped, it returns ([], [])
  865. sampled_token_ids_tensor: A tensor of sampled token ids.
  866. """
  867. return _sample_with_torch(
  868. probs,
  869. logprobs,
  870. sampling_metadata,
  871. include_gpu_probs_tensor=include_gpu_probs_tensor,
  872. modify_greedy_probs=modify_greedy_probs,
  873. )
  874. # TODO: Enable once Triton kernel & associated code is faster.
  875. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  876. # sampling_tensors)
  877. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  878. """
  879. This function calculates the ranks of the chosen tokens in a logprob tensor.
  880. Args:
  881. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  882. where N is the no. of tokens and M is the vocab dim.
  883. indices (torch.Tensor): List of chosen token indices.
  884. Returns:
  885. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  886. Each element in the returned tensor represents the rank
  887. of the chosen token in the input logprob tensor.
  888. """
  889. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  890. indices]
  891. return (x > vals[:, None]).long().sum(1).add_(1)
  892. def _get_logprobs(
  893. logprobs: torch.Tensor,
  894. sampling_metadata: SamplingMetadata,
  895. sample_results: List[Tuple[List[int], List[int]]],
  896. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  897. """Return sample lobprobs and prompt logprobs.
  898. The logic consists of 3 parts.
  899. - Select indices to compute logprob from, ranks of token ids, and
  900. the top k token ids from logprobs.
  901. - Compute prompt logprobs if required.
  902. - Compute sample logprobs if required.
  903. Args:
  904. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  905. logprob per vocab. Sequence groups' query tokens are batched in a
  906. single flattened tensor. For example, assuming there are N
  907. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  908. prompt logprob is enabled), decode tokens for seq_group_1 (if
  909. sampling is required), prefill tokens for seq_group_2, ...
  910. sampling_metadata: The sampling metadata.
  911. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  912. parent_ids) for each sequence group. When beam search is enabled,
  913. sample_results can contain different number of seq_ids from
  914. sampling_metadata.seq_groups. It is because beam search creates
  915. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  916. BEAM_WIDTH number of seq_ids).
  917. Returns:
  918. A tuple of prompt and sample logprobs per sequence group in a batch.
  919. """
  920. # The index of query token to calculate logprobs. It includes both
  921. # prompt and sample logprob indices.
  922. query_indices: List[int] = []
  923. # The next token ids to get the logprob value from.
  924. next_token_ids: List[int] = []
  925. # The largest requested number of logprobs. We find logprobs as many as the
  926. # largest num logprobs in this API. If every logprobs is None, it will be
  927. # set to -1.
  928. largest_num_logprobs = -1
  929. # If beam search is enabled.
  930. use_beam_search = False
  931. # Select indices to compute logprob from, ranks of token ids, and the top
  932. # k token ids from logprobs.
  933. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  934. sample_results):
  935. sampling_params = seq_group.sampling_params
  936. # Update indices and tokens for prompt logprobs.
  937. if (seq_group.is_prompt
  938. and sampling_params.prompt_logprobs is not None):
  939. largest_num_logprobs = max(largest_num_logprobs,
  940. sampling_params.prompt_logprobs)
  941. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  942. query_indices.extend(seq_group.prompt_logprob_indices)
  943. next_token_ids.extend(next_prompt_tokens)
  944. # Update indices and next tokenes for sample logprob.
  945. if seq_group.do_sample:
  946. token_ids, parent_seq_ids = sample_result
  947. # NOTE: We cannot directly use sample_indices because
  948. # sample_indices only contain parent seq_ids of a previous step.
  949. # The current step may have different number of seq_ids, and
  950. # we can obtain it from `sample_result[1]`.
  951. query_idx = seq_group.sample_indices[0]
  952. query_indices.extend(
  953. [query_idx + parent_id for parent_id in parent_seq_ids])
  954. next_token_ids.extend(token_ids)
  955. if sampling_params.logprobs is not None:
  956. largest_num_logprobs = max(largest_num_logprobs,
  957. sampling_params.logprobs)
  958. use_beam_search = use_beam_search or sampling_params.use_beam_search
  959. assert len(next_token_ids) == len(query_indices)
  960. if len(query_indices) == 0:
  961. empty_sampled_logprob = []
  962. empty_prompt_logprob = None
  963. return [empty_prompt_logprob], [empty_sampled_logprob]
  964. selected_logprobs, ranks = None, None
  965. top_logprobs, top_token_ids = None, None
  966. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  967. # skip the whole logprob calculation.
  968. if largest_num_logprobs >= 0 or use_beam_search:
  969. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  970. next_token_ids_gpu = torch.tensor(next_token_ids,
  971. device=logprobs.device)
  972. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  973. # contain duplicates if beam search is enabled.
  974. selected_logprobs = logprobs[[
  975. query_indices_gpu,
  976. next_token_ids_gpu,
  977. ]]
  978. ranks = _get_ranks(
  979. logprobs[query_indices_gpu],
  980. next_token_ids_gpu,
  981. )
  982. assert selected_logprobs.shape[0] == ranks.shape[0]
  983. # We need to compute top k only if there exists logprobs > 0.
  984. if largest_num_logprobs > 0:
  985. # Logprobs of topk tokens for a batch of sequence groups.
  986. # (num_query_tokens_across_batch).
  987. top_logprobs, top_token_ids = torch.topk(logprobs,
  988. largest_num_logprobs,
  989. dim=-1)
  990. top_logprobs = top_logprobs.to('cpu')
  991. top_token_ids = top_token_ids.to('cpu')
  992. selected_logprobs = selected_logprobs.to('cpu')
  993. ranks = ranks.to('cpu')
  994. # Find prompt/sample logprobs.
  995. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  996. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  997. top_logprob_idx = 0
  998. selected_logprobs_idx = 0
  999. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  1000. sample_results):
  1001. (prompt_logprobs, top_logprob_idx,
  1002. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  1003. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  1004. selected_logprobs_idx, top_logprob_idx)
  1005. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  1006. (sampled_logprobs, top_logprob_idx,
  1007. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  1008. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1009. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1010. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1011. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1012. def _get_prompt_logprob_if_needed(
  1013. seq_group: SequenceGroupToSample,
  1014. selected_logprobs: torch.Tensor,
  1015. ranks: torch.Tensor,
  1016. top_token_ids: torch.Tensor,
  1017. top_logprobs: torch.Tensor,
  1018. selected_logprobs_idx: int,
  1019. top_logprob_idx: int,
  1020. ):
  1021. """Compute the prompt logprob from a sequence group if needed."""
  1022. sampling_params = seq_group.sampling_params
  1023. is_prompt = seq_group.is_prompt
  1024. # Find prompt logprobs
  1025. prompt_logprobs: Optional[PromptLogprobs] = None
  1026. if is_prompt and sampling_params.prompt_logprobs is not None:
  1027. prompt_logprobs = []
  1028. num_logprobs = sampling_params.prompt_logprobs
  1029. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1030. # Pre-select indexes and create a list. It is faster than calling .item
  1031. # repetitively.
  1032. selected_logprob_items = selected_logprobs[
  1033. selected_logprobs_idx:selected_logprobs_idx +
  1034. len(next_prompt_tokens)].tolist()
  1035. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1036. len(next_prompt_tokens)].tolist()
  1037. for idx, token_id in enumerate(next_prompt_tokens):
  1038. # Calculate the prompt logprob of the real prompt tokens.
  1039. # {token_id: (logprob, rank_from_vocab)}
  1040. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1041. token_id: (selected_logprob_items[idx], rank_items[idx])
  1042. }
  1043. # Add top K prompt logprobs along with its rank.
  1044. if num_logprobs > 0:
  1045. top_ids = top_token_ids[
  1046. top_logprob_idx, :num_logprobs].tolist()
  1047. top_probs = top_logprobs[
  1048. top_logprob_idx, :num_logprobs].tolist()
  1049. # Top K is already sorted by rank, so we can use 1 ~
  1050. # num_logprobs + 1 for rank.
  1051. top_ranks = range(1, num_logprobs + 1)
  1052. prompt_logprobs_dict.update({
  1053. top_id: (top_prob, rank)
  1054. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1055. top_ranks)
  1056. })
  1057. prompt_logprobs.append({
  1058. token_id: Logprob(*logprob_and_rank)
  1059. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1060. })
  1061. # + 1 to go to the next prompt token.
  1062. top_logprob_idx += 1
  1063. # + len(next_prompt_tokens) to go to the next prompt.
  1064. selected_logprobs_idx += len(next_prompt_tokens)
  1065. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1066. def _get_sampled_logprob_if_needed(
  1067. seq_group: SequenceGroupToSample,
  1068. sample_result: Tuple[List[int], List[int]],
  1069. selected_logprobs: torch.Tensor,
  1070. ranks: torch.Tensor,
  1071. top_token_ids: torch.Tensor,
  1072. top_logprobs: torch.Tensor,
  1073. selected_logprobs_idx: int,
  1074. top_logprob_idx: int,
  1075. ):
  1076. """Compute the sample logprob if needed."""
  1077. seq_ids = seq_group.seq_ids
  1078. num_logprobs = seq_group.sampling_params.logprobs
  1079. use_beam_search = seq_group.sampling_params.use_beam_search
  1080. sampled_logprobs: SampleLogprobs = []
  1081. next_token_ids, parent_seq_ids = sample_result
  1082. if seq_group.do_sample:
  1083. assert len(next_token_ids) > 0
  1084. if num_logprobs is None and not use_beam_search:
  1085. for next_token_id in next_token_ids:
  1086. # Use a dummy logprob
  1087. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1088. else:
  1089. # Pre-select items from tensor. tolist() is faster than repetitive
  1090. # `.item()` calls.
  1091. selected_logprob_items = selected_logprobs[
  1092. selected_logprobs_idx:selected_logprobs_idx +
  1093. len(next_token_ids)].tolist()
  1094. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1095. len(next_token_ids)].tolist()
  1096. for idx, (next_token_id, parent_id) in enumerate(
  1097. zip(next_token_ids, parent_seq_ids)):
  1098. # Get the logprob of a sampled token.
  1099. sampled_logprobs_dict = {
  1100. next_token_id:
  1101. (selected_logprob_items[idx], rank_items[idx])
  1102. }
  1103. if num_logprobs is not None and num_logprobs > 0:
  1104. # Get top K logprobs.
  1105. top_ids = top_token_ids[top_logprob_idx +
  1106. parent_id, :num_logprobs].tolist()
  1107. top_probs = top_logprobs[
  1108. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1109. # Top K is already sorted by rank, so we can use 1 ~
  1110. # num_logprobs + 1 for rank.
  1111. top_ranks = range(1, num_logprobs + 1)
  1112. sampled_logprobs_dict.update({
  1113. top_id: (top_prob, rank)
  1114. for top_id, top_prob, rank in zip(
  1115. top_ids, top_probs, top_ranks)
  1116. })
  1117. sampled_logprobs.append({
  1118. token_id: Logprob(*logprob_and_rank)
  1119. for token_id, logprob_and_rank in
  1120. sampled_logprobs_dict.items()
  1121. })
  1122. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1123. # logprobs for the current step, which has len(next_token_ids) tokens
  1124. # per sequence group. `logprobs` includes logprobs from the previous
  1125. # steps, which has len(seq_ids) tokens per sequence group.
  1126. # Iterate to the next sequence group in a batch.
  1127. selected_logprobs_idx += len(next_token_ids)
  1128. # Iterate to the next sequence group in a batch.
  1129. top_logprob_idx += len(seq_ids)
  1130. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1131. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1132. sample_indices: torch.Tensor,
  1133. greedy_samples: torch.Tensor) -> None:
  1134. """Modify the probability distributions of the greedily-sampled tokens such
  1135. that each sampled token has a "probability" of 1.0. This is required by
  1136. speculative decoding, which depends on the sampling method being encoded
  1137. within the probability distribution for correctness.
  1138. # Why do we only need to do this for greedy sampling?
  1139. Aphrodite's sampler performs the following steps for greedy or multinomial
  1140. (random) sampling:
  1141. 1. Get logits from model.
  1142. 2. Modify logits according to per-sequence sampling parameters.
  1143. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1144. according to their frequency, etc.
  1145. 3. Sample a token.
  1146. - Random sampling simply samples from the modified probability
  1147. distribution.
  1148. - Greedy sampling performs `argmax` to obtain the token with the
  1149. highest likelihood.
  1150. Ignoring greedy sampling for a moment, we find that the computed probability
  1151. distribution has the following property: we can sample from it independently
  1152. and find that the token sampled by the Sampler has a frequency corresponding
  1153. to how often we see it in our sampling. In other words, for tokens sampled
  1154. with Aphrodite's random SamplingType, the computed probability distribution
  1155. encodes the sampling methodology completely.
  1156. Greedy sampling does not normally have this property. Aphrodite modifies
  1157. logits according to sampling params, then performs `argmax`, then returns
  1158. the sampled token and the computed probability distribution. If we sample
  1159. from the distribution, we'll find the likelihood of the greedily-sampled
  1160. token is not always 1.0.
  1161. Since lossless speculative decoding requires that the sampling methodology
  1162. be encoded within the probability distribution, we are motivated to modify
  1163. the probability distribution such that the sampled token has probability 1
  1164. when speculative decoding is used.
  1165. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1166. greedy sampling using multinomial computation and unite the codepaths. This
  1167. has implications on the overall design of the sampler, e.g. how to record
  1168. accurate logprobs for the user, so this improvement is deferred to later.
  1169. """
  1170. # NOTE: logprobs are not modified so they can be returned to the user.
  1171. probs[sample_indices, :] = 0
  1172. probs[sample_indices, greedy_samples] = 1.0
  1173. def _build_sampler_output(
  1174. sample_results: SampleResultType,
  1175. sampling_metadata: SamplingMetadata,
  1176. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1177. sample_logprobs: Optional[List[SampleLogprobs]],
  1178. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1179. torch.Tensor]],
  1180. skip_sampler_cpu_output: bool = False,
  1181. ) -> SamplerOutput:
  1182. """Construct Python objects with the output of sampling.
  1183. Args:
  1184. on_device_tensors: Tuple containing on-device tensors with the
  1185. probabilities used in sampling and the sampled token ids. This
  1186. allows post-processing without copies to CPU/serialization, e.g. in
  1187. speculative decoding rejection sampling.
  1188. """
  1189. sampler_output: List[CompletionSequenceGroupOutput] = []
  1190. if not skip_sampler_cpu_output:
  1191. assert prompt_logprobs is not None
  1192. assert sample_logprobs is not None
  1193. for (seq_group, sample_result, group_prompt_logprobs,
  1194. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1195. sample_results, prompt_logprobs,
  1196. sample_logprobs):
  1197. seq_ids = seq_group.seq_ids
  1198. next_token_ids, parent_ids = sample_result
  1199. seq_outputs: List[SequenceOutput] = []
  1200. for parent_id, next_token_id, logprobs in zip(
  1201. parent_ids, next_token_ids, group_sample_logprobs):
  1202. seq_outputs.append(
  1203. SequenceOutput(seq_ids[parent_id], next_token_id,
  1204. logprobs))
  1205. sampler_output.append(
  1206. CompletionSequenceGroupOutput(seq_outputs,
  1207. group_prompt_logprobs))
  1208. # If not specified, store None values in SamplerOutput.
  1209. if on_device_tensors is not None:
  1210. (sampled_token_probs, logprobs_tensor,
  1211. sampled_token_ids) = on_device_tensors
  1212. else:
  1213. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1214. None)
  1215. return SamplerOutput(
  1216. outputs=sampler_output,
  1217. sampled_token_probs=sampled_token_probs,
  1218. sampled_token_ids=sampled_token_ids,
  1219. logprobs=logprobs_tensor,
  1220. )
  1221. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1222. """Get a list of next prompt tokens to compute logprob from a
  1223. given sequence group.
  1224. It is used to compute prompt logprob. Imagine you have logprob for each
  1225. query token. Query token needs to know the next prompt token id to compute
  1226. prompt logprob. This is a helper to obtain next prompt token ids.
  1227. This API has to be used only when the caller knows seq_group is in prefill
  1228. stage.
  1229. Returns:
  1230. A list of next prompt tokens to compute logprob.
  1231. """
  1232. assert seq_group.is_prompt, (
  1233. "Caller should ensure the sequence group is in a prefill stage.")
  1234. seq_ids = seq_group.seq_ids
  1235. query_len = seq_group.query_len
  1236. assert query_len is not None
  1237. # prompt has only 1 seq id.
  1238. assert len(seq_ids) == 1
  1239. seq_data = seq_group.seq_data[seq_ids[0]]
  1240. computed_len = seq_data.get_num_computed_tokens()
  1241. prompt_tokens = seq_data.prompt_token_ids
  1242. # +1 because we are looking for a next prompt token.
  1243. next_token_index_start = computed_len + 1
  1244. next_token_index_end = min(computed_len + query_len + 1,
  1245. len(prompt_tokens))
  1246. next_prompt_tokens = prompt_tokens[
  1247. next_token_index_start:next_token_index_end]
  1248. return next_prompt_tokens
  1249. # def _apply_mirostat_v2(logits: torch.Tensor,
  1250. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1251. # # Reduce our view to just the affected logits
  1252. # logit_view = logits[sampling_tensors.miro_indices]
  1253. # # Calculate surprise value per token
  1254. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1255. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1256. # # Mask out "too-surprising" tokens (surprisal > mu)
  1257. # mus = sampling_tensors.miro_mus
  1258. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1259. # # Unmask most-likely logit to guarantee a selection.
  1260. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1261. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1262. # # Apply logit mask (effectively a top-k filter).
  1263. # logit_view[miro_mask] = -float("inf")
  1264. # # Project logit changes made to the view onto the original.
  1265. # # I think this step might be redundant.
  1266. # logits[sampling_tensors.miro_indices] = logit_view
  1267. # return logits
  1268. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1269. # sample_results: List[Tuple[List[int], List[int]]],
  1270. # sampling_metadata: SamplingMetadata,
  1271. # output_metadata: OutputMetadata) -> None:
  1272. # """Based on whichever token was finally sampled, we calculate the
  1273. # final surprisal values to update the mus.
  1274. # Because a single sequence can have multiple samples, we must fork
  1275. # the mu accordingly."""
  1276. # assert sampling_metadata.seq_groups is not None
  1277. # seqid_to_tokens = {}
  1278. # seqid_to_indices = {}
  1279. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1280. # sample_results):
  1281. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1282. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1283. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1284. # seqids = args.miro_seqids
  1285. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1286. # device=logits.device,
  1287. # dtype=torch.long)
  1288. # # Clumsily, we recalculate token surprisals.
  1289. # logits_view = logits[args.miro_indices]
  1290. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1291. # dim=-1,
  1292. # index=picked_tokens) / -math.log(2)
  1293. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1294. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1295. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1296. # nu_mus = mus - (picked_surprise - taus) * etas
  1297. # # Record updated mu values for use in the next iteration
  1298. # # Note how each mu is split into multiple based on the number of samples.
  1299. # for seqid, seq_mus in zip(seqids, nu_mus):
  1300. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1301. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)