sampler.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import warnings
  4. from enum import IntEnum
  5. from math import inf
  6. from typing import Dict, List, Optional, Tuple
  7. import torch
  8. import torch.nn as nn
  9. from loguru import logger
  10. import aphrodite._custom_ops as ops
  11. import aphrodite.common.envs as envs
  12. from aphrodite.common.sampling_params import SamplingType
  13. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  14. PromptLogprobs, SampleLogprobs,
  15. SamplerOutput, SequenceOutput)
  16. from aphrodite.triton_utils import HAS_TRITON
  17. if HAS_TRITON:
  18. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  19. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  20. SamplingTensors,
  21. SequenceGroupToSample)
# One entry per sequence group: (sampled token ids, parent ids), each a list
# of ints.
SampleResultType = List[Tuple[List[int], List[int]]]

# There isn't a "safe" temperature range for fp16 logits.
# This value was chosen because 1/2e-5 (= 50,000) is just under the fp16 max
# of 65,504, meaning that this temperature well-uses the fp16 space after the
# logits are offset. _apply_temperatures clamps NaN temperatures and any
# temperature at or below this value to exactly this value.
_TEMPERATURE_MINIMUM: float = 2e-5

# If enabled, we switch to a more performant implementation
# of top-k and top-p; Sampler.forward then skips its own top-p/top-k step.
APHRODITE_USE_SAMPLING_KERNELS = envs.APHRODITE_USE_SAMPLING_KERNELS
  31. class SamplerID(IntEnum):
  32. # Mirror these in aphrodite/common/sampling_params.py
  33. # Values out of order to keep backwards compatibility
  34. # with Koboldcpp values
  35. DRY = 7
  36. PENALTIES = 6
  37. NO_REPEAT_NGRAM = 8
  38. TEMPERATURE = 5
  39. TOP_NSIGMA = 9
  40. TOP_P_TOP_K = 0
  41. TOP_A = 1
  42. MIN_P = 2
  43. TFS = 3
  44. ETA_CUTOFF = 10
  45. EPSILON_CUTOFF = 11
  46. TYPICAL_P = 4
  47. QUADRATIC = 12
  48. XTC = 13
  49. class Sampler(nn.Module):
  50. """Samples the next tokens from the model's outputs.
  51. This layer does the following:
  52. 1. Discard the hidden states that are not used for sampling (i.e., all
  53. tokens except the final one in each prompt).
  54. 2. Compute the logits for the next tokens.
  55. 3. Apply presence, frequency and repetition penalties.
  56. 4. Apply temperature scaling.
  57. 5. Apply top-p and top-k truncation.
  58. 6. Sample the next tokens.
  59. Here, each sequence group within the batch can have different sampling
  60. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  61. The structure of the logits tensor is coupled with the seq_groups in
  62. sampling_metadata. Typically, each sequence in each seq_group has one row in
  63. logits for the next token to be sampled; however, for a seq_group with a
  64. prompt request with the prompt_logprobs sampling parameter, there are rows
  65. in logits for each token in the input prompt.
  66. """
  67. def __init__(self):
  68. super().__init__()
  69. # Whether or not the SamplerOutput should have on-device tensors
  70. # containing the sampled token ids and probabilities. This is used by
  71. # speculative decoding.
  72. self.include_gpu_probs_tensor = False
  73. self.should_modify_greedy_probs_inplace = False
  74. def _init_sampling_tensors(
  75. self,
  76. logits: torch.Tensor,
  77. sampling_metadata: SamplingMetadata,
  78. ):
  79. """The goal here is to reuse sampling tensors between similar decode
  80. runs. This is possible because sampling logic does not change between
  81. decodes of the same sequences.
  82. """
  83. _, vocab_size = logits.shape
  84. # First free any existing stored sampling tensors.
  85. # This is necessary because some sampling tensors may
  86. # have pinned memory.
  87. self._sampling_tensors = None
  88. # Initialize new sampling tensors
  89. (sampling_tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
  90. do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
  91. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
  92. do_dry, do_skew, do_temp_last
  93. ) = SamplingTensors.from_sampling_metadata(
  94. sampling_metadata, vocab_size, logits.device, logits.dtype)
  95. self._sampling_tensors = sampling_tensors
  96. self._do_penalties = do_penalties
  97. self._do_no_repeat_ngrams = do_no_repeat_ngrams
  98. self._do_temperatures = do_temperatures
  99. self._do_top_p_top_k = do_top_p_top_k
  100. self._do_top_as = do_top_as
  101. self._do_min_p = do_min_p
  102. self._do_tfss = do_tfss
  103. self._do_eta_cutoffs = do_eta_cutoffs
  104. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  105. self._do_typical_ps = do_typical_ps
  106. self._do_quadratic = do_quadratic
  107. self._do_xtc = do_xtc
  108. self._do_nsgimas = do_nsigmas
  109. self._do_dry = do_dry
  110. self._do_skew = do_skew
  111. self._do_temp_last = do_temp_last
  112. def forward(
  113. self,
  114. logits: torch.Tensor,
  115. sampling_metadata: SamplingMetadata,
  116. ) -> Optional[SamplerOutput]:
  117. """
  118. Args:
  119. logits: (num_tokens, vocab_size).
  120. sampling_metadata: Metadata for sampling.
  121. """
  122. assert logits is not None
  123. _, vocab_size = logits.shape
  124. # Prepare sampling tensors with pinned memory to avoid blocking.
  125. if not sampling_metadata.reuse_sampling_tensors:
  126. self._init_sampling_tensors(logits, sampling_metadata)
  127. elif self._do_penalties or self._do_dry:
  128. # In this case, the sampling tensors logic depends on
  129. # "output_tokens" of a sequence. As a result, we cannot
  130. # reuse sampling tensors, since "output_tokens" changes
  131. # between decode runs.
  132. self._init_sampling_tensors(logits, sampling_metadata)
  133. assert self._sampling_tensors is not None
  134. sampling_tensors = self._sampling_tensors
  135. do_penalties = self._do_penalties
  136. do_no_repeat_ngrams = self._do_no_repeat_ngrams
  137. do_temperatures = self._do_temperatures
  138. do_top_p_top_k = self._do_top_p_top_k
  139. do_top_as = self._do_top_as
  140. do_min_p = self._do_min_p
  141. do_tfss = self._do_tfss
  142. do_eta_cutoffs = self._do_eta_cutoffs
  143. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  144. do_typical_ps = self._do_typical_ps
  145. do_quadratic = self._do_quadratic
  146. do_xtc = self._do_xtc
  147. do_nsigmas = self._do_nsgimas
  148. do_dry = self._do_dry
  149. do_skew = self._do_skew
  150. do_temp_last = self._do_temp_last
  151. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  152. banned_tokens = _get_custom_token_bans(sampling_metadata)
  153. logits = _apply_token_bans(logits, banned_tokens)
  154. sampler_order = None
  155. if sampling_metadata.seq_groups:
  156. sampler_order = sampling_metadata.seq_groups[
  157. 0].sampling_params.sampler_priority
  158. # Warn if both custom order and temp_last are specified
  159. if (sampling_metadata.seq_groups and
  160. sampling_metadata.seq_groups[0].is_prompt and
  161. sampler_order is not None and
  162. do_temp_last):
  163. logger.warning(
  164. "Both sampler_priority and temperature_last=True "
  165. "were specified. Using custom sampler_priority order "
  166. "and ignoring temperature_last.")
  167. if sampler_order is None:
  168. default_order = [
  169. SamplerID.DRY,
  170. SamplerID.PENALTIES,
  171. SamplerID.NO_REPEAT_NGRAM,
  172. SamplerID.TEMPERATURE,
  173. SamplerID.TOP_NSIGMA,
  174. SamplerID.TOP_P_TOP_K,
  175. SamplerID.TOP_A,
  176. SamplerID.MIN_P,
  177. SamplerID.TFS,
  178. SamplerID.ETA_CUTOFF,
  179. SamplerID.EPSILON_CUTOFF,
  180. SamplerID.TYPICAL_P,
  181. SamplerID.QUADRATIC,
  182. SamplerID.XTC,
  183. ]
  184. sampler_order = []
  185. for sampler_id in default_order:
  186. if sampler_id == SamplerID.TEMPERATURE and do_temp_last:
  187. continue
  188. sampler_order.append(sampler_id)
  189. if sampler_id == SamplerID.XTC and do_temp_last:
  190. sampler_order.append(SamplerID.TEMPERATURE)
  191. if sampling_metadata.seq_groups and sampling_metadata.seq_groups[
  192. 0].is_prompt:
  193. logger.debug("Sampler execution order: ")
  194. for i, sampler_id in enumerate(sampler_order, 1):
  195. logger.debug(f"{i}. {SamplerID(sampler_id).name}")
  196. enabled_samplers = []
  197. # ruff: noqa: E701
  198. if do_penalties: enabled_samplers.append("PENALTIES")
  199. if do_no_repeat_ngrams: enabled_samplers.append("NO_REPEAT_NGRAM")
  200. if do_temperatures: enabled_samplers.append("TEMPERATURE")
  201. if do_top_p_top_k: enabled_samplers.append("TOP_P_TOP_K")
  202. if do_top_as: enabled_samplers.append("TOP_A")
  203. if do_min_p: enabled_samplers.append("MIN_P")
  204. if do_tfss: enabled_samplers.append("TFS")
  205. if do_eta_cutoffs: enabled_samplers.append("ETA_CUTOFF")
  206. if do_epsilon_cutoffs: enabled_samplers.append("EPSILON_CUTOFF")
  207. if do_typical_ps: enabled_samplers.append("TYPICAL_P")
  208. if do_quadratic: enabled_samplers.append("QUADRATIC")
  209. if do_xtc: enabled_samplers.append("XTC")
  210. if do_nsigmas: enabled_samplers.append("TOP_NSIGMA")
  211. if do_dry: enabled_samplers.append("DRY")
  212. if do_skew: enabled_samplers.append("SKEW")
  213. logger.debug(f"Enabled samplers: {', '.join(enabled_samplers)}")
  214. for sampler_id in sampler_order:
  215. if sampler_id == SamplerID.DRY and do_dry:
  216. if (sampling_metadata.seq_groups and
  217. sampling_metadata.seq_groups[0].is_prompt):
  218. logger.debug(
  219. f"Applying DRY with dry_multiplier: "
  220. f"{sampling_tensors.dry_multipliers}.")
  221. logits = _apply_dry(
  222. logits,
  223. sampling_tensors.prompt_tokens,
  224. sampling_tensors.output_tokens,
  225. sampling_tensors.dry_multipliers,
  226. sampling_tensors.dry_bases,
  227. sampling_tensors.dry_allowed_lengths,
  228. sampling_tensors.dry_sequence_breaker_ids,
  229. sampling_tensors.dry_ranges)
  230. elif sampler_id == SamplerID.PENALTIES and do_penalties:
  231. if (sampling_metadata.seq_groups and
  232. sampling_metadata.seq_groups[0].is_prompt):
  233. logger.debug(
  234. "Applying penalties with "
  235. f"pres_pen: {sampling_tensors.presence_penalties}, "
  236. f"freq_pen: {sampling_tensors.frequency_penalties}, "
  237. f"rep_pen: {sampling_tensors.repetition_penalties}.")
  238. logits = _apply_penalties(
  239. logits, sampling_tensors.prompt_tokens,
  240. sampling_tensors.output_tokens,
  241. sampling_tensors.presence_penalties,
  242. sampling_tensors.frequency_penalties,
  243. sampling_tensors.repetition_penalties)
  244. elif sampler_id == SamplerID.NO_REPEAT_NGRAM and \
  245. do_no_repeat_ngrams:
  246. if (sampling_metadata.seq_groups and
  247. sampling_metadata.seq_groups[0].is_prompt):
  248. logger.debug(
  249. "Applying no_repeat_ngram with no_repeat_ngram_size: "
  250. f"{sampling_tensors.no_repeat_ngram_sizes}.")
  251. logits = _apply_no_repeat_ngram(
  252. logits,
  253. sampling_tensors.prompt_tokens,
  254. sampling_tensors.no_repeat_ngram_sizes)
  255. elif sampler_id == SamplerID.TEMPERATURE and do_temperatures:
  256. if (sampling_metadata.seq_groups and
  257. sampling_metadata.seq_groups[0].is_prompt):
  258. logger.debug(
  259. "Applying temperatures with temperature: "
  260. f"{sampling_tensors.temperatures}, "
  261. f"dynatemp_min: {sampling_tensors.dynatemp_mins}, "
  262. f"dynatemp_max: {sampling_tensors.dynatemp_maxs}, "
  263. f"dynamtep_exp: {sampling_tensors.dynatemp_exps}.")
  264. _apply_temperatures(
  265. logits, sampling_tensors.temperatures,
  266. sampling_tensors.dynatemp_mins,
  267. sampling_tensors.dynatemp_maxs,
  268. sampling_tensors.dynatemp_exps)
  269. elif sampler_id == SamplerID.TOP_NSIGMA and do_nsigmas:
  270. if (sampling_metadata.seq_groups and
  271. sampling_metadata.seq_groups[0].is_prompt):
  272. logger.debug(
  273. "Applying Top-Nsigma with nsigma: "
  274. f"{sampling_tensors.nsigmas}")
  275. logits = _apply_top_nsigma(
  276. logits, sampling_tensors.nsigmas)
  277. elif sampler_id == SamplerID.TOP_P_TOP_K and do_top_p_top_k and \
  278. not APHRODITE_USE_SAMPLING_KERNELS:
  279. if (sampling_metadata.seq_groups and
  280. sampling_metadata.seq_groups[0].is_prompt):
  281. logger.debug(
  282. "Applying Top-p and Top-k with top-p: "
  283. f"{sampling_tensors.top_ps}, top_k: "
  284. f"{sampling_tensors.top_ks}.")
  285. logits = _apply_top_k_top_p(
  286. logits, sampling_tensors.top_ps,
  287. sampling_tensors.top_ks)
  288. elif sampler_id == SamplerID.TOP_A and do_top_as:
  289. if (sampling_metadata.seq_groups and
  290. sampling_metadata.seq_groups[0].is_prompt):
  291. logger.debug(
  292. "Applying Top-a with Top-a: "
  293. f"{sampling_tensors.top_as}.")
  294. logits = _apply_top_a(
  295. logits, sampling_tensors.top_as)
  296. elif sampler_id == SamplerID.MIN_P and do_min_p:
  297. if (sampling_metadata.seq_groups and
  298. sampling_metadata.seq_groups[0].is_prompt):
  299. logger.debug(
  300. "Applying Min-p with Min-p: "
  301. f"{sampling_tensors.min_ps}.")
  302. logits = _apply_min_p(
  303. logits, sampling_tensors.min_ps)
  304. elif sampler_id == SamplerID.TFS and do_tfss:
  305. if (sampling_metadata.seq_groups and
  306. sampling_metadata.seq_groups[0].is_prompt):
  307. logger.debug(
  308. "Applying Tail-Free Sampling with tfs: "
  309. f"{sampling_tensors.tfss}.")
  310. logits = _apply_tfs(
  311. logits, sampling_tensors.tfss)
  312. elif sampler_id == SamplerID.ETA_CUTOFF and do_eta_cutoffs:
  313. if (sampling_metadata.seq_groups and
  314. sampling_metadata.seq_groups[0].is_prompt):
  315. logger.debug(
  316. "Applying ETA Cutoff with eta_cutoff: "
  317. f"{sampling_tensors.eta_cutoffs}.")
  318. logits = _apply_eta_cutoff(
  319. logits, sampling_tensors.eta_cutoffs)
  320. elif sampler_id == SamplerID.EPSILON_CUTOFF and do_epsilon_cutoffs:
  321. if (sampling_metadata.seq_groups and
  322. sampling_metadata.seq_groups[0].is_prompt):
  323. logger.debug(
  324. "Applying Epsilon Cutoff with epsilon_cutoff: "
  325. f"{sampling_tensors.epsilon_cutoffs}.")
  326. logits = _apply_epsilon_cutoff(
  327. logits, sampling_tensors.epsilon_cutoffs)
  328. elif sampler_id == SamplerID.TYPICAL_P and do_typical_ps:
  329. if (sampling_metadata.seq_groups and
  330. sampling_metadata.seq_groups[0].is_prompt):
  331. logger.debug(
  332. "Applying Locally Typical Sampling with typical_p: "
  333. f"{sampling_tensors.typical_ps}.")
  334. logits = _apply_typical_sampling(
  335. logits, sampling_tensors.typical_ps)
  336. elif sampler_id == SamplerID.QUADRATIC and do_quadratic:
  337. if (sampling_metadata.seq_groups and
  338. sampling_metadata.seq_groups[0].is_prompt):
  339. logger.debug(
  340. "Applying Quadratic and Cubic Sampling with "
  341. "smoothing_factors: "
  342. f"{sampling_tensors.smoothing_factors},"
  343. f" smoothing_curves: "
  344. f"{sampling_tensors.smoothing_curves}.")
  345. logits = _apply_quadratic_sampling(
  346. logits, sampling_tensors.smoothing_factors,
  347. sampling_tensors.smoothing_curves)
  348. elif sampler_id == SamplerID.XTC and do_xtc:
  349. if (sampling_metadata.seq_groups and
  350. sampling_metadata.seq_groups[0].is_prompt):
  351. logger.debug(
  352. "Applying Exclude Top Choices sampling with "
  353. f"xtc_threshold: {sampling_tensors.xtc_thresholds}, "
  354. "xtc_probability: "
  355. f"{sampling_tensors.xtc_probabilities}.")
  356. logits = _apply_xtc_sampling(
  357. logits, sampling_tensors.xtc_thresholds,
  358. sampling_tensors.xtc_probabilities)
  359. # We use float32 for probabilities and log probabilities.
  360. # Compute the probabilities.
  361. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  362. # skew needs to be applied post-softmax
  363. if do_skew:
  364. if (sampling_metadata.seq_groups and
  365. sampling_metadata.seq_groups[0].is_prompt):
  366. logger.debug(
  367. "Applying Skew sampling with skew: "
  368. f"{sampling_tensors.skews}.")
  369. # reference: https://github.com/turboderp/exllamav2/commit/1de4cdd70b09208e7b4f17ee322c190e16f60efd
  370. cum_probs = torch.cumsum(probs, dim=-1)
  371. cum_probs = torch.pow(cum_probs, torch.exp(
  372. sampling_tensors.skews).unsqueeze(dim=1))
  373. probs = torch.diff(cum_probs, dim=-1,
  374. prepend=torch.zeros_like(cum_probs[..., :1]))
  375. logits = torch.log(probs)
  376. # Compute the log probabilities.
  377. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  378. # Sample the next tokens.
  379. sample_results, maybe_sampled_tokens_tensor = _sample(
  380. probs,
  381. logprobs,
  382. sampling_metadata,
  383. sampling_tensors,
  384. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  385. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  386. )
  387. if self.include_gpu_probs_tensor:
  388. assert maybe_sampled_tokens_tensor is not None
  389. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  390. else:
  391. on_device_tensors = None
  392. # Get the logprobs query results.
  393. prompt_logprobs = None
  394. sample_logprobs = None
  395. if not sampling_metadata.skip_sampler_cpu_output:
  396. prompt_logprobs, sample_logprobs = _get_logprobs(
  397. logprobs, sampling_metadata, sample_results)
  398. return _build_sampler_output(
  399. sample_results,
  400. sampling_metadata,
  401. prompt_logprobs,
  402. sample_logprobs,
  403. on_device_tensors=on_device_tensors,
  404. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  405. @property
  406. def _should_modify_greedy_probs_inplace(self) -> bool:
  407. """Whether or not the sampler should modify the probability distribution
  408. of greedily-sampled tokens such that multinomial sampling would sample
  409. the greedily-sampled token.
  410. In other words, if True then we set the probability of the greedily-
  411. sampled token to 1.
  412. This is used by speculative decoding, which requires that the sampling
  413. method be encoded into the probability distribution.
  414. """
  415. return self.should_modify_greedy_probs_inplace
  416. def _get_bin_counts_and_mask(
  417. tokens: torch.Tensor,
  418. vocab_size: int,
  419. num_seqs: int,
  420. ) -> Tuple[torch.Tensor, torch.Tensor]:
  421. # Compute the bin counts for the tokens.
  422. # vocab_size + 1 for padding.
  423. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  424. dtype=torch.long,
  425. device=tokens.device)
  426. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  427. bin_counts = bin_counts[:, :vocab_size]
  428. mask = bin_counts > 0
  429. return bin_counts, mask
  430. def _get_custom_token_bans(
  431. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  432. assert sampling_metadata.seq_groups is not None
  433. banned_tokens: List[List[int]] = []
  434. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  435. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  436. seq_ids = seq_group.seq_ids
  437. custom_token_bans = sampling_params.custom_token_bans
  438. if (i < sampling_metadata.num_prompts
  439. and sampling_params.prompt_logprobs is not None):
  440. prompt_len = len(seq_group.prompt_logprob_indices)
  441. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  442. banned_tokens += [custom_token_bans] * len(seq_ids)
  443. return banned_tokens
  444. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  445. output_tokens_tensor: torch.Tensor,
  446. presence_penalties: torch.Tensor,
  447. frequency_penalties: torch.Tensor,
  448. repetition_penalties: torch.Tensor) -> torch.Tensor:
  449. num_seqs, vocab_size = logits.shape
  450. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  451. num_seqs)
  452. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  453. output_tokens_tensor, vocab_size, num_seqs)
  454. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  455. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  456. logits = torch.where(logits > 0, logits / repetition_penalties,
  457. logits * repetition_penalties)
  458. # We follow the definition in OpenAI API.
  459. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  460. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  461. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  462. return logits
  463. def _apply_temperatures(
  464. logits: torch.Tensor,
  465. temperatures: torch.Tensor,
  466. dynatemp_mins: torch.Tensor,
  467. dynatemp_maxs: torch.Tensor,
  468. dynatemp_exps: torch.Tensor,
  469. ) -> None:
  470. dynatemp_mask = (dynatemp_mins != 0) | (dynatemp_maxs != 0)
  471. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  472. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  473. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  474. dynatemp_logits = logits[dynatemp_mask]
  475. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  476. dynatemp_probs = dynatemp_shifted_logits.exp()
  477. dynatemp_entropies = -(dynatemp_probs *
  478. dynatemp_shifted_logits).nansum(dim=-1)
  479. dynatemp_max_entropies = torch.log_(
  480. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  481. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  482. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  483. normalized_entropies.pow_(dynatemp_exps))
  484. temperatures[dynatemp_mask] = dyn_temp
  485. temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
  486. temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM
  487. # To prevent saturation of top logits, we shift the range to [-inf, 1]
  488. # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
  489. # Why mask? So we aren't potentially discarding data in milder temps.
  490. low_temps = temperatures < 0.1
  491. logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1
  492. logits.div_(temperatures.unsqueeze(dim=1))
  493. def _apply_token_bans(logits: torch.Tensor,
  494. banned_tokens: List[List[int]]) -> torch.Tensor:
  495. for i, banned_token_ids in enumerate(banned_tokens):
  496. if i >= logits.size(0):
  497. break
  498. if not banned_token_ids:
  499. continue
  500. logits[i, banned_token_ids] = -float("inf")
  501. return logits
  502. def _apply_min_tokens_penalty(
  503. logits: torch.Tensor,
  504. sampling_metadata: SamplingMetadata,
  505. ) -> torch.Tensor:
  506. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  507. have not been generated yet
  508. """
  509. # list of indices in logits that will be set to -inf
  510. logits_to_penalize = []
  511. logits_applied = 0
  512. for seq_group in sampling_metadata.seq_groups:
  513. seq_ids = seq_group.seq_ids
  514. sampling_params = seq_group.sampling_params
  515. sample_indices = seq_group.sample_indices
  516. logits_applied += len(sample_indices) + len(
  517. seq_group.prompt_logprob_indices)
  518. if not seq_group.do_sample:
  519. continue
  520. start_idx = sample_indices[0]
  521. min_tokens = sampling_params.min_tokens
  522. token_ids_to_penalize = sampling_params.all_stop_token_ids
  523. if min_tokens > 0 and token_ids_to_penalize:
  524. seqs_to_penalize = []
  525. for j, seq_id in enumerate(seq_ids):
  526. seq_data = seq_group.seq_data[seq_id]
  527. if len(seq_data.output_token_ids_array) < min_tokens:
  528. seqs_to_penalize.append(j)
  529. if seqs_to_penalize:
  530. # convert to the index into logits
  531. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  532. # itertools.product pairs each seq index with every token id
  533. logits_to_penalize.extend(
  534. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  535. if logits_to_penalize:
  536. # use zip and * to group indices along each dimension
  537. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  538. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  539. # verifies that no rows in logits were missed unexpectedly
  540. assert logits_applied == logits.shape[0]
  541. return logits
  542. def _apply_dry(
  543. logits: torch.Tensor,
  544. input_token_ids: torch.Tensor,
  545. output_token_ids: torch.Tensor,
  546. multipliers: torch.Tensor,
  547. bases: torch.Tensor,
  548. allowed_lengths: torch.Tensor,
  549. sequence_breakers_ids: torch.Tensor,
  550. ranges: torch.Tensor,
  551. ) -> torch.Tensor:
  552. """
  553. Apply Don't Repeat Yourself (DRY) sampling to the logits.
  554. Reference: https://github.com/oobabooga/text-generation-webui/pull/5677
  555. """
  556. VOCAB_SIZE = logits.size(-1)
  557. MAX_NGRAM = 100
  558. # Process each sequence that has a nonzero multiplier
  559. applies_to = multipliers.nonzero(as_tuple=True)[0]
  560. for irow in applies_to.tolist():
  561. # DRY applies to input AND output tokens, so concat w/o padding tokens
  562. prompt_len = (len(input_token_ids[irow]) -
  563. (input_token_ids[irow] == VOCAB_SIZE).sum().item())
  564. ouput_len = (len(output_token_ids[irow]) -
  565. (output_token_ids[irow] == VOCAB_SIZE).sum().item())
  566. token_seq = torch.cat((input_token_ids[irow][:prompt_len],
  567. output_token_ids[irow][:ouput_len]), dim=0)
  568. range_limit = ranges[irow].item()
  569. if range_limit: # could this be done in cat? yes. do i care? no.
  570. token_seq = token_seq[-range_limit:]
  571. last_token = token_seq[-1].item()
  572. if last_token in sequence_breakers_ids[irow]:
  573. continue # early out for everything up to the min_ngram check?
  574. # Build a mask of all the breaking tokens in the context
  575. break_mask = torch.zeros(len(token_seq), dtype=torch.bool,
  576. device=logits.device)
  577. for break_tok in sequence_breakers_ids[irow]:
  578. break_mask.logical_or_(token_seq == break_tok)
  579. # Find the most recent breaking token (sets ngram limit)
  580. max_ngram = 0
  581. for max_ngram in range(min(len(break_mask), MAX_NGRAM + 1)):
  582. if break_mask[-max_ngram - 1]:
  583. break
  584. min_ngram = allowed_lengths[irow].item()
  585. if max_ngram <= min_ngram: # Too close to a break to match anything
  586. continue
  587. # If [token] is picked, what's the longest ngram that would match?
  588. ngram_lens = torch.zeros(VOCAB_SIZE, dtype=torch.int32,
  589. device=logits.device)
  590. # Find all instances of the last token- potential ngrams!
  591. endpoint_indexes = torch.nonzero(token_seq == last_token,
  592. as_tuple=True)[0].tolist()
  593. # NOTE: This seems like the slow part. Haven't benchmarked.
  594. for idx in endpoint_indexes[:-1]: # Skip the last_token match
  595. unwind = 0
  596. # Check up to max_ngram tokens prior to idx (we know idx matches)
  597. for unwind in range(1, min(idx, max_ngram) + 1):
  598. if break_mask[idx - unwind]:
  599. break
  600. if token_seq[idx - unwind] != token_seq[-unwind - 1]:
  601. break
  602. next_tok = token_seq[idx+1]
  603. # The repeated tokens BEFORE next_tok (+1 to include [idx]).
  604. ngram_lens[next_tok] = max(ngram_lens[next_tok].item(), unwind + 1)
  605. # Convert ngram lengths to penalty exponents
  606. penalty_mask = ngram_lens > 0
  607. scales = bases[irow] ** (ngram_lens[penalty_mask] - min_ngram)
  608. # Calculate and apply penalties
  609. logits[irow][penalty_mask] -= multipliers[irow] * scales
  610. return logits
  611. def _apply_no_repeat_ngram(
  612. logits: torch.Tensor,
  613. input_ids: torch.Tensor,
  614. ngram_size: torch.Tensor,
  615. ) -> torch.Tensor:
  616. """Apply no-repeat-ngram penalty which sets logits to -inf for tokens that
  617. would create a repeated n-gram.
  618. """
  619. if torch.all(ngram_size == 0):
  620. return logits
  621. batch_size = logits.shape[0]
  622. for i in range(batch_size):
  623. size = int(ngram_size[i].item())
  624. if size == 0:
  625. continue
  626. cur_len = len(input_ids[i])
  627. if cur_len < size:
  628. continue
  629. banned_tokens = _calc_banned_ngram_tokens(
  630. ngram_size=size,
  631. prev_input_ids=input_ids[i],
  632. cur_len=cur_len-1
  633. )
  634. if banned_tokens:
  635. logits[i, banned_tokens] = -float("inf")
  636. return logits
  637. def _apply_top_k_top_p(
  638. logits: torch.Tensor,
  639. p: torch.Tensor,
  640. k: torch.Tensor,
  641. ) -> torch.Tensor:
  642. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  643. # Apply top-k.
  644. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  645. # Get all the top_k values.
  646. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  647. top_k_mask = logits_sort < top_k_mask
  648. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  649. # Apply top-p.
  650. probs_sort = logits_sort.softmax(dim=-1)
  651. probs_sum = probs_sort.cumsum(dim=-1)
  652. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  653. # at least one
  654. top_p_mask[:, -1] = False
  655. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  656. # Re-sort the probabilities.
  657. src = torch.arange(logits_idx.shape[-1],
  658. device=logits_idx.device).expand_as(logits_idx)
  659. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  660. index=logits_idx,
  661. src=src)
  662. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  663. return logits
  664. def _apply_min_p(
  665. logits: torch.Tensor,
  666. min_p: torch.Tensor,
  667. ) -> torch.Tensor:
  668. """
  669. Adapted from
  670. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  671. """
  672. probs = torch.softmax(logits, dim=-1)
  673. top_probs, _ = probs.max(dim=-1, keepdim=True)
  674. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  675. tokens_to_remove = probs < scaled_min_p
  676. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  677. return logits
  678. def _apply_top_a(
  679. logits: torch.Tensor,
  680. top_a: torch.Tensor,
  681. ) -> torch.Tensor:
  682. probs = torch.softmax(logits, dim=-1)
  683. top_probs, _ = probs.max(dim=-1, keepdim=True)
  684. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  685. tokens_to_remove = probs < threshold
  686. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  687. return logits
  688. def _apply_tfs(
  689. logits: torch.Tensor,
  690. tfs: torch.Tensor,
  691. ) -> torch.Tensor:
  692. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  693. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  694. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  695. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  696. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  697. tfs_mask = torch.cat(
  698. (
  699. torch.zeros(
  700. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  701. tfs_mask,
  702. torch.ones(
  703. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  704. ),
  705. dim=-1,
  706. )
  707. logits_sort[tfs_mask] = -float("inf")
  708. logits = torch.gather(logits_sort,
  709. dim=-1,
  710. index=torch.argsort(logits_idx, dim=-1))
  711. return logits
  712. def _apply_eta_cutoff(
  713. logits: torch.Tensor,
  714. eta_cutoff: torch.Tensor,
  715. ) -> torch.Tensor:
  716. shifted_logits = torch.log_softmax(logits, dim=-1)
  717. probs = shifted_logits.exp()
  718. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  719. eps = torch.min(eta_cutoff,
  720. torch.sqrt(eta_cutoff) *
  721. torch.exp(neg_entropy)).unsqueeze(dim=1)
  722. eta_mask = probs < eps
  723. # guard against nulling out all the logits
  724. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  725. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  726. logits[eta_mask] = -float("inf")
  727. return logits
  728. def _apply_epsilon_cutoff(
  729. logits: torch.Tensor,
  730. epsilon_cutoff: torch.Tensor,
  731. ) -> torch.Tensor:
  732. probs = logits.softmax(dim=-1)
  733. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  734. # guard against nulling out all the logits
  735. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  736. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  737. logits[eps_mask] = -float("inf")
  738. return logits
  739. def _apply_typical_sampling(
  740. logits: torch.Tensor,
  741. typical_p: torch.Tensor,
  742. ) -> torch.Tensor:
  743. shifted_logits = torch.log_softmax(logits, dim=-1)
  744. probs = shifted_logits.exp()
  745. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  746. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  747. _, indices = torch.sort(surprisal_deviations)
  748. reordered_probs = probs.gather(-1, indices)
  749. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  750. dim=1)
  751. min_tokens_to_keep = 1
  752. # Keep at least min_tokens_to_keep
  753. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  754. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  755. logits[typ_mask] = -float("inf")
  756. return logits
  757. def _apply_quadratic_sampling(
  758. logits: torch.Tensor,
  759. smoothing_factor: torch.Tensor,
  760. smoothing_curve: torch.Tensor,
  761. ) -> torch.Tensor:
  762. """
  763. Applies a quadratic transformation to the logits based on the
  764. provided smoothing factors and curves. The transformation is
  765. centered around the maximum logit value in the batch.
  766. The transformation involves a quadratic and cubic term, with the
  767. cubic term controlled by the smoothing curve. The quadratic term is
  768. scaled by the smoothing factor, and the cubic term is scaled by the
  769. product of the smoothing factor and the smoothing curve.
  770. params:
  771. logits (torch.Tensor): The logits to be transformed.
  772. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  773. term in the transformation.
  774. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  775. in the transformation.
  776. returns:
  777. torch.Tensor: The transformed logits.
  778. Credits: @kalomaze
  779. """
  780. mask = smoothing_factor != 0
  781. smoothing_factor.unsqueeze_(dim=1)
  782. smoothing_curve.unsqueeze_(dim=1)
  783. k = smoothing_factor * (3 - smoothing_curve) / 2
  784. s = smoothing_factor * (smoothing_curve - 1) / 2
  785. quadlogits = logits[mask] # limit to logits using this sampler
  786. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  787. # Construct the delta from each logit to its new value
  788. diff = quadlogits - max_logits
  789. diff -= diff**2 * (s[mask] * diff - k[mask])
  790. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  791. logits[mask] -= diff
  792. return logits
  793. def _apply_xtc_sampling(
  794. logits: torch.Tensor,
  795. xtc_thresholds: torch.Tensor,
  796. xtc_probabilities: torch.Tensor,
  797. ) -> torch.Tensor:
  798. """Apply Exclude Top Choices (XTC) sampling to the logits.
  799. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  800. Args:
  801. logits: (num_tokens, vocab_size) The input logits.
  802. xtc_thresholds: (num_tokens,) The threshold for each token.
  803. xtc_probabilities: (num_tokens,) The probability of applying XTC
  804. for each token.
  805. Returns:
  806. torch.Tensor: The modified logits.
  807. """
  808. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  809. if not apply_xtc.any():
  810. return logits
  811. probs = torch.softmax(logits, dim=-1)
  812. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  813. # Find indices where the next probability is above the threshold
  814. # Skips the top choice, which later on becomes skipping the last choice.
  815. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  816. # Apply XTC only to rows where it should be applied
  817. for i in range(logits.shape[0]):
  818. if apply_xtc[i]:
  819. # Count logits above the threshold (skipping the first)
  820. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  821. if indices_to_remove > 0:
  822. # Implies the top logit and at least one other is >= threshold.
  823. # Mask out above_thresh logits except the last/lowest one.
  824. logits[i].scatter_(
  825. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  826. return logits
  827. def _apply_top_nsigma(
  828. logits: torch.Tensor,
  829. nsigma: torch.Tensor,
  830. ) -> torch.Tensor:
  831. """Apply top-nsigma truncation to the logits.
  832. Reference: https://arxiv.org/abs/2411.07641
  833. Args:
  834. logits: Logits of shape (num_tokens, vocab_size)
  835. nsigma: Number of standard deviations to use as threshold
  836. Returns:
  837. Modified logits with values below threshold set to -inf
  838. """
  839. std = logits.std(dim=-1, keepdim=True)
  840. threshold = (logits.max(dim=-1, keepdim=True).values -
  841. nsigma.unsqueeze(dim=1) * std)
  842. logits[logits < threshold] = float("-inf")
  843. return logits
  844. def _greedy_sample(
  845. selected_seq_groups: List[SequenceGroupToSample],
  846. samples: torch.Tensor,
  847. ) -> List[Tuple[List[int], List[int]]]:
  848. """Run greedy sampling on a given samples.
  849. Args:
  850. selected_seq_groups: A list of sequence groups batched.
  851. samples: (num_selected_samples,) A tensor of samples. The length of
  852. samples could be smaller than selected_seq_groups if
  853. seq_group.do_sample is False.
  854. Returns:
  855. Tuple of (next_token_ids, parent_ids). The length of returned list is
  856. same as the length of selected_seq_groups. If the corresponding
  857. seq_group has do_sample=False, tuple contains ([], [])
  858. """
  859. samples = samples.tolist()
  860. sample_idx = 0
  861. results = []
  862. for seq_group in selected_seq_groups:
  863. if not seq_group.do_sample:
  864. results.append(([], []))
  865. continue
  866. seq_ids = seq_group.seq_ids
  867. num_parent_seqs = len(seq_ids)
  868. assert num_parent_seqs == 1, (
  869. "Greedy sampling should have only one seq.")
  870. parent_ids = list(range(num_parent_seqs))
  871. next_token_ids = [samples[sample_idx]]
  872. results.append((next_token_ids, parent_ids))
  873. sample_idx += num_parent_seqs
  874. return results
  875. def _random_sample(
  876. selected_seq_groups: List[SequenceGroupToSample],
  877. random_samples: torch.Tensor,
  878. ) -> List[Tuple[List[int], List[int]]]:
  879. """Run random sampling on a given samples.
  880. Args:
  881. selected_seq_groups: A list of sequence groups batched.
  882. random_samples: (num_selected_samples,) A tensor of samples. The
  883. length of samples could be smaller than selected_seq_groups if
  884. seq_group.do_sample is False.
  885. Returns:
  886. Tuple of (next_token_ids, parent_ids). The length of returned list is
  887. same as the length of selected_seq_groups. If the corresponding
  888. seq_group has do_sample=False, tuple contains ([], [])
  889. """
  890. # Find the maximum best_of value of the prompt phase requests.
  891. random_samples = random_samples.cpu()
  892. sample_idx = 0
  893. results = []
  894. for seq_group in selected_seq_groups:
  895. if not seq_group.do_sample:
  896. results.append(([], []))
  897. continue
  898. seq_ids = seq_group.seq_ids
  899. sampling_params = seq_group.sampling_params
  900. is_prompt = seq_group.is_prompt
  901. num_parent_seqs = len(seq_ids)
  902. if is_prompt:
  903. # Prompt phase.
  904. parent_ids = [0] * sampling_params.best_of
  905. next_token_ids = random_samples[
  906. sample_idx, :sampling_params.best_of].tolist()
  907. else:
  908. # Generation phase.
  909. parent_ids = list(range(num_parent_seqs))
  910. next_token_ids = random_samples[sample_idx:sample_idx +
  911. num_parent_seqs, 0].tolist()
  912. results.append((next_token_ids, parent_ids))
  913. sample_idx += num_parent_seqs
  914. return results
  915. def _beam_search_sample(
  916. selected_seq_groups: List[SequenceGroupToSample],
  917. logprobs: torch.Tensor,
  918. ) -> List[Tuple[List[int], List[int]]]:
  919. """Run beam sampling on a given samples.
  920. Args:
  921. selected_seq_groups: A list of sequence groups batched.
  922. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  923. on selected sample indices.
  924. Returns:
  925. Tuple of (next_token_ids, parent_ids). The length of returned list is
  926. same as the length of selected_seq_groups. If the corresponding
  927. seq_group has do_sample=False, tuple contains ([], [])
  928. """
  929. # We sample 2 * beam_width candidates to make sure that with high
  930. # probability we can get `beam_width` candidates in addition to
  931. # the finished sequences for the next iteration. See
  932. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  933. # for details. See also HF reference:
  934. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  935. #
  936. # NOTE: Beam search is not vectorized, so its speed can be slower than
  937. # other sampling methods.
  938. sample_idx = 0
  939. results = []
  940. for seq_group in selected_seq_groups:
  941. if not seq_group.do_sample:
  942. results.append(([], []))
  943. continue
  944. is_prompt = seq_group.is_prompt
  945. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  946. num_parent_seqs = len(seq_ids)
  947. beam_width = sampling_params.best_of
  948. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  949. if is_prompt:
  950. # Prompt phase.
  951. assert num_parent_seqs == 1, (
  952. "Prompt input should have only one seq.")
  953. parent_ids = [0] * (2 * beam_width)
  954. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  955. 2 * beam_width)
  956. next_token_ids = next_token_ids.tolist()
  957. else:
  958. # Generation phase.
  959. cumulative_logprobs = [
  960. seq_group.seq_data[seq_id].cumulative_logprob
  961. for seq_id in seq_ids
  962. ]
  963. cumulative_logprobs = torch.tensor(
  964. cumulative_logprobs,
  965. dtype=torch.float,
  966. device=seq_group_logprobs.device)
  967. seq_group_logprobs = (seq_group_logprobs +
  968. cumulative_logprobs.unsqueeze(dim=1))
  969. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  970. 2 * beam_width)
  971. topk_ids = topk_ids.tolist()
  972. vocab_size = seq_group_logprobs.size(-1)
  973. parent_ids = [i // vocab_size for i in topk_ids]
  974. next_token_ids = [i % vocab_size for i in topk_ids]
  975. results.append((next_token_ids, parent_ids))
  976. sample_idx += num_parent_seqs
  977. assert sample_idx == logprobs.size(0)
  978. return results
  979. # torch.multinomial forces a GPU<->CPU sync.
  980. # Therefore, we use an optimized implementation instead.
  981. # Note that we always sample with replacement.
  982. # probs will be modified in place, but this is fine, as we pass
  983. # in a copy already.
  984. def _multinomial(
  985. probs: torch.Tensor,
  986. num_samples: int,
  987. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  988. ) -> torch.Tensor:
  989. if num_samples > 1:
  990. probs = probs.repeat_interleave(num_samples, dim=0)
  991. q = torch.empty_like(probs)
  992. if seq_groups is None:
  993. q.exponential_()
  994. else:
  995. sample_idx = 0
  996. for seq_group in seq_groups:
  997. seq_ids = seq_group.seq_ids
  998. stride = len(seq_ids) * num_samples
  999. assert seq_group.generator is not None
  1000. q[sample_idx:sample_idx +
  1001. stride].exponential_(generator=seq_group.generator)
  1002. sample_idx += stride
  1003. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  1004. def _top_k_top_p_multinomial_with_kernels(
  1005. probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
  1006. num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
  1007. max_top_k_round = 32
  1008. if num_samples > 1:
  1009. probs = probs.repeat_interleave(num_samples, dim=0)
  1010. top_ks = top_ks.repeat_interleave(num_samples)
  1011. top_ps = top_ps.repeat_interleave(num_samples)
  1012. batch_size = probs.shape[0]
  1013. uniform_samples = torch.empty((max_top_k_round, batch_size),
  1014. device=probs.device)
  1015. if seq_groups is None:
  1016. uniform_samples.uniform_()
  1017. else:
  1018. sample_idx = 0
  1019. for seq_group in seq_groups:
  1020. seq_ids = seq_group.seq_ids
  1021. stride = len(seq_ids) * num_samples
  1022. assert seq_group.generator is not None
  1023. uniform_samples[:, sample_idx:sample_idx +
  1024. stride].uniform_(generator=seq_group.generator)
  1025. sample_idx += stride
  1026. batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
  1027. probs,
  1028. uniform_samples,
  1029. top_ks,
  1030. top_ps,
  1031. )
  1032. if not success.all():
  1033. warnings.warn("CUDA rejection sampling failed, fallback.",
  1034. stacklevel=1)
  1035. probs = ops.top_k_renorm_prob(probs, top_ks)
  1036. probs = ops.top_p_renorm_prob(probs, top_ps)
  1037. batch_next_token_ids = ops.sampling_from_probs(
  1038. probs, uniform_samples[0])
  1039. return batch_next_token_ids.view(-1, num_samples)
  1040. def _sample_with_torch(
  1041. probs: torch.Tensor,
  1042. logprobs: torch.Tensor,
  1043. sampling_metadata: SamplingMetadata,
  1044. sampling_tensors: SamplingTensors,
  1045. include_gpu_probs_tensor: bool,
  1046. modify_greedy_probs: bool,
  1047. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  1048. categorized_seq_group_ids = {t: [] for t in SamplingType}
  1049. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  1050. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  1051. sampling_params = seq_group.sampling_params
  1052. sampling_type = sampling_params.sampling_type
  1053. categorized_seq_group_ids[sampling_type].append(i)
  1054. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  1055. sample_metadata = {}
  1056. multinomial_samples = {}
  1057. # Create output tensor for sampled token ids.
  1058. if include_gpu_probs_tensor:
  1059. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  1060. 1,
  1061. dtype=torch.long,
  1062. device=logprobs.device)
  1063. else:
  1064. sampled_token_ids_tensor = None
  1065. # Counterintuitively, having two loops here is actually faster.
  1066. # The first loop can run without waiting on GPU<->CPU sync.
  1067. for sampling_type in SamplingType:
  1068. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  1069. num_tokens = len(sample_indices)
  1070. if num_tokens == 0:
  1071. continue
  1072. seq_group_id = categorized_seq_group_ids[sampling_type]
  1073. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  1074. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  1075. long_sample_indices = sample_indices.long()
  1076. if sampling_type == SamplingType.GREEDY:
  1077. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  1078. dim=-1)
  1079. if include_gpu_probs_tensor:
  1080. # Store sampled tokens in output tensor.
  1081. sampled_token_ids_tensor[
  1082. long_sample_indices] = greedy_samples.unsqueeze(-1)
  1083. if modify_greedy_probs:
  1084. # If required, modify the probabilities such that sampling from
  1085. # the modified distribution would always sample the argmax
  1086. # token id.
  1087. _modify_greedy_probs_inplace(logprobs, probs,
  1088. long_sample_indices,
  1089. greedy_samples)
  1090. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  1091. max_best_of_in_batch = 1
  1092. for seq_group in seq_groups:
  1093. if seq_group.is_prompt:
  1094. sampling_params = seq_group.sampling_params
  1095. max_best_of_in_batch = max(max_best_of_in_batch,
  1096. sampling_params.best_of)
  1097. seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
  1098. seq_groups)
  1099. if APHRODITE_USE_SAMPLING_KERNELS is not None:
  1100. multinomial_samples[
  1101. sampling_type] = _top_k_top_p_multinomial_with_kernels(
  1102. probs[long_sample_indices],
  1103. sampling_tensors.top_ks[long_sample_indices],
  1104. sampling_tensors.top_ps[long_sample_indices],
  1105. max_best_of_in_batch,
  1106. seq_groups_arg,
  1107. )
  1108. else:
  1109. multinomial_samples[sampling_type] = _multinomial(
  1110. probs[long_sample_indices],
  1111. max_best_of_in_batch,
  1112. seq_groups=seq_groups_arg)
  1113. if sampled_token_ids_tensor is not None:
  1114. # Store sampled tokens in output tensor.
  1115. sampled_token_ids_tensor[long_sample_indices] = \
  1116. multinomial_samples[sampling_type].to(torch.long)
  1117. elif sampling_type == SamplingType.BEAM:
  1118. beam_search_logprobs = logprobs[sample_indices]
  1119. else:
  1120. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  1121. # GPU<->CPU sync happens in the loop below.
  1122. # This also converts the sample output to Python objects.
  1123. if not sampling_metadata.skip_sampler_cpu_output:
  1124. for sampling_type in SamplingType:
  1125. if sampling_type not in sample_metadata:
  1126. continue
  1127. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  1128. if sampling_type == SamplingType.GREEDY:
  1129. sample_results = _greedy_sample(seq_groups, greedy_samples)
  1130. elif sampling_type in (SamplingType.RANDOM,
  1131. SamplingType.RANDOM_SEED):
  1132. sample_results = _random_sample(
  1133. seq_groups, multinomial_samples[sampling_type])
  1134. elif sampling_type == SamplingType.BEAM:
  1135. sample_results = _beam_search_sample(seq_groups,
  1136. beam_search_logprobs)
  1137. sample_results_dict.update(zip(seq_group_id, sample_results))
  1138. sample_results = [
  1139. sample_results_dict.get(i, ([], []))
  1140. for i in range(len(sampling_metadata.seq_groups))
  1141. ]
  1142. else:
  1143. sample_results = []
  1144. return sample_results, sampled_token_ids_tensor
  1145. def _sample_with_triton_kernel(
  1146. probs: torch.Tensor,
  1147. logprobs: torch.Tensor,
  1148. sampling_metadata: SamplingMetadata,
  1149. sampling_tensors: SamplingTensors,
  1150. ) -> List[Tuple[List[int], List[int]]]:
  1151. categorized_seq_group_ids = {t: [] for t in SamplingType}
  1152. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  1153. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  1154. sampling_params = seq_group.sampling_params
  1155. sampling_type = sampling_params.sampling_type
  1156. categorized_seq_group_ids[sampling_type].append(i)
  1157. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  1158. sample_metadata = {}
  1159. max_best_of_in_batch = 1
  1160. # Counterintuitively, having two loops here is actually faster.
  1161. # The first loop can run without waiting on GPU<->CPU sync.
  1162. for sampling_type in SamplingType:
  1163. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  1164. sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
  1165. num_tokens = len(sample_indices)
  1166. if num_tokens == 0:
  1167. continue
  1168. seq_group_id = categorized_seq_group_ids[sampling_type]
  1169. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  1170. sample_metadata[sampling_type] = (seq_group_id, seq_groups,
  1171. sample_indices,
  1172. sampled_token_indices)
  1173. if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
  1174. SamplingType.RANDOM_SEED):
  1175. for seq_group in seq_groups:
  1176. if seq_group.is_prompt:
  1177. sampling_params = seq_group.sampling_params
  1178. max_best_of_in_batch = max(max_best_of_in_batch,
  1179. sampling_params.best_of)
  1180. elif sampling_type == SamplingType.BEAM:
  1181. beam_search_logprobs = logprobs[sample_indices]
  1182. else:
  1183. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  1184. sampled_tokens, _, _ = sample_triton(
  1185. probs=probs,
  1186. seeds=sampling_tensors.sampling_seeds,
  1187. max_best_of=max_best_of_in_batch,
  1188. sample_indices=sampling_tensors.sample_indices,
  1189. logprobs=logprobs,
  1190. # don't save logprobs because we have logic for that below
  1191. # TODO: use this instead of the CPU-based logic below
  1192. save_logprobs=False,
  1193. )
  1194. # GPU<->CPU sync happens in the loop below.
  1195. for sampling_type in SamplingType:
  1196. if sampling_type not in sample_metadata:
  1197. continue
  1198. (seq_group_id, seq_groups, sample_indices,
  1199. sampled_token_indices) = sample_metadata[sampling_type]
  1200. if sampling_type == SamplingType.GREEDY:
  1201. sample_results = _greedy_sample(
  1202. seq_groups, sampled_tokens[sampled_token_indices][:, 0])
  1203. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  1204. sample_results = _random_sample(
  1205. seq_groups, sampled_tokens[sampled_token_indices])
  1206. elif sampling_type == SamplingType.BEAM:
  1207. sample_results = _beam_search_sample(seq_groups,
  1208. beam_search_logprobs)
  1209. sample_results_dict.update(zip(seq_group_id, sample_results))
  1210. sample_results = [
  1211. sample_results_dict.get(i, ([], []))
  1212. for i in range(len(sampling_metadata.seq_groups))
  1213. ]
  1214. return sample_results
  1215. def _sample(
  1216. probs: torch.Tensor, logprobs: torch.Tensor,
  1217. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  1218. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  1219. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  1220. """
  1221. Args:
  1222. probs: (num_query_tokens_in_batch, num_vocab)
  1223. logprobs: (num_query_tokens_in_batch, num_vocab)
  1224. sampling_metadata: The metadata for a batch for sampling.
  1225. sampling_tensors: Tensors that include sampling related metadata.
  1226. Returns:
  1227. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  1228. If sampling is skipped, it returns ([], [])
  1229. sampled_token_ids_tensor: A tensor of sampled token ids.
  1230. """
  1231. return _sample_with_torch(
  1232. probs,
  1233. logprobs,
  1234. sampling_metadata,
  1235. sampling_tensors,
  1236. include_gpu_probs_tensor=include_gpu_probs_tensor,
  1237. modify_greedy_probs=modify_greedy_probs,
  1238. )
  1239. # TODO: Enable once Triton kernel & associated code is faster.
  1240. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  1241. # sampling_tensors)
  1242. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  1243. """
  1244. This function calculates the ranks of the chosen tokens in a logprob tensor.
  1245. Args:
  1246. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  1247. where N is the no. of tokens and M is the vocab dim.
  1248. indices (torch.Tensor): List of chosen token indices.
  1249. Returns:
  1250. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  1251. Each element in the returned tensor represents the rank
  1252. of the chosen token in the input logprob tensor.
  1253. """
  1254. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  1255. indices]
  1256. return (x > vals[:, None]).long().sum(1).add_(1)
  1257. def _get_logprobs(
  1258. logprobs: torch.Tensor,
  1259. sampling_metadata: SamplingMetadata,
  1260. sample_results: List[Tuple[List[int], List[int]]],
  1261. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  1262. """Return sample lobprobs and prompt logprobs.
  1263. The logic consists of 3 parts.
  1264. - Select indices to compute logprob from, ranks of token ids, and
  1265. the top k token ids from logprobs.
  1266. - Compute prompt logprobs if required.
  1267. - Compute sample logprobs if required.
  1268. Args:
  1269. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  1270. logprob per vocab. Sequence groups' query tokens are batched in a
  1271. single flattened tensor. For example, assuming there are N
  1272. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  1273. prompt logprob is enabled), decode tokens for seq_group_1 (if
  1274. sampling is required), prefill tokens for seq_group_2, ...
  1275. sampling_metadata: The sampling metadata.
  1276. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  1277. parent_ids) for each sequence group. When beam search is enabled,
  1278. sample_results can contain different number of seq_ids from
  1279. sampling_metadata.seq_groups. It is because beam search creates
  1280. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  1281. BEAM_WIDTH number of seq_ids).
  1282. Returns:
  1283. A tuple of prompt and sample logprobs per sequence group in a batch.
  1284. """
  1285. # The index of query token to calculate logprobs. It includes both
  1286. # prompt and sample logprob indices.
  1287. query_indices: List[int] = []
  1288. # The next token ids to get the logprob value from.
  1289. next_token_ids: List[int] = []
  1290. # The largest requested number of logprobs. We find logprobs as many as the
  1291. # largest num logprobs in this API. If every logprobs is None, it will be
  1292. # set to -1.
  1293. largest_num_logprobs = -1
  1294. # If beam search is enabled.
  1295. use_beam_search = False
  1296. # Select indices to compute logprob from, ranks of token ids, and the top
  1297. # k token ids from logprobs.
  1298. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  1299. sample_results):
  1300. sampling_params = seq_group.sampling_params
  1301. # Update indices and tokens for prompt logprobs.
  1302. if (seq_group.is_prompt
  1303. and sampling_params.prompt_logprobs is not None):
  1304. largest_num_logprobs = max(largest_num_logprobs,
  1305. sampling_params.prompt_logprobs)
  1306. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1307. query_indices.extend(seq_group.prompt_logprob_indices)
  1308. next_token_ids.extend(next_prompt_tokens)
  1309. # Update indices and next tokenes for sample logprob.
  1310. if seq_group.do_sample:
  1311. token_ids, parent_seq_ids = sample_result
  1312. # NOTE: We cannot directly use sample_indices because
  1313. # sample_indices only contain parent seq_ids of a previous step.
  1314. # The current step may have different number of seq_ids, and
  1315. # we can obtain it from `sample_result[1]`.
  1316. query_idx = seq_group.sample_indices[0]
  1317. query_indices.extend(
  1318. [query_idx + parent_id for parent_id in parent_seq_ids])
  1319. next_token_ids.extend(token_ids)
  1320. if sampling_params.logprobs is not None:
  1321. largest_num_logprobs = max(largest_num_logprobs,
  1322. sampling_params.logprobs)
  1323. use_beam_search = use_beam_search or sampling_params.use_beam_search
  1324. assert len(next_token_ids) == len(query_indices)
  1325. if len(query_indices) == 0:
  1326. empty_sampled_logprob = []
  1327. empty_prompt_logprob = None
  1328. return [empty_prompt_logprob], [empty_sampled_logprob]
  1329. selected_logprobs, ranks = None, None
  1330. top_logprobs, top_token_ids = None, None
  1331. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  1332. # skip the whole logprob calculation.
  1333. if largest_num_logprobs >= 0 or use_beam_search:
  1334. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  1335. next_token_ids_gpu = torch.tensor(next_token_ids,
  1336. device=logprobs.device)
  1337. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  1338. # contain duplicates if beam search is enabled.
  1339. selected_logprobs = logprobs[[
  1340. query_indices_gpu,
  1341. next_token_ids_gpu,
  1342. ]]
  1343. ranks = _get_ranks(
  1344. logprobs[query_indices_gpu],
  1345. next_token_ids_gpu,
  1346. )
  1347. assert selected_logprobs.shape[0] == ranks.shape[0]
  1348. # We need to compute top k only if there exists logprobs > 0.
  1349. if largest_num_logprobs > 0:
  1350. # Logprobs of topk tokens for a batch of sequence groups.
  1351. # (num_query_tokens_across_batch).
  1352. top_logprobs, top_token_ids = torch.topk(logprobs,
  1353. largest_num_logprobs,
  1354. dim=-1)
  1355. top_logprobs = top_logprobs.to('cpu')
  1356. top_token_ids = top_token_ids.to('cpu')
  1357. selected_logprobs = selected_logprobs.to('cpu')
  1358. ranks = ranks.to('cpu')
  1359. # Find prompt/sample logprobs.
  1360. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  1361. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  1362. top_logprob_idx = 0
  1363. selected_logprobs_idx = 0
  1364. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  1365. sample_results):
  1366. (prompt_logprobs, top_logprob_idx,
  1367. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  1368. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  1369. selected_logprobs_idx, top_logprob_idx)
  1370. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  1371. (sampled_logprobs, top_logprob_idx,
  1372. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  1373. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1374. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1375. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1376. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1377. def _get_prompt_logprob_if_needed(
  1378. seq_group: SequenceGroupToSample,
  1379. selected_logprobs: torch.Tensor,
  1380. ranks: torch.Tensor,
  1381. top_token_ids: torch.Tensor,
  1382. top_logprobs: torch.Tensor,
  1383. selected_logprobs_idx: int,
  1384. top_logprob_idx: int,
  1385. ):
  1386. """Compute the prompt logprob from a sequence group if needed."""
  1387. sampling_params = seq_group.sampling_params
  1388. is_prompt = seq_group.is_prompt
  1389. # Find prompt logprobs
  1390. prompt_logprobs: Optional[PromptLogprobs] = None
  1391. if is_prompt and sampling_params.prompt_logprobs is not None:
  1392. prompt_logprobs = []
  1393. num_logprobs = sampling_params.prompt_logprobs
  1394. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1395. # Pre-select indexes and create a list. It is faster than calling .item
  1396. # repetitively.
  1397. selected_logprob_items = selected_logprobs[
  1398. selected_logprobs_idx:selected_logprobs_idx +
  1399. len(next_prompt_tokens)].tolist()
  1400. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1401. len(next_prompt_tokens)].tolist()
  1402. for idx, token_id in enumerate(next_prompt_tokens):
  1403. # Calculate the prompt logprob of the real prompt tokens.
  1404. # {token_id: (logprob, rank_from_vocab)}
  1405. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1406. token_id: (selected_logprob_items[idx], rank_items[idx])
  1407. }
  1408. # Add top K prompt logprobs along with its rank.
  1409. if num_logprobs > 0:
  1410. top_ids = top_token_ids[
  1411. top_logprob_idx, :num_logprobs].tolist()
  1412. top_probs = top_logprobs[
  1413. top_logprob_idx, :num_logprobs].tolist()
  1414. # Top K is already sorted by rank, so we can use 1 ~
  1415. # num_logprobs + 1 for rank.
  1416. top_ranks = range(1, num_logprobs + 1)
  1417. prompt_logprobs_dict.update({
  1418. top_id: (top_prob, rank)
  1419. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1420. top_ranks)
  1421. })
  1422. prompt_logprobs.append({
  1423. token_id: Logprob(*logprob_and_rank)
  1424. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1425. })
  1426. # + 1 to go to the next prompt token.
  1427. top_logprob_idx += 1
  1428. # + len(next_prompt_tokens) to go to the next prompt.
  1429. selected_logprobs_idx += len(next_prompt_tokens)
  1430. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1431. def _get_sampled_logprob_if_needed(
  1432. seq_group: SequenceGroupToSample,
  1433. sample_result: Tuple[List[int], List[int]],
  1434. selected_logprobs: torch.Tensor,
  1435. ranks: torch.Tensor,
  1436. top_token_ids: torch.Tensor,
  1437. top_logprobs: torch.Tensor,
  1438. selected_logprobs_idx: int,
  1439. top_logprob_idx: int,
  1440. ):
  1441. """Compute the sample logprob if needed."""
  1442. seq_ids = seq_group.seq_ids
  1443. num_logprobs = seq_group.sampling_params.logprobs
  1444. use_beam_search = seq_group.sampling_params.use_beam_search
  1445. sampled_logprobs: SampleLogprobs = []
  1446. next_token_ids, parent_seq_ids = sample_result
  1447. if seq_group.do_sample:
  1448. assert len(next_token_ids) > 0
  1449. if num_logprobs is None and not use_beam_search:
  1450. for next_token_id in next_token_ids:
  1451. # Use a dummy logprob
  1452. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1453. else:
  1454. # Pre-select items from tensor. tolist() is faster than repetitive
  1455. # `.item()` calls.
  1456. selected_logprob_items = selected_logprobs[
  1457. selected_logprobs_idx:selected_logprobs_idx +
  1458. len(next_token_ids)].tolist()
  1459. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1460. len(next_token_ids)].tolist()
  1461. for idx, (next_token_id, parent_id) in enumerate(
  1462. zip(next_token_ids, parent_seq_ids)):
  1463. # Get the logprob of a sampled token.
  1464. sampled_logprobs_dict = {
  1465. next_token_id:
  1466. (selected_logprob_items[idx], rank_items[idx])
  1467. }
  1468. if num_logprobs is not None and num_logprobs > 0:
  1469. # Get top K logprobs.
  1470. top_ids = top_token_ids[top_logprob_idx +
  1471. parent_id, :num_logprobs].tolist()
  1472. top_probs = top_logprobs[
  1473. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1474. # Top K is already sorted by rank, so we can use 1 ~
  1475. # num_logprobs + 1 for rank.
  1476. top_ranks = range(1, num_logprobs + 1)
  1477. sampled_logprobs_dict.update({
  1478. top_id: (top_prob, rank)
  1479. for top_id, top_prob, rank in zip(
  1480. top_ids, top_probs, top_ranks)
  1481. })
  1482. sampled_logprobs.append({
  1483. token_id: Logprob(*logprob_and_rank)
  1484. for token_id, logprob_and_rank in
  1485. sampled_logprobs_dict.items()
  1486. })
  1487. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1488. # logprobs for the current step, which has len(next_token_ids) tokens
  1489. # per sequence group. `logprobs` includes logprobs from the previous
  1490. # steps, which has len(seq_ids) tokens per sequence group.
  1491. # Iterate to the next sequence group in a batch.
  1492. selected_logprobs_idx += len(next_token_ids)
  1493. # Iterate to the next sequence group in a batch.
  1494. top_logprob_idx += len(seq_ids)
  1495. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1496. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1497. sample_indices: torch.Tensor,
  1498. greedy_samples: torch.Tensor) -> None:
  1499. """Modify the probability distributions of the greedily-sampled tokens such
  1500. that each sampled token has a "probability" of 1.0. This is required by
  1501. speculative decoding, which depends on the sampling method being encoded
  1502. within the probability distribution for correctness.
  1503. # Why do we only need to do this for greedy sampling?
  1504. Aphrodite's sampler performs the following steps for greedy or multinomial
  1505. (random) sampling:
  1506. 1. Get logits from model.
  1507. 2. Modify logits according to per-sequence sampling parameters.
  1508. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1509. according to their frequency, etc.
  1510. 3. Sample a token.
  1511. - Random sampling simply samples from the modified probability
  1512. distribution.
  1513. - Greedy sampling performs `argmax` to obtain the token with the
  1514. highest likelihood.
  1515. Ignoring greedy sampling for a moment, we find that the computed probability
  1516. distribution has the following property: we can sample from it independently
  1517. and find that the token sampled by the Sampler has a frequency corresponding
  1518. to how often we see it in our sampling. In other words, for tokens sampled
  1519. with Aphrodite's random SamplingType, the computed probability distribution
  1520. encodes the sampling methodology completely.
  1521. Greedy sampling does not normally have this property. Aphrodite modifies
  1522. logits according to sampling params, then performs `argmax`, then returns
  1523. the sampled token and the computed probability distribution. If we sample
  1524. from the distribution, we'll find the likelihood of the greedily-sampled
  1525. token is not always 1.0.
  1526. Since lossless speculative decoding requires that the sampling methodology
  1527. be encoded within the probability distribution, we are motivated to modify
  1528. the probability distribution such that the sampled token has probability 1
  1529. when speculative decoding is used.
  1530. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1531. greedy sampling using multinomial computation and unite the codepaths. This
  1532. has implications on the overall design of the sampler, e.g. how to record
  1533. accurate logprobs for the user, so this improvement is deferred to later.
  1534. """
  1535. # NOTE: logprobs are not modified so they can be returned to the user.
  1536. probs[sample_indices, :] = 0
  1537. probs[sample_indices, greedy_samples] = 1.0
  1538. def _build_sampler_output(
  1539. sample_results: SampleResultType,
  1540. sampling_metadata: SamplingMetadata,
  1541. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1542. sample_logprobs: Optional[List[SampleLogprobs]],
  1543. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1544. torch.Tensor]],
  1545. skip_sampler_cpu_output: bool = False,
  1546. ) -> SamplerOutput:
  1547. """Construct Python objects with the output of sampling.
  1548. Args:
  1549. on_device_tensors: Tuple containing on-device tensors with the
  1550. probabilities used in sampling and the sampled token ids. This
  1551. allows post-processing without copies to CPU/serialization, e.g. in
  1552. speculative decoding rejection sampling.
  1553. """
  1554. sampler_output: List[CompletionSequenceGroupOutput] = []
  1555. if not skip_sampler_cpu_output:
  1556. assert prompt_logprobs is not None
  1557. assert sample_logprobs is not None
  1558. for (seq_group, sample_result, group_prompt_logprobs,
  1559. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1560. sample_results, prompt_logprobs,
  1561. sample_logprobs):
  1562. seq_ids = seq_group.seq_ids
  1563. next_token_ids, parent_ids = sample_result
  1564. seq_outputs: List[SequenceOutput] = []
  1565. for parent_id, next_token_id, logprobs in zip(
  1566. parent_ids, next_token_ids, group_sample_logprobs):
  1567. seq_outputs.append(
  1568. SequenceOutput(seq_ids[parent_id], next_token_id,
  1569. logprobs))
  1570. sampler_output.append(
  1571. CompletionSequenceGroupOutput(seq_outputs,
  1572. group_prompt_logprobs))
  1573. # If not specified, store None values in SamplerOutput.
  1574. if on_device_tensors is not None:
  1575. (sampled_token_probs, logprobs_tensor,
  1576. sampled_token_ids) = on_device_tensors
  1577. else:
  1578. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1579. None)
  1580. return SamplerOutput(
  1581. outputs=sampler_output,
  1582. sampled_token_probs=sampled_token_probs,
  1583. sampled_token_ids=sampled_token_ids,
  1584. logprobs=logprobs_tensor,
  1585. )
  1586. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1587. """Get a list of next prompt tokens to compute logprob from a
  1588. given sequence group.
  1589. It is used to compute prompt logprob. Imagine you have logprob for each
  1590. query token. Query token needs to know the next prompt token id to compute
  1591. prompt logprob. This is a helper to obtain next prompt token ids.
  1592. This API has to be used only when the caller knows seq_group is in prefill
  1593. stage.
  1594. Returns:
  1595. A list of next prompt tokens to compute logprob.
  1596. """
  1597. assert seq_group.is_prompt, (
  1598. "Caller should ensure the sequence group is in a prefill stage.")
  1599. seq_ids = seq_group.seq_ids
  1600. query_len = seq_group.query_len
  1601. assert query_len is not None
  1602. # prompt has only 1 seq id.
  1603. assert len(seq_ids) == 1
  1604. seq_data = seq_group.seq_data[seq_ids[0]]
  1605. computed_len = seq_data.get_num_computed_tokens()
  1606. prompt_tokens = seq_data.prompt_token_ids
  1607. # +1 because we are looking for a next prompt token.
  1608. next_token_index_start = computed_len + 1
  1609. next_token_index_end = min(computed_len + query_len + 1,
  1610. len(prompt_tokens))
  1611. next_prompt_tokens = prompt_tokens[
  1612. next_token_index_start:next_token_index_end]
  1613. return next_prompt_tokens
  1614. def _get_ngrams(
  1615. ngram_size: int,
  1616. prev_input_ids: torch.Tensor
  1617. ) -> Dict[Tuple[int, ...], List[int]]:
  1618. """Get dictionary of ngrams and the tokens that followed them.
  1619. Args:
  1620. ngram_size: Size of ngrams to track
  1621. prev_input_ids: 1D tensor of previous token ids
  1622. Returns:
  1623. Dictionary mapping ngram tuples to list of tokens that followed them
  1624. """
  1625. generated_ngrams = {}
  1626. gen_tokens = prev_input_ids.tolist()
  1627. for i in range(len(gen_tokens) - ngram_size + 1):
  1628. ngram = tuple(gen_tokens[i:i + ngram_size - 1])
  1629. next_token = gen_tokens[i + ngram_size - 1]
  1630. if ngram in generated_ngrams:
  1631. generated_ngrams[ngram].append(next_token)
  1632. else:
  1633. generated_ngrams[ngram] = [next_token]
  1634. return generated_ngrams
  1635. def _get_generated_ngrams(
  1636. banned_ngrams: Dict[Tuple[int, ...], List[int]],
  1637. prev_input_ids: torch.Tensor,
  1638. ngram_size: int,
  1639. cur_len: int
  1640. ) -> List[int]:
  1641. """Get list of tokens that would create a repeated ngram if generated next.
  1642. Args:
  1643. banned_ngrams: Dictionary of previously seen ngrams and their next
  1644. tokens
  1645. prev_input_ids: Previous token ids
  1646. ngram_size: Size of ngrams to check
  1647. cur_len: Current position in sequence
  1648. Returns:
  1649. List of token ids that would create a repeat ngram
  1650. """
  1651. start_idx = cur_len + 1 - ngram_size
  1652. current_ngram = tuple(prev_input_ids[start_idx:cur_len].tolist())
  1653. return banned_ngrams.get(current_ngram, [])
  1654. def _calc_banned_ngram_tokens(
  1655. ngram_size: int,
  1656. prev_input_ids: torch.Tensor,
  1657. cur_len: int
  1658. ) -> List[int]:
  1659. """Calculate tokens that would create repeated ngrams if generated next.
  1660. Args:
  1661. ngram_size: Size of ngrams to prevent repeating
  1662. prev_input_ids: Previous token ids in sequence
  1663. cur_len: Current position in sequence
  1664. Returns:
  1665. List of token ids that should be banned to prevent ngram repetition
  1666. """
  1667. if cur_len + 1 < ngram_size:
  1668. return []
  1669. generated_ngrams = _get_ngrams(ngram_size, prev_input_ids)
  1670. banned_tokens = _get_generated_ngrams(
  1671. generated_ngrams,
  1672. prev_input_ids,
  1673. ngram_size,
  1674. cur_len
  1675. )
  1676. return banned_tokens
  1677. # def _apply_mirostat_v2(logits: torch.Tensor,
  1678. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1679. # # Reduce our view to just the affected logits
  1680. # logit_view = logits[sampling_tensors.miro_indices]
  1681. # # Calculate surprise value per token
  1682. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1683. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1684. # # Mask out "too-surprising" tokens (surprisal > mu)
  1685. # mus = sampling_tensors.miro_mus
  1686. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1687. # # Unmask most-likely logit to guarantee a selection.
  1688. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1689. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1690. # # Apply logit mask (effectively a top-k filter).
  1691. # logit_view[miro_mask] = -float("inf")
  1692. # # Project logit changes made to the view onto the original.
  1693. # # I think this step might be redundant.
  1694. # logits[sampling_tensors.miro_indices] = logit_view
  1695. # return logits
  1696. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1697. # sample_results: List[Tuple[List[int], List[int]]],
  1698. # sampling_metadata: SamplingMetadata,
  1699. # output_metadata: OutputMetadata) -> None:
  1700. # """Based on whichever token was finally sampled, we calculate the
  1701. # final surprisal values to update the mus.
  1702. # Because a single sequence can have multiple samples, we must fork
  1703. # the mu accordingly."""
  1704. # assert sampling_metadata.seq_groups is not None
  1705. # seqid_to_tokens = {}
  1706. # seqid_to_indices = {}
  1707. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1708. # sample_results):
  1709. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1710. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1711. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1712. # seqids = args.miro_seqids
  1713. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1714. # device=logits.device,
  1715. # dtype=torch.long)
  1716. # # Clumsily, we recalculate token surprisals.
  1717. # logits_view = logits[args.miro_indices]
  1718. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1719. # dim=-1,
  1720. # index=picked_tokens) / -math.log(2)
  1721. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1722. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1723. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1724. # nu_mus = mus - (picked_surprise - taus) * etas
  1725. # # Record updated mu values for use in the next iteration
  1726. # # Note how each mu is split into multiple based on the number of samples.
  1727. # for seqid, seq_mus in zip(seqids, nu_mus):
  1728. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1729. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)