sampler.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. from math import inf
  4. from typing import Dict, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. from aphrodite.common.sampling_params import SamplingType
  8. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  9. PromptLogprobs, SampleLogprobs,
  10. SamplerOutput, SequenceOutput)
  11. from aphrodite.triton_utils import HAS_TRITON
  12. if HAS_TRITON:
  13. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  14. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  15. SamplingTensors,
  16. SequenceGroupToSample)
# Sampling result type: one (token ids, parent ids) pair of integer lists per
# sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
  19. class Sampler(nn.Module):
  20. """Samples the next tokens from the model's outputs.
  21. This layer does the following:
  22. 1. Discard the hidden states that are not used for sampling (i.e., all
  23. tokens except the final one in each prompt).
  24. 2. Compute the logits for the next tokens.
  25. 3. Apply presence, frequency and repetition penalties.
  26. 4. Apply temperature scaling.
  27. 5. Apply top-p and top-k truncation.
  28. 6. Sample the next tokens.
  29. Here, each sequence group within the batch can have different sampling
  30. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  31. The structure of the logits tensor is coupled with the seq_groups in
  32. sampling_metadata. Typically, each sequence in each seq_group has one row in
  33. logits for the next token to be sampled; however, for a seq_group with a
  34. prompt request with the prompt_logprobs sampling parameter, there are rows
  35. in logits for each token in the input prompt.
  36. """
  37. def __init__(self):
  38. super().__init__()
  39. # Whether or not the SamplerOutput should have on-device tensors
  40. # containing the sampled token ids and probabilities. This is used by
  41. # speculative decoding.
  42. self.include_gpu_probs_tensor = False
  43. self.should_modify_greedy_probs_inplace = False
  44. def _init_sampling_tensors(
  45. self,
  46. logits: torch.Tensor,
  47. sampling_metadata: SamplingMetadata,
  48. ):
  49. """The goal here is to reuse sampling tensors between similar decode
  50. runs. This is possible because sampling logic does not change between
  51. decodes of the same sequences.
  52. """
  53. _, vocab_size = logits.shape
  54. # First free any existing stored sampling tensors.
  55. # This is necessary because some sampling tensors may
  56. # have pinned memory.
  57. self._sampling_tensors = None
  58. # Initialize new sampling tensors
  59. (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
  60. do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
  61. do_typical_ps, do_quadratic, do_xtc, do_kl_threshold, do_jsd_threshold,
  62. do_dynatypical_p, do_temp_last
  63. ) = SamplingTensors.from_sampling_metadata(
  64. sampling_metadata, vocab_size, logits.device, logits.dtype)
  65. self._sampling_tensors = sampling_tensors
  66. self._do_penalties = do_penalties
  67. self._do_temperatures = do_temperatures
  68. self._do_top_p_top_k = do_top_p_top_k
  69. self._do_top_as = do_top_as
  70. self._do_min_p = do_min_p
  71. self._do_tfss = do_tfss
  72. self._do_eta_cutoffs = do_eta_cutoffs
  73. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  74. self._do_typical_ps = do_typical_ps
  75. self._do_quadratic = do_quadratic
  76. self._do_xtc = do_xtc
  77. self._do_kl_threshold = do_kl_threshold
  78. self._do_jsd_threshold = do_jsd_threshold
  79. self._do_dynatypical_p = do_dynatypical_p
  80. self._do_temp_last = do_temp_last
  81. def forward(
  82. self,
  83. logits: torch.Tensor,
  84. sampling_metadata: SamplingMetadata,
  85. ) -> Optional[SamplerOutput]:
  86. """
  87. Args:
  88. logits: (num_tokens, vocab_size).
  89. sampling_metadata: Metadata for sampling.
  90. """
  91. assert logits is not None
  92. _, vocab_size = logits.shape
  93. # Prepare sampling tensors with pinned memory to avoid blocking.
  94. if not sampling_metadata.reuse_sampling_tensors:
  95. self._init_sampling_tensors(logits, sampling_metadata)
  96. elif self._do_penalties:
  97. # In this case, the sampling tensors logic depends on
  98. # "output_tokens" of a sequence. As a result, we cannot
  99. # reuse sampling tensors, since "output_tokens" changes
  100. # between decode runs.
  101. self._init_sampling_tensors(logits, sampling_metadata)
  102. assert self._sampling_tensors is not None
  103. sampling_tensors = self._sampling_tensors
  104. do_penalties = self._do_penalties
  105. do_temperatures = self._do_temperatures
  106. do_top_p_top_k = self._do_top_p_top_k
  107. do_top_as = self._do_top_as
  108. do_min_p = self._do_min_p
  109. do_tfss = self._do_tfss
  110. do_eta_cutoffs = self._do_eta_cutoffs
  111. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  112. do_typical_ps = self._do_typical_ps
  113. do_quadratic = self._do_quadratic
  114. do_xtc = self._do_xtc
  115. do_kl_threshold = self._do_kl_threshold
  116. do_jsd_threshold = self._do_jsd_threshold
  117. do_dynatypical_p = self._do_dynatypical_p
  118. do_temp_last = self._do_temp_last
  119. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  120. # Apply presence and frequency penalties.
  121. if do_penalties:
  122. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  123. sampling_tensors.output_tokens,
  124. sampling_tensors.presence_penalties,
  125. sampling_tensors.frequency_penalties,
  126. sampling_tensors.repetition_penalties)
  127. # Apply temperature scaling if not doing temp_last.
  128. if do_temperatures and not do_temp_last:
  129. _apply_temperatures(logits, sampling_tensors.temperatures,
  130. sampling_tensors.dynatemp_mins,
  131. sampling_tensors.dynatemp_maxs,
  132. sampling_tensors.dynatemp_exps)
  133. if do_top_p_top_k:
  134. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  135. sampling_tensors.top_ks)
  136. if do_top_as:
  137. logits = _apply_top_a(logits, sampling_tensors.top_as)
  138. if do_min_p:
  139. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  140. if do_tfss:
  141. logits = _apply_tfs(logits, sampling_tensors.tfss)
  142. if do_eta_cutoffs:
  143. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  144. if do_epsilon_cutoffs:
  145. logits = _apply_epsilon_cutoff(logits,
  146. sampling_tensors.epsilon_cutoffs)
  147. if do_typical_ps:
  148. logits = _apply_typical_sampling(logits,
  149. sampling_tensors.typical_ps)
  150. if do_quadratic:
  151. logits = _apply_quadratic_sampling(
  152. logits, sampling_tensors.smoothing_factors,
  153. sampling_tensors.smoothing_curves)
  154. if do_xtc:
  155. logits = _apply_xtc_sampling(
  156. logits, sampling_tensors.xtc_thresholds,
  157. sampling_tensors.xtc_probabilities)
  158. if do_kl_threshold:
  159. logits = _apply_kl_divergence_sampling(
  160. logits, sampling_tensors.kl_thresholds)
  161. if do_jsd_threshold:
  162. logits = _apply_jsd_sampling(logits, sampling_tensors.jsd_thresholds)
  163. if do_dynatypical_p:
  164. logits = _apply_dynamic_typical_sampling(
  165. logits, sampling_tensors.min_typical_ps,
  166. sampling_tensors.max_typical_ps)
  167. if do_temperatures and do_temp_last:
  168. _apply_temperatures(logits, sampling_tensors.temperatures,
  169. sampling_tensors.dynatemp_mins,
  170. sampling_tensors.dynatemp_maxs,
  171. sampling_tensors.dynatemp_exps)
  172. banned_tokens = _get_custom_token_bans(sampling_metadata)
  173. logits = _apply_token_bans(logits, banned_tokens)
  174. # We use float32 for probabilities and log probabilities.
  175. # Compute the probabilities.
  176. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  177. # Compute the log probabilities.
  178. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  179. # Sample the next tokens.
  180. sample_results, maybe_sampled_tokens_tensor = _sample(
  181. probs,
  182. logprobs,
  183. sampling_metadata,
  184. sampling_tensors,
  185. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  186. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  187. )
  188. if self.include_gpu_probs_tensor:
  189. assert maybe_sampled_tokens_tensor is not None
  190. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  191. else:
  192. on_device_tensors = None
  193. # Get the logprobs query results.
  194. prompt_logprobs = None
  195. sample_logprobs = None
  196. if not sampling_metadata.skip_sampler_cpu_output:
  197. prompt_logprobs, sample_logprobs = _get_logprobs(
  198. logprobs, sampling_metadata, sample_results)
  199. return _build_sampler_output(
  200. sample_results,
  201. sampling_metadata,
  202. prompt_logprobs,
  203. sample_logprobs,
  204. on_device_tensors=on_device_tensors,
  205. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  206. @property
  207. def _should_modify_greedy_probs_inplace(self) -> bool:
  208. """Whether or not the sampler should modify the probability distribution
  209. of greedily-sampled tokens such that multinomial sampling would sample
  210. the greedily-sampled token.
  211. In other words, if True then we set the probability of the greedily-
  212. sampled token to 1.
  213. This is used by speculative decoding, which requires that the sampling
  214. method be encoded into the probability distribution.
  215. """
  216. return self.should_modify_greedy_probs_inplace
  217. def _get_bin_counts_and_mask(
  218. tokens: torch.Tensor,
  219. vocab_size: int,
  220. num_seqs: int,
  221. ) -> Tuple[torch.Tensor, torch.Tensor]:
  222. # Compute the bin counts for the tokens.
  223. # vocab_size + 1 for padding.
  224. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  225. dtype=torch.long,
  226. device=tokens.device)
  227. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  228. bin_counts = bin_counts[:, :vocab_size]
  229. mask = bin_counts > 0
  230. return bin_counts, mask
  231. def _get_custom_token_bans(
  232. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  233. assert sampling_metadata.seq_groups is not None
  234. banned_tokens: List[List[int]] = []
  235. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  236. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  237. seq_ids = seq_group.seq_ids
  238. custom_token_bans = sampling_params.custom_token_bans
  239. if (i < sampling_metadata.num_prompts
  240. and sampling_params.prompt_logprobs is not None):
  241. prompt_len = len(seq_group.prompt_logprob_indices)
  242. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  243. banned_tokens += [custom_token_bans] * len(seq_ids)
  244. return banned_tokens
  245. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  246. output_tokens_tensor: torch.Tensor,
  247. presence_penalties: torch.Tensor,
  248. frequency_penalties: torch.Tensor,
  249. repetition_penalties: torch.Tensor) -> torch.Tensor:
  250. num_seqs, vocab_size = logits.shape
  251. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  252. num_seqs)
  253. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  254. output_tokens_tensor, vocab_size, num_seqs)
  255. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  256. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  257. logits = torch.where(logits > 0, logits / repetition_penalties,
  258. logits * repetition_penalties)
  259. # We follow the definition in OpenAI API.
  260. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  261. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  262. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  263. return logits
  264. def _apply_temperatures(
  265. logits: torch.Tensor,
  266. temperatures: torch.Tensor,
  267. dynatemp_mins: torch.Tensor,
  268. dynatemp_maxs: torch.Tensor,
  269. dynatemp_exps: torch.Tensor,
  270. ) -> None:
  271. dynatemp_mask = dynatemp_exps != 0
  272. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  273. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  274. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  275. dynatemp_logits = logits[dynatemp_mask]
  276. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  277. dynatemp_probs = dynatemp_shifted_logits.exp()
  278. dynatemp_entropies = -(dynatemp_probs *
  279. dynatemp_shifted_logits).nansum(dim=-1)
  280. dynatemp_max_entropies = torch.log_(
  281. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  282. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  283. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  284. normalized_entropies.pow_(dynatemp_exps))
  285. temperatures[dynatemp_mask] = dyn_temp
  286. temperatures[temperatures <= 0.0] = 1.0
  287. # Use float32 to apply temp.
  288. # Use in-place division to avoid creating a new tensor.
  289. logits = logits.to(torch.float)
  290. logits.div_(temperatures.unsqueeze(dim=1))
  291. def _apply_token_bans(logits: torch.Tensor,
  292. banned_tokens: List[List[int]]) -> torch.Tensor:
  293. for i, banned_token_ids in enumerate(banned_tokens):
  294. if i >= logits.size(0):
  295. break
  296. if not banned_token_ids:
  297. continue
  298. logits[i, banned_token_ids] = -float("inf")
  299. return logits
  300. def _apply_min_tokens_penalty(
  301. logits: torch.Tensor,
  302. sampling_metadata: SamplingMetadata,
  303. ) -> torch.Tensor:
  304. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  305. have not been generated yet
  306. """
  307. # list of indices in logits that will be set to -inf
  308. logits_to_penalize = []
  309. logits_applied = 0
  310. for seq_group in sampling_metadata.seq_groups:
  311. seq_ids = seq_group.seq_ids
  312. sampling_params = seq_group.sampling_params
  313. sample_indices = seq_group.sample_indices
  314. logits_applied += len(sample_indices) + len(
  315. seq_group.prompt_logprob_indices)
  316. if not seq_group.do_sample:
  317. continue
  318. start_idx = sample_indices[0]
  319. min_tokens = sampling_params.min_tokens
  320. token_ids_to_penalize = sampling_params.all_stop_token_ids
  321. if min_tokens > 0 and token_ids_to_penalize:
  322. seqs_to_penalize = []
  323. for j, seq_id in enumerate(seq_ids):
  324. seq_data = seq_group.seq_data[seq_id]
  325. if len(seq_data.output_token_ids_array) < min_tokens:
  326. seqs_to_penalize.append(j)
  327. if seqs_to_penalize:
  328. # convert to the index into logits
  329. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  330. # itertools.product pairs each seq index with every token id
  331. logits_to_penalize.extend(
  332. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  333. if logits_to_penalize:
  334. # use zip and * to group indices along each dimension
  335. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  336. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  337. # verifies that no rows in logits were missed unexpectedly
  338. assert logits_applied == logits.shape[0]
  339. return logits
  340. def _apply_top_k_top_p(
  341. logits: torch.Tensor,
  342. p: torch.Tensor,
  343. k: torch.Tensor,
  344. ) -> torch.Tensor:
  345. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  346. # Apply top-k.
  347. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  348. # Get all the top_k values.
  349. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  350. top_k_mask = logits_sort < top_k_mask
  351. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  352. # Apply top-p.
  353. probs_sort = logits_sort.softmax(dim=-1)
  354. probs_sum = probs_sort.cumsum(dim=-1)
  355. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  356. # at least one
  357. top_p_mask[:, -1] = False
  358. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  359. # Re-sort the probabilities.
  360. src = torch.arange(logits_idx.shape[-1],
  361. device=logits_idx.device).expand_as(logits_idx)
  362. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  363. index=logits_idx,
  364. src=src)
  365. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  366. return logits
  367. def _apply_min_p(
  368. logits: torch.Tensor,
  369. min_p: torch.Tensor,
  370. ) -> torch.Tensor:
  371. """
  372. Adapted from
  373. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  374. """
  375. probs = torch.softmax(logits, dim=-1)
  376. top_probs, _ = probs.max(dim=-1, keepdim=True)
  377. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  378. tokens_to_remove = probs < scaled_min_p
  379. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  380. return logits
  381. def _apply_top_a(
  382. logits: torch.Tensor,
  383. top_a: torch.Tensor,
  384. ) -> torch.Tensor:
  385. probs = torch.softmax(logits, dim=-1)
  386. top_probs, _ = probs.max(dim=-1, keepdim=True)
  387. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  388. tokens_to_remove = probs < threshold
  389. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  390. return logits
  391. def _apply_tfs(
  392. logits: torch.Tensor,
  393. tfs: torch.Tensor,
  394. ) -> torch.Tensor:
  395. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  396. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  397. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  398. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  399. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  400. tfs_mask = torch.cat(
  401. (
  402. torch.zeros(
  403. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  404. tfs_mask,
  405. torch.ones(
  406. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  407. ),
  408. dim=-1,
  409. )
  410. logits_sort[tfs_mask] = -float("inf")
  411. logits = torch.gather(logits_sort,
  412. dim=-1,
  413. index=torch.argsort(logits_idx, dim=-1))
  414. return logits
  415. def _apply_eta_cutoff(
  416. logits: torch.Tensor,
  417. eta_cutoff: torch.Tensor,
  418. ) -> torch.Tensor:
  419. shifted_logits = torch.log_softmax(logits, dim=-1)
  420. probs = shifted_logits.exp()
  421. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  422. eps = torch.min(eta_cutoff,
  423. torch.sqrt(eta_cutoff) *
  424. torch.exp(neg_entropy)).unsqueeze(dim=1)
  425. eta_mask = probs < eps
  426. # guard against nulling out all the logits
  427. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  428. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  429. logits[eta_mask] = -float("inf")
  430. return logits
  431. def _apply_epsilon_cutoff(
  432. logits: torch.Tensor,
  433. epsilon_cutoff: torch.Tensor,
  434. ) -> torch.Tensor:
  435. probs = logits.softmax(dim=-1)
  436. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  437. # guard against nulling out all the logits
  438. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  439. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  440. logits[eps_mask] = -float("inf")
  441. return logits
  442. def _apply_typical_sampling(
  443. logits: torch.Tensor,
  444. typical_p: torch.Tensor,
  445. ) -> torch.Tensor:
  446. shifted_logits = torch.log_softmax(logits, dim=-1)
  447. probs = shifted_logits.exp()
  448. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  449. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  450. _, indices = torch.sort(surprisal_deviations)
  451. reordered_probs = probs.gather(-1, indices)
  452. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  453. dim=1)
  454. min_tokens_to_keep = 1
  455. # Keep at least min_tokens_to_keep
  456. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  457. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  458. logits[typ_mask] = -float("inf")
  459. return logits
  460. def _apply_dynamic_typical_sampling(
  461. logits: torch.Tensor,
  462. min_typical_p: torch.Tensor,
  463. max_typical_p: torch.Tensor,
  464. ) -> torch.Tensor:
  465. """Applies typical sampling with a dynamic typical_p threshold based on entropy.
  466. Args:
  467. logits: Tensor of shape (batch_size, vocab_size) containing the logits.
  468. min_typical_p: Minimum threshold for typical sampling.
  469. max_typical_p: Maximum threshold for typical sampling.
  470. Returns:
  471. Modified logits tensor with atypical tokens masked out.
  472. """
  473. shifted_logits = torch.log_softmax(logits, dim=-1)
  474. probs = shifted_logits.exp()
  475. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  476. entropy = -neg_entropy # Entropy is the negative of the sum
  477. # Normalize entropy to range [0, 1]
  478. max_entropy = torch.log(torch.tensor(logits.size(-1), dtype=logits.dtype, device=logits.device))
  479. normalized_entropy = entropy / max_entropy # (batch_size, 1)
  480. # Compute dynamic typical_p
  481. dynamic_typical_p = min_typical_p + (max_typical_p - min_typical_p) * normalized_entropy
  482. dynamic_typical_p = dynamic_typical_p.squeeze(dim=-1) # (batch_size,)
  483. # Calculate surprisal deviation for each token
  484. surprisal_deviations = (entropy - shifted_logits).abs()
  485. # Sort tokens by surprisal deviation
  486. sorted_deviations, indices = torch.sort(surprisal_deviations, dim=-1)
  487. sorted_probs = probs.gather(-1, indices)
  488. # Compute cumulative probabilities of sorted tokens
  489. cum_probs = sorted_probs.cumsum(dim=-1)
  490. # Create a mask for tokens exceeding the dynamic typical_p threshold
  491. typ_mask = cum_probs > dynamic_typical_p.unsqueeze(dim=1)
  492. # Ensure at least one token is kept
  493. typ_mask[..., 0] = False
  494. # Scatter the mask back to the original logits order
  495. typ_mask = typ_mask.scatter(1, indices, typ_mask)
  496. # Mask out the filtered tokens
  497. logits = logits.masked_fill(typ_mask, -float("inf"))
  498. return logits
  499. def _apply_kl_divergence_sampling(
  500. logits: torch.Tensor,
  501. kl_thresholds: torch.Tensor,
  502. ) -> torch.Tensor:
  503. """Applies KL Divergence-based sampling to filter tokens.
  504. Args:
  505. logits: Tensor of shape (batch_size, vocab_size) containing the logits.
  506. kl_thresholds: Tensor of shape (batch_size,) containing the KL divergence thresholds.
  507. Returns:
  508. Modified logits tensor with tokens below the KL divergence threshold masked out.
  509. """
  510. # Compute the probability distribution from logits
  511. probs = torch.softmax(logits, dim=-1)
  512. # Create a uniform distribution
  513. uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1))
  514. # Calculate the KL divergence between the predicted and uniform distributions for each token
  515. kl_div_tokens = -torch.log(probs + 1e-8) - torch.log(uniform_probs + 1e-8)
  516. # Create a mask for tokens where KL divergence is less than the threshold
  517. kl_mask = kl_div_tokens < kl_thresholds.unsqueeze(-1)
  518. # Mask out tokens where the KL divergence is below the threshold
  519. logits = logits.masked_fill(kl_mask, -float("inf"))
  520. return logits
  521. def _apply_jsd_sampling(
  522. logits: torch.Tensor,
  523. jsd_thresholds: torch.Tensor,
  524. ) -> torch.Tensor:
  525. """Applies Jensen-Shannon Divergence-based sampling to filter tokens.
  526. Args:
  527. logits: Tensor of shape (batch_size, vocab_size) containing the logits.
  528. jsd_thresholds: Tensor of shape (batch_size,) with the JSD threshold for each sequence.
  529. Returns:
  530. Modified logits tensor with tokens exceeding the JSD threshold masked out.
  531. """
  532. # Compute the probability distribution from logits
  533. probs = torch.softmax(logits, dim=-1) # Shape: (batch_size, vocab_size)
  534. # Create a uniform distribution
  535. uniform_probs = torch.full_like(probs, 1.0 / probs.size(-1)) # Shape: (batch_size, vocab_size)
  536. # Compute the average distribution
  537. average_probs = 0.5 * (probs + uniform_probs) # Shape: (batch_size, vocab_size)
  538. # Compute per-token KL divergences
  539. kl_probs = probs * (torch.log(probs + 1e-8) - torch.log(average_probs + 1e-8)) # Shape: (batch_size, vocab_size)
  540. kl_uniform = uniform_probs * (torch.log(uniform_probs + 1e-8) - torch.log(average_probs + 1e-8)) # Shape: (batch_size, vocab_size)
  541. # Compute per-token JSD
  542. jsd = 0.5 * (kl_probs + kl_uniform) # Shape: (batch_size, vocab_size)
  543. # Compare per-token JSD with threshold
  544. jsd_thresholds_expanded = jsd_thresholds.unsqueeze(-1) # Shape: (batch_size, 1)
  545. jsd_mask = jsd >= jsd_thresholds_expanded # Shape: (batch_size, vocab_size)
  546. # Mask out tokens where the JSD exceeds the threshold
  547. logits = logits.masked_fill(jsd_mask, -float("inf"))
  548. return logits
  549. def _apply_quadratic_sampling(
  550. logits: torch.Tensor,
  551. smoothing_factor: torch.Tensor,
  552. smoothing_curve: torch.Tensor,
  553. ) -> torch.Tensor:
  554. """
  555. Applies a quadratic transformation to the logits based on the
  556. provided smoothing factors and curves. The transformation is
  557. centered around the maximum logit value in the batch.
  558. The transformation involves a quadratic and cubic term, with the
  559. cubic term controlled by the smoothing curve. The quadratic term is
  560. scaled by the smoothing factor, and the cubic term is scaled by the
  561. product of the smoothing factor and the smoothing curve.
  562. params:
  563. logits (torch.Tensor): The logits to be transformed.
  564. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  565. term in the transformation.
  566. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  567. in the transformation.
  568. returns:
  569. torch.Tensor: The transformed logits.
  570. Credits: @kalomaze
  571. """
  572. mask = smoothing_factor != 0
  573. smoothing_factor.unsqueeze_(dim=1)
  574. smoothing_curve.unsqueeze_(dim=1)
  575. k = smoothing_factor * (3 - smoothing_curve) / 2
  576. s = smoothing_factor * (smoothing_curve - 1) / 2
  577. quadlogits = logits[mask] # limit to logits using this sampler
  578. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  579. # Construct the delta from each logit to its new value
  580. diff = quadlogits - max_logits
  581. diff -= diff**2 * (s[mask] * diff - k[mask])
  582. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  583. logits[mask] -= diff
  584. return logits
  585. def _apply_xtc_sampling(
  586. logits: torch.Tensor,
  587. xtc_thresholds: torch.Tensor,
  588. xtc_probabilities: torch.Tensor,
  589. ) -> torch.Tensor:
  590. """Apply Exclude Top Choices (XTC) sampling to the logits.
  591. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  592. Args:
  593. logits: (num_tokens, vocab_size) The input logits.
  594. xtc_thresholds: (num_tokens,) The threshold for each token.
  595. xtc_probabilities: (num_tokens,) The probability of applying XTC
  596. for each token.
  597. Returns:
  598. torch.Tensor: The modified logits.
  599. """
  600. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  601. if not apply_xtc.any():
  602. return logits
  603. probs = torch.softmax(logits, dim=-1)
  604. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  605. # Find indices where the next probability is above the threshold
  606. # Skips the top choice, which later on becomes skipping the last choice.
  607. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  608. # Apply XTC only to rows where it should be applied
  609. for i in range(logits.shape[0]):
  610. if apply_xtc[i]:
  611. # Count logits above the threshold (skipping the first)
  612. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  613. if indices_to_remove > 0:
  614. # Implies the top logit and at least one other is >= threshold.
  615. # Mask out above_thresh logits except the last/lowest one.
  616. logits[i].scatter_(
  617. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  618. return logits
  619. def _greedy_sample(
  620. selected_seq_groups: List[SequenceGroupToSample],
  621. samples: torch.Tensor,
  622. ) -> List[Tuple[List[int], List[int]]]:
  623. """Run greedy sampling on a given samples.
  624. Args:
  625. selected_seq_groups: A list of sequence groups batched.
  626. samples: (num_selected_samples,) A tensor of samples. The length of
  627. samples could be smaller than selected_seq_groups if
  628. seq_group.do_sample is False.
  629. Returns:
  630. Tuple of (next_token_ids, parent_ids). The length of returned list is
  631. same as the length of selected_seq_groups. If the corresponding
  632. seq_group has do_sample=False, tuple contains ([], [])
  633. """
  634. samples = samples.tolist()
  635. sample_idx = 0
  636. results = []
  637. for seq_group in selected_seq_groups:
  638. if not seq_group.do_sample:
  639. results.append(([], []))
  640. continue
  641. seq_ids = seq_group.seq_ids
  642. num_parent_seqs = len(seq_ids)
  643. assert num_parent_seqs == 1, (
  644. "Greedy sampling should have only one seq.")
  645. parent_ids = list(range(num_parent_seqs))
  646. next_token_ids = [samples[sample_idx]]
  647. results.append((next_token_ids, parent_ids))
  648. sample_idx += num_parent_seqs
  649. return results
  650. def _random_sample(
  651. selected_seq_groups: List[SequenceGroupToSample],
  652. random_samples: torch.Tensor,
  653. ) -> List[Tuple[List[int], List[int]]]:
  654. """Run random sampling on a given samples.
  655. Args:
  656. selected_seq_groups: A list of sequence groups batched.
  657. random_samples: (num_selected_samples,) A tensor of samples. The
  658. length of samples could be smaller than selected_seq_groups if
  659. seq_group.do_sample is False.
  660. Returns:
  661. Tuple of (next_token_ids, parent_ids). The length of returned list is
  662. same as the length of selected_seq_groups. If the corresponding
  663. seq_group has do_sample=False, tuple contains ([], [])
  664. """
  665. # Find the maximum best_of value of the prompt phase requests.
  666. random_samples = random_samples.cpu()
  667. sample_idx = 0
  668. results = []
  669. for seq_group in selected_seq_groups:
  670. if not seq_group.do_sample:
  671. results.append(([], []))
  672. continue
  673. seq_ids = seq_group.seq_ids
  674. sampling_params = seq_group.sampling_params
  675. is_prompt = seq_group.is_prompt
  676. num_parent_seqs = len(seq_ids)
  677. if is_prompt:
  678. # Prompt phase.
  679. parent_ids = [0] * sampling_params.best_of
  680. next_token_ids = random_samples[
  681. sample_idx, :sampling_params.best_of].tolist()
  682. else:
  683. # Generation phase.
  684. parent_ids = list(range(num_parent_seqs))
  685. next_token_ids = random_samples[sample_idx:sample_idx +
  686. num_parent_seqs, 0].tolist()
  687. results.append((next_token_ids, parent_ids))
  688. sample_idx += num_parent_seqs
  689. return results
  690. def _beam_search_sample(
  691. selected_seq_groups: List[SequenceGroupToSample],
  692. logprobs: torch.Tensor,
  693. ) -> List[Tuple[List[int], List[int]]]:
  694. """Run beam sampling on a given samples.
  695. Args:
  696. selected_seq_groups: A list of sequence groups batched.
  697. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  698. on selected sample indices.
  699. Returns:
  700. Tuple of (next_token_ids, parent_ids). The length of returned list is
  701. same as the length of selected_seq_groups. If the corresponding
  702. seq_group has do_sample=False, tuple contains ([], [])
  703. """
  704. # We sample 2 * beam_width candidates to make sure that with high
  705. # probability we can get `beam_width` candidates in addition to
  706. # the finished sequences for the next iteration. See
  707. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  708. # for details. See also HF reference:
  709. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  710. #
  711. # NOTE: Beam search is not vectorized, so its speed can be slower than
  712. # other sampling methods.
  713. sample_idx = 0
  714. results = []
  715. for seq_group in selected_seq_groups:
  716. if not seq_group.do_sample:
  717. results.append(([], []))
  718. continue
  719. is_prompt = seq_group.is_prompt
  720. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  721. num_parent_seqs = len(seq_ids)
  722. beam_width = sampling_params.best_of
  723. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  724. if is_prompt:
  725. # Prompt phase.
  726. assert num_parent_seqs == 1, (
  727. "Prompt input should have only one seq.")
  728. parent_ids = [0] * (2 * beam_width)
  729. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  730. 2 * beam_width)
  731. next_token_ids = next_token_ids.tolist()
  732. else:
  733. # Generation phase.
  734. cumulative_logprobs = [
  735. seq_group.seq_data[seq_id].cumulative_logprob
  736. for seq_id in seq_ids
  737. ]
  738. cumulative_logprobs = torch.tensor(
  739. cumulative_logprobs,
  740. dtype=torch.float,
  741. device=seq_group_logprobs.device)
  742. seq_group_logprobs = (seq_group_logprobs +
  743. cumulative_logprobs.unsqueeze(dim=1))
  744. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  745. 2 * beam_width)
  746. topk_ids = topk_ids.tolist()
  747. vocab_size = seq_group_logprobs.size(-1)
  748. parent_ids = [i // vocab_size for i in topk_ids]
  749. next_token_ids = [i % vocab_size for i in topk_ids]
  750. results.append((next_token_ids, parent_ids))
  751. sample_idx += num_parent_seqs
  752. assert sample_idx == logprobs.size(0)
  753. return results
  754. # torch.multinomial forces a GPU<->CPU sync.
  755. # Therefore, we use an optimized implementation instead.
  756. # Note that we always sample with replacement.
  757. # probs will be modified in place, but this is fine, as we pass
  758. # in a copy already.
  759. def _multinomial(
  760. probs: torch.Tensor,
  761. num_samples: int,
  762. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  763. ) -> torch.Tensor:
  764. if num_samples > 1:
  765. # This is equivalent to torch.repeat_interleaved (which also
  766. # forces a GPU<->CPU sync).
  767. # This allows us to do sampling with replacement by creating
  768. # num_samples copies of each row in the tensor, and then
  769. # batch sampling the resulting tensor.
  770. probs = probs[:, None, :].expand(probs.shape[0], num_samples,
  771. probs.shape[1]).contiguous().view(
  772. -1, probs.shape[1])
  773. q = torch.empty_like(probs)
  774. if seq_groups is None:
  775. q.exponential_()
  776. else:
  777. sample_idx = 0
  778. for seq_group in seq_groups:
  779. seq_ids = seq_group.seq_ids
  780. next_sample_idx = sample_idx + len(seq_ids) * num_samples
  781. q[sample_idx:next_sample_idx].exponential_(
  782. generator=seq_group.generator)
  783. sample_idx = next_sample_idx
  784. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """Sample next tokens for every sequence group using torch ops.

    Sequence groups are bucketed by their sampling type. A first pass
    launches the tensor work per sampling type; a second pass converts the
    results into Python lists.

    Args:
        probs: Probabilities, indexed by the row indices found in
            `sampling_metadata.categorized_sample_indices`.
        logprobs: Log-probabilities, indexed the same way as `probs`.
        sampling_metadata: The metadata for a batch for sampling.
        include_gpu_probs_tensor: When True, a (logprobs.shape[0], 1) long
            tensor is allocated on `logprobs.device` and filled with the
            sampled token ids.
        modify_greedy_probs: When True, `_modify_greedy_probs_inplace` is
            called for the greedily sampled rows.

    Returns:
        A tuple (sample_results, sampled_token_ids_tensor).
        sample_results holds one (next_token_ids, parent_ids) per sequence
        group, with ([], []) for groups that produced no result; it is an
        empty list when `sampling_metadata.skip_sampler_cpu_output` is set.
        sampled_token_ids_tensor is None unless include_gpu_probs_tensor
        is True.

    Raises:
        ValueError: If a sampling type with tokens to sample is not handled.
    """
    # Bucket the positions of the sequence groups by sampling type.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    # sampling type -> (positions of its groups, the groups themselves)
    sample_metadata = {}
    # sampling type -> tensor returned by _multinomial
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            # Nothing in the batch uses this sampling type.
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Only prompt-phase groups can raise the number of samples drawn
            # per row above 1.
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            # Seeded sampling passes the groups along so _multinomial can
            # use each group's generator.
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            # Key each result by the group's position in the batch.
            sample_results_dict.update(zip(seq_group_id, sample_results))

        # Restore batch order; groups without a result get ([], []).
        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []

    return sample_results, sampled_token_ids_tensor
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group via `sample_triton`.

    Sequence groups are bucketed by sampling type, `sample_triton` is called
    once for the whole batch, and the sampled tokens are then converted to
    Python lists per sampling type.

    NOTE: `_sample` currently has its call to this function commented out.

    Args:
        probs: Probabilities passed through to `sample_triton`.
        logprobs: Log-probabilities; passed to `sample_triton` and indexed
            directly for beam search.
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Supplies `sampling_seeds` and `sample_indices`
            for `sample_triton`.

    Returns:
        One (next_token_ids, parent_ids) per sequence group, with ([], [])
        for groups that produced no result.

    Raises:
        ValueError: If a sampling type with tokens to sample is not handled.
    """
    # Bucket the positions of the sequence groups by sampling type.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    # Largest best_of among prompt-phase groups; shared by the single
    # sample_triton call below.
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        # Column 0 is used to index logprobs; column 1 is used to index the
        # tensor returned by sample_triton.
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            # Nothing in the batch uses this sampling type.
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            # Greedy keeps only the first sampled column.
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        # Key each result by the group's position in the batch.
        sample_results_dict.update(zip(seq_group_id, sample_results))

    # Restore batch order; groups without a result get ([], []).
    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results
  949. def _sample(
  950. probs: torch.Tensor, logprobs: torch.Tensor,
  951. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  952. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  953. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  954. """
  955. Args:
  956. probs: (num_query_tokens_in_batch, num_vocab)
  957. logprobs: (num_query_tokens_in_batch, num_vocab)
  958. sampling_metadata: The metadata for a batch for sampling.
  959. sampling_tensors: Tensors that include sampling related metadata.
  960. Returns:
  961. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  962. If sampling is skipped, it returns ([], [])
  963. sampled_token_ids_tensor: A tensor of sampled token ids.
  964. """
  965. return _sample_with_torch(
  966. probs,
  967. logprobs,
  968. sampling_metadata,
  969. include_gpu_probs_tensor=include_gpu_probs_tensor,
  970. modify_greedy_probs=modify_greedy_probs,
  971. )
  972. # TODO: Enable once Triton kernel & associated code is faster.
  973. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  974. # sampling_tensors)
  975. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  976. """
  977. This function calculates the ranks of the chosen tokens in a logprob tensor.
  978. Args:
  979. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  980. where N is the no. of tokens and M is the vocab dim.
  981. indices (torch.Tensor): List of chosen token indices.
  982. Returns:
  983. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  984. Each element in the returned tensor represents the rank
  985. of the chosen token in the input logprob tensor.
  986. """
  987. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  988. indices]
  989. return (x > vals[:, None]).long().sum(1).add_(1)
  990. def _get_logprobs(
  991. logprobs: torch.Tensor,
  992. sampling_metadata: SamplingMetadata,
  993. sample_results: List[Tuple[List[int], List[int]]],
  994. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  995. """Return sample lobprobs and prompt logprobs.
  996. The logic consists of 3 parts.
  997. - Select indices to compute logprob from, ranks of token ids, and
  998. the top k token ids from logprobs.
  999. - Compute prompt logprobs if required.
  1000. - Compute sample logprobs if required.
  1001. Args:
  1002. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  1003. logprob per vocab. Sequence groups' query tokens are batched in a
  1004. single flattened tensor. For example, assuming there are N
  1005. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  1006. prompt logprob is enabled), decode tokens for seq_group_1 (if
  1007. sampling is required), prefill tokens for seq_group_2, ...
  1008. sampling_metadata: The sampling metadata.
  1009. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  1010. parent_ids) for each sequence group. When beam search is enabled,
  1011. sample_results can contain different number of seq_ids from
  1012. sampling_metadata.seq_groups. It is because beam search creates
  1013. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  1014. BEAM_WIDTH number of seq_ids).
  1015. Returns:
  1016. A tuple of prompt and sample logprobs per sequence group in a batch.
  1017. """
  1018. # The index of query token to calculate logprobs. It includes both
  1019. # prompt and sample logprob indices.
  1020. query_indices: List[int] = []
  1021. # The next token ids to get the logprob value from.
  1022. next_token_ids: List[int] = []
  1023. # The largest requested number of logprobs. We find logprobs as many as the
  1024. # largest num logprobs in this API. If every logprobs is None, it will be
  1025. # set to -1.
  1026. largest_num_logprobs = -1
  1027. # If beam search is enabled.
  1028. use_beam_search = False
  1029. # Select indices to compute logprob from, ranks of token ids, and the top
  1030. # k token ids from logprobs.
  1031. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  1032. sample_results):
  1033. sampling_params = seq_group.sampling_params
  1034. # Update indices and tokens for prompt logprobs.
  1035. if (seq_group.is_prompt
  1036. and sampling_params.prompt_logprobs is not None):
  1037. largest_num_logprobs = max(largest_num_logprobs,
  1038. sampling_params.prompt_logprobs)
  1039. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1040. query_indices.extend(seq_group.prompt_logprob_indices)
  1041. next_token_ids.extend(next_prompt_tokens)
  1042. # Update indices and next tokenes for sample logprob.
  1043. if seq_group.do_sample:
  1044. token_ids, parent_seq_ids = sample_result
  1045. # NOTE: We cannot directly use sample_indices because
  1046. # sample_indices only contain parent seq_ids of a previous step.
  1047. # The current step may have different number of seq_ids, and
  1048. # we can obtain it from `sample_result[1]`.
  1049. query_idx = seq_group.sample_indices[0]
  1050. query_indices.extend(
  1051. [query_idx + parent_id for parent_id in parent_seq_ids])
  1052. next_token_ids.extend(token_ids)
  1053. if sampling_params.logprobs is not None:
  1054. largest_num_logprobs = max(largest_num_logprobs,
  1055. sampling_params.logprobs)
  1056. use_beam_search = use_beam_search or sampling_params.use_beam_search
  1057. assert len(next_token_ids) == len(query_indices)
  1058. if len(query_indices) == 0:
  1059. empty_sampled_logprob = []
  1060. empty_prompt_logprob = None
  1061. return [empty_prompt_logprob], [empty_sampled_logprob]
  1062. selected_logprobs, ranks = None, None
  1063. top_logprobs, top_token_ids = None, None
  1064. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  1065. # skip the whole logprob calculation.
  1066. if largest_num_logprobs >= 0 or use_beam_search:
  1067. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  1068. next_token_ids_gpu = torch.tensor(next_token_ids,
  1069. device=logprobs.device)
  1070. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  1071. # contain duplicates if beam search is enabled.
  1072. selected_logprobs = logprobs[[
  1073. query_indices_gpu,
  1074. next_token_ids_gpu,
  1075. ]]
  1076. ranks = _get_ranks(
  1077. logprobs[query_indices_gpu],
  1078. next_token_ids_gpu,
  1079. )
  1080. assert selected_logprobs.shape[0] == ranks.shape[0]
  1081. # We need to compute top k only if there exists logprobs > 0.
  1082. if largest_num_logprobs > 0:
  1083. # Logprobs of topk tokens for a batch of sequence groups.
  1084. # (num_query_tokens_across_batch).
  1085. top_logprobs, top_token_ids = torch.topk(logprobs,
  1086. largest_num_logprobs,
  1087. dim=-1)
  1088. top_logprobs = top_logprobs.to('cpu')
  1089. top_token_ids = top_token_ids.to('cpu')
  1090. selected_logprobs = selected_logprobs.to('cpu')
  1091. ranks = ranks.to('cpu')
  1092. # Find prompt/sample logprobs.
  1093. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  1094. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  1095. top_logprob_idx = 0
  1096. selected_logprobs_idx = 0
  1097. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  1098. sample_results):
  1099. (prompt_logprobs, top_logprob_idx,
  1100. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  1101. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  1102. selected_logprobs_idx, top_logprob_idx)
  1103. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  1104. (sampled_logprobs, top_logprob_idx,
  1105. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  1106. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1107. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1108. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1109. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1110. def _get_prompt_logprob_if_needed(
  1111. seq_group: SequenceGroupToSample,
  1112. selected_logprobs: torch.Tensor,
  1113. ranks: torch.Tensor,
  1114. top_token_ids: torch.Tensor,
  1115. top_logprobs: torch.Tensor,
  1116. selected_logprobs_idx: int,
  1117. top_logprob_idx: int,
  1118. ):
  1119. """Compute the prompt logprob from a sequence group if needed."""
  1120. sampling_params = seq_group.sampling_params
  1121. is_prompt = seq_group.is_prompt
  1122. # Find prompt logprobs
  1123. prompt_logprobs: Optional[PromptLogprobs] = None
  1124. if is_prompt and sampling_params.prompt_logprobs is not None:
  1125. prompt_logprobs = []
  1126. num_logprobs = sampling_params.prompt_logprobs
  1127. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1128. # Pre-select indexes and create a list. It is faster than calling .item
  1129. # repetitively.
  1130. selected_logprob_items = selected_logprobs[
  1131. selected_logprobs_idx:selected_logprobs_idx +
  1132. len(next_prompt_tokens)].tolist()
  1133. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1134. len(next_prompt_tokens)].tolist()
  1135. for idx, token_id in enumerate(next_prompt_tokens):
  1136. # Calculate the prompt logprob of the real prompt tokens.
  1137. # {token_id: (logprob, rank_from_vocab)}
  1138. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1139. token_id: (selected_logprob_items[idx], rank_items[idx])
  1140. }
  1141. # Add top K prompt logprobs along with its rank.
  1142. if num_logprobs > 0:
  1143. top_ids = top_token_ids[
  1144. top_logprob_idx, :num_logprobs].tolist()
  1145. top_probs = top_logprobs[
  1146. top_logprob_idx, :num_logprobs].tolist()
  1147. # Top K is already sorted by rank, so we can use 1 ~
  1148. # num_logprobs + 1 for rank.
  1149. top_ranks = range(1, num_logprobs + 1)
  1150. prompt_logprobs_dict.update({
  1151. top_id: (top_prob, rank)
  1152. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1153. top_ranks)
  1154. })
  1155. prompt_logprobs.append({
  1156. token_id: Logprob(*logprob_and_rank)
  1157. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1158. })
  1159. # + 1 to go to the next prompt token.
  1160. top_logprob_idx += 1
  1161. # + len(next_prompt_tokens) to go to the next prompt.
  1162. selected_logprobs_idx += len(next_prompt_tokens)
  1163. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1164. def _get_sampled_logprob_if_needed(
  1165. seq_group: SequenceGroupToSample,
  1166. sample_result: Tuple[List[int], List[int]],
  1167. selected_logprobs: torch.Tensor,
  1168. ranks: torch.Tensor,
  1169. top_token_ids: torch.Tensor,
  1170. top_logprobs: torch.Tensor,
  1171. selected_logprobs_idx: int,
  1172. top_logprob_idx: int,
  1173. ):
  1174. """Compute the sample logprob if needed."""
  1175. seq_ids = seq_group.seq_ids
  1176. num_logprobs = seq_group.sampling_params.logprobs
  1177. use_beam_search = seq_group.sampling_params.use_beam_search
  1178. sampled_logprobs: SampleLogprobs = []
  1179. next_token_ids, parent_seq_ids = sample_result
  1180. if seq_group.do_sample:
  1181. assert len(next_token_ids) > 0
  1182. if num_logprobs is None and not use_beam_search:
  1183. for next_token_id in next_token_ids:
  1184. # Use a dummy logprob
  1185. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1186. else:
  1187. # Pre-select items from tensor. tolist() is faster than repetitive
  1188. # `.item()` calls.
  1189. selected_logprob_items = selected_logprobs[
  1190. selected_logprobs_idx:selected_logprobs_idx +
  1191. len(next_token_ids)].tolist()
  1192. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1193. len(next_token_ids)].tolist()
  1194. for idx, (next_token_id, parent_id) in enumerate(
  1195. zip(next_token_ids, parent_seq_ids)):
  1196. # Get the logprob of a sampled token.
  1197. sampled_logprobs_dict = {
  1198. next_token_id:
  1199. (selected_logprob_items[idx], rank_items[idx])
  1200. }
  1201. if num_logprobs is not None and num_logprobs > 0:
  1202. # Get top K logprobs.
  1203. top_ids = top_token_ids[top_logprob_idx +
  1204. parent_id, :num_logprobs].tolist()
  1205. top_probs = top_logprobs[
  1206. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1207. # Top K is already sorted by rank, so we can use 1 ~
  1208. # num_logprobs + 1 for rank.
  1209. top_ranks = range(1, num_logprobs + 1)
  1210. sampled_logprobs_dict.update({
  1211. top_id: (top_prob, rank)
  1212. for top_id, top_prob, rank in zip(
  1213. top_ids, top_probs, top_ranks)
  1214. })
  1215. sampled_logprobs.append({
  1216. token_id: Logprob(*logprob_and_rank)
  1217. for token_id, logprob_and_rank in
  1218. sampled_logprobs_dict.items()
  1219. })
  1220. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1221. # logprobs for the current step, which has len(next_token_ids) tokens
  1222. # per sequence group. `logprobs` includes logprobs from the previous
  1223. # steps, which has len(seq_ids) tokens per sequence group.
  1224. # Iterate to the next sequence group in a batch.
  1225. selected_logprobs_idx += len(next_token_ids)
  1226. # Iterate to the next sequence group in a batch.
  1227. top_logprob_idx += len(seq_ids)
  1228. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1229. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1230. sample_indices: torch.Tensor,
  1231. greedy_samples: torch.Tensor) -> None:
  1232. """Modify the probability distributions of the greedily-sampled tokens such
  1233. that each sampled token has a "probability" of 1.0. This is required by
  1234. speculative decoding, which depends on the sampling method being encoded
  1235. within the probability distribution for correctness.
  1236. # Why do we only need to do this for greedy sampling?
  1237. Aphrodite's sampler performs the following steps for greedy or multinomial
  1238. (random) sampling:
  1239. 1. Get logits from model.
  1240. 2. Modify logits according to per-sequence sampling parameters.
  1241. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1242. according to their frequency, etc.
  1243. 3. Sample a token.
  1244. - Random sampling simply samples from the modified probability
  1245. distribution.
  1246. - Greedy sampling performs `argmax` to obtain the token with the
  1247. highest likelihood.
  1248. Ignoring greedy sampling for a moment, we find that the computed probability
  1249. distribution has the following property: we can sample from it independently
  1250. and find that the token sampled by the Sampler has a frequency corresponding
  1251. to how often we see it in our sampling. In other words, for tokens sampled
  1252. with Aphrodite's random SamplingType, the computed probability distribution
  1253. encodes the sampling methodology completely.
  1254. Greedy sampling does not normally have this property. Aphrodite modifies
  1255. logits according to sampling params, then performs `argmax`, then returns
  1256. the sampled token and the computed probability distribution. If we sample
  1257. from the distribution, we'll find the likelihood of the greedily-sampled
  1258. token is not always 1.0.
  1259. Since lossless speculative decoding requires that the sampling methodology
  1260. be encoded within the probability distribution, we are motivated to modify
  1261. the probability distribution such that the sampled token has probability 1
  1262. when speculative decoding is used.
  1263. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1264. greedy sampling using multinomial computation and unite the codepaths. This
  1265. has implications on the overall design of the sampler, e.g. how to record
  1266. accurate logprobs for the user, so this improvement is deferred to later.
  1267. """
  1268. # NOTE: logprobs are not modified so they can be returned to the user.
  1269. probs[sample_indices, :] = 0
  1270. probs[sample_indices, greedy_samples] = 1.0
  1271. def _build_sampler_output(
  1272. sample_results: SampleResultType,
  1273. sampling_metadata: SamplingMetadata,
  1274. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1275. sample_logprobs: Optional[List[SampleLogprobs]],
  1276. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1277. torch.Tensor]],
  1278. skip_sampler_cpu_output: bool = False,
  1279. ) -> SamplerOutput:
  1280. """Construct Python objects with the output of sampling.
  1281. Args:
  1282. on_device_tensors: Tuple containing on-device tensors with the
  1283. probabilities used in sampling and the sampled token ids. This
  1284. allows post-processing without copies to CPU/serialization, e.g. in
  1285. speculative decoding rejection sampling.
  1286. """
  1287. sampler_output: List[CompletionSequenceGroupOutput] = []
  1288. if not skip_sampler_cpu_output:
  1289. assert prompt_logprobs is not None
  1290. assert sample_logprobs is not None
  1291. for (seq_group, sample_result, group_prompt_logprobs,
  1292. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1293. sample_results, prompt_logprobs,
  1294. sample_logprobs):
  1295. seq_ids = seq_group.seq_ids
  1296. next_token_ids, parent_ids = sample_result
  1297. seq_outputs: List[SequenceOutput] = []
  1298. for parent_id, next_token_id, logprobs in zip(
  1299. parent_ids, next_token_ids, group_sample_logprobs):
  1300. seq_outputs.append(
  1301. SequenceOutput(seq_ids[parent_id], next_token_id,
  1302. logprobs))
  1303. sampler_output.append(
  1304. CompletionSequenceGroupOutput(seq_outputs,
  1305. group_prompt_logprobs))
  1306. # If not specified, store None values in SamplerOutput.
  1307. if on_device_tensors is not None:
  1308. (sampled_token_probs, logprobs_tensor,
  1309. sampled_token_ids) = on_device_tensors
  1310. else:
  1311. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1312. None)
  1313. return SamplerOutput(
  1314. outputs=sampler_output,
  1315. sampled_token_probs=sampled_token_probs,
  1316. sampled_token_ids=sampled_token_ids,
  1317. logprobs=logprobs_tensor,
  1318. )
  1319. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1320. """Get a list of next prompt tokens to compute logprob from a
  1321. given sequence group.
  1322. It is used to compute prompt logprob. Imagine you have logprob for each
  1323. query token. Query token needs to know the next prompt token id to compute
  1324. prompt logprob. This is a helper to obtain next prompt token ids.
  1325. This API has to be used only when the caller knows seq_group is in prefill
  1326. stage.
  1327. Returns:
  1328. A list of next prompt tokens to compute logprob.
  1329. """
  1330. assert seq_group.is_prompt, (
  1331. "Caller should ensure the sequence group is in a prefill stage.")
  1332. seq_ids = seq_group.seq_ids
  1333. query_len = seq_group.query_len
  1334. assert query_len is not None
  1335. # prompt has only 1 seq id.
  1336. assert len(seq_ids) == 1
  1337. seq_data = seq_group.seq_data[seq_ids[0]]
  1338. computed_len = seq_data.get_num_computed_tokens()
  1339. prompt_tokens = seq_data.prompt_token_ids
  1340. # +1 because we are looking for a next prompt token.
  1341. next_token_index_start = computed_len + 1
  1342. next_token_index_end = min(computed_len + query_len + 1,
  1343. len(prompt_tokens))
  1344. next_prompt_tokens = prompt_tokens[
  1345. next_token_index_start:next_token_index_end]
  1346. return next_prompt_tokens
  1347. # def _apply_mirostat_v2(logits: torch.Tensor,
  1348. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1349. # # Reduce our view to just the affected logits
  1350. # logit_view = logits[sampling_tensors.miro_indices]
  1351. # # Calculate surprise value per token
  1352. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1353. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1354. # # Mask out "too-surprising" tokens (surprisal > mu)
  1355. # mus = sampling_tensors.miro_mus
  1356. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1357. # # Unmask most-likely logit to guarantee a selection.
  1358. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1359. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1360. # # Apply logit mask (effectively a top-k filter).
  1361. # logit_view[miro_mask] = -float("inf")
  1362. # # Project logit changes made to the view onto the original.
  1363. # # I think this step might be redundant.
  1364. # logits[sampling_tensors.miro_indices] = logit_view
  1365. # return logits
  1366. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1367. # sample_results: List[Tuple[List[int], List[int]]],
  1368. # sampling_metadata: SamplingMetadata,
  1369. # output_metadata: OutputMetadata) -> None:
  1370. # """Based on whichever token was finally sampled, we calculate the
  1371. # final surprisal values to update the mus.
  1372. # Because a single sequence can have multiple samples, we must fork
  1373. # the mu accordingly."""
  1374. # assert sampling_metadata.seq_groups is not None
  1375. # seqid_to_tokens = {}
  1376. # seqid_to_indices = {}
  1377. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1378. # sample_results):
  1379. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1380. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1381. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1382. # seqids = args.miro_seqids
  1383. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1384. # device=logits.device,
  1385. # dtype=torch.long)
  1386. # # Clumsily, we recalculate token surprisals.
  1387. # logits_view = logits[args.miro_indices]
  1388. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1389. # dim=-1,
  1390. # index=picked_tokens) / -math.log(2)
  1391. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1392. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1393. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1394. # nu_mus = mus - (picked_surprise - taus) * etas
  1395. # # Record updated mu values for use in the next iteration
  1396. # # Note how each mu is split into multiple based on the number of samples.
  1397. # for seqid, seq_mus in zip(seqids, nu_mus):
  1398. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1399. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)