sampler.py 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import warnings
  4. from dataclasses import dataclass
  5. from enum import IntEnum
  6. from math import inf
  7. from typing import Dict, List, Optional, Tuple, Union
  8. import msgspec
  9. import torch
  10. import torch.nn as nn
  11. from loguru import logger
  12. import aphrodite._custom_ops as ops
  13. import aphrodite.common.envs as envs
  14. from aphrodite.common.sampling_params import SamplingType
  15. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  16. PromptLogprobs, SampleLogprobs,
  17. SequenceOutput)
  18. from aphrodite.common.utils import is_cpu
  19. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  20. SamplingTensors,
  21. SequenceGroupToSample)
  22. from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
  23. # (num_token_ids, num_parent_ids) per sequence group.
  24. SampleResultType = List[Tuple[List[int], List[int]]]
  25. # Types of temporary data structures used for
  26. # computing sample_result
  27. SampleMetadataType = Dict[SamplingType, Tuple[List[int],
  28. List[SequenceGroupToSample]]]
  29. MultinomialSamplesType = Dict[SamplingType, torch.Tensor]
  30. SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]]
  31. # Encapsulates temporary data structures for computing
  32. # sample_result.
  33. #
  34. # * For multi-step scheduling: must be returned
  35. # by `Sampler.forward()` and used later to compute the pythonized
  36. # sample_result
  37. #
  38. # * For single-step scheduling: consumed immediately
  39. # inside `Sampler.forward()` to compute pythonized sample_result.
  40. @dataclass
  41. class SampleResultArgsType:
  42. sample_metadata: SampleMetadataType
  43. multinomial_samples: MultinomialSamplesType
  44. sample_results_dict: SampleResultsDictType
  45. sampling_metadata: SamplingMetadata
  46. greedy_samples: Optional[torch.Tensor]
  47. beam_search_logprobs: Optional[torch.Tensor]
  48. # Union of non-deferred (single-step scheduling)
  49. # vs deferred (multi-step scheduling)
  50. # sample result types
  51. MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
  52. # Abbreviation of the _sample() return type
  53. SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
  54. class SamplerOutput(
  55. msgspec.Struct,
  56. omit_defaults=True, # type: ignore[call-arg]
  57. array_like=True): # type: ignore[call-arg]
  58. """For each sequence group, we generate a list of SequenceOutput object,
  59. each of which contains one possible candidate for the next token.
  60. This data structure implements methods, so it can be used like a list, but
  61. also has optional fields for device tensors.
  62. """
  63. outputs: List[CompletionSequenceGroupOutput]
  64. # On-device tensor containing probabilities of each token.
  65. sampled_token_probs: Optional[torch.Tensor] = None
  66. # On-device tensor containing the logprobs of each token.
  67. logprobs: Optional["torch.Tensor"] = None
  68. # Holds either (1) the pythonized sampler result (single-step scheduling)
  69. # or (2) what will be arguments for later deferred pythonization of the
  70. # sampler result (muliti-step scheduling)
  71. deferred_sample_results_args: Optional[SampleResultArgsType] = None
  72. # On-device tensor containing the sampled token ids.
  73. sampled_token_ids: Optional[torch.Tensor] = None
  74. # CPU tensor containing the sampled token ids. Used during multi-step to
  75. # return the sampled token ids from last rank to AsyncLLMEngine to be
  76. # 'broadcasted' to all other PP ranks for next step.
  77. sampled_token_ids_cpu: Optional[torch.Tensor] = None
  78. # Spec decode metrics populated by workers.
  79. spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
  80. # Optional last hidden states from the model.
  81. hidden_states: Optional[torch.Tensor] = None
  82. # Optional prefill hidden states from the model
  83. # (used for models like EAGLE).
  84. prefill_hidden_states: Optional[torch.Tensor] = None
  85. # Time taken in the forward pass for this across all workers
  86. model_forward_time: Optional[float] = None
  87. # Time taken in the model execute function. This will include model forward,
  88. # block/sync across workers, cpu-gpu sync time and sampling time.
  89. model_execute_time: Optional[float] = None
  90. def __getitem__(self, idx: int):
  91. return self.outputs[idx]
  92. def __setitem__(self, idx: int, value):
  93. self.outputs[idx] = value
  94. def __len__(self):
  95. return len(self.outputs)
  96. def __eq__(self, other: object):
  97. return isinstance(other,
  98. self.__class__) and self.outputs == other.outputs
  99. def __repr__(self) -> str:
  100. """Show the shape of a tensor instead of its values to reduce noise.
  101. """
  102. sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
  103. else self.sampled_token_probs.shape)
  104. sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
  105. self.sampled_token_ids.shape)
  106. return (
  107. f"SamplerOutput(outputs={self.outputs}, "
  108. f"sampled_token_probs={sampled_token_probs_repr}, "
  109. f"sampled_token_ids={sampled_token_ids_repr}, "
  110. f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
  111. # There isn't a "safe" temperature range for fp16 logits.
  112. # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
  113. # that this temperature well-uses the fp16 space after the logits are offset.
  114. _TEMPERATURE_MINIMUM = 2e-5
  115. # If enabled, we switch to a more performant implementation
  116. # of top-k and top-p
  117. APHRODITE_USE_SAMPLING_KERNELS = envs.APHRODITE_USE_SAMPLING_KERNELS
  118. class SamplerID(IntEnum):
  119. # Mirror these in aphrodite/common/sampling_params.py
  120. # Values out of order to keep backwards compatibility
  121. # with Koboldcpp values
  122. DRY = 7
  123. PENALTIES = 6
  124. NO_REPEAT_NGRAM = 8
  125. TEMPERATURE = 5
  126. TOP_NSIGMA = 9
  127. TOP_P_TOP_K = 0
  128. TOP_A = 1
  129. MIN_P = 2
  130. TFS = 3
  131. ETA_CUTOFF = 10
  132. EPSILON_CUTOFF = 11
  133. TYPICAL_P = 4
  134. QUADRATIC = 12
  135. XTC = 13
  136. class Sampler(nn.Module):
  137. """Samples the next tokens from the model's outputs.
  138. This layer does the following:
  139. 1. Discard the hidden states that are not used for sampling (i.e., all
  140. tokens except the final one in each prompt).
  141. 2. Compute the logits for the next tokens.
  142. 3. Apply presence, frequency and repetition penalties.
  143. 4. Apply temperature scaling.
  144. 5. Apply top-p and top-k truncation.
  145. 6. Sample the next tokens.
  146. Here, each sequence group within the batch can have different sampling
  147. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  148. The structure of the logits tensor is coupled with the seq_groups in
  149. sampling_metadata. Typically, each sequence in each seq_group has one row in
  150. logits for the next token to be sampled; however, for a seq_group with a
  151. prompt request with the prompt_logprobs sampling parameter, there are rows
  152. in logits for each token in the input prompt.
  153. """
  154. def __init__(self):
  155. super().__init__()
  156. # Whether or not the SamplerOutput should have on-device tensors
  157. # containing the sampled token ids and probabilities. This is used by
  158. # speculative decoding.
  159. self.include_gpu_probs_tensor = False
  160. self.should_modify_greedy_probs_inplace = False
  161. def _init_sampling_tensors(
  162. self,
  163. logits: torch.Tensor,
  164. sampling_metadata: SamplingMetadata,
  165. ):
  166. """The goal here is to reuse sampling tensors between similar decode
  167. runs. This is possible because sampling logic does not change between
  168. decodes of the same sequences.
  169. """
  170. _, vocab_size = logits.shape
  171. # First free any existing stored sampling tensors.
  172. # This is necessary because some sampling tensors may
  173. # have pinned memory.
  174. self._sampling_tensors = None
  175. # Initialize new sampling tensors
  176. (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
  177. do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
  178. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
  179. do_dry, do_skew, do_temp_last
  180. ) = SamplingTensors.from_sampling_metadata(
  181. sampling_metadata, vocab_size, logits.device, logits.dtype)
  182. self._sampling_tensors = sampling_tensors
  183. self._do_penalties = do_penalties
  184. self._do_no_repeat_ngrams = do_no_repeat_ngrams
  185. self._do_temperatures = do_temperatures
  186. self._do_top_p_top_k = do_top_p_top_k
  187. self._do_top_as = do_top_as
  188. self._do_min_p = do_min_p
  189. self._do_tfss = do_tfss
  190. self._do_eta_cutoffs = do_eta_cutoffs
  191. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  192. self._do_typical_ps = do_typical_ps
  193. self._do_quadratic = do_quadratic
  194. self._do_xtc = do_xtc
  195. self._do_nsgimas = do_nsigmas
  196. self._do_dry = do_dry
  197. self._do_skew = do_skew
  198. self._do_temp_last = do_temp_last
  199. def forward(
  200. self,
  201. logits: torch.Tensor,
  202. sampling_metadata: SamplingMetadata,
  203. ) -> Optional[SamplerOutput]:
  204. """
  205. Single-step scheduling:
  206. * Perform GPU-side sampling computation & compute
  207. GPU-side logprobs tensor
  208. * Pythonize sampling result & logprobs tensor
  209. Multi-step scheduling:
  210. * Perform GPU-side sampling computation & compute
  211. GPU-side logprobs tensor
  212. * Defer Pythonization of sampling result & logprobs
  213. tensor
  214. * Encapsulate arguments required for deferred Pythonization
  215. in the :class:`SamplerOutput` structure
  216. Args:
  217. logits: (num_tokens, vocab_size).
  218. sampling_metadata: Metadata for sampling.
  219. """
  220. assert logits is not None
  221. _, vocab_size = logits.shape
  222. # Prepare sampling tensors with pinned memory to avoid blocking.
  223. if not sampling_metadata.reuse_sampling_tensors:
  224. self._init_sampling_tensors(logits, sampling_metadata)
  225. elif self._do_penalties or self._do_dry:
  226. # In this case, the sampling tensors logic depends on
  227. # "output_tokens" of a sequence. As a result, we cannot
  228. # reuse sampling tensors, since "output_tokens" changes
  229. # between decode runs.
  230. self._init_sampling_tensors(logits, sampling_metadata)
  231. assert self._sampling_tensors is not None
  232. sampling_tensors = self._sampling_tensors
  233. do_penalties = self._do_penalties
  234. do_no_repeat_ngrams = self._do_no_repeat_ngrams
  235. do_temperatures = self._do_temperatures
  236. do_top_p_top_k = self._do_top_p_top_k
  237. do_top_as = self._do_top_as
  238. do_min_p = self._do_min_p
  239. do_tfss = self._do_tfss
  240. do_eta_cutoffs = self._do_eta_cutoffs
  241. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  242. do_typical_ps = self._do_typical_ps
  243. do_quadratic = self._do_quadratic
  244. do_xtc = self._do_xtc
  245. do_nsigmas = self._do_nsgimas
  246. do_dry = self._do_dry
  247. do_skew = self._do_skew
  248. do_temp_last = self._do_temp_last
  249. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  250. banned_tokens = _get_custom_token_bans(sampling_metadata)
  251. logits = _apply_token_bans(logits, banned_tokens)
  252. sampler_order = None
  253. if sampling_metadata.seq_groups:
  254. sampler_order = sampling_metadata.seq_groups[
  255. 0].sampling_params.sampler_priority
  256. # Warn if both custom order and temp_last are specified
  257. if (sampling_metadata.seq_groups and
  258. sampling_metadata.seq_groups[0].is_prompt and
  259. sampler_order is not None and
  260. do_temp_last):
  261. logger.warning(
  262. "Both sampler_priority and temperature_last=True "
  263. "were specified. Using custom sampler_priority order "
  264. "and ignoring temperature_last.")
  265. if sampler_order is None:
  266. default_order = [
  267. SamplerID.DRY,
  268. SamplerID.PENALTIES,
  269. SamplerID.NO_REPEAT_NGRAM,
  270. SamplerID.TEMPERATURE,
  271. SamplerID.TOP_NSIGMA,
  272. SamplerID.TOP_P_TOP_K,
  273. SamplerID.TOP_A,
  274. SamplerID.MIN_P,
  275. SamplerID.TFS,
  276. SamplerID.ETA_CUTOFF,
  277. SamplerID.EPSILON_CUTOFF,
  278. SamplerID.TYPICAL_P,
  279. SamplerID.QUADRATIC,
  280. SamplerID.XTC,
  281. ]
  282. sampler_order = []
  283. for sampler_id in default_order:
  284. if sampler_id == SamplerID.TEMPERATURE and do_temp_last:
  285. continue
  286. sampler_order.append(sampler_id)
  287. if sampler_id == SamplerID.XTC and do_temp_last:
  288. sampler_order.append(SamplerID.TEMPERATURE)
  289. if sampling_metadata.seq_groups and sampling_metadata.seq_groups[
  290. 0].is_prompt:
  291. logger.debug("Sampler execution order: ")
  292. for i, sampler_id in enumerate(sampler_order, 1):
  293. logger.debug(f"{i}. {SamplerID(sampler_id).name}")
  294. enabled_samplers = []
  295. # ruff: noqa: E701
  296. if do_penalties: enabled_samplers.append("PENALTIES")
  297. if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM")
  298. if do_temperatures: enabled_samplers.append("TEMPERATURE")
  299. if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K")
  300. if do_top_as: enabled_samplers.append("TOP_A")
  301. if do_min_p: enabled_samplers.append("MIN_P")
  302. if do_tfss: enabled_samplers.append("TFS")
  303. if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF")
  304. if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF")
  305. if do_typical_ps: enabled_samplers.append("TYPICAL_P")
  306. if do_quadratic: enabled_samplers.append("QUADRATIC")
  307. if do_xtc: enabled_samplers.append("XTC")
  308. if do_nsigmas: enabled_samplers.append("TOP_NSIGMA")
  309. if do_dry: enabled_samplers.append("DRY")
  310. if do_skew: enabled_samplers.append("SKEW")
  311. logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}")
  312. for sampler_id in sampler_order:
  313. if sampler_id == SamplerID.DRY and do_dry:
  314. if (sampling_metadata.seq_groups and
  315. sampling_metadata.seq_groups[0].is_prompt):
  316. logger.debug(
  317. f"Applying DRY with dry_multiplier: "
  318. f"{sampling_tensors.dry_multipliers}.")
  319. logits = _apply_dry(
  320. logits,
  321. sampling_tensors.prompt_tokens,
  322. sampling_tensors.output_tokens,
  323. sampling_tensors.dry_multipliers,
  324. sampling_tensors.dry_bases,
  325. sampling_tensors.dry_allowed_lengths,
  326. sampling_tensors.dry_sequence_breaker_ids,
  327. sampling_tensors.dry_ranges)
  328. elif sampler_id == SamplerID.PENALTIES and do_penalties:
  329. if (sampling_metadata.seq_groups and
  330. sampling_metadata.seq_groups[0].is_prompt):
  331. logger.debug(
  332. "Applying penalties with "
  333. f"pres_pen: {sampling_tensors.presence_penalties}, "
  334. f"freq_pen: {sampling_tensors.frequency_penalties}, "
  335. f"rep_pen: {sampling_tensors.repetition_penalties}.")
  336. logits = _apply_penalties(
  337. logits, sampling_tensors.prompt_tokens,
  338. sampling_tensors.output_tokens,
  339. sampling_tensors.presence_penalties,
  340. sampling_tensors.frequency_penalties,
  341. sampling_tensors.repetition_penalties)
  342. elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \
  343. do_no_repeat_ngrams:
  344. if (sampling_metadata.seq_groups and
  345. sampling_metadata.seq_groups[0].is_prompt):
  346. logger.debug(
  347. "Applying no_repeat_ngram with no_repeat_ngram_size: "
  348. f"{sampling_tensors.no_repeat_ngram_sizes}.")
  349. logits = _apply_no_repeat_ngram(
  350. logits,
  351. sampling_tensors.prompt_tokens,
  352. sampling_tensors.no_repeat_ngram_sizes)
  353. elif sampler_id == SamplerID.TEMPERATURE and do_temperatures:
  354. if (sampling_metadata.seq_groups and
  355. sampling_metadata.seq_groups[0].is_prompt):
  356. logger.debug(
  357. "Applying temperatures with temperature: "
  358. f"{sampling_tensors.temperatures}, "
  359. f"dynatemp_min: {sampling_tensors.dynatemp_mins}, "
  360. f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, "
  361. f"dynamtep_exp: {sampling_tensors.dynatemp_exps}.")
  362. _apply_temperatures(
  363. logits, sampling_tensors.temperatures,
  364. sampling_tensors.dynatemp_mins,
  365. sampling_tensors.dynatemp_maxs,
  366. sampling_tensors.dynatemp_exps)
  367. elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas:
  368. if (sampling_metadata.seq_groups and
  369. sampling_metadata.seq_groups[0].is_prompt):
  370. logger.debug(
  371. "Applying Top-Nsigma with nsigma: "
  372. f"{sampling_tensors.nsigmas}")
  373. logits = _apply_top_nsigma(
  374. logits, sampling_tensors.nsigmas)
  375. elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \
  376. not APHRODITE_USE_SAMPLING_KERNELS:
  377. if (sampling_metadata.seq_groups and
  378. sampling_metadata.seq_groups[0].is_prompt):
  379. logger.debug(
  380. "Applying Top-p and Top-k with top-p: "
  381. f"{sampling_tensors.top_ps}, top_k: "
  382. f"{sampling_tensors.top_ks}.")
  383. logits = _apply_top_k_top_p(
  384. logits, sampling_tensors.top_ps,
  385. sampling_tensors.top_ks)
  386. elif sampler_id == SamplerID.TOP_A and do_top_as:
  387. if (sampling_metadata.seq_groups and
  388. sampling_metadata.seq_groups[0].is_prompt):
  389. logger.debug(
  390. "Applying Top-a with Top-a: "
  391. f"{sampling_tensors.top_as}.")
  392. logits = _apply_top_a(
  393. logits, sampling_tensors.top_as)
  394. elif sampler_id == SamplerID.MIN_P and do_min_p:
  395. if (sampling_metadata.seq_groups and
  396. sampling_metadata.seq_groups[0].is_prompt):
  397. logger.debug(
  398. "Applying Min-p with Min-p: "
  399. f"{sampling_tensors.min_ps}.")
  400. logits = _apply_min_p(
  401. logits, sampling_tensors.min_ps)
  402. elif sampler_id == SamplerID.TFS and do_tfss:
  403. if (sampling_metadata.seq_groups and
  404. sampling_metadata.seq_groups[0].is_prompt):
  405. logger.debug(
  406. "Applying Tail-Free Sampling with tfs: "
  407. f"{sampling_tensors.tfss}.")
  408. logits = _apply_tfs(
  409. logits, sampling_tensors.tfss)
  410. elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs:
  411. if (sampling_metadata.seq_groups and
  412. sampling_metadata.seq_groups[0].is_prompt):
  413. logger.debug(
  414. "Applying ETA Cutoff with eta_cutoff: "
  415. f"{sampling_tensors.eta_cutoffs}.")
  416. logits = _apply_eta_cutoff(
  417. logits, sampling_tensors.eta_cutoffs)
  418. elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs:
  419. if (sampling_metadata.seq_groups and
  420. sampling_metadata.seq_groups[0].is_prompt):
  421. logger.debug(
  422. "Applying Epsilon Cutoff with epsilon_cutoff: "
  423. f"{sampling_tensors.epsilon_cutoffs}.")
  424. logits = _apply_epsilon_cutoff(
  425. logits, sampling_tensors.epsilon_cutoffs)
  426. elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps:
  427. if (sampling_metadata.seq_groups and
  428. sampling_metadata.seq_groups[0].is_prompt):
  429. logger.debug(
  430. "Applying Locally Typical Sampling with typical_p: "
  431. f"{sampling_tensors.typical_ps}.")
  432. logits = _apply_typical_sampling(
  433. logits, sampling_tensors.typical_ps)
  434. elif sampler_id == SamplerID.QUADRATIC and do_quadratic:
  435. if (sampling_metadata.seq_groups and
  436. sampling_metadata.seq_groups[0].is_prompt):
  437. logger.debug(
  438. "Applying Quadratic and Cubic Sampling with "
  439. "smoothing_factors: "
  440. f"{sampling_tensors.smoothing_factors},"
  441. f" smoothing_curves: "
  442. f"{sampling_tensors.smoothing_curves}.")
  443. logits = _apply_quadratic_sampling(
  444. logits, sampling_tensors.smoothing_factors,
  445. sampling_tensors.smoothing_curves)
  446. elif sampler_id == SamplerID.XTC and do_xtc:
  447. if (sampling_metadata.seq_groups and
  448. sampling_metadata.seq_groups[0].is_prompt):
  449. logger.debug(
  450. "Applying Exclude Top Choices sampling with "
  451. f"xtc_threshold: {sampling_tensors.xtc_thresholds}, "
  452. "xtc_probability: "
  453. f"{sampling_tensors.xtc_probabilities}.")
  454. logits = _apply_xtc_sampling(
  455. logits, sampling_tensors.xtc_thresholds,
  456. sampling_tensors.xtc_probabilities)
  457. # We use float32 for probabilities and log probabilities.
  458. # Compute the probabilities.
  459. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  460. # skew needs to be applied post-softmax
  461. if do_skew:
  462. if (sampling_metadata.seq_groups and
  463. sampling_metadata.seq_groups[0].is_prompt):
  464. logger.debug(
  465. "Applying Skew sampling with skew: "
  466. f"{sampling_tensors.skews}.")
  467. # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
  468. cum_probs = torch.cumsum(probs, dim=-1)
  469. cum_probs = torch.pow(cum_probs, torch.exp(
  470. sampling_tensors.skews).unsqueeze(dim=1))
  471. probs = torch.diff(cum_probs, dim=-1,
  472. prepend=torch.zeros_like(cum_probs[..., :1]))
  473. logits = torch.log(probs)
  474. # Compute the log probabilities.
  475. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  476. # Sample the next tokens.
  477. maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
  478. probs,
  479. logprobs,
  480. sampling_metadata,
  481. sampling_tensors,
  482. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  483. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  484. )
  485. if self.include_gpu_probs_tensor:
  486. # Since we will defer sampler result Pythonization,
  487. # preserve GPU-side tensors in support of later
  488. # deferred pythonization of logprobs
  489. assert maybe_sampled_tokens_tensor is not None
  490. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  491. else:
  492. # Since Pythonization has already happened, don't preserve
  493. # GPU-side tensors.
  494. on_device_tensors = None
  495. # Get the logprobs query results.
  496. prompt_logprobs = None
  497. sample_logprobs = None
  498. if not sampling_metadata.skip_sampler_cpu_output:
  499. # Pythonize logprobs now (GPU -> CPU); do not defer.
  500. assert not isinstance(maybe_deferred_sample_results,
  501. SampleResultArgsType)
  502. prompt_logprobs, sample_logprobs = get_logprobs(
  503. logprobs, sampling_metadata, maybe_deferred_sample_results)
  504. return _build_sampler_output(
  505. maybe_deferred_sample_results,
  506. sampling_metadata,
  507. prompt_logprobs,
  508. sample_logprobs,
  509. on_device_tensors=on_device_tensors,
  510. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  511. @property
  512. def _should_modify_greedy_probs_inplace(self) -> bool:
  513. """Whether or not the sampler should modify the probability distribution
  514. of greedily-sampled tokens such that multinomial sampling would sample
  515. the greedily-sampled token.
  516. In other words, if True then we set the probability of the greedily-
  517. sampled token to 1.
  518. This is used by speculative decoding, which requires that the sampling
  519. method be encoded into the probability distribution.
  520. """
  521. return self.should_modify_greedy_probs_inplace
  522. def _get_bin_counts_and_mask(
  523. tokens: torch.Tensor,
  524. vocab_size: int,
  525. num_seqs: int,
  526. ) -> Tuple[torch.Tensor, torch.Tensor]:
  527. # Compute the bin counts for the tokens.
  528. # vocab_size + 1 for padding.
  529. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  530. dtype=torch.long,
  531. device=tokens.device)
  532. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  533. bin_counts = bin_counts[:, :vocab_size]
  534. mask = bin_counts > 0
  535. return bin_counts, mask
  536. def _get_custom_token_bans(
  537. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  538. assert sampling_metadata.seq_groups is not None
  539. banned_tokens: List[List[int]] = []
  540. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  541. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  542. seq_ids = seq_group.seq_ids
  543. custom_token_bans = sampling_params.custom_token_bans
  544. if (i < sampling_metadata.num_prompts
  545. and sampling_params.prompt_logprobs is not None):
  546. prompt_len = len(seq_group.prompt_logprob_indices)
  547. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  548. banned_tokens += [custom_token_bans] * len(seq_ids)
  549. return banned_tokens
  550. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  551. output_tokens_tensor: torch.Tensor,
  552. presence_penalties: torch.Tensor,
  553. frequency_penalties: torch.Tensor,
  554. repetition_penalties: torch.Tensor) -> torch.Tensor:
  555. num_seqs, vocab_size = logits.shape
  556. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  557. num_seqs)
  558. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  559. output_tokens_tensor, vocab_size, num_seqs)
  560. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  561. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  562. logits = torch.where(logits > 0, logits / repetition_penalties,
  563. logits * repetition_penalties)
  564. # We follow the definition in OpenAI API.
  565. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  566. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  567. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  568. return logits
  569. def _apply_temperatures(
  570. logits: torch.Tensor,
  571. temperatures: torch.Tensor,
  572. dynatemp_mins: torch.Tensor,
  573. dynatemp_maxs: torch.Tensor,
  574. dynatemp_exps: torch.Tensor,
  575. ) -> None:
  576. dynatemp_mask = (dynatemp_mins != 0) | (dynatemp_maxs != 0)
  577. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  578. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  579. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  580. dynatemp_logits = logits[dynatemp_mask]
  581. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  582. dynatemp_probs = dynatemp_shifted_logits.exp()
  583. dynatemp_entropies = -(dynatemp_probs *
  584. dynatemp_shifted_logits).nansum(dim=-1)
  585. dynatemp_max_entropies = torch.log_(
  586. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  587. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  588. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  589. normalized_entropies.pow_(dynatemp_exps))
  590. temperatures[dynatemp_mask] = dyn_temp
  591. temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
  592. temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
  593. # To prevent saturation of top logits, we shift the range to [-inf, 1]
  594. # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
  595. # Why mask? So we aren't potentially discarding data in milder temps.
  596. low_temps = temperatures < 0.1
  597. logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
  598. logits.div_(temperatures.unsqueeze(dim=1))
  599. def _apply_token_bans(logits: torch.Tensor,
  600. banned_tokens: List[List[int]]) -> torch.Tensor:
  601. for i, banned_token_ids in enumerate(banned_tokens):
  602. if i >= logits.size(0):
  603. break
  604. if not banned_token_ids:
  605. continue
  606. logits[i, banned_token_ids] = -float("inf")
  607. return logits
  608. def _apply_min_tokens_penalty(
  609. logits: torch.Tensor,
  610. sampling_metadata: SamplingMetadata,
  611. ) -> torch.Tensor:
  612. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  613. have not been generated yet
  614. """
  615. # list of indices in logits that will be set to -inf
  616. logits_to_penalize = []
  617. logits_applied = 0
  618. for seq_group in sampling_metadata.seq_groups:
  619. seq_ids = seq_group.seq_ids
  620. sampling_params = seq_group.sampling_params
  621. sample_indices = seq_group.sample_indices
  622. logits_applied += len(sample_indices) + len(
  623. seq_group.prompt_logprob_indices)
  624. if not seq_group.do_sample:
  625. continue
  626. start_idx = sample_indices[0]
  627. min_tokens = sampling_params.min_tokens
  628. token_ids_to_penalize = sampling_params.all_stop_token_ids
  629. if min_tokens > 0 and token_ids_to_penalize:
  630. seqs_to_penalize = []
  631. for j, seq_id in enumerate(seq_ids):
  632. seq_data = seq_group.seq_data[seq_id]
  633. if len(seq_data.output_token_ids_array) < min_tokens:
  634. seqs_to_penalize.append(j)
  635. if seqs_to_penalize:
  636. # convert to the index into logits
  637. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  638. # itertools.product pairs each seq index with every token id
  639. logits_to_penalize.extend(
  640. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  641. if logits_to_penalize:
  642. # use zip and * to group indices along each dimension
  643. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  644. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  645. # verifies that no rows in logits were missed unexpectedly
  646. assert logits_applied == logits.shape[0]
  647. return logits
  648. def _apply_dry(
  649. logits: torch.Tensor,
  650. input_token_ids: torch.Tensor,
  651. output_token_ids: torch.Tensor,
  652. multipliers: torch.Tensor,
  653. bases: torch.Tensor,
  654. allowed_lengths: torch.Tensor,
  655. sequence_breakers_ids: torch.Tensor,
  656. ranges: torch.Tensor,
  657. ) -> torch.Tensor:
  658. """
  659. Apply Don't Repeat Yourself (DRY) sampling to the logits.
  660. Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
  661. """
  662. VOCAB_SIZE = logits.size(-1)
  663. MAX_NGRAM = 100
  664. # Process each sequence that has a nonzero multiplier
  665. applies_to = multipliers.nonzero(as_tuple=True)[0]
  666. for irow in applies_to.tolist():
  667. # DRY applies to input AND output tokens, so concat w/o padding tokens
  668. prompt_len = (len(input_token_ids[irow]) -
  669. (input_token_ids[irow] == VOCAB_SIZE).sum().item())
  670. ouput_len = (len(output_token_ids[irow]) -
  671. (output_token_ids[irow] == VOCAB_SIZE).sum().item())
  672. token_seq = torch.cat((input_token_ids[irow][:prompt_len],
  673. output_token_ids[irow][:ouput_len]), dim=0)
  674. range_limit = ranges[irow].item()
  675. if range_limit: # could this be done in cat? yes. do i care? no.
  676. token_seq = token_seq[-range_limit:]
  677. last_token = token_seq[-1].item()
  678. if last_token in sequence_breakers_ids[irow]:
  679. continue # early out for everything up to the min_ngram check?
  680. # Build a mask of all the breaking tokens in the context
  681. break_mask = torch.zeros(len(token_seq), dtype=torch.bool,
  682. device=logits.device)
  683. for break_tok in sequence_breakers_ids[irow]:
  684. break_mask.logical_or_(token_seq == break_tok)
  685. # Find the most recent breaking token (sets ngram limit)
  686. max_ngram = 0
  687. for max_ngram in range(min(len(break_mask), MAX_NGRAM + 1)):
  688. if break_mask[-max_ngram - 1]:
  689. break
  690. min_ngram = allowed_lengths[irow].item()
  691. if max_ngram <= min_ngram: # Too close to a break to match anything
  692. continue
  693. # If [token] is picked, what's the longest ngram that would match?
  694. ngram_lens = torch.zeros(VOCAB_SIZE, dtype=torch.int32,
  695. device=logits.device)
  696. # Find all instances of the last token- potential ngrams!
  697. endpoint_indexes = torch.nonzero(token_seq == last_token,
  698. as_tuple=True)[0].tolist()
  699. # NOTE: This seems like the slow part. Haven't benchmarked.
  700. for idx in endpoint_indexes[:-1]: # Skip the last_token match
  701. unwind = 0
  702. # Check up to max_ngram tokens prior to idx (we know idx matches)
  703. for unwind in range(1, min(idx, max_ngram) + 1):
  704. if break_mask[idx - unwind]:
  705. break
  706. if token_seq[idx - unwind] != token_seq[-unwind - 1]:
  707. break
  708. next_tok = token_seq[idx+1]
  709. # The repeated tokens BEFORE next_tok (+1 to include [idx]).
  710. ngram_lens[next_tok] = max(ngram_lens[next_tok].item(), unwind + 1)
  711. # Convert ngram lengths to penalty exponents
  712. penalty_mask = ngram_lens > 0
  713. scales = bases[irow] ** (ngram_lens[penalty_mask] - min_ngram)
  714. # Calculate and apply penalties
  715. logits[irow][penalty_mask] -= multipliers[irow] * scales
  716. return logits
  717. def _apply_no_repeat_ngram(
  718. logits: torch.Tensor,
  719. input_ids: torch.Tensor,
  720. ngram_size: torch.Tensor,
  721. ) -> torch.Tensor:
  722. """Apply no-repeat-ngram penalty which sets logits to -inf for tokens that
  723. would create a repeated n-gram.
  724. """
  725. if torch.all(ngram_size == 0):
  726. return logits
  727. batch_size = logits.shape[0]
  728. for i in range(batch_size):
  729. size = int(ngram_size[i].item())
  730. if size == 0:
  731. continue
  732. cur_len = len(input_ids[i])
  733. if cur_len < size:
  734. continue
  735. banned_tokens = _calc_banned_ngram_tokens(
  736. ngram_size=size,
  737. prev_input_ids=input_ids[i],
  738. cur_len=cur_len-1
  739. )
  740. if banned_tokens:
  741. logits[i, banned_tokens] = -float("inf")
  742. return logits
  743. def _apply_top_k_top_p(
  744. logits: torch.Tensor,
  745. p: torch.Tensor,
  746. k: torch.Tensor,
  747. ) -> torch.Tensor:
  748. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  749. # Apply top-k.
  750. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  751. # Get all the top_k values.
  752. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  753. top_k_mask = logits_sort < top_k_mask
  754. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  755. # Apply top-p.
  756. probs_sort = logits_sort.softmax(dim=-1)
  757. probs_sum = probs_sort.cumsum(dim=-1)
  758. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  759. # at least one
  760. top_p_mask[:, -1] = False
  761. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  762. # Re-sort the probabilities.
  763. logits = torch.empty_like(logits_sort).scatter_(dim=-1,
  764. index=logits_idx,
  765. src=logits_sort)
  766. return logits
  767. def _apply_min_p(
  768. logits: torch.Tensor,
  769. min_p: torch.Tensor,
  770. ) -> torch.Tensor:
  771. """
  772. Adapted from
  773. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  774. """
  775. probs = torch.softmax(logits, dim=-1)
  776. top_probs, _ = probs.max(dim=-1, keepdim=True)
  777. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  778. tokens_to_remove = probs < scaled_min_p
  779. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  780. return logits
  781. def _apply_top_a(
  782. logits: torch.Tensor,
  783. top_a: torch.Tensor,
  784. ) -> torch.Tensor:
  785. probs = torch.softmax(logits, dim=-1)
  786. top_probs, _ = probs.max(dim=-1, keepdim=True)
  787. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  788. tokens_to_remove = probs < threshold
  789. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  790. return logits
  791. def _apply_tfs(
  792. logits: torch.Tensor,
  793. tfs: torch.Tensor,
  794. ) -> torch.Tensor:
  795. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  796. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  797. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  798. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  799. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  800. tfs_mask = torch.cat(
  801. (
  802. torch.zeros(
  803. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  804. tfs_mask,
  805. torch.ones(
  806. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  807. ),
  808. dim=-1,
  809. )
  810. logits_sort[tfs_mask] = -float("inf")
  811. logits = torch.gather(logits_sort,
  812. dim=-1,
  813. index=torch.argsort(logits_idx, dim=-1))
  814. return logits
  815. def _apply_eta_cutoff(
  816. logits: torch.Tensor,
  817. eta_cutoff: torch.Tensor,
  818. ) -> torch.Tensor:
  819. shifted_logits = torch.log_softmax(logits, dim=-1)
  820. probs = shifted_logits.exp()
  821. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  822. eps = torch.min(eta_cutoff,
  823. torch.sqrt(eta_cutoff) *
  824. torch.exp(neg_entropy)).unsqueeze(dim=1)
  825. eta_mask = probs < eps
  826. # guard against nulling out all the logits
  827. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  828. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  829. logits[eta_mask] = -float("inf")
  830. return logits
  831. def _apply_epsilon_cutoff(
  832. logits: torch.Tensor,
  833. epsilon_cutoff: torch.Tensor,
  834. ) -> torch.Tensor:
  835. probs = logits.softmax(dim=-1)
  836. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  837. # guard against nulling out all the logits
  838. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  839. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  840. logits[eps_mask] = -float("inf")
  841. return logits
  842. def _apply_typical_sampling(
  843. logits: torch.Tensor,
  844. typical_p: torch.Tensor,
  845. ) -> torch.Tensor:
  846. shifted_logits = torch.log_softmax(logits, dim=-1)
  847. probs = shifted_logits.exp()
  848. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  849. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  850. _, indices = torch.sort(surprisal_deviations)
  851. reordered_probs = probs.gather(-1, indices)
  852. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  853. dim=1)
  854. min_tokens_to_keep = 1
  855. # Keep at least min_tokens_to_keep
  856. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  857. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  858. logits[typ_mask] = -float("inf")
  859. return logits
  860. def _apply_quadratic_sampling(
  861. logits: torch.Tensor,
  862. smoothing_factor: torch.Tensor,
  863. smoothing_curve: torch.Tensor,
  864. ) -> torch.Tensor:
  865. """
  866. Applies a quadratic transformation to the logits based on the
  867. provided smoothing factors and curves. The transformation is
  868. centered around the maximum logit value in the batch.
  869. The transformation involves a quadratic and cubic term, with the
  870. cubic term controlled by the smoothing curve. The quadratic term is
  871. scaled by the smoothing factor, and the cubic term is scaled by the
  872. product of the smoothing factor and the smoothing curve.
  873. params:
  874. logits (torch.Tensor): The logits to be transformed.
  875. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  876. term in the transformation.
  877. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  878. in the transformation.
  879. returns:
  880. torch.Tensor: The transformed logits.
  881. Credits: @kalomaze
  882. """
  883. mask = smoothing_factor != 0
  884. smoothing_factor.unsqueeze_(dim=1)
  885. smoothing_curve.unsqueeze_(dim=1)
  886. k = smoothing_factor * (3 - smoothing_curve) / 2
  887. s = smoothing_factor * (smoothing_curve - 1) / 2
  888. quadlogits = logits[mask] # limit to logits using this sampler
  889. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  890. # Construct the delta from each logit to its new value
  891. diff = quadlogits - max_logits
  892. diff -= diff**2 * (s[mask] * diff - k[mask])
  893. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  894. logits[mask] -= diff
  895. return logits
  896. def _apply_xtc_sampling(
  897. logits: torch.Tensor,
  898. xtc_thresholds: torch.Tensor,
  899. xtc_probabilities: torch.Tensor,
  900. ) -> torch.Tensor:
  901. """Apply Exclude Top Choices (XTC) sampling to the logits.
  902. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  903. Args:
  904. logits: (num_tokens, vocab_size) The input logits.
  905. xtc_thresholds: (num_tokens,) The threshold for each token.
  906. xtc_probabilities: (num_tokens,) The probability of applying XTC
  907. for each token.
  908. Returns:
  909. torch.Tensor: The modified logits.
  910. """
  911. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  912. if not apply_xtc.any():
  913. return logits
  914. probs = torch.softmax(logits, dim=-1)
  915. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  916. # Find indices where the next probability is above the threshold
  917. # Skips the top choice, which later on becomes skipping the last choice.
  918. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  919. # Apply XTC only to rows where it should be applied
  920. for i in range(logits.shape[0]):
  921. if apply_xtc[i]:
  922. # Count logits above the threshold (skipping the first)
  923. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  924. if indices_to_remove > 0:
  925. # Implies the top logit and at least one other is >= threshold.
  926. # Mask out above_thresh logits except the last/lowest one.
  927. logits[i].scatter_(
  928. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  929. return logits
  930. def _apply_top_nsigma(
  931. logits: torch.Tensor,
  932. nsigma: torch.Tensor,
  933. ) -> torch.Tensor:
  934. """Apply top-nsigma truncation to the logits.
  935. Reference: https://arxiv.org/abs/2411.07641
  936. Args:
  937. logits: Logits of shape (num_tokens, vocab_size)
  938. nsigma: Number of standard deviations to use as threshold
  939. Returns:
  940. Modified logits with values below threshold set to -inf
  941. """
  942. std = logits.std(dim=-1, keepdim=True)
  943. threshold = (logits.max(dim=-1, keepdim=True).values -
  944. nsigma.unsqueeze(dim=1) * std)
  945. logits[logits < threshold] = float("-inf")
  946. return logits
  947. def _greedy_sample(
  948. selected_seq_groups: List[SequenceGroupToSample],
  949. samples: torch.Tensor,
  950. ) -> List[Tuple[List[int], List[int]]]:
  951. """Run greedy sampling on a given samples.
  952. Args:
  953. selected_seq_groups: A list of sequence groups batched.
  954. samples: (num_selected_samples,) A tensor of samples. The length of
  955. samples could be smaller than selected_seq_groups if
  956. seq_group.do_sample is False.
  957. Returns:
  958. Tuple of (next_token_ids, parent_ids). The length of returned list is
  959. same as the length of selected_seq_groups. If the corresponding
  960. seq_group has do_sample=False, tuple contains ([], [])
  961. """
  962. samples = samples.tolist()
  963. sample_idx = 0
  964. results = []
  965. for seq_group in selected_seq_groups:
  966. if not seq_group.do_sample:
  967. results.append(([], []))
  968. continue
  969. seq_ids = seq_group.seq_ids
  970. num_parent_seqs = len(seq_ids)
  971. assert num_parent_seqs == 1, (
  972. "Greedy sampling should have only one seq.")
  973. parent_ids = list(range(num_parent_seqs))
  974. next_token_ids = [samples[sample_idx]]
  975. results.append((next_token_ids, parent_ids))
  976. sample_idx += num_parent_seqs
  977. return results
  978. def _random_sample(
  979. selected_seq_groups: List[SequenceGroupToSample],
  980. random_samples: torch.Tensor,
  981. ) -> List[Tuple[List[int], List[int]]]:
  982. """Run random sampling on a given samples.
  983. Args:
  984. selected_seq_groups: A list of sequence groups batched.
  985. random_samples: (num_selected_samples,) A tensor of samples. The
  986. length of samples could be smaller than selected_seq_groups if
  987. seq_group.do_sample is False.
  988. Returns:
  989. Tuple of (next_token_ids, parent_ids). The length of returned list is
  990. same as the length of selected_seq_groups. If the corresponding
  991. seq_group has do_sample=False, tuple contains ([], [])
  992. """
  993. # Find the maximum best_of value of the prompt phase requests.
  994. random_samples = random_samples.cpu()
  995. sample_idx = 0
  996. results = []
  997. for seq_group in selected_seq_groups:
  998. if not seq_group.do_sample:
  999. results.append(([], []))
  1000. continue
  1001. seq_ids = seq_group.seq_ids
  1002. sampling_params = seq_group.sampling_params
  1003. is_prompt = seq_group.is_prompt
  1004. num_parent_seqs = len(seq_ids)
  1005. if is_prompt:
  1006. # Prompt phase.
  1007. parent_ids = [0] * sampling_params.best_of
  1008. next_token_ids = random_samples[
  1009. sample_idx, :sampling_params.best_of].tolist()
  1010. else:
  1011. # Generation phase.
  1012. parent_ids = list(range(num_parent_seqs))
  1013. next_token_ids = random_samples[sample_idx:sample_idx +
  1014. num_parent_seqs, 0].tolist()
  1015. results.append((next_token_ids, parent_ids))
  1016. sample_idx += num_parent_seqs
  1017. return results
  1018. def _beam_search_sample(
  1019. selected_seq_groups: List[SequenceGroupToSample],
  1020. logprobs: torch.Tensor,
  1021. ) -> List[Tuple[List[int], List[int]]]:
  1022. """Run beam sampling on a given samples.
  1023. Args:
  1024. selected_seq_groups: A list of sequence groups batched.
  1025. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  1026. on selected sample indices.
  1027. Returns:
  1028. Tuple of (next_token_ids, parent_ids). The length of returned list is
  1029. same as the length of selected_seq_groups. If the corresponding
  1030. seq_group has do_sample=False, tuple contains ([], [])
  1031. """
  1032. # We sample 2 * beam_width candidates to make sure that with high
  1033. # probability we can get `beam_width` candidates in addition to
  1034. # the finished sequences for the next iteration. See
  1035. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  1036. # for details. See also HF reference:
  1037. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  1038. #
  1039. # NOTE: Beam search is not vectorized, so its speed can be slower than
  1040. # other sampling methods.
  1041. sample_idx = 0
  1042. results = []
  1043. for seq_group in selected_seq_groups:
  1044. if not seq_group.do_sample:
  1045. results.append(([], []))
  1046. continue
  1047. is_prompt = seq_group.is_prompt
  1048. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  1049. num_parent_seqs = len(seq_ids)
  1050. beam_width = sampling_params.best_of
  1051. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  1052. if is_prompt:
  1053. # Prompt phase.
  1054. assert num_parent_seqs == 1, (
  1055. "Prompt input should have only one seq.")
  1056. parent_ids = [0] * (2 * beam_width)
  1057. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  1058. 2 * beam_width)
  1059. next_token_ids = next_token_ids.tolist()
  1060. else:
  1061. # Generation phase.
  1062. cumulative_logprobs = [
  1063. seq_group.seq_data[seq_id].cumulative_logprob
  1064. for seq_id in seq_ids
  1065. ]
  1066. cumulative_logprobs = torch.tensor(
  1067. cumulative_logprobs,
  1068. dtype=torch.float,
  1069. device=seq_group_logprobs.device)
  1070. seq_group_logprobs = (seq_group_logprobs +
  1071. cumulative_logprobs.unsqueeze(dim=1))
  1072. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  1073. 2 * beam_width)
  1074. topk_ids = topk_ids.tolist()
  1075. vocab_size = seq_group_logprobs.size(-1)
  1076. parent_ids = [i // vocab_size for i in topk_ids]
  1077. next_token_ids = [i % vocab_size for i in topk_ids]
  1078. results.append((next_token_ids, parent_ids))
  1079. sample_idx += num_parent_seqs
  1080. assert sample_idx == logprobs.size(0)
  1081. return results
  1082. # torch.multinomial forces a GPU<->CPU sync.
  1083. # Therefore, we use an optimized implementation instead.
  1084. # Note that we always sample with replacement.
  1085. # probs will be modified in place, but this is fine, as we pass
  1086. # in a copy already.
  1087. def _multinomial(
  1088. probs: torch.Tensor,
  1089. num_samples: int,
  1090. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  1091. ) -> torch.Tensor:
  1092. if num_samples > 1:
  1093. probs = probs.repeat_interleave(num_samples, dim=0)
  1094. q = torch.empty_like(probs)
  1095. if seq_groups is None:
  1096. q.exponential_()
  1097. else:
  1098. sample_idx = 0
  1099. for seq_group in seq_groups:
  1100. seq_ids = seq_group.seq_ids
  1101. stride = len(seq_ids) * num_samples
  1102. assert seq_group.generator is not None
  1103. q[sample_idx:sample_idx +
  1104. stride].exponential_(generator=seq_group.generator)
  1105. sample_idx += stride
  1106. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  1107. def _top_k_top_p_multinomial_with_kernels(
  1108. probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
  1109. num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
  1110. max_top_k_round = 32
  1111. if num_samples > 1:
  1112. probs = probs.repeat_interleave(num_samples, dim=0)
  1113. top_ks = top_ks.repeat_interleave(num_samples)
  1114. top_ps = top_ps.repeat_interleave(num_samples)
  1115. batch_size = probs.shape[0]
  1116. uniform_samples = torch.empty((max_top_k_round, batch_size),
  1117. device=probs.device)
  1118. if seq_groups is None:
  1119. uniform_samples.uniform_()
  1120. else:
  1121. sample_idx = 0
  1122. for seq_group in seq_groups:
  1123. seq_ids = seq_group.seq_ids
  1124. stride = len(seq_ids) * num_samples
  1125. assert seq_group.generator is not None
  1126. uniform_samples[:, sample_idx:sample_idx +
  1127. stride].uniform_(generator=seq_group.generator)
  1128. sample_idx += stride
  1129. batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
  1130. probs,
  1131. uniform_samples,
  1132. top_ks,
  1133. top_ps,
  1134. )
  1135. if not success.all():
  1136. warnings.warn("CUDA rejection sampling failed, fallback.",
  1137. stacklevel=1)
  1138. probs = ops.top_k_renorm_prob(probs, top_ks)
  1139. probs = ops.top_p_renorm_prob(probs, top_ps)
  1140. batch_next_token_ids = ops.sampling_from_probs(
  1141. probs, uniform_samples[0])
  1142. return batch_next_token_ids.view(-1, num_samples)
  1143. def get_pythonized_sample_results(
  1144. sample_result_args: SampleResultArgsType) -> SampleResultType:
  1145. """This function consumes GPU-side sampler results and computes
  1146. Pythonized CPU-side sampler results (GPU -> CPU sync.)
  1147. Single-step scheduling: this function is invoked at sampling-time
  1148. for immediate Pythonization.
  1149. Multi-step scheduling: Pythonization is deferred until after multiple
  1150. GPU-side steps have been completed.
  1151. Args:
  1152. sample_result_args: GPU-side inputs to the Pythonization process
  1153. Returns:
  1154. Pythonized sampler results
  1155. """
  1156. (
  1157. sample_metadata,
  1158. sampling_metadata,
  1159. greedy_samples,
  1160. multinomial_samples,
  1161. beam_search_logprobs,
  1162. sample_results_dict,
  1163. ) = (
  1164. sample_result_args.sample_metadata,
  1165. sample_result_args.sampling_metadata,
  1166. sample_result_args.greedy_samples,
  1167. sample_result_args.multinomial_samples,
  1168. sample_result_args.beam_search_logprobs,
  1169. sample_result_args.sample_results_dict,
  1170. )
  1171. for sampling_type in SamplingType:
  1172. if sampling_type not in sample_metadata:
  1173. continue
  1174. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  1175. if sampling_type == SamplingType.GREEDY:
  1176. sample_results = _greedy_sample(seq_groups, greedy_samples)
  1177. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  1178. sample_results = _random_sample(seq_groups,
  1179. multinomial_samples[sampling_type])
  1180. elif sampling_type == SamplingType.BEAM:
  1181. sample_results = _beam_search_sample(seq_groups,
  1182. beam_search_logprobs)
  1183. sample_results_dict.update(zip(seq_group_id, sample_results))
  1184. return [
  1185. sample_results_dict.get(i, ([], []))
  1186. for i in range(len(sampling_metadata.seq_groups))
  1187. ]
  1188. def _sample_with_torch(
  1189. probs: torch.Tensor,
  1190. logprobs: torch.Tensor,
  1191. sampling_metadata: SamplingMetadata,
  1192. sampling_tensors: SamplingTensors,
  1193. include_gpu_probs_tensor: bool,
  1194. modify_greedy_probs: bool,
  1195. ) -> SampleReturnType:
  1196. """Torch-oriented _sample() implementation.
  1197. Single-step scheduling:
  1198. * Perform GPU-side sampling computation
  1199. * Immediately Pythonize sampling result
  1200. Multi-step scheduling:
  1201. * Perform GPU-side sampling computation
  1202. * Defer Pythonization & preserve GPU-side
  1203. tensors required for Pythonization
  1204. """
  1205. categorized_seq_group_ids = {t: [] for t in SamplingType}
  1206. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  1207. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  1208. sampling_params = seq_group.sampling_params
  1209. sampling_type = sampling_params.sampling_type
  1210. categorized_seq_group_ids[sampling_type].append(i)
  1211. sample_results_dict: SampleResultsDictType = {}
  1212. sample_metadata: SampleMetadataType = {}
  1213. multinomial_samples: MultinomialSamplesType = {}
  1214. greedy_samples: Optional[torch.Tensor] = None
  1215. beam_search_logprobs: Optional[torch.Tensor] = None
  1216. # Create output tensor for sampled token ids.
  1217. if include_gpu_probs_tensor:
  1218. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  1219. 1,
  1220. dtype=torch.long,
  1221. device=logprobs.device)
  1222. else:
  1223. sampled_token_ids_tensor = None
  1224. # Counterintuitively, having two loops here is actually faster.
  1225. # The first loop can run without waiting on GPU<->CPU sync.
  1226. for sampling_type in SamplingType:
  1227. sample_indices = categorized_sample_indices[sampling_type]
  1228. num_tokens = len(sample_indices)
  1229. if num_tokens == 0:
  1230. continue
  1231. seq_group_id = categorized_seq_group_ids[sampling_type]
  1232. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  1233. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  1234. long_sample_indices = sample_indices.long()
  1235. if sampling_type == SamplingType.GREEDY:
  1236. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  1237. dim=-1)
  1238. if include_gpu_probs_tensor:
  1239. # Store sampled tokens in output tensor.
  1240. sampled_token_ids_tensor[
  1241. long_sample_indices] = greedy_samples.unsqueeze(-1)
  1242. if modify_greedy_probs:
  1243. # If required, modify the probabilities such that sampling from
  1244. # the modified distribution would always sample the argmax
  1245. # token id.
  1246. _modify_greedy_probs_inplace(logprobs, probs,
  1247. long_sample_indices,
  1248. greedy_samples)
  1249. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  1250. max_best_of_in_batch = 1
  1251. for seq_group in seq_groups:
  1252. if seq_group.is_prompt:
  1253. sampling_params = seq_group.sampling_params
  1254. max_best_of_in_batch = max(max_best_of_in_batch,
  1255. sampling_params.best_of)
  1256. seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
  1257. seq_groups)
  1258. if APHRODITE_USE_SAMPLING_KERNELS and not is_cpu():
  1259. multinomial_samples[
  1260. sampling_type] = _top_k_top_p_multinomial_with_kernels(
  1261. probs[long_sample_indices],
  1262. sampling_tensors.top_ks[long_sample_indices],
  1263. sampling_tensors.top_ps[long_sample_indices],
  1264. max_best_of_in_batch,
  1265. seq_groups_arg,
  1266. )
  1267. else:
  1268. multinomial_samples[sampling_type] = _multinomial(
  1269. probs[long_sample_indices],
  1270. max_best_of_in_batch,
  1271. seq_groups=seq_groups_arg)
  1272. if sampled_token_ids_tensor is not None:
  1273. # Store sampled tokens in output tensor.
  1274. sampled_token_ids_tensor[long_sample_indices] = \
  1275. multinomial_samples[sampling_type].to(torch.long)
  1276. elif sampling_type == SamplingType.BEAM:
  1277. beam_search_logprobs = logprobs[sample_indices]
  1278. else:
  1279. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  1280. # Encapsulate arguments for computing Pythonized sampler
  1281. # results, whether deferred or otherwise.
  1282. maybe_deferred_args = SampleResultArgsType(
  1283. sampling_metadata=sampling_metadata,
  1284. sample_metadata=sample_metadata,
  1285. multinomial_samples=multinomial_samples,
  1286. greedy_samples=greedy_samples,
  1287. beam_search_logprobs=beam_search_logprobs,
  1288. sample_results_dict=sample_results_dict)
  1289. if not sampling_metadata.skip_sampler_cpu_output:
  1290. # GPU<->CPU sync happens here.
  1291. # This also converts the sampler output to a Python object.
  1292. # Return Pythonized sampler result & sampled token ids
  1293. return get_pythonized_sample_results(
  1294. maybe_deferred_args), sampled_token_ids_tensor
  1295. else:
  1296. # Defer sampler result Pythonization; return deferred
  1297. # Pythonization args & sampled token ids
  1298. return (
  1299. maybe_deferred_args,
  1300. sampled_token_ids_tensor,
  1301. )
  1302. def _sample(
  1303. probs: torch.Tensor,
  1304. logprobs: torch.Tensor,
  1305. sampling_metadata: SamplingMetadata,
  1306. sampling_tensors: SamplingTensors,
  1307. include_gpu_probs_tensor: bool,
  1308. modify_greedy_probs: bool,
  1309. ) -> SampleReturnType:
  1310. """
  1311. Args:
  1312. probs: (num_query_tokens_in_batch, num_vocab)
  1313. logprobs: (num_query_tokens_in_batch, num_vocab)
  1314. sampling_metadata: The metadata for a batch for sampling.
  1315. sampling_tensors: Tensors that include sampling related metadata.
  1316. Returns:
  1317. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  1318. If sampling is skipped, it returns ([], [])
  1319. sampled_token_ids_tensor: A tensor of sampled token ids.
  1320. """
  1321. return _sample_with_torch(
  1322. probs,
  1323. logprobs,
  1324. sampling_metadata,
  1325. sampling_tensors,
  1326. include_gpu_probs_tensor=include_gpu_probs_tensor,
  1327. modify_greedy_probs=modify_greedy_probs,
  1328. )
  1329. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  1330. """
  1331. This function calculates the ranks of the chosen tokens in a logprob tensor.
  1332. Args:
  1333. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  1334. where N is the no. of tokens and M is the vocab dim.
  1335. indices (torch.Tensor): List of chosen token indices.
  1336. Returns:
  1337. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  1338. Each element in the returned tensor represents the rank
  1339. of the chosen token in the input logprob tensor.
  1340. """
  1341. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  1342. indices]
  1343. return (x > vals[:, None]).long().sum(1).add_(1)
  1344. def get_logprobs(
  1345. logprobs: torch.Tensor,
  1346. sampling_metadata: SamplingMetadata,
  1347. sample_results: List[Tuple[List[int], List[int]]],
  1348. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  1349. """Return sample lobprobs and prompt logprobs.
  1350. The logic consists of 3 parts.
  1351. - Select indices to compute logprob from, ranks of token ids, and
  1352. the top k token ids from logprobs.
  1353. - Compute prompt logprobs if required.
  1354. - Compute sample logprobs if required.
  1355. Args:
  1356. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  1357. logprob per vocab. Sequence groups' query tokens are batched in a
  1358. single flattened tensor. For example, assuming there are N
  1359. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  1360. prompt logprob is enabled), decode tokens for seq_group_1 (if
  1361. sampling is required), prefill tokens for seq_group_2, ...
  1362. sampling_metadata: The sampling metadata.
  1363. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  1364. parent_ids) for each sequence group. When beam search is enabled,
  1365. sample_results can contain different number of seq_ids from
  1366. sampling_metadata.seq_groups. It is because beam search creates
  1367. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  1368. BEAM_WIDTH number of seq_ids).
  1369. Returns:
  1370. A tuple of prompt and sample logprobs per sequence group in a batch.
  1371. """
  1372. # The index of query token to calculate logprobs. It includes both
  1373. # prompt and sample logprob indices.
  1374. query_indices: List[int] = []
  1375. # The next token ids to get the logprob value from.
  1376. next_token_ids: List[int] = []
  1377. # The largest requested number of logprobs. We find logprobs as many as the
  1378. # largest num logprobs in this API. If every logprobs is None, it will be
  1379. # set to -1.
  1380. largest_num_logprobs = -1
  1381. # If beam search is enabled.
  1382. use_beam_search = False
  1383. # Select indices to compute logprob from, ranks of token ids, and the top
  1384. # k token ids from logprobs.
  1385. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  1386. sample_results):
  1387. sampling_params = seq_group.sampling_params
  1388. # Update indices and tokens for prompt logprobs.
  1389. if (seq_group.is_prompt
  1390. and sampling_params.prompt_logprobs is not None):
  1391. largest_num_logprobs = max(largest_num_logprobs,
  1392. sampling_params.prompt_logprobs)
  1393. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1394. query_indices.extend(seq_group.prompt_logprob_indices)
  1395. next_token_ids.extend(next_prompt_tokens)
  1396. # Update indices and next tokenes for sample logprob.
  1397. if seq_group.do_sample:
  1398. token_ids, parent_seq_ids = sample_result
  1399. # NOTE: We cannot directly use sample_indices because
  1400. # sample_indices only contain parent seq_ids of a previous step.
  1401. # The current step may have different number of seq_ids, and
  1402. # we can obtain it from `sample_result[1]`.
  1403. query_idx = seq_group.sample_indices[0]
  1404. query_indices.extend(
  1405. [query_idx + parent_id for parent_id in parent_seq_ids])
  1406. next_token_ids.extend(token_ids)
  1407. if sampling_params.logprobs is not None:
  1408. largest_num_logprobs = max(largest_num_logprobs,
  1409. sampling_params.logprobs)
  1410. use_beam_search = use_beam_search or sampling_params.use_beam_search
  1411. assert len(next_token_ids) == len(query_indices)
  1412. if len(query_indices) == 0:
  1413. empty_sampled_logprob = []
  1414. empty_prompt_logprob = None
  1415. return [empty_prompt_logprob], [empty_sampled_logprob]
  1416. selected_logprobs, ranks = None, None
  1417. top_logprobs, top_token_ids = None, None
  1418. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  1419. # skip the whole logprob calculation.
  1420. if largest_num_logprobs >= 0 or use_beam_search:
  1421. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  1422. next_token_ids_gpu = torch.tensor(next_token_ids,
  1423. device=logprobs.device)
  1424. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  1425. # contain duplicates if beam search is enabled.
  1426. selected_logprobs = logprobs[[
  1427. query_indices_gpu,
  1428. next_token_ids_gpu,
  1429. ]]
  1430. ranks = _get_ranks(
  1431. logprobs[query_indices_gpu],
  1432. next_token_ids_gpu,
  1433. )
  1434. assert selected_logprobs.shape[0] == ranks.shape[0]
  1435. # We need to compute top k only if there exists logprobs > 0.
  1436. if largest_num_logprobs > 0:
  1437. # Logprobs of topk tokens for a batch of sequence groups.
  1438. # (num_query_tokens_across_batch).
  1439. top_logprobs, top_token_ids = torch.topk(logprobs,
  1440. largest_num_logprobs,
  1441. dim=-1)
  1442. top_logprobs = top_logprobs.to('cpu')
  1443. top_token_ids = top_token_ids.to('cpu')
  1444. selected_logprobs = selected_logprobs.to('cpu')
  1445. ranks = ranks.to('cpu')
  1446. # Find prompt/sample logprobs.
  1447. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  1448. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  1449. top_logprob_idx = 0
  1450. selected_logprobs_idx = 0
  1451. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  1452. sample_results):
  1453. (prompt_logprobs, top_logprob_idx,
  1454. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  1455. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  1456. selected_logprobs_idx, top_logprob_idx)
  1457. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  1458. (sampled_logprobs, top_logprob_idx,
  1459. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  1460. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1461. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1462. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1463. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1464. def _get_prompt_logprob_if_needed(
  1465. seq_group: SequenceGroupToSample,
  1466. selected_logprobs: torch.Tensor,
  1467. ranks: torch.Tensor,
  1468. top_token_ids: torch.Tensor,
  1469. top_logprobs: torch.Tensor,
  1470. selected_logprobs_idx: int,
  1471. top_logprob_idx: int,
  1472. ):
  1473. """Compute the prompt logprob from a sequence group if needed."""
  1474. sampling_params = seq_group.sampling_params
  1475. is_prompt = seq_group.is_prompt
  1476. # Find prompt logprobs
  1477. prompt_logprobs: Optional[PromptLogprobs] = None
  1478. if is_prompt and sampling_params.prompt_logprobs is not None:
  1479. prompt_logprobs = []
  1480. num_logprobs = sampling_params.prompt_logprobs
  1481. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1482. # Pre-select indexes and create a list. It is faster than calling .item
  1483. # repetitively.
  1484. selected_logprob_items = selected_logprobs[
  1485. selected_logprobs_idx:selected_logprobs_idx +
  1486. len(next_prompt_tokens)].tolist()
  1487. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1488. len(next_prompt_tokens)].tolist()
  1489. for idx, token_id in enumerate(next_prompt_tokens):
  1490. # Calculate the prompt logprob of the real prompt tokens.
  1491. # {token_id: (logprob, rank_from_vocab)}
  1492. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1493. token_id: (selected_logprob_items[idx], rank_items[idx])
  1494. }
  1495. # Add top K prompt logprobs along with its rank.
  1496. if num_logprobs > 0:
  1497. top_ids = top_token_ids[
  1498. top_logprob_idx, :num_logprobs].tolist()
  1499. top_probs = top_logprobs[
  1500. top_logprob_idx, :num_logprobs].tolist()
  1501. # Top K is already sorted by rank, so we can use 1 ~
  1502. # num_logprobs + 1 for rank.
  1503. top_ranks = range(1, num_logprobs + 1)
  1504. prompt_logprobs_dict.update({
  1505. top_id: (top_prob, rank)
  1506. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1507. top_ranks)
  1508. })
  1509. prompt_logprobs.append({
  1510. token_id: Logprob(*logprob_and_rank)
  1511. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1512. })
  1513. # + 1 to go to the next prompt token.
  1514. top_logprob_idx += 1
  1515. # + len(next_prompt_tokens) to go to the next prompt.
  1516. selected_logprobs_idx += len(next_prompt_tokens)
  1517. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1518. def _get_sampled_logprob_if_needed(
  1519. seq_group: SequenceGroupToSample,
  1520. sample_result: Tuple[List[int], List[int]],
  1521. selected_logprobs: torch.Tensor,
  1522. ranks: torch.Tensor,
  1523. top_token_ids: torch.Tensor,
  1524. top_logprobs: torch.Tensor,
  1525. selected_logprobs_idx: int,
  1526. top_logprob_idx: int,
  1527. ):
  1528. """Compute the sample logprob if needed."""
  1529. seq_ids = seq_group.seq_ids
  1530. num_logprobs = seq_group.sampling_params.logprobs
  1531. use_beam_search = seq_group.sampling_params.use_beam_search
  1532. sampled_logprobs: SampleLogprobs = []
  1533. next_token_ids, parent_seq_ids = sample_result
  1534. if seq_group.do_sample:
  1535. assert len(next_token_ids) > 0
  1536. if num_logprobs is None and not use_beam_search:
  1537. for next_token_id in next_token_ids:
  1538. # Use a dummy logprob
  1539. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1540. else:
  1541. # Pre-select items from tensor. tolist() is faster than repetitive
  1542. # `.item()` calls.
  1543. selected_logprob_items = selected_logprobs[
  1544. selected_logprobs_idx:selected_logprobs_idx +
  1545. len(next_token_ids)].tolist()
  1546. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1547. len(next_token_ids)].tolist()
  1548. for idx, (next_token_id, parent_id) in enumerate(
  1549. zip(next_token_ids, parent_seq_ids)):
  1550. # Get the logprob of a sampled token.
  1551. sampled_logprobs_dict = {
  1552. next_token_id:
  1553. (selected_logprob_items[idx], rank_items[idx])
  1554. }
  1555. if num_logprobs is not None and num_logprobs > 0:
  1556. # Get top K logprobs.
  1557. top_ids = top_token_ids[top_logprob_idx +
  1558. parent_id, :num_logprobs].tolist()
  1559. top_probs = top_logprobs[
  1560. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1561. # Top K is already sorted by rank, so we can use 1 ~
  1562. # num_logprobs + 1 for rank.
  1563. top_ranks = range(1, num_logprobs + 1)
  1564. sampled_logprobs_dict.update({
  1565. top_id: (top_prob, rank)
  1566. for top_id, top_prob, rank in zip(
  1567. top_ids, top_probs, top_ranks)
  1568. })
  1569. sampled_logprobs.append({
  1570. token_id: Logprob(*logprob_and_rank)
  1571. for token_id, logprob_and_rank in
  1572. sampled_logprobs_dict.items()
  1573. })
  1574. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1575. # logprobs for the current step, which has len(next_token_ids) tokens
  1576. # per sequence group. `logprobs` includes logprobs from the previous
  1577. # steps, which has len(seq_ids) tokens per sequence group.
  1578. # Iterate to the next sequence group in a batch.
  1579. selected_logprobs_idx += len(next_token_ids)
  1580. # Iterate to the next sequence group in a batch.
  1581. top_logprob_idx += len(seq_ids)
  1582. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1583. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1584. sample_indices: torch.Tensor,
  1585. greedy_samples: torch.Tensor) -> None:
  1586. """Modify the probability distributions of the greedily-sampled tokens such
  1587. that each sampled token has a "probability" of 1.0. This is required by
  1588. speculative decoding, which depends on the sampling method being encoded
  1589. within the probability distribution for correctness.
  1590. # Why do we only need to do this for greedy sampling?
  1591. Aphrodite's sampler performs the following steps for greedy or multinomial
  1592. (random) sampling:
  1593. 1. Get logits from model.
  1594. 2. Modify logits according to per-sequence sampling parameters.
  1595. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1596. according to their frequency, etc.
  1597. 3. Sample a token.
  1598. - Random sampling simply samples from the modified probability
  1599. distribution.
  1600. - Greedy sampling performs `argmax` to obtain the token with the
  1601. highest likelihood.
  1602. Ignoring greedy sampling for a moment, we find that the computed probability
  1603. distribution has the following property: we can sample from it independently
  1604. and find that the token sampled by the Sampler has a frequency corresponding
  1605. to how often we see it in our sampling. In other words, for tokens sampled
  1606. with Aphrodite's random SamplingType, the computed probability distribution
  1607. encodes the sampling methodology completely.
  1608. Greedy sampling does not normally have this property. Aphrodite modifies
  1609. logits according to sampling params, then performs `argmax`, then returns
  1610. the sampled token and the computed probability distribution. If we sample
  1611. from the distribution, we'll find the likelihood of the greedily-sampled
  1612. token is not always 1.0.
  1613. Since lossless speculative decoding requires that the sampling methodology
  1614. be encoded within the probability distribution, we are motivated to modify
  1615. the probability distribution such that the sampled token has probability 1
  1616. when speculative decoding is used.
  1617. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1618. greedy sampling using multinomial computation and unite the codepaths. This
  1619. has implications on the overall design of the sampler, e.g. how to record
  1620. accurate logprobs for the user, so this improvement is deferred to later.
  1621. """
  1622. # NOTE: logprobs are not modified so they can be returned to the user.
  1623. probs[sample_indices, :] = 0
  1624. probs[sample_indices, greedy_samples] = 1.0
  1625. def _build_sampler_output(
  1626. maybe_deferred_sample_results: MaybeDeferredSampleResultType,
  1627. sampling_metadata: SamplingMetadata,
  1628. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1629. sample_logprobs: Optional[List[SampleLogprobs]],
  1630. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1631. torch.Tensor]],
  1632. skip_sampler_cpu_output: bool = False,
  1633. ) -> SamplerOutput:
  1634. """Construct Python objects with the output of sampling.
  1635. Args:
  1636. on_device_tensors: Tuple containing on-device tensors with the
  1637. probabilities used in sampling and the sampled token ids. This
  1638. allows post-processing without copies to CPU/serialization, e.g. in
  1639. speculative decoding rejection sampling.
  1640. """
  1641. sampler_output: List[CompletionSequenceGroupOutput] = []
  1642. if skip_sampler_cpu_output:
  1643. assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
  1644. deferred_sample_results_args = maybe_deferred_sample_results
  1645. else:
  1646. assert prompt_logprobs is not None
  1647. assert sample_logprobs is not None
  1648. assert not isinstance(maybe_deferred_sample_results,
  1649. SampleResultArgsType)
  1650. deferred_sample_results_args = None
  1651. for (seq_group, sample_result, group_prompt_logprobs,
  1652. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1653. maybe_deferred_sample_results,
  1654. prompt_logprobs, sample_logprobs):
  1655. seq_ids = seq_group.seq_ids
  1656. next_token_ids, parent_ids = sample_result
  1657. seq_outputs: List[SequenceOutput] = []
  1658. for parent_id, next_token_id, logprobs in zip(
  1659. parent_ids, next_token_ids, group_sample_logprobs):
  1660. seq_outputs.append(
  1661. SequenceOutput(seq_ids[parent_id], next_token_id,
  1662. logprobs))
  1663. sampler_output.append(
  1664. CompletionSequenceGroupOutput(seq_outputs,
  1665. group_prompt_logprobs))
  1666. # If not specified, store None values in SamplerOutput.
  1667. if on_device_tensors is not None:
  1668. (sampled_token_probs, logprobs_tensor,
  1669. sampled_token_ids) = on_device_tensors
  1670. else:
  1671. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1672. None)
  1673. return SamplerOutput(
  1674. outputs=sampler_output,
  1675. sampled_token_probs=sampled_token_probs,
  1676. sampled_token_ids=sampled_token_ids,
  1677. logprobs=logprobs_tensor,
  1678. deferred_sample_results_args=deferred_sample_results_args)
  1679. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1680. """Get a list of next prompt tokens to compute logprob from a
  1681. given sequence group.
  1682. It is used to compute prompt logprob. Imagine you have logprob for each
  1683. query token. Query token needs to know the next prompt token id to compute
  1684. prompt logprob. This is a helper to obtain next prompt token ids.
  1685. This API has to be used only when the caller knows seq_group is in prefill
  1686. stage.
  1687. Returns:
  1688. A list of next prompt tokens to compute logprob.
  1689. """
  1690. assert seq_group.is_prompt, (
  1691. "Caller should ensure the sequence group is in a prefill stage.")
  1692. seq_ids = seq_group.seq_ids
  1693. query_len = seq_group.query_len
  1694. assert query_len is not None
  1695. # prompt has only 1 seq id.
  1696. assert len(seq_ids) == 1
  1697. seq_data = seq_group.seq_data[seq_ids[0]]
  1698. computed_len = seq_data.get_num_computed_tokens()
  1699. prompt_tokens = seq_data.prompt_token_ids
  1700. # +1 because we are looking for a next prompt token.
  1701. next_token_index_start = computed_len + 1
  1702. next_token_index_end = min(computed_len + query_len + 1,
  1703. len(prompt_tokens))
  1704. next_prompt_tokens = prompt_tokens[
  1705. next_token_index_start:next_token_index_end]
  1706. return next_prompt_tokens
  1707. def _get_ngrams(
  1708. ngram_size: int,
  1709. prev_input_ids: torch.Tensor
  1710. ) -> Dict[Tuple[int, ...], List[int]]:
  1711. """Get dictionary of ngrams and the tokens that followed them.
  1712. Args:
  1713. ngram_size: Size of ngrams to track
  1714. prev_input_ids: 1D tensor of previous token ids
  1715. Returns:
  1716. Dictionary mapping ngram tuples to list of tokens that followed them
  1717. """
  1718. generated_ngrams = {}
  1719. gen_tokens = prev_input_ids.tolist()
  1720. for i in range(len(gen_tokens) - ngram_size + 1):
  1721. ngram = tuple(gen_tokens[i:i + ngram_size - 1])
  1722. next_token = gen_tokens[i + ngram_size - 1]
  1723. if ngram in generated_ngrams:
  1724. generated_ngrams[ngram].append(next_token)
  1725. else:
  1726. generated_ngrams[ngram] = [next_token]
  1727. return generated_ngrams
  1728. def _get_generated_ngrams(
  1729. banned_ngrams: Dict[Tuple[int, ...], List[int]],
  1730. prev_input_ids: torch.Tensor,
  1731. ngram_size: int,
  1732. cur_len: int
  1733. ) -> List[int]:
  1734. """Get list of tokens that would create a repeated ngram if generated next.
  1735. Args:
  1736. banned_ngrams: Dictionary of previously seen ngrams and their next
  1737. tokens
  1738. prev_input_ids: Previous token ids
  1739. ngram_size: Size of ngrams to check
  1740. cur_len: Current position in sequence
  1741. Returns:
  1742. List of token ids that would create a repeat ngram
  1743. """
  1744. start_idx = cur_len + 1 - ngram_size
  1745. current_ngram = tuple(prev_input_ids[start_idx:cur_len].tolist())
  1746. return banned_ngrams.get(current_ngram, [])
  1747. def _calc_banned_ngram_tokens(
  1748. ngram_size: int,
  1749. prev_input_ids: torch.Tensor,
  1750. cur_len: int
  1751. ) -> List[int]:
  1752. """Calculate tokens that would create repeated ngrams if generated next.
  1753. Args:
  1754. ngram_size: Size of ngrams to prevent repeating
  1755. prev_input_ids: Previous token ids in sequence
  1756. cur_len: Current position in sequence
  1757. Returns:
  1758. List of token ids that should be banned to prevent ngram repetition
  1759. """
  1760. if cur_len + 1 < ngram_size:
  1761. return []
  1762. generated_ngrams = _get_ngrams(ngram_size, prev_input_ids)
  1763. banned_tokens = _get_generated_ngrams(
  1764. generated_ngrams,
  1765. prev_input_ids,
  1766. ngram_size,
  1767. cur_len
  1768. )
  1769. return banned_tokens
  1770. # def _apply_mirostat_v2(logits: torch.Tensor,
  1771. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1772. # # Reduce our view to just the affected logits
  1773. # logit_view = logits[sampling_tensors.miro_indices]
  1774. # # Calculate surprise value per token
  1775. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1776. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1777. # # Mask out "too-surprising" tokens (surprisal > mu)
  1778. # mus = sampling_tensors.miro_mus
  1779. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1780. # # Unmask most-likely logit to guarantee a selection.
  1781. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1782. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1783. # # Apply logit mask (effectively a top-k filter).
  1784. # logit_view[miro_mask] = -float("inf")
  1785. # # Project logit changes made to the view onto the original.
  1786. # # I think this step might be redundant.
  1787. # logits[sampling_tensors.miro_indices] = logit_view
  1788. # return logits
  1789. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1790. # sample_results: List[Tuple[List[int], List[int]]],
  1791. # sampling_metadata: SamplingMetadata,
  1792. # output_metadata: OutputMetadata) -> None:
  1793. # """Based on whichever token was finally sampled, we calculate the
  1794. # final surprisal values to update the mus.
  1795. # Because a single sequence can have multiple samples, we must fork
  1796. # the mu accordingly."""
  1797. # assert sampling_metadata.seq_groups is not None
  1798. # seqid_to_tokens = {}
  1799. # seqid_to_indices = {}
  1800. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1801. # sample_results):
  1802. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1803. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1804. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1805. # seqids = args.miro_seqids
  1806. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1807. # device=logits.device,
  1808. # dtype=torch.long)
  1809. # # Clumsily, we recalculate token surprisals.
  1810. # logits_view = logits[args.miro_indices]
  1811. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1812. # dim=-1,
  1813. # index=picked_tokens) / -math.log(2)
  1814. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1815. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1816. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1817. # nu_mus = mus - (picked_surprise - taus) * etas
  1818. # # Record updated mu values for use in the next iteration
  1819. # # Note how each mu is split into multiple based on the number of samples.
  1820. # for seqid, seq_mus in zip(seqids, nu_mus):
  1821. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1822. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)