1
0

sampler.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import os
  4. import warnings
  5. from enum import IntEnum
  6. from math import inf
  7. from typing import Dict, List, Optional, Tuple
  8. import torch
  9. import torch.nn as nn
  10. from loguru import logger
  11. import aphrodite._custom_ops as ops
  12. from aphrodite.common.sampling_params import SamplingType
  13. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  14. PromptLogprobs, SampleLogprobs,
  15. SamplerOutput, SequenceOutput)
  16. from aphrodite.triton_utils import HAS_TRITON
  17. if HAS_TRITON:
  18. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  19. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  20. SamplingTensors,
  21. SequenceGroupToSample)
  22. # (num_token_ids, num_parent_ids) per sequence group.
  23. SampleResultType = List[Tuple[List[int], List[int]]]
  24. # There isn't a "safe" temperature range for fp16 logits.
  25. # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
  26. # that this temperature well-uses the fp16 space after the logits are offset.
  27. _TEMPERATURE_MINIMUM = 2e-5
  28. # If enabled, we switch to a more performant implementation
  29. # of top-k and top-p
  30. APHRODITE_USE_SAMPLING_KERNELS = bool(int(
  31. os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
class SamplerID(IntEnum):
    """Numeric ids of the logit-processing stages run by `Sampler.forward`.

    A request's ``sampler_priority`` is a sequence of these values, and
    `Sampler.forward` runs the stages in that order; it converts each entry
    back with ``SamplerID(sampler_id)`` when logging. Without a priority list,
    `forward` uses its own default order. The numeric values below do NOT
    reflect any execution order.

    Skew is not a member: `forward` always applies it after the softmax,
    outside this ordering.
    """
    # Mirror these in aphrodite/common/sampling_params.py
    # Values out of order to keep backwards compatibility
    # with Koboldcpp values
    DRY = 7
    PENALTIES = 6  # presence, frequency and repetition penalties together
    NO_REPEAT_NGRAM = 8
    TEMPERATURE = 5  # also covers dynamic temperature (dynatemp)
    TOP_NSIGMA = 9
    TOP_P_TOP_K = 0  # top-p and top-k run as a single stage
    TOP_A = 1
    MIN_P = 2
    TFS = 3  # tail-free sampling
    ETA_CUTOFF = 10
    EPSILON_CUTOFF = 11
    TYPICAL_P = 4  # locally typical sampling
    QUADRATIC = 12  # quadratic / cubic smoothing
    XTC = 13  # exclude top choices
  50. class Sampler(nn.Module):
  51. """Samples the next tokens from the model's outputs.
  52. This layer does the following:
  53. 1. Discard the hidden states that are not used for sampling (i.e., all
  54. tokens except the final one in each prompt).
  55. 2. Compute the logits for the next tokens.
  56. 3. Apply presence, frequency and repetition penalties.
  57. 4. Apply temperature scaling.
  58. 5. Apply top-p and top-k truncation.
  59. 6. Sample the next tokens.
  60. Here, each sequence group within the batch can have different sampling
  61. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  62. The structure of the logits tensor is coupled with the seq_groups in
  63. sampling_metadata. Typically, each sequence in each seq_group has one row in
  64. logits for the next token to be sampled; however, for a seq_group with a
  65. prompt request with the prompt_logprobs sampling parameter, there are rows
  66. in logits for each token in the input prompt.
  67. """
  68. def __init__(self):
  69. super().__init__()
  70. # Whether or not the SamplerOutput should have on-device tensors
  71. # containing the sampled token ids and probabilities. This is used by
  72. # speculative decoding.
  73. self.include_gpu_probs_tensor = False
  74. self.should_modify_greedy_probs_inplace = False
  75. def _init_sampling_tensors(
  76. self,
  77. logits: torch.Tensor,
  78. sampling_metadata: SamplingMetadata,
  79. ):
  80. """The goal here is to reuse sampling tensors between similar decode
  81. runs. This is possible because sampling logic does not change between
  82. decodes of the same sequences.
  83. """
  84. _, vocab_size = logits.shape
  85. # First free any existing stored sampling tensors.
  86. # This is necessary because some sampling tensors may
  87. # have pinned memory.
  88. self._sampling_tensors = None
  89. # Initialize new sampling tensors
  90. (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
  91. do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
  92. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
  93. do_dry, do_skew, do_temp_last
  94. ) = SamplingTensors.from_sampling_metadata(
  95. sampling_metadata, vocab_size, logits.device, logits.dtype)
  96. self._sampling_tensors = sampling_tensors
  97. self._do_penalties = do_penalties
  98. self._do_no_repeat_ngrams = do_no_repeat_ngrams
  99. self._do_temperatures = do_temperatures
  100. self._do_top_p_top_k = do_top_p_top_k
  101. self._do_top_as = do_top_as
  102. self._do_min_p = do_min_p
  103. self._do_tfss = do_tfss
  104. self._do_eta_cutoffs = do_eta_cutoffs
  105. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  106. self._do_typical_ps = do_typical_ps
  107. self._do_quadratic = do_quadratic
  108. self._do_xtc = do_xtc
  109. self._do_nsgimas = do_nsigmas
  110. self._do_dry = do_dry
  111. self._do_skew = do_skew
  112. self._do_temp_last = do_temp_last
  113. def forward(
  114. self,
  115. logits: torch.Tensor,
  116. sampling_metadata: SamplingMetadata,
  117. ) -> Optional[SamplerOutput]:
  118. """
  119. Args:
  120. logits: (num_tokens, vocab_size).
  121. sampling_metadata: Metadata for sampling.
  122. """
  123. assert logits is not None
  124. _, vocab_size = logits.shape
  125. # Prepare sampling tensors with pinned memory to avoid blocking.
  126. if not sampling_metadata.reuse_sampling_tensors:
  127. self._init_sampling_tensors(logits, sampling_metadata)
  128. elif self._do_penalties or self._do_dry:
  129. # In this case, the sampling tensors logic depends on
  130. # "output_tokens" of a sequence. As a result, we cannot
  131. # reuse sampling tensors, since "output_tokens" changes
  132. # between decode runs.
  133. self._init_sampling_tensors(logits, sampling_metadata)
  134. assert self._sampling_tensors is not None
  135. sampling_tensors = self._sampling_tensors
  136. do_penalties = self._do_penalties
  137. do_no_repeat_ngrams = self._do_no_repeat_ngrams
  138. do_temperatures = self._do_temperatures
  139. do_top_p_top_k = self._do_top_p_top_k
  140. do_top_as = self._do_top_as
  141. do_min_p = self._do_min_p
  142. do_tfss = self._do_tfss
  143. do_eta_cutoffs = self._do_eta_cutoffs
  144. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  145. do_typical_ps = self._do_typical_ps
  146. do_quadratic = self._do_quadratic
  147. do_xtc = self._do_xtc
  148. do_nsigmas = self._do_nsgimas
  149. do_dry = self._do_dry
  150. do_skew = self._do_skew
  151. do_temp_last = self._do_temp_last
  152. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  153. banned_tokens = _get_custom_token_bans(sampling_metadata)
  154. logits = _apply_token_bans(logits, banned_tokens)
  155. sampler_order = None
  156. if sampling_metadata.seq_groups:
  157. sampler_order = sampling_metadata.seq_groups[
  158. 0].sampling_params.sampler_priority
  159. # Warn if both custom order and temp_last are specified
  160. if sampler_order is not None and do_temp_last:
  161. logger.warning(
  162. "Both sampler_priority and temperature_last=True "
  163. "were specified. Using custom sampler_priority order "
  164. "and ignoring temperature_last.")
  165. if sampler_order is None:
  166. default_order = [
  167. SamplerID.DRY,
  168. SamplerID.PENALTIES,
  169. SamplerID.NO_REPEAT_NGRAM,
  170. SamplerID.TEMPERATURE,
  171. SamplerID.TOP_NSIGMA,
  172. SamplerID.TOP_P_TOP_K,
  173. SamplerID.TOP_A,
  174. SamplerID.MIN_P,
  175. SamplerID.TFS,
  176. SamplerID.ETA_CUTOFF,
  177. SamplerID.EPSILON_CUTOFF,
  178. SamplerID.TYPICAL_P,
  179. SamplerID.QUADRATIC,
  180. SamplerID.XTC,
  181. ]
  182. sampler_order = []
  183. for sampler_id in default_order:
  184. if sampler_id == SamplerID.TEMPERATURE and do_temp_last:
  185. continue
  186. sampler_order.append(sampler_id)
  187. if sampler_id == SamplerID.XTC and do_temp_last:
  188. sampler_order.append(SamplerID.TEMPERATURE)
  189. if sampling_metadata.seq_groups and sampling_metadata.seq_groups[
  190. 0].is_prompt:
  191. logger.debug("Sampler execution order: ")
  192. for i, sampler_id in enumerate(sampler_order, 1):
  193. logger.debug(f"{i}. {SamplerID(sampler_id).name}")
  194. enabled_samplers = []
  195. # ruff: noqa: E701
  196. if do_penalties: enabled_samplers.append("PENALTIES")
  197. if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM")
  198. if do_temperatures: enabled_samplers.append("TEMPERATURE")
  199. if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K")
  200. if do_top_as: enabled_samplers.append("TOP_A")
  201. if do_min_p: enabled_samplers.append("MIN_P")
  202. if do_tfss: enabled_samplers.append("TFS")
  203. if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF")
  204. if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF")
  205. if do_typical_ps: enabled_samplers.append("TYPICAL_P")
  206. if do_quadratic: enabled_samplers.append("QUADRATIC")
  207. if do_xtc: enabled_samplers.append("XTC")
  208. if do_nsigmas: enabled_samplers.append("TOP_NSIGMA")
  209. if do_dry: enabled_samplers.append("DRY")
  210. if do_skew: enabled_samplers.append("SKEW")
  211. logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}")
  212. for sampler_id in sampler_order:
  213. if sampler_id == SamplerID.DRY and do_dry:
  214. if (sampling_metadata.seq_groups and
  215. sampling_metadata.seq_groups[0].is_prompt):
  216. logger.debug(
  217. f"Applying DRY with dry_multiplier: "
  218. f"{sampling_tensors.dry_multipliers}.")
  219. logits = _apply_dry(
  220. logits,
  221. sampling_tensors.prompt_tokens,
  222. sampling_tensors.output_tokens,
  223. sampling_tensors.dry_multipliers,
  224. sampling_tensors.dry_bases,
  225. sampling_tensors.dry_allowed_lengths,
  226. sampling_tensors.dry_sequence_breaker_ids,
  227. sampling_tensors.dry_ranges)
  228. elif sampler_id == SamplerID.PENALTIES and do_penalties:
  229. if (sampling_metadata.seq_groups and
  230. sampling_metadata.seq_groups[0].is_prompt):
  231. logger.debug(
  232. "Applying penalties with "
  233. f"pres_pen: {sampling_tensors.presence_penalties}, "
  234. f"freq_pen: {sampling_tensors.frequency_penalties}, "
  235. f"rep_pen: {sampling_tensors.repetition_penalties}.")
  236. logits = _apply_penalties(
  237. logits, sampling_tensors.prompt_tokens,
  238. sampling_tensors.output_tokens,
  239. sampling_tensors.presence_penalties,
  240. sampling_tensors.frequency_penalties,
  241. sampling_tensors.repetition_penalties)
  242. elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \
  243. do_no_repeat_ngrams:
  244. if (sampling_metadata.seq_groups and
  245. sampling_metadata.seq_groups[0].is_prompt):
  246. logger.debug(
  247. "Applying no_repeat_ngram with no_repeat_ngram_size: "
  248. f"{sampling_tensors.no_repeat_ngram_sizes}.")
  249. logits = _apply_no_repeat_ngram(
  250. logits,
  251. sampling_tensors.prompt_tokens,
  252. sampling_tensors.no_repeat_ngram_sizes)
  253. elif sampler_id == SamplerID.TEMPERATURE and do_temperatures:
  254. if (sampling_metadata.seq_groups and
  255. sampling_metadata.seq_groups[0].is_prompt):
  256. logger.debug(
  257. "Applying temperatures with temperature: "
  258. f"{sampling_tensors.temperatures}, "
  259. f"dynatemp_min: {sampling_tensors.dynatemp_mins}, "
  260. f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, "
  261. f"dynamtep_exp: {sampling_tensors.dynatemp_exps}.")
  262. _apply_temperatures(
  263. logits, sampling_tensors.temperatures,
  264. sampling_tensors.dynatemp_mins,
  265. sampling_tensors.dynatemp_maxs,
  266. sampling_tensors.dynatemp_exps)
  267. elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas:
  268. if (sampling_metadata.seq_groups and
  269. sampling_metadata.seq_groups[0].is_prompt):
  270. logger.debug(
  271. "Applying Top-Nsigma with nsigma: "
  272. f"{sampling_tensors.nsigmas}")
  273. logits = _apply_top_nsigma(
  274. logits, sampling_tensors.nsigmas)
  275. elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \
  276. not APHRODITE_USE_SAMPLING_KERNELS:
  277. if (sampling_metadata.seq_groups and
  278. sampling_metadata.seq_groups[0].is_prompt):
  279. logger.debug(
  280. "Applying Top-p and Top-k with top-p: "
  281. f"{sampling_tensors.top_ps}, top_k: "
  282. f"{sampling_tensors.top_ks}.")
  283. logits = _apply_top_k_top_p(
  284. logits, sampling_tensors.top_ps,
  285. sampling_tensors.top_ks)
  286. elif sampler_id == SamplerID.TOP_A and do_top_as:
  287. if (sampling_metadata.seq_groups and
  288. sampling_metadata.seq_groups[0].is_prompt):
  289. logger.debug(
  290. "Applying Top-a with Top-a: "
  291. f"{sampling_tensors.top_as}.")
  292. logits = _apply_top_a(
  293. logits, sampling_tensors.top_as)
  294. elif sampler_id == SamplerID.MIN_P and do_min_p:
  295. if (sampling_metadata.seq_groups and
  296. sampling_metadata.seq_groups[0].is_prompt):
  297. logger.debug(
  298. "Applying Min-p with Min-p: "
  299. f"{sampling_tensors.min_ps}.")
  300. logits = _apply_min_p(
  301. logits, sampling_tensors.min_ps)
  302. elif sampler_id == SamplerID.TFS and do_tfss:
  303. if (sampling_metadata.seq_groups and
  304. sampling_metadata.seq_groups[0].is_prompt):
  305. logger.debug(
  306. "Applying Tail-Free Sampling with tfs: "
  307. f"{sampling_tensors.tfss}.")
  308. logits = _apply_tfs(
  309. logits, sampling_tensors.tfss)
  310. elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs:
  311. if (sampling_metadata.seq_groups and
  312. sampling_metadata.seq_groups[0].is_prompt):
  313. logger.debug(
  314. "Applying ETA Cutoff with eta_cutoff: "
  315. f"{sampling_tensors.eta_cutoffs}.")
  316. logits = _apply_eta_cutoff(
  317. logits, sampling_tensors.eta_cutoffs)
  318. elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs:
  319. if (sampling_metadata.seq_groups and
  320. sampling_metadata.seq_groups[0].is_prompt):
  321. logger.debug(
  322. "Applying Epsilon Cutoff with epsilon_cutoff: "
  323. f"{sampling_tensors.epsilon_cutoffs}.")
  324. logits = _apply_epsilon_cutoff(
  325. logits, sampling_tensors.epsilon_cutoffs)
  326. elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps:
  327. if (sampling_metadata.seq_groups and
  328. sampling_metadata.seq_groups[0].is_prompt):
  329. logger.debug(
  330. "Applying Locally Typical Sampling with typical_p: "
  331. f"{sampling_tensors.typical_ps}.")
  332. logits = _apply_typical_sampling(
  333. logits, sampling_tensors.typical_ps)
  334. elif sampler_id == SamplerID.QUADRATIC and do_quadratic:
  335. if (sampling_metadata.seq_groups and
  336. sampling_metadata.seq_groups[0].is_prompt):
  337. logger.debug(
  338. "Applying Quadratic and Cubic Sampling with "
  339. "smoothing_factors: "
  340. f"{sampling_tensors.smoothing_factors},"
  341. f" smoothing_curves: "
  342. f"{sampling_tensors.smoothing_curves}.")
  343. logits = _apply_quadratic_sampling(
  344. logits, sampling_tensors.smoothing_factors,
  345. sampling_tensors.smoothing_curves)
  346. elif sampler_id == SamplerID.XTC and do_xtc:
  347. if (sampling_metadata.seq_groups and
  348. sampling_metadata.seq_groups[0].is_prompt):
  349. logger.debug(
  350. "Applying Exclude Top Choices sampling with "
  351. f"xtc_threshold: {sampling_tensors.xtc_thresholds}, "
  352. "xtc_probability: "
  353. f"{sampling_tensors.xtc_probabilities}.")
  354. logits = _apply_xtc_sampling(
  355. logits, sampling_tensors.xtc_thresholds,
  356. sampling_tensors.xtc_probabilities)
  357. # We use float32 for probabilities and log probabilities.
  358. # Compute the probabilities.
  359. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  360. # skew needs to be applied post-softmax
  361. if do_skew:
  362. if (sampling_metadata.seq_groups and
  363. sampling_metadata.seq_groups[0].is_prompt):
  364. logger.debug(
  365. "Applying Skew sampling with skew: "
  366. f"{sampling_tensors.skews}.")
  367. # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
  368. cum_probs = torch.cumsum(probs, dim=-1)
  369. cum_probs = torch.pow(cum_probs, torch.exp(
  370. sampling_tensors.skews).unsqueeze(dim=1))
  371. probs = torch.diff(cum_probs, dim=-1,
  372. prepend=torch.zeros_like(cum_probs[..., :1]))
  373. logits = torch.log(probs)
  374. # Compute the log probabilities.
  375. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  376. # Sample the next tokens.
  377. sample_results, maybe_sampled_tokens_tensor = _sample(
  378. probs,
  379. logprobs,
  380. sampling_metadata,
  381. sampling_tensors,
  382. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  383. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  384. )
  385. if self.include_gpu_probs_tensor:
  386. assert maybe_sampled_tokens_tensor is not None
  387. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  388. else:
  389. on_device_tensors = None
  390. # Get the logprobs query results.
  391. prompt_logprobs = None
  392. sample_logprobs = None
  393. if not sampling_metadata.skip_sampler_cpu_output:
  394. prompt_logprobs, sample_logprobs = _get_logprobs(
  395. logprobs, sampling_metadata, sample_results)
  396. return _build_sampler_output(
  397. sample_results,
  398. sampling_metadata,
  399. prompt_logprobs,
  400. sample_logprobs,
  401. on_device_tensors=on_device_tensors,
  402. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  403. @property
  404. def _should_modify_greedy_probs_inplace(self) -> bool:
  405. """Whether or not the sampler should modify the probability distribution
  406. of greedily-sampled tokens such that multinomial sampling would sample
  407. the greedily-sampled token.
  408. In other words, if True then we set the probability of the greedily-
  409. sampled token to 1.
  410. This is used by speculative decoding, which requires that the sampling
  411. method be encoded into the probability distribution.
  412. """
  413. return self.should_modify_greedy_probs_inplace
  414. def _get_bin_counts_and_mask(
  415. tokens: torch.Tensor,
  416. vocab_size: int,
  417. num_seqs: int,
  418. ) -> Tuple[torch.Tensor, torch.Tensor]:
  419. # Compute the bin counts for the tokens.
  420. # vocab_size + 1 for padding.
  421. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  422. dtype=torch.long,
  423. device=tokens.device)
  424. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  425. bin_counts = bin_counts[:, :vocab_size]
  426. mask = bin_counts > 0
  427. return bin_counts, mask
  428. def _get_custom_token_bans(
  429. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  430. assert sampling_metadata.seq_groups is not None
  431. banned_tokens: List[List[int]] = []
  432. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  433. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  434. seq_ids = seq_group.seq_ids
  435. custom_token_bans = sampling_params.custom_token_bans
  436. if (i < sampling_metadata.num_prompts
  437. and sampling_params.prompt_logprobs is not None):
  438. prompt_len = len(seq_group.prompt_logprob_indices)
  439. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  440. banned_tokens += [custom_token_bans] * len(seq_ids)
  441. return banned_tokens
  442. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  443. output_tokens_tensor: torch.Tensor,
  444. presence_penalties: torch.Tensor,
  445. frequency_penalties: torch.Tensor,
  446. repetition_penalties: torch.Tensor) -> torch.Tensor:
  447. num_seqs, vocab_size = logits.shape
  448. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  449. num_seqs)
  450. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  451. output_tokens_tensor, vocab_size, num_seqs)
  452. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  453. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  454. logits = torch.where(logits > 0, logits / repetition_penalties,
  455. logits * repetition_penalties)
  456. # We follow the definition in OpenAI API.
  457. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  458. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  459. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  460. return logits
  461. def _apply_temperatures(
  462. logits: torch.Tensor,
  463. temperatures: torch.Tensor,
  464. dynatemp_mins: torch.Tensor,
  465. dynatemp_maxs: torch.Tensor,
  466. dynatemp_exps: torch.Tensor,
  467. ) -> None:
  468. dynatemp_mask = (dynatemp_mins != 0) | (dynatemp_maxs != 0)
  469. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  470. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  471. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  472. dynatemp_logits = logits[dynatemp_mask]
  473. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  474. dynatemp_probs = dynatemp_shifted_logits.exp()
  475. dynatemp_entropies = -(dynatemp_probs *
  476. dynatemp_shifted_logits).nansum(dim=-1)
  477. dynatemp_max_entropies = torch.log_(
  478. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  479. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  480. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  481. normalized_entropies.pow_(dynatemp_exps))
  482. temperatures[dynatemp_mask] = dyn_temp
  483. temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
  484. temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
  485. # To prevent saturation of top logits, we shift the range to [-inf, 1]
  486. # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
  487. # Why mask? So we aren't potentially discarding data in milder temps.
  488. low_temps = temperatures < 0.1
  489. logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
  490. logits.div_(temperatures.unsqueeze(dim=1))
  491. def _apply_token_bans(logits: torch.Tensor,
  492. banned_tokens: List[List[int]]) -> torch.Tensor:
  493. for i, banned_token_ids in enumerate(banned_tokens):
  494. if i >= logits.size(0):
  495. break
  496. if not banned_token_ids:
  497. continue
  498. logits[i, banned_token_ids] = -float("inf")
  499. return logits
  500. def _apply_min_tokens_penalty(
  501. logits: torch.Tensor,
  502. sampling_metadata: SamplingMetadata,
  503. ) -> torch.Tensor:
  504. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  505. have not been generated yet
  506. """
  507. # list of indices in logits that will be set to -inf
  508. logits_to_penalize = []
  509. logits_applied = 0
  510. for seq_group in sampling_metadata.seq_groups:
  511. seq_ids = seq_group.seq_ids
  512. sampling_params = seq_group.sampling_params
  513. sample_indices = seq_group.sample_indices
  514. logits_applied += len(sample_indices) + len(
  515. seq_group.prompt_logprob_indices)
  516. if not seq_group.do_sample:
  517. continue
  518. start_idx = sample_indices[0]
  519. min_tokens = sampling_params.min_tokens
  520. token_ids_to_penalize = sampling_params.all_stop_token_ids
  521. if min_tokens > 0 and token_ids_to_penalize:
  522. seqs_to_penalize = []
  523. for j, seq_id in enumerate(seq_ids):
  524. seq_data = seq_group.seq_data[seq_id]
  525. if len(seq_data.output_token_ids_array) < min_tokens:
  526. seqs_to_penalize.append(j)
  527. if seqs_to_penalize:
  528. # convert to the index into logits
  529. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  530. # itertools.product pairs each seq index with every token id
  531. logits_to_penalize.extend(
  532. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  533. if logits_to_penalize:
  534. # use zip and * to group indices along each dimension
  535. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  536. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  537. # verifies that no rows in logits were missed unexpectedly
  538. assert logits_applied == logits.shape[0]
  539. return logits
  540. def _apply_dry(
  541. logits: torch.Tensor,
  542. input_token_ids: torch.Tensor,
  543. output_token_ids: torch.Tensor,
  544. multipliers: torch.Tensor,
  545. bases: torch.Tensor,
  546. allowed_lengths: torch.Tensor,
  547. sequence_breakers_ids: torch.Tensor,
  548. ranges: torch.Tensor,
  549. ) -> torch.Tensor:
  550. """
  551. Apply Don't Repeat Yourself (DRY) sampling to the logits.
  552. Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
  553. """
  554. dry_indices = torch.nonzero(multipliers).squeeze(-1)
  555. if len(dry_indices) == 0:
  556. return logits
  557. # DRY needs to be applied to both input AND output tokens
  558. input_ids = torch.cat((input_token_ids, output_token_ids), dim=1)
  559. vocab_size = logits.size(-1)
  560. def compute_z_array_gpu(
  561. s: torch.Tensor, end: int,
  562. search_start: int) -> torch.Tensor:
  563. """
  564. GPU-optimized version of Z array computation using two-pointer technique
  565. """
  566. n = len(s)
  567. z = torch.zeros(n, dtype=torch.int64, device=s.device)
  568. if end < search_start:
  569. return z
  570. right = end - 1
  571. left = end - 1
  572. while right >= search_start:
  573. while left == right and left >= search_start:
  574. if s[right] == s[end]:
  575. break
  576. right -= 1
  577. left -= 1
  578. while left >= search_start and s[left] == s[end - (right - left)]:
  579. z[right] += 1
  580. left -= 1
  581. helper = right
  582. while right > left:
  583. right -= 1
  584. if left == right:
  585. break
  586. z[right] = min(z[end - (helper - right)].item(), right - left)
  587. if left >= search_start and right - z[right] <= left:
  588. break
  589. return z
  590. # Process only sequences that have DRY enabled
  591. for idx in dry_indices:
  592. input_ids_row = input_ids[idx]
  593. logits_row = logits[idx]
  594. seq_breakers = set(sequence_breakers_ids[idx].tolist())
  595. last_token = input_ids_row[-1].item()
  596. if last_token in seq_breakers:
  597. continue
  598. range_limit = ranges[idx].item()
  599. if range_limit == 0:
  600. search_start = 0
  601. else:
  602. search_start = max(0, len(input_ids_row) - range_limit)
  603. # Find max match length based on sequence breakers
  604. max_match_length = 0
  605. MAX_LENGTH = min(len(input_ids_row), 1000) # Prevent overflow
  606. while (max_match_length < MAX_LENGTH and
  607. input_ids_row[len(input_ids_row) - max_match_length - 1].item()
  608. not in seq_breakers):
  609. max_match_length += 1
  610. z_array = compute_z_array_gpu(
  611. input_ids_row, len(input_ids_row) - 1, search_start)
  612. z_array = torch.minimum(
  613. z_array, torch.tensor(max_match_length, device=z_array.device))
  614. penalties = {}
  615. allowed_length = allowed_lengths[idx]
  616. base = bases[idx]
  617. multiplier = multipliers[idx].item()
  618. for idx2, match_length in enumerate(z_array[:-1]):
  619. match_length = match_length.item() # Convert tensor to int
  620. if match_length >= allowed_length:
  621. next_token = input_ids_row[idx2 + 1].item()
  622. if (next_token >= vocab_size or next_token in
  623. seq_breakers):
  624. continue
  625. penalty = multiplier * (base ** (match_length - allowed_length))
  626. penalties[next_token] = max(
  627. penalty, penalties.get(next_token, 0))
  628. for token, penalty in penalties.items():
  629. logits_row[token] -= penalty
  630. return logits
  631. def _apply_no_repeat_ngram(
  632. logits: torch.Tensor,
  633. input_ids: torch.Tensor,
  634. ngram_size: torch.Tensor,
  635. ) -> torch.Tensor:
  636. """Apply no-repeat-ngram penalty which sets logits to -inf for tokens that
  637. would create a repeated n-gram.
  638. """
  639. if torch.all(ngram_size == 0):
  640. return logits
  641. batch_size = logits.shape[0]
  642. for i in range(batch_size):
  643. size = int(ngram_size[i].item())
  644. if size == 0:
  645. continue
  646. cur_len = len(input_ids[i])
  647. if cur_len < size:
  648. continue
  649. banned_tokens = _calc_banned_ngram_tokens(
  650. ngram_size=size,
  651. prev_input_ids=input_ids[i],
  652. cur_len=cur_len-1
  653. )
  654. if banned_tokens:
  655. logits[i, banned_tokens] = -float("inf")
  656. return logits
  657. def _apply_top_k_top_p(
  658. logits: torch.Tensor,
  659. p: torch.Tensor,
  660. k: torch.Tensor,
  661. ) -> torch.Tensor:
  662. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  663. # Apply top-k.
  664. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  665. # Get all the top_k values.
  666. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  667. top_k_mask = logits_sort < top_k_mask
  668. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  669. # Apply top-p.
  670. probs_sort = logits_sort.softmax(dim=-1)
  671. probs_sum = probs_sort.cumsum(dim=-1)
  672. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  673. # at least one
  674. top_p_mask[:, -1] = False
  675. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  676. # Re-sort the probabilities.
  677. src = torch.arange(logits_idx.shape[-1],
  678. device=logits_idx.device).expand_as(logits_idx)
  679. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  680. index=logits_idx,
  681. src=src)
  682. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  683. return logits
  684. def _apply_min_p(
  685. logits: torch.Tensor,
  686. min_p: torch.Tensor,
  687. ) -> torch.Tensor:
  688. """
  689. Adapted from
  690. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  691. """
  692. probs = torch.softmax(logits, dim=-1)
  693. top_probs, _ = probs.max(dim=-1, keepdim=True)
  694. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  695. tokens_to_remove = probs < scaled_min_p
  696. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  697. return logits
  698. def _apply_top_a(
  699. logits: torch.Tensor,
  700. top_a: torch.Tensor,
  701. ) -> torch.Tensor:
  702. probs = torch.softmax(logits, dim=-1)
  703. top_probs, _ = probs.max(dim=-1, keepdim=True)
  704. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  705. tokens_to_remove = probs < threshold
  706. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  707. return logits
  708. def _apply_tfs(
  709. logits: torch.Tensor,
  710. tfs: torch.Tensor,
  711. ) -> torch.Tensor:
  712. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  713. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  714. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  715. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  716. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  717. tfs_mask = torch.cat(
  718. (
  719. torch.zeros(
  720. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  721. tfs_mask,
  722. torch.ones(
  723. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  724. ),
  725. dim=-1,
  726. )
  727. logits_sort[tfs_mask] = -float("inf")
  728. logits = torch.gather(logits_sort,
  729. dim=-1,
  730. index=torch.argsort(logits_idx, dim=-1))
  731. return logits
  732. def _apply_eta_cutoff(
  733. logits: torch.Tensor,
  734. eta_cutoff: torch.Tensor,
  735. ) -> torch.Tensor:
  736. shifted_logits = torch.log_softmax(logits, dim=-1)
  737. probs = shifted_logits.exp()
  738. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  739. eps = torch.min(eta_cutoff,
  740. torch.sqrt(eta_cutoff) *
  741. torch.exp(neg_entropy)).unsqueeze(dim=1)
  742. eta_mask = probs < eps
  743. # guard against nulling out all the logits
  744. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  745. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  746. logits[eta_mask] = -float("inf")
  747. return logits
  748. def _apply_epsilon_cutoff(
  749. logits: torch.Tensor,
  750. epsilon_cutoff: torch.Tensor,
  751. ) -> torch.Tensor:
  752. probs = logits.softmax(dim=-1)
  753. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  754. # guard against nulling out all the logits
  755. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  756. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  757. logits[eps_mask] = -float("inf")
  758. return logits
  759. def _apply_typical_sampling(
  760. logits: torch.Tensor,
  761. typical_p: torch.Tensor,
  762. ) -> torch.Tensor:
  763. shifted_logits = torch.log_softmax(logits, dim=-1)
  764. probs = shifted_logits.exp()
  765. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  766. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  767. _, indices = torch.sort(surprisal_deviations)
  768. reordered_probs = probs.gather(-1, indices)
  769. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  770. dim=1)
  771. min_tokens_to_keep = 1
  772. # Keep at least min_tokens_to_keep
  773. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  774. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  775. logits[typ_mask] = -float("inf")
  776. return logits
  777. def _apply_quadratic_sampling(
  778. logits: torch.Tensor,
  779. smoothing_factor: torch.Tensor,
  780. smoothing_curve: torch.Tensor,
  781. ) -> torch.Tensor:
  782. """
  783. Applies a quadratic transformation to the logits based on the
  784. provided smoothing factors and curves. The transformation is
  785. centered around the maximum logit value in the batch.
  786. The transformation involves a quadratic and cubic term, with the
  787. cubic term controlled by the smoothing curve. The quadratic term is
  788. scaled by the smoothing factor, and the cubic term is scaled by the
  789. product of the smoothing factor and the smoothing curve.
  790. params:
  791. logits (torch.Tensor): The logits to be transformed.
  792. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  793. term in the transformation.
  794. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  795. in the transformation.
  796. returns:
  797. torch.Tensor: The transformed logits.
  798. Credits: @kalomaze
  799. """
  800. mask = smoothing_factor != 0
  801. smoothing_factor.unsqueeze_(dim=1)
  802. smoothing_curve.unsqueeze_(dim=1)
  803. k = smoothing_factor * (3 - smoothing_curve) / 2
  804. s = smoothing_factor * (smoothing_curve - 1) / 2
  805. quadlogits = logits[mask] # limit to logits using this sampler
  806. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  807. # Construct the delta from each logit to its new value
  808. diff = quadlogits - max_logits
  809. diff -= diff**2 * (s[mask] * diff - k[mask])
  810. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  811. logits[mask] -= diff
  812. return logits
  813. def _apply_xtc_sampling(
  814. logits: torch.Tensor,
  815. xtc_thresholds: torch.Tensor,
  816. xtc_probabilities: torch.Tensor,
  817. ) -> torch.Tensor:
  818. """Apply Exclude Top Choices (XTC) sampling to the logits.
  819. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  820. Args:
  821. logits: (num_tokens, vocab_size) The input logits.
  822. xtc_thresholds: (num_tokens,) The threshold for each token.
  823. xtc_probabilities: (num_tokens,) The probability of applying XTC
  824. for each token.
  825. Returns:
  826. torch.Tensor: The modified logits.
  827. """
  828. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  829. if not apply_xtc.any():
  830. return logits
  831. probs = torch.softmax(logits, dim=-1)
  832. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  833. # Find indices where the next probability is above the threshold
  834. # Skips the top choice, which later on becomes skipping the last choice.
  835. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  836. # Apply XTC only to rows where it should be applied
  837. for i in range(logits.shape[0]):
  838. if apply_xtc[i]:
  839. # Count logits above the threshold (skipping the first)
  840. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  841. if indices_to_remove > 0:
  842. # Implies the top logit and at least one other is >= threshold.
  843. # Mask out above_thresh logits except the last/lowest one.
  844. logits[i].scatter_(
  845. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  846. return logits
  847. def _apply_top_nsigma(
  848. logits: torch.Tensor,
  849. nsigma: torch.Tensor,
  850. ) -> torch.Tensor:
  851. """Apply top-nsigma truncation to the logits.
  852. Reference: https://arxiv.org/abs/2411.07641
  853. Args:
  854. logits: Logits of shape (num_tokens, vocab_size)
  855. nsigma: Number of standard deviations to use as threshold
  856. Returns:
  857. Modified logits with values below threshold set to -inf
  858. """
  859. std = logits.std(dim=-1, keepdim=True)
  860. threshold = (logits.max(dim=-1, keepdim=True).values -
  861. nsigma.unsqueeze(dim=1) * std)
  862. logits[logits < threshold] = float("-inf")
  863. return logits
  864. def _greedy_sample(
  865. selected_seq_groups: List[SequenceGroupToSample],
  866. samples: torch.Tensor,
  867. ) -> List[Tuple[List[int], List[int]]]:
  868. """Run greedy sampling on a given samples.
  869. Args:
  870. selected_seq_groups: A list of sequence groups batched.
  871. samples: (num_selected_samples,) A tensor of samples. The length of
  872. samples could be smaller than selected_seq_groups if
  873. seq_group.do_sample is False.
  874. Returns:
  875. Tuple of (next_token_ids, parent_ids). The length of returned list is
  876. same as the length of selected_seq_groups. If the corresponding
  877. seq_group has do_sample=False, tuple contains ([], [])
  878. """
  879. samples = samples.tolist()
  880. sample_idx = 0
  881. results = []
  882. for seq_group in selected_seq_groups:
  883. if not seq_group.do_sample:
  884. results.append(([], []))
  885. continue
  886. seq_ids = seq_group.seq_ids
  887. num_parent_seqs = len(seq_ids)
  888. assert num_parent_seqs == 1, (
  889. "Greedy sampling should have only one seq.")
  890. parent_ids = list(range(num_parent_seqs))
  891. next_token_ids = [samples[sample_idx]]
  892. results.append((next_token_ids, parent_ids))
  893. sample_idx += num_parent_seqs
  894. return results
  895. def _random_sample(
  896. selected_seq_groups: List[SequenceGroupToSample],
  897. random_samples: torch.Tensor,
  898. ) -> List[Tuple[List[int], List[int]]]:
  899. """Run random sampling on a given samples.
  900. Args:
  901. selected_seq_groups: A list of sequence groups batched.
  902. random_samples: (num_selected_samples,) A tensor of samples. The
  903. length of samples could be smaller than selected_seq_groups if
  904. seq_group.do_sample is False.
  905. Returns:
  906. Tuple of (next_token_ids, parent_ids). The length of returned list is
  907. same as the length of selected_seq_groups. If the corresponding
  908. seq_group has do_sample=False, tuple contains ([], [])
  909. """
  910. # Find the maximum best_of value of the prompt phase requests.
  911. random_samples = random_samples.cpu()
  912. sample_idx = 0
  913. results = []
  914. for seq_group in selected_seq_groups:
  915. if not seq_group.do_sample:
  916. results.append(([], []))
  917. continue
  918. seq_ids = seq_group.seq_ids
  919. sampling_params = seq_group.sampling_params
  920. is_prompt = seq_group.is_prompt
  921. num_parent_seqs = len(seq_ids)
  922. if is_prompt:
  923. # Prompt phase.
  924. parent_ids = [0] * sampling_params.best_of
  925. next_token_ids = random_samples[
  926. sample_idx, :sampling_params.best_of].tolist()
  927. else:
  928. # Generation phase.
  929. parent_ids = list(range(num_parent_seqs))
  930. next_token_ids = random_samples[sample_idx:sample_idx +
  931. num_parent_seqs, 0].tolist()
  932. results.append((next_token_ids, parent_ids))
  933. sample_idx += num_parent_seqs
  934. return results
  935. def _beam_search_sample(
  936. selected_seq_groups: List[SequenceGroupToSample],
  937. logprobs: torch.Tensor,
  938. ) -> List[Tuple[List[int], List[int]]]:
  939. """Run beam sampling on a given samples.
  940. Args:
  941. selected_seq_groups: A list of sequence groups batched.
  942. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  943. on selected sample indices.
  944. Returns:
  945. Tuple of (next_token_ids, parent_ids). The length of returned list is
  946. same as the length of selected_seq_groups. If the corresponding
  947. seq_group has do_sample=False, tuple contains ([], [])
  948. """
  949. # We sample 2 * beam_width candidates to make sure that with high
  950. # probability we can get `beam_width` candidates in addition to
  951. # the finished sequences for the next iteration. See
  952. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  953. # for details. See also HF reference:
  954. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  955. #
  956. # NOTE: Beam search is not vectorized, so its speed can be slower than
  957. # other sampling methods.
  958. sample_idx = 0
  959. results = []
  960. for seq_group in selected_seq_groups:
  961. if not seq_group.do_sample:
  962. results.append(([], []))
  963. continue
  964. is_prompt = seq_group.is_prompt
  965. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  966. num_parent_seqs = len(seq_ids)
  967. beam_width = sampling_params.best_of
  968. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  969. if is_prompt:
  970. # Prompt phase.
  971. assert num_parent_seqs == 1, (
  972. "Prompt input should have only one seq.")
  973. parent_ids = [0] * (2 * beam_width)
  974. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  975. 2 * beam_width)
  976. next_token_ids = next_token_ids.tolist()
  977. else:
  978. # Generation phase.
  979. cumulative_logprobs = [
  980. seq_group.seq_data[seq_id].cumulative_logprob
  981. for seq_id in seq_ids
  982. ]
  983. cumulative_logprobs = torch.tensor(
  984. cumulative_logprobs,
  985. dtype=torch.float,
  986. device=seq_group_logprobs.device)
  987. seq_group_logprobs = (seq_group_logprobs +
  988. cumulative_logprobs.unsqueeze(dim=1))
  989. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  990. 2 * beam_width)
  991. topk_ids = topk_ids.tolist()
  992. vocab_size = seq_group_logprobs.size(-1)
  993. parent_ids = [i // vocab_size for i in topk_ids]
  994. next_token_ids = [i % vocab_size for i in topk_ids]
  995. results.append((next_token_ids, parent_ids))
  996. sample_idx += num_parent_seqs
  997. assert sample_idx == logprobs.size(0)
  998. return results
  999. # torch.multinomial forces a GPU<->CPU sync.
  1000. # Therefore, we use an optimized implementation instead.
  1001. # Note that we always sample with replacement.
  1002. # probs will be modified in place, but this is fine, as we pass
  1003. # in a copy already.
  1004. def _multinomial(
  1005. probs: torch.Tensor,
  1006. num_samples: int,
  1007. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  1008. ) -> torch.Tensor:
  1009. if num_samples > 1:
  1010. probs = probs.repeat_interleave(num_samples, dim=0)
  1011. q = torch.empty_like(probs)
  1012. if seq_groups is None:
  1013. q.exponential_()
  1014. else:
  1015. sample_idx = 0
  1016. for seq_group in seq_groups:
  1017. seq_ids = seq_group.seq_ids
  1018. stride = len(seq_ids) * num_samples
  1019. assert seq_group.generator is not None
  1020. q[sample_idx:sample_idx +
  1021. stride].exponential_(generator=seq_group.generator)
  1022. sample_idx += stride
  1023. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  1024. def _top_k_top_p_multinomial_with_kernels(
  1025. probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
  1026. num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
  1027. max_top_k_round = 32
  1028. if num_samples > 1:
  1029. probs = probs.repeat_interleave(num_samples, dim=0)
  1030. top_ks = top_ks.repeat_interleave(num_samples)
  1031. top_ps = top_ps.repeat_interleave(num_samples)
  1032. batch_size = probs.shape[0]
  1033. uniform_samples = torch.empty((max_top_k_round, batch_size),
  1034. device=probs.device)
  1035. if seq_groups is None:
  1036. uniform_samples.uniform_()
  1037. else:
  1038. sample_idx = 0
  1039. for seq_group in seq_groups:
  1040. seq_ids = seq_group.seq_ids
  1041. stride = len(seq_ids) * num_samples
  1042. assert seq_group.generator is not None
  1043. uniform_samples[:, sample_idx:sample_idx +
  1044. stride].uniform_(generator=seq_group.generator)
  1045. sample_idx += stride
  1046. batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
  1047. probs,
  1048. uniform_samples,
  1049. top_ks,
  1050. top_ps,
  1051. )
  1052. if not success.all():
  1053. warnings.warn("CUDA rejection sampling failed, fallback.",
  1054. stacklevel=1)
  1055. probs = ops.top_k_renorm_prob(probs, top_ks)
  1056. probs = ops.top_p_renorm_prob(probs, top_ps)
  1057. batch_next_token_ids = ops.sampling_from_probs(
  1058. probs, uniform_samples[0])
  1059. return batch_next_token_ids.view(-1, num_samples)
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """Sample next tokens for every sequence group using PyTorch ops.

    Sequence groups are bucketed by ``sampling_params.sampling_type``. A
    first pass launches the greedy / random sampling for each bucket and
    slices the beam-search logprobs. A second pass (skipped when
    ``sampling_metadata.skip_sampler_cpu_output`` is set) converts the
    results into Python lists per group.

    Returns:
        (sample_results, sampled_token_ids_tensor). ``sample_results`` has
        one (next_token_ids, parent_ids) tuple per seq group (([], []) for
        groups with no result), or is an empty list when CPU output is
        skipped. ``sampled_token_ids_tensor`` is a (num_rows, 1) long tensor
        of sampled ids when ``include_gpu_probs_tensor`` is True, else None.
    """
    # Bucket seq-group indices by sampling type.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    # `greedy_samples`, `multinomial_samples` and `beam_search_logprobs` are
    # produced here and consumed by the second loop, which only visits
    # sampling types recorded in `sample_metadata` by this loop.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Prompt-phase groups may ask for best_of samples; draw enough
            # samples per row for the largest request in this bucket.
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            # Seeded requests pass their groups so per-group generators are
            # used; unseeded requests use the global RNG.
            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                              seq_groups)

            # NOTE(review): this only checks that the flag is not None; if the
            # flag is a bool, a False value still selects the kernel path.
            # Confirm against the flag's definition.
            if APHRODITE_USE_SAMPLING_KERNELS is not None:
                multinomial_samples[
                    sampling_type] = _top_k_top_p_multinomial_with_kernels(
                        probs[long_sample_indices],
                        sampling_tensors.top_ks[long_sample_indices],
                        sampling_tensors.top_ps[long_sample_indices],
                        max_best_of_in_batch,
                        seq_groups_arg,
                    )
            else:
                multinomial_samples[sampling_type] = _multinomial(
                    probs[long_sample_indices],
                    max_best_of_in_batch,
                    seq_groups=seq_groups_arg)

            if sampled_token_ids_tensor is not None:
                # Store sampled tokens in output tensor.
                # NOTE(review): the target slice has one column while the
                # samples have max_best_of_in_batch columns; this assignment
                # only broadcasts when that value is 1.
                sampled_token_ids_tensor[long_sample_indices] = \
                    multinomial_samples[sampling_type].to(torch.long)

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))

        # Re-assemble results in the original seq-group order.
        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []

    return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens using a single ``sample_triton`` call.

    Greedy and random groups share one kernel launch that draws
    ``max_best_of_in_batch`` samples per row. Beam-search groups are handled
    from ``logprobs`` on the CPU side as in the torch path.

    Returns:
        One (next_token_ids, parent_ids) tuple per seq group, with ([], [])
        for groups that produced no result.
    """
    # Bucket seq-group indices by sampling type.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    # Shared across GREEDY / RANDOM / RANDOM_SEED since one kernel launch
    # serves all of them.
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        # Column 0: rows into probs/logprobs; column 1: rows into the
        # kernel's sampled_tokens output.
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            # Greedy groups use only the first sample of each row.
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    # Re-assemble results in the original seq-group order.
    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results
  1235. def _sample(
  1236. probs: torch.Tensor, logprobs: torch.Tensor,
  1237. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  1238. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  1239. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  1240. """
  1241. Args:
  1242. probs: (num_query_tokens_in_batch, num_vocab)
  1243. logprobs: (num_query_tokens_in_batch, num_vocab)
  1244. sampling_metadata: The metadata for a batch for sampling.
  1245. sampling_tensors: Tensors that include sampling related metadata.
  1246. Returns:
  1247. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  1248. If sampling is skipped, it returns ([], [])
  1249. sampled_token_ids_tensor: A tensor of sampled token ids.
  1250. """
  1251. return _sample_with_torch(
  1252. probs,
  1253. logprobs,
  1254. sampling_metadata,
  1255. sampling_tensors,
  1256. include_gpu_probs_tensor=include_gpu_probs_tensor,
  1257. modify_greedy_probs=modify_greedy_probs,
  1258. )
  1259. # TODO: Enable once Triton kernel & associated code is faster.
  1260. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  1261. # sampling_tensors)
  1262. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  1263. """
  1264. This function calculates the ranks of the chosen tokens in a logprob tensor.
  1265. Args:
  1266. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  1267. where N is the no. of tokens and M is the vocab dim.
  1268. indices (torch.Tensor): List of chosen token indices.
  1269. Returns:
  1270. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  1271. Each element in the returned tensor represents the rank
  1272. of the chosen token in the input logprob tensor.
  1273. """
  1274. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  1275. indices]
  1276. return (x > vals[:, None]).long().sum(1).add_(1)
  def _get_logprobs(
      logprobs: torch.Tensor,
      sampling_metadata: SamplingMetadata,
      sample_results: List[Tuple[List[int], List[int]]],
  ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
      """Return sample logprobs and prompt logprobs.

      The logic consists of 3 parts.
      - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
      - Compute prompt logprobs if required.
      - Compute sample logprobs if required.

      Args:
          logprobs: (num_query_tokens_across_batch, num_vocab). Each query
              token's logprob per vocab. Sequence groups' query tokens are
              batched in a single flattened tensor. For example, assuming
              there are N seq groups, it is sorted by prefill tokens for
              seq_group_1 (if prompt logprob is enabled), decode tokens for
              seq_group_1 (if sampling is required), prefill tokens for
              seq_group_2, ...
          sampling_metadata: The sampling metadata.
          sample_results: (num_seq_groups) The tuple of (next_token_ids,
              parent_ids) for each sequence group. When beam search is
              enabled, sample_results can contain a different number of
              seq_ids from sampling_metadata.seq_groups, because beam search
              creates 2 * BEAM_WIDTH samples (whereas there are only up to
              BEAM_WIDTH seq_ids).

      Returns:
          A tuple of prompt and sample logprobs per sequence group in a
          batch. When no query token needs a logprob, every sequence group
          still gets an entry (``None`` prompt logprobs and an empty sample
          logprob list).
      """
      # The index of query token to calculate logprobs. It includes both
      # prompt and sample logprob indices.
      query_indices: List[int] = []
      # The next token ids to get the logprob value from.
      next_token_ids: List[int] = []
      # The largest requested number of logprobs. We fetch as many logprobs
      # as the largest requested amount. If every logprobs is None, it stays
      # -1.
      largest_num_logprobs = -1
      # If beam search is enabled.
      use_beam_search = False

      # Select indices to compute logprob from, ranks of token ids, and the
      # top k token ids from logprobs.
      for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                          sample_results):
          sampling_params = seq_group.sampling_params

          # Update indices and tokens for prompt logprobs.
          if (seq_group.is_prompt
                  and sampling_params.prompt_logprobs is not None):
              largest_num_logprobs = max(largest_num_logprobs,
                                         sampling_params.prompt_logprobs)
              next_prompt_tokens = _get_next_prompt_tokens(seq_group)
              query_indices.extend(seq_group.prompt_logprob_indices)
              next_token_ids.extend(next_prompt_tokens)

          # Update indices and next tokens for sample logprob.
          if seq_group.do_sample:
              token_ids, parent_seq_ids = sample_result
              # NOTE: We cannot directly use sample_indices because
              # sample_indices only contain parent seq_ids of a previous step.
              # The current step may have different number of seq_ids, and
              # we can obtain it from `sample_result[1]`.
              query_idx = seq_group.sample_indices[0]
              query_indices.extend(
                  [query_idx + parent_id for parent_id in parent_seq_ids])
              next_token_ids.extend(token_ids)

              if sampling_params.logprobs is not None:
                  largest_num_logprobs = max(largest_num_logprobs,
                                             sampling_params.logprobs)

          use_beam_search = use_beam_search or sampling_params.use_beam_search

      assert len(next_token_ids) == len(query_indices)

      if len(query_indices) == 0:
          # Emit one placeholder per sequence group. Callers (e.g.
          # _build_sampler_output) zip these lists with seq_groups, so a
          # single-element result would silently drop every group but the
          # first.
          num_seq_groups = len(sampling_metadata.seq_groups)
          empty_prompt_logprobs: List[Optional[PromptLogprobs]] = [
              None
          ] * num_seq_groups
          # Distinct list objects so groups never share one mutable list.
          empty_sample_logprobs: List[SampleLogprobs] = [
              [] for _ in range(num_seq_groups)
          ]
          return empty_prompt_logprobs, empty_sample_logprobs

      selected_logprobs, ranks = None, None
      top_logprobs, top_token_ids = None, None

      # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
      # skip the whole logprob calculation (unless beam search needs it).
      if largest_num_logprobs >= 0 or use_beam_search:
          query_indices_gpu = torch.tensor(query_indices,
                                           device=logprobs.device)
          next_token_ids_gpu = torch.tensor(next_token_ids,
                                            device=logprobs.device)

          # (num_selected_query_tokens,): logprob of each chosen token. Note
          # that query_indices can contain duplicates if beam search is
          # enabled.
          selected_logprobs = logprobs[query_indices_gpu, next_token_ids_gpu]
          ranks = _get_ranks(
              logprobs[query_indices_gpu],
              next_token_ids_gpu,
          )
          assert selected_logprobs.shape[0] == ranks.shape[0]

          # We need to compute top k only if there exists logprobs > 0.
          if largest_num_logprobs > 0:
              # Logprobs of topk tokens for a batch of sequence groups.
              # (num_query_tokens_across_batch, largest_num_logprobs).
              top_logprobs, top_token_ids = torch.topk(logprobs,
                                                       largest_num_logprobs,
                                                       dim=-1)
              top_logprobs = top_logprobs.to('cpu')
              top_token_ids = top_token_ids.to('cpu')

          selected_logprobs = selected_logprobs.to('cpu')
          ranks = ranks.to('cpu')

      # Find prompt/sample logprobs.
      prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
      sample_logprobs_per_seq_group: List[SampleLogprobs] = []
      top_logprob_idx = 0
      selected_logprobs_idx = 0

      for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                          sample_results):
          (prompt_logprobs, top_logprob_idx,
           selected_logprobs_idx) = _get_prompt_logprob_if_needed(
               seq_group, selected_logprobs, ranks, top_token_ids,
               top_logprobs, selected_logprobs_idx, top_logprob_idx)
          prompt_logprobs_per_seq_group.append(prompt_logprobs)

          (sampled_logprobs, top_logprob_idx,
           selected_logprobs_idx) = _get_sampled_logprob_if_needed(
               seq_group, sample_result, selected_logprobs, ranks,
               top_token_ids, top_logprobs, selected_logprobs_idx,
               top_logprob_idx)
          sample_logprobs_per_seq_group.append(sampled_logprobs)

      return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  def _get_prompt_logprob_if_needed(
      seq_group: SequenceGroupToSample,
      selected_logprobs: torch.Tensor,
      ranks: torch.Tensor,
      top_token_ids: torch.Tensor,
      top_logprobs: torch.Tensor,
      selected_logprobs_idx: int,
      top_logprob_idx: int,
  ):
      """Build the prompt logprobs of ``seq_group`` when they were requested.

      Prompt logprobs exist only for a group in its prefill stage
      (``seq_group.is_prompt``) whose ``sampling_params.prompt_logprobs`` is
      set. For every next prompt token, a dict is produced that maps the
      actual token id to ``Logprob(logprob, rank)``. When ``prompt_logprobs``
      is positive, the dict also contains the top ``prompt_logprobs``
      candidates with ranks 1..k.

      Args:
          seq_group: The sequence group to inspect.
          selected_logprobs: Logprob of each selected (query, token) pair,
              in the order produced by ``_get_logprobs``.
          ranks: Vocabulary rank of each selected (query, token) pair.
          top_token_ids: Top-k token ids, one row per query token.
          top_logprobs: Top-k logprobs, one row per query token.
          selected_logprobs_idx: Read cursor into ``selected_logprobs`` and
              ``ranks``.
          top_logprob_idx: Read cursor (row) into ``top_token_ids`` and
              ``top_logprobs``.

      Returns:
          ``(prompt_logprobs, top_logprob_idx, selected_logprobs_idx)``.
          ``prompt_logprobs`` is None and both cursors are unchanged when
          nothing was requested; otherwise both cursors are advanced past
          the consumed entries.
      """
      params = seq_group.sampling_params
      wants_prompt_logprobs = (seq_group.is_prompt
                               and params.prompt_logprobs is not None)
      if not wants_prompt_logprobs:
          return None, top_logprob_idx, selected_logprobs_idx

      k = params.prompt_logprobs
      targets = _get_next_prompt_tokens(seq_group)
      window_end = selected_logprobs_idx + len(targets)
      # A single tolist() per tensor is much cheaper than .item() per token.
      target_logprobs = selected_logprobs[
          selected_logprobs_idx:window_end].tolist()
      target_ranks = ranks[selected_logprobs_idx:window_end].tolist()

      result: PromptLogprobs = []
      for offset, target_id in enumerate(targets):
          # {token_id: (logprob, rank_from_vocab)}; the real prompt token
          # goes in first.
          entry: Dict[int, Tuple[float, int]] = {
              target_id: (target_logprobs[offset], target_ranks[offset])
          }
          if k > 0:
              row = top_logprob_idx + offset
              cand_ids = top_token_ids[row, :k].tolist()
              cand_logprobs = top_logprobs[row, :k].tolist()
              # Top-k output is already sorted, so candidate ranks are 1..k.
              for rank, (cand_id, cand_logprob) in enumerate(
                      zip(cand_ids, cand_logprobs), start=1):
                  entry[cand_id] = (cand_logprob, rank)
          result.append({
              token_id: Logprob(*logprob_and_rank)
              for token_id, logprob_and_rank in entry.items()
          })

      # One top-k row is consumed per next prompt token.
      return result, top_logprob_idx + len(targets), window_end
  def _get_sampled_logprob_if_needed(
      seq_group: SequenceGroupToSample,
      sample_result: Tuple[List[int], List[int]],
      selected_logprobs: torch.Tensor,
      ranks: torch.Tensor,
      top_token_ids: torch.Tensor,
      top_logprobs: torch.Tensor,
      selected_logprobs_idx: int,
      top_logprob_idx: int,
  ):
      """Build logprobs for the tokens ``seq_group`` sampled in this step.

      - If the group did not sample, an empty list is returned and both
        cursors are left unchanged.
      - If neither ``logprobs`` nor beam search was requested, each sampled
        token gets a placeholder ``Logprob(inf)``.
      - Otherwise each sampled token gets its own ``(logprob, rank)``. When
        ``logprobs > 0``, it also gets the top ``logprobs`` candidates taken
        from the row of its parent sequence.

      Args:
          seq_group: The sequence group to inspect.
          sample_result: ``(next_token_ids, parent_seq_ids)`` for the group.
          selected_logprobs: Logprob of each selected (query, token) pair.
          ranks: Vocabulary rank of each selected (query, token) pair.
          top_token_ids: Top-k token ids per row.
          top_logprobs: Top-k logprobs per row.
          selected_logprobs_idx: Read cursor into ``selected_logprobs``/``ranks``.
          top_logprob_idx: Read cursor (row) into ``top_token_ids``/``top_logprobs``.

      Returns:
          ``(sampled_logprobs, top_logprob_idx, selected_logprobs_idx)``.
      """
      params = seq_group.sampling_params
      k = params.logprobs
      sampled_ids, parent_ids = sample_result
      result: SampleLogprobs = []

      if not seq_group.do_sample:
          # No tokens were sampled for this group, so nothing is consumed.
          return result, top_logprob_idx, selected_logprobs_idx

      assert len(sampled_ids) > 0
      num_sampled = len(sampled_ids)

      if k is None and not params.use_beam_search:
          # Logprobs were not requested: attach a dummy value per token.
          result.extend({token_id: Logprob(inf)} for token_id in sampled_ids)
      else:
          window_end = selected_logprobs_idx + num_sampled
          # A single tolist() per tensor is faster than repeated .item().
          chosen_logprobs = selected_logprobs[
              selected_logprobs_idx:window_end].tolist()
          chosen_ranks = ranks[selected_logprobs_idx:window_end].tolist()
          want_top_k = k is not None and k > 0
          for offset, (token_id, parent_id) in enumerate(
                  zip(sampled_ids, parent_ids)):
              entry = {
                  token_id: (chosen_logprobs[offset], chosen_ranks[offset])
              }
              if want_top_k:
                  # Top-k rows are indexed by parent sequence, not by sample.
                  row = top_logprob_idx + parent_id
                  cand_ids = top_token_ids[row, :k].tolist()
                  cand_logprobs = top_logprobs[row, :k].tolist()
                  # Top-k output is already sorted, so ranks are 1..k.
                  for rank, (cand_id, cand_logprob) in enumerate(
                          zip(cand_ids, cand_logprobs), start=1):
                      entry[cand_id] = (cand_logprob, rank)
              result.append({
                  tid: Logprob(*logprob_and_rank)
                  for tid, logprob_and_rank in entry.items()
              })

      # `selected_logprobs` holds one entry per token sampled this step, while
      # the top-k tensors hold one row per sequence of the group, so the two
      # cursors advance by different amounts.
      return (result, top_logprob_idx + len(seq_group.seq_ids),
              selected_logprobs_idx + num_sampled)
  1516. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1517. sample_indices: torch.Tensor,
  1518. greedy_samples: torch.Tensor) -> None:
  1519. """Modify the probability distributions of the greedily-sampled tokens such
  1520. that each sampled token has a "probability" of 1.0. This is required by
  1521. speculative decoding, which depends on the sampling method being encoded
  1522. within the probability distribution for correctness.
  1523. # Why do we only need to do this for greedy sampling?
  1524. Aphrodite's sampler performs the following steps for greedy or multinomial
  1525. (random) sampling:
  1526. 1. Get logits from model.
  1527. 2. Modify logits according to per-sequence sampling parameters.
  1528. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1529. according to their frequency, etc.
  1530. 3. Sample a token.
  1531. - Random sampling simply samples from the modified probability
  1532. distribution.
  1533. - Greedy sampling performs `argmax` to obtain the token with the
  1534. highest likelihood.
  1535. Ignoring greedy sampling for a moment, we find that the computed probability
  1536. distribution has the following property: we can sample from it independently
  1537. and find that the token sampled by the Sampler has a frequency corresponding
  1538. to how often we see it in our sampling. In other words, for tokens sampled
  1539. with Aphrodite's random SamplingType, the computed probability distribution
  1540. encodes the sampling methodology completely.
  1541. Greedy sampling does not normally have this property. Aphrodite modifies
  1542. logits according to sampling params, then performs `argmax`, then returns
  1543. the sampled token and the computed probability distribution. If we sample
  1544. from the distribution, we'll find the likelihood of the greedily-sampled
  1545. token is not always 1.0.
  1546. Since lossless speculative decoding requires that the sampling methodology
  1547. be encoded within the probability distribution, we are motivated to modify
  1548. the probability distribution such that the sampled token has probability 1
  1549. when speculative decoding is used.
  1550. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1551. greedy sampling using multinomial computation and unite the codepaths. This
  1552. has implications on the overall design of the sampler, e.g. how to record
  1553. accurate logprobs for the user, so this improvement is deferred to later.
  1554. """
  1555. # NOTE: logprobs are not modified so they can be returned to the user.
  1556. probs[sample_indices, :] = 0
  1557. probs[sample_indices, greedy_samples] = 1.0
  def _build_sampler_output(
      sample_results: SampleResultType,
      sampling_metadata: SamplingMetadata,
      prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
      sample_logprobs: Optional[List[SampleLogprobs]],
      on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                        torch.Tensor]],
      skip_sampler_cpu_output: bool = False,
  ) -> SamplerOutput:
      """Package the results of sampling into a ``SamplerOutput``.

      Args:
          sample_results: ``(next_token_ids, parent_ids)`` per sequence group.
          sampling_metadata: Metadata whose ``seq_groups`` line up with
              ``sample_results``.
          prompt_logprobs: Prompt logprobs per sequence group. Must not be
              None unless ``skip_sampler_cpu_output`` is set.
          sample_logprobs: Sample logprobs per sequence group. Must not be
              None unless ``skip_sampler_cpu_output`` is set.
          on_device_tensors: Optional tuple of on-device tensors
              ``(sampled_token_probs, logprobs, sampled_token_ids)``. This
              allows post-processing without copies to CPU/serialization,
              e.g. in speculative decoding rejection sampling. When None,
              those fields of the output are None.
          skip_sampler_cpu_output: When True, no per-sequence Python objects
              are built and ``outputs`` is an empty list.

      Returns:
          The assembled ``SamplerOutput``.
      """
      sampler_output: List[CompletionSequenceGroupOutput] = []

      if not skip_sampler_cpu_output:
          assert prompt_logprobs is not None
          assert sample_logprobs is not None
          per_group = zip(sampling_metadata.seq_groups, sample_results,
                          prompt_logprobs, sample_logprobs)
          for seq_group, (token_ids, parents), group_prompt, group_sample in (
                  per_group):
              seq_ids = seq_group.seq_ids
              # Map each sample back to the sequence it was forked from.
              seq_outputs = [
                  SequenceOutput(seq_ids[parent], token_id, token_logprobs)
                  for parent, token_id, token_logprobs in zip(
                      parents, token_ids, group_sample)
              ]
              sampler_output.append(
                  CompletionSequenceGroupOutput(seq_outputs, group_prompt))

      # Missing on-device tensors are reported as None.
      sampled_token_probs, logprobs_tensor, sampled_token_ids = (
          on_device_tensors if on_device_tensors is not None else
          (None, None, None))

      return SamplerOutput(
          outputs=sampler_output,
          sampled_token_probs=sampled_token_probs,
          sampled_token_ids=sampled_token_ids,
          logprobs=logprobs_tensor,
      )
  def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
      """Return the prompt token that follows each query token of ``seq_group``.

      Prompt logprobs are computed per query token, and each query token needs
      the id of the *next* prompt token to look up its logprob. For a chunk
      that starts after ``computed_len`` already-computed tokens and spans
      ``query_len`` query tokens, this is the prompt slice
      ``[computed_len + 1, computed_len + query_len + 1)``. The slice is
      clipped to the prompt length, because the last prompt token has no
      successor.

      This API must only be used when the caller knows ``seq_group`` is in
      the prefill stage.

      Returns:
          The next prompt token ids (a slice of ``seq_data.prompt_token_ids``).
      """
      assert seq_group.is_prompt, (
          "Caller should ensure the sequence group is in a prefill stage.")
      seq_ids = seq_group.seq_ids
      query_len = seq_group.query_len
      assert query_len is not None
      # A prompt (prefill) group has only 1 seq id.
      assert len(seq_ids) == 1
      seq_data = seq_group.seq_data[seq_ids[0]]
      computed_len = seq_data.get_num_computed_tokens()
      prompt_tokens = seq_data.prompt_token_ids
      # +1 because each query token is paired with the token after it.
      next_token_index_start = computed_len + 1
      next_token_index_end = min(computed_len + query_len + 1,
                                 len(prompt_tokens))
      return prompt_tokens[next_token_index_start:next_token_index_end]
  1634. def _get_ngrams(
  1635. ngram_size: int,
  1636. prev_input_ids: torch.Tensor
  1637. ) -> Dict[Tuple[int, ...], List[int]]:
  1638. """Get dictionary of ngrams and the tokens that followed them.
  1639. Args:
  1640. ngram_size: Size of ngrams to track
  1641. prev_input_ids: 1D tensor of previous token ids
  1642. Returns:
  1643. Dictionary mapping ngram tuples to list of tokens that followed them
  1644. """
  1645. generated_ngrams = {}
  1646. gen_tokens = prev_input_ids.tolist()
  1647. for i in range(len(gen_tokens) - ngram_size + 1):
  1648. ngram = tuple(gen_tokens[i:i + ngram_size - 1])
  1649. next_token = gen_tokens[i + ngram_size - 1]
  1650. if ngram in generated_ngrams:
  1651. generated_ngrams[ngram].append(next_token)
  1652. else:
  1653. generated_ngrams[ngram] = [next_token]
  1654. return generated_ngrams
  1655. def _get_generated_ngrams(
  1656. banned_ngrams: Dict[Tuple[int, ...], List[int]],
  1657. prev_input_ids: torch.Tensor,
  1658. ngram_size: int,
  1659. cur_len: int
  1660. ) -> List[int]:
  1661. """Get list of tokens that would create a repeated ngram if generated next.
  1662. Args:
  1663. banned_ngrams: Dictionary of previously seen ngrams and their next
  1664. tokens
  1665. prev_input_ids: Previous token ids
  1666. ngram_size: Size of ngrams to check
  1667. cur_len: Current position in sequence
  1668. Returns:
  1669. List of token ids that would create a repeat ngram
  1670. """
  1671. start_idx = cur_len + 1 - ngram_size
  1672. current_ngram = tuple(prev_input_ids[start_idx:cur_len].tolist())
  1673. return banned_ngrams.get(current_ngram, [])
  1674. def _calc_banned_ngram_tokens(
  1675. ngram_size: int,
  1676. prev_input_ids: torch.Tensor,
  1677. cur_len: int
  1678. ) -> List[int]:
  1679. """Calculate tokens that would create repeated ngrams if generated next.
  1680. Args:
  1681. ngram_size: Size of ngrams to prevent repeating
  1682. prev_input_ids: Previous token ids in sequence
  1683. cur_len: Current position in sequence
  1684. Returns:
  1685. List of token ids that should be banned to prevent ngram repetition
  1686. """
  1687. if cur_len + 1 < ngram_size:
  1688. return []
  1689. generated_ngrams = _get_ngrams(ngram_size, prev_input_ids)
  1690. banned_tokens = _get_generated_ngrams(
  1691. generated_ngrams,
  1692. prev_input_ids,
  1693. ngram_size,
  1694. cur_len
  1695. )
  1696. return banned_tokens
  1697. # def _apply_mirostat_v2(logits: torch.Tensor,
  1698. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1699. # # Reduce our view to just the affected logits
  1700. # logit_view = logits[sampling_tensors.miro_indices]
  1701. # # Calculate surprise value per token
  1702. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1703. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1704. # # Mask out "too-surprising" tokens (surprisal > mu)
  1705. # mus = sampling_tensors.miro_mus
  1706. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1707. # # Unmask most-likely logit to guarantee a selection.
  1708. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1709. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1710. # # Apply logit mask (effectively a top-k filter).
  1711. # logit_view[miro_mask] = -float("inf")
  1712. # # Project logit changes made to the view onto the original.
  1713. # # I think this step might be redundant.
  1714. # logits[sampling_tensors.miro_indices] = logit_view
  1715. # return logits
  1716. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1717. # sample_results: List[Tuple[List[int], List[int]]],
  1718. # sampling_metadata: SamplingMetadata,
  1719. # output_metadata: OutputMetadata) -> None:
  1720. # """Based on whichever token was finally sampled, we calculate the
  1721. # final surprisal values to update the mus.
  1722. # Because a single sequence can have multiple samples, we must fork
  1723. # the mu accordingly."""
  1724. # assert sampling_metadata.seq_groups is not None
  1725. # seqid_to_tokens = {}
  1726. # seqid_to_indices = {}
  1727. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1728. # sample_results):
  1729. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1730. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1731. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1732. # seqids = args.miro_seqids
  1733. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1734. # device=logits.device,
  1735. # dtype=torch.long)
  1736. # # Clumsily, we recalculate token surprisals.
  1737. # logits_view = logits[args.miro_indices]
  1738. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1739. # dim=-1,
  1740. # index=picked_tokens) / -math.log(2)
  1741. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1742. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1743. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1744. # nu_mus = mus - (picked_surprise - taus) * etas
  1745. # # Record updated mu values for use in the next iteration
  1746. # # Note how each mu is split into multiple based on the number of samples.
  1747. # for seqid, seq_mus in zip(seqids, nu_mus):
  1748. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1749. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)