sampler.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import os
  4. import warnings
  5. from math import inf
  6. from typing import Dict, List, Optional, Tuple
  7. import torch
  8. import torch.nn as nn
  9. import aphrodite._custom_ops as ops
  10. from aphrodite.common.sampling_params import SamplingType
  11. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  12. PromptLogprobs, SampleLogprobs,
  13. SamplerOutput, SequenceOutput)
  14. from aphrodite.triton_utils import HAS_TRITON
  15. if HAS_TRITON:
  16. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  17. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  18. SamplingTensors,
  19. SequenceGroupToSample)
  20. # (num_token_ids, num_parent_ids) per sequence group.
  21. SampleResultType = List[Tuple[List[int], List[int]]]
  22. # There isn't a "safe" temperature range for fp16 logits.
  23. # This value was chosen because 1/2e-5 is just under the 65k fp16 max, meaning
  24. # that this temperature well-uses the fp16 space after the logits are offset.
  25. _TEMPERATURE_MINIMUM = 2e-5
  26. # If enabled, we switch to a more performant implementation
  27. # of top-k and top-p
  28. APHRODITE_USE_SAMPLING_KERNELS = bool(int(
  29. os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
  30. class Sampler(nn.Module):
  31. """Samples the next tokens from the model's outputs.
  32. This layer does the following:
  33. 1. Discard the hidden states that are not used for sampling (i.e., all
  34. tokens except the final one in each prompt).
  35. 2. Compute the logits for the next tokens.
  36. 3. Apply presence, frequency and repetition penalties.
  37. 4. Apply temperature scaling.
  38. 5. Apply top-p and top-k truncation.
  39. 6. Sample the next tokens.
  40. Here, each sequence group within the batch can have different sampling
  41. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  42. The structure of the logits tensor is coupled with the seq_groups in
  43. sampling_metadata. Typically, each sequence in each seq_group has one row in
  44. logits for the next token to be sampled; however, for a seq_group with a
  45. prompt request with the prompt_logprobs sampling parameter, there are rows
  46. in logits for each token in the input prompt.
  47. """
  48. def __init__(self):
  49. super().__init__()
  50. # Whether or not the SamplerOutput should have on-device tensors
  51. # containing the sampled token ids and probabilities. This is used by
  52. # speculative decoding.
  53. self.include_gpu_probs_tensor = False
  54. self.should_modify_greedy_probs_inplace = False
  55. def _init_sampling_tensors(
  56. self,
  57. logits: torch.Tensor,
  58. sampling_metadata: SamplingMetadata,
  59. ):
  60. """The goal here is to reuse sampling tensors between similar decode
  61. runs. This is possible because sampling logic does not change between
  62. decodes of the same sequences.
  63. """
  64. _, vocab_size = logits.shape
  65. # First free any existing stored sampling tensors.
  66. # This is necessary because some sampling tensors may
  67. # have pinned memory.
  68. self._sampling_tensors = None
  69. # Initialize new sampling tensors
  70. (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
  71. do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
  72. do_typical_ps, do_quadratic, do_xtc, do_nsigmas, do_temp_last
  73. ) = SamplingTensors.from_sampling_metadata(
  74. sampling_metadata, vocab_size, logits.device, logits.dtype)
  75. self._sampling_tensors = sampling_tensors
  76. self._do_penalties = do_penalties
  77. self._do_temperatures = do_temperatures
  78. self._do_top_p_top_k = do_top_p_top_k
  79. self._do_top_as = do_top_as
  80. self._do_min_p = do_min_p
  81. self._do_tfss = do_tfss
  82. self._do_eta_cutoffs = do_eta_cutoffs
  83. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  84. self._do_typical_ps = do_typical_ps
  85. self._do_quadratic = do_quadratic
  86. self._do_xtc = do_xtc
  87. self._do_nsgimas = do_nsigmas
  88. self._do_temp_last = do_temp_last
  89. def forward(
  90. self,
  91. logits: torch.Tensor,
  92. sampling_metadata: SamplingMetadata,
  93. ) -> Optional[SamplerOutput]:
  94. """
  95. Args:
  96. logits: (num_tokens, vocab_size).
  97. sampling_metadata: Metadata for sampling.
  98. """
  99. assert logits is not None
  100. _, vocab_size = logits.shape
  101. # Prepare sampling tensors with pinned memory to avoid blocking.
  102. if not sampling_metadata.reuse_sampling_tensors:
  103. self._init_sampling_tensors(logits, sampling_metadata)
  104. elif self._do_penalties:
  105. # In this case, the sampling tensors logic depends on
  106. # "output_tokens" of a sequence. As a result, we cannot
  107. # reuse sampling tensors, since "output_tokens" changes
  108. # between decode runs.
  109. self._init_sampling_tensors(logits, sampling_metadata)
  110. assert self._sampling_tensors is not None
  111. sampling_tensors = self._sampling_tensors
  112. do_penalties = self._do_penalties
  113. do_temperatures = self._do_temperatures
  114. do_top_p_top_k = self._do_top_p_top_k
  115. do_top_as = self._do_top_as
  116. do_min_p = self._do_min_p
  117. do_tfss = self._do_tfss
  118. do_eta_cutoffs = self._do_eta_cutoffs
  119. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  120. do_typical_ps = self._do_typical_ps
  121. do_quadratic = self._do_quadratic
  122. do_xtc = self._do_xtc
  123. do_nsigmas = self._do_nsgimas
  124. do_temp_last = self._do_temp_last
  125. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  126. # Apply presence and frequency penalties.
  127. if do_penalties:
  128. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  129. sampling_tensors.output_tokens,
  130. sampling_tensors.presence_penalties,
  131. sampling_tensors.frequency_penalties,
  132. sampling_tensors.repetition_penalties)
  133. # Apply temperature scaling if not doing temp_last.
  134. if do_temperatures and not do_temp_last:
  135. _apply_temperatures(logits, sampling_tensors.temperatures,
  136. sampling_tensors.dynatemp_mins,
  137. sampling_tensors.dynatemp_maxs,
  138. sampling_tensors.dynatemp_exps)
  139. if do_nsigmas:
  140. logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
  141. if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
  142. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  143. sampling_tensors.top_ks)
  144. if do_top_as:
  145. logits = _apply_top_a(logits, sampling_tensors.top_as)
  146. if do_min_p:
  147. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  148. if do_tfss:
  149. logits = _apply_tfs(logits, sampling_tensors.tfss)
  150. if do_eta_cutoffs:
  151. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  152. if do_epsilon_cutoffs:
  153. logits = _apply_epsilon_cutoff(logits,
  154. sampling_tensors.epsilon_cutoffs)
  155. if do_typical_ps:
  156. logits = _apply_typical_sampling(logits,
  157. sampling_tensors.typical_ps)
  158. if do_quadratic:
  159. logits = _apply_quadratic_sampling(
  160. logits, sampling_tensors.smoothing_factors,
  161. sampling_tensors.smoothing_curves)
  162. if do_xtc:
  163. logits = _apply_xtc_sampling(
  164. logits, sampling_tensors.xtc_thresholds,
  165. sampling_tensors.xtc_probabilities)
  166. if do_temperatures and do_temp_last:
  167. _apply_temperatures(logits, sampling_tensors.temperatures,
  168. sampling_tensors.dynatemp_mins,
  169. sampling_tensors.dynatemp_maxs,
  170. sampling_tensors.dynatemp_exps)
  171. banned_tokens = _get_custom_token_bans(sampling_metadata)
  172. logits = _apply_token_bans(logits, banned_tokens)
  173. # We use float32 for probabilities and log probabilities.
  174. # Compute the probabilities.
  175. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  176. # Compute the log probabilities.
  177. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  178. # Sample the next tokens.
  179. sample_results, maybe_sampled_tokens_tensor = _sample(
  180. probs,
  181. logprobs,
  182. sampling_metadata,
  183. sampling_tensors,
  184. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  185. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  186. )
  187. if self.include_gpu_probs_tensor:
  188. assert maybe_sampled_tokens_tensor is not None
  189. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  190. else:
  191. on_device_tensors = None
  192. # Get the logprobs query results.
  193. prompt_logprobs = None
  194. sample_logprobs = None
  195. if not sampling_metadata.skip_sampler_cpu_output:
  196. prompt_logprobs, sample_logprobs = _get_logprobs(
  197. logprobs, sampling_metadata, sample_results)
  198. return _build_sampler_output(
  199. sample_results,
  200. sampling_metadata,
  201. prompt_logprobs,
  202. sample_logprobs,
  203. on_device_tensors=on_device_tensors,
  204. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  205. @property
  206. def _should_modify_greedy_probs_inplace(self) -> bool:
  207. """Whether or not the sampler should modify the probability distribution
  208. of greedily-sampled tokens such that multinomial sampling would sample
  209. the greedily-sampled token.
  210. In other words, if True then we set the probability of the greedily-
  211. sampled token to 1.
  212. This is used by speculative decoding, which requires that the sampling
  213. method be encoded into the probability distribution.
  214. """
  215. return self.should_modify_greedy_probs_inplace
  216. def _get_bin_counts_and_mask(
  217. tokens: torch.Tensor,
  218. vocab_size: int,
  219. num_seqs: int,
  220. ) -> Tuple[torch.Tensor, torch.Tensor]:
  221. # Compute the bin counts for the tokens.
  222. # vocab_size + 1 for padding.
  223. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  224. dtype=torch.long,
  225. device=tokens.device)
  226. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  227. bin_counts = bin_counts[:, :vocab_size]
  228. mask = bin_counts > 0
  229. return bin_counts, mask
  230. def _get_custom_token_bans(
  231. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  232. assert sampling_metadata.seq_groups is not None
  233. banned_tokens: List[List[int]] = []
  234. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  235. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  236. seq_ids = seq_group.seq_ids
  237. custom_token_bans = sampling_params.custom_token_bans
  238. if (i < sampling_metadata.num_prompts
  239. and sampling_params.prompt_logprobs is not None):
  240. prompt_len = len(seq_group.prompt_logprob_indices)
  241. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  242. banned_tokens += [custom_token_bans] * len(seq_ids)
  243. return banned_tokens
  244. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  245. output_tokens_tensor: torch.Tensor,
  246. presence_penalties: torch.Tensor,
  247. frequency_penalties: torch.Tensor,
  248. repetition_penalties: torch.Tensor) -> torch.Tensor:
  249. num_seqs, vocab_size = logits.shape
  250. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  251. num_seqs)
  252. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  253. output_tokens_tensor, vocab_size, num_seqs)
  254. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  255. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  256. logits = torch.where(logits > 0, logits / repetition_penalties,
  257. logits * repetition_penalties)
  258. # We follow the definition in OpenAI API.
  259. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  260. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  261. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  262. return logits
def _apply_temperatures(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    dynatemp_mins: torch.Tensor,
    dynatemp_maxs: torch.Tensor,
    dynatemp_exps: torch.Tensor,
) -> None:
    """Divide each row of ``logits`` by its temperature, in place.

    Rows whose ``dynatemp_mins`` or ``dynatemp_maxs`` entry is non-zero use a
    dynamic temperature instead of their static one::

        temp = min + (max - min) * (H / log(n_finite)) ** exp

    where ``H`` is the entropy of the row's softmax distribution and
    ``n_finite`` is the number of logits greater than ``-inf`` in that row.

    Rows with a (final) temperature below 0.1 are first shifted so that their
    maximum logit equals 1, which limits saturation when dividing by a tiny
    temperature.

    Side effects:
        * ``logits`` is modified in place (nothing is returned).
        * ``temperatures`` is modified in place. Dynamic-temperature rows are
          overwritten with the computed value. NaN entries and entries
          ``<= _TEMPERATURE_MINIMUM`` are raised to ``_TEMPERATURE_MINIMUM``.
    """
    # Select only the rows that use dynamic temperature. Boolean indexing
    # returns copies, so the in-place ops below do not touch the inputs.
    dynatemp_mask = (dynatemp_mins != 0) | (dynatemp_maxs != 0)
    dynatemp_mins = dynatemp_mins[dynatemp_mask]
    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
    dynatemp_exps = dynatemp_exps[dynatemp_mask]

    dynatemp_logits = logits[dynatemp_mask]
    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
    dynatemp_probs = dynatemp_shifted_logits.exp()
    # Entropy per row; nansum drops the NaN produced by 0 * -inf for tokens
    # that were already masked to -inf.
    dynatemp_entropies = -(dynatemp_probs *
                           dynatemp_shifted_logits).nansum(dim=-1)
    # Maximum possible entropy: log of the number of non-masked tokens.
    dynatemp_max_entropies = torch.log_(
        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
    # The in-place div_/pow_ keep the dtype of dynatemp_entropies (that of the
    # logits). A row with a single finite logit gives 0 / 0 = NaN here, which
    # the NaN clamp below turns into _TEMPERATURE_MINIMUM.
    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                normalized_entropies.pow_(dynatemp_exps))

    temperatures[dynatemp_mask] = dyn_temp
    temperatures[temperatures.isnan()] = _TEMPERATURE_MINIMUM
    temperatures[temperatures <= _TEMPERATURE_MINIMUM] = _TEMPERATURE_MINIMUM

    # To prevent saturation of top logits, we shift the range to [-inf, 1]
    # Why align to 1, instead of 0? Because [0, 1] holds 25% of all floats.
    # Why mask? So we aren't potentially discarding data in milder temps.
    low_temps = temperatures < 0.1
    logits[low_temps] -= logits.max(dim=-1, keepdim=True).values[low_temps] - 1

    logits.div_(temperatures.unsqueeze(dim=1))
  293. def _apply_token_bans(logits: torch.Tensor,
  294. banned_tokens: List[List[int]]) -> torch.Tensor:
  295. for i, banned_token_ids in enumerate(banned_tokens):
  296. if i >= logits.size(0):
  297. break
  298. if not banned_token_ids:
  299. continue
  300. logits[i, banned_token_ids] = -float("inf")
  301. return logits
  302. def _apply_min_tokens_penalty(
  303. logits: torch.Tensor,
  304. sampling_metadata: SamplingMetadata,
  305. ) -> torch.Tensor:
  306. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  307. have not been generated yet
  308. """
  309. # list of indices in logits that will be set to -inf
  310. logits_to_penalize = []
  311. logits_applied = 0
  312. for seq_group in sampling_metadata.seq_groups:
  313. seq_ids = seq_group.seq_ids
  314. sampling_params = seq_group.sampling_params
  315. sample_indices = seq_group.sample_indices
  316. logits_applied += len(sample_indices) + len(
  317. seq_group.prompt_logprob_indices)
  318. if not seq_group.do_sample:
  319. continue
  320. start_idx = sample_indices[0]
  321. min_tokens = sampling_params.min_tokens
  322. token_ids_to_penalize = sampling_params.all_stop_token_ids
  323. if min_tokens > 0 and token_ids_to_penalize:
  324. seqs_to_penalize = []
  325. for j, seq_id in enumerate(seq_ids):
  326. seq_data = seq_group.seq_data[seq_id]
  327. if len(seq_data.output_token_ids_array) < min_tokens:
  328. seqs_to_penalize.append(j)
  329. if seqs_to_penalize:
  330. # convert to the index into logits
  331. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  332. # itertools.product pairs each seq index with every token id
  333. logits_to_penalize.extend(
  334. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  335. if logits_to_penalize:
  336. # use zip and * to group indices along each dimension
  337. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  338. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  339. # verifies that no rows in logits were missed unexpectedly
  340. assert logits_applied == logits.shape[0]
  341. return logits
  342. def _apply_top_k_top_p(
  343. logits: torch.Tensor,
  344. p: torch.Tensor,
  345. k: torch.Tensor,
  346. ) -> torch.Tensor:
  347. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  348. # Apply top-k.
  349. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  350. # Get all the top_k values.
  351. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  352. top_k_mask = logits_sort < top_k_mask
  353. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  354. # Apply top-p.
  355. probs_sort = logits_sort.softmax(dim=-1)
  356. probs_sum = probs_sort.cumsum(dim=-1)
  357. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  358. # at least one
  359. top_p_mask[:, -1] = False
  360. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  361. # Re-sort the probabilities.
  362. src = torch.arange(logits_idx.shape[-1],
  363. device=logits_idx.device).expand_as(logits_idx)
  364. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  365. index=logits_idx,
  366. src=src)
  367. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  368. return logits
  369. def _apply_min_p(
  370. logits: torch.Tensor,
  371. min_p: torch.Tensor,
  372. ) -> torch.Tensor:
  373. """
  374. Adapted from
  375. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  376. """
  377. probs = torch.softmax(logits, dim=-1)
  378. top_probs, _ = probs.max(dim=-1, keepdim=True)
  379. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  380. tokens_to_remove = probs < scaled_min_p
  381. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  382. return logits
  383. def _apply_top_a(
  384. logits: torch.Tensor,
  385. top_a: torch.Tensor,
  386. ) -> torch.Tensor:
  387. probs = torch.softmax(logits, dim=-1)
  388. top_probs, _ = probs.max(dim=-1, keepdim=True)
  389. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  390. tokens_to_remove = probs < threshold
  391. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  392. return logits
  393. def _apply_tfs(
  394. logits: torch.Tensor,
  395. tfs: torch.Tensor,
  396. ) -> torch.Tensor:
  397. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  398. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  399. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  400. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  401. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  402. tfs_mask = torch.cat(
  403. (
  404. torch.zeros(
  405. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  406. tfs_mask,
  407. torch.ones(
  408. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  409. ),
  410. dim=-1,
  411. )
  412. logits_sort[tfs_mask] = -float("inf")
  413. logits = torch.gather(logits_sort,
  414. dim=-1,
  415. index=torch.argsort(logits_idx, dim=-1))
  416. return logits
  417. def _apply_eta_cutoff(
  418. logits: torch.Tensor,
  419. eta_cutoff: torch.Tensor,
  420. ) -> torch.Tensor:
  421. shifted_logits = torch.log_softmax(logits, dim=-1)
  422. probs = shifted_logits.exp()
  423. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  424. eps = torch.min(eta_cutoff,
  425. torch.sqrt(eta_cutoff) *
  426. torch.exp(neg_entropy)).unsqueeze(dim=1)
  427. eta_mask = probs < eps
  428. # guard against nulling out all the logits
  429. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  430. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  431. logits[eta_mask] = -float("inf")
  432. return logits
  433. def _apply_epsilon_cutoff(
  434. logits: torch.Tensor,
  435. epsilon_cutoff: torch.Tensor,
  436. ) -> torch.Tensor:
  437. probs = logits.softmax(dim=-1)
  438. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  439. # guard against nulling out all the logits
  440. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  441. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  442. logits[eps_mask] = -float("inf")
  443. return logits
  444. def _apply_typical_sampling(
  445. logits: torch.Tensor,
  446. typical_p: torch.Tensor,
  447. ) -> torch.Tensor:
  448. shifted_logits = torch.log_softmax(logits, dim=-1)
  449. probs = shifted_logits.exp()
  450. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  451. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  452. _, indices = torch.sort(surprisal_deviations)
  453. reordered_probs = probs.gather(-1, indices)
  454. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  455. dim=1)
  456. min_tokens_to_keep = 1
  457. # Keep at least min_tokens_to_keep
  458. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  459. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  460. logits[typ_mask] = -float("inf")
  461. return logits
  462. def _apply_quadratic_sampling(
  463. logits: torch.Tensor,
  464. smoothing_factor: torch.Tensor,
  465. smoothing_curve: torch.Tensor,
  466. ) -> torch.Tensor:
  467. """
  468. Applies a quadratic transformation to the logits based on the
  469. provided smoothing factors and curves. The transformation is
  470. centered around the maximum logit value in the batch.
  471. The transformation involves a quadratic and cubic term, with the
  472. cubic term controlled by the smoothing curve. The quadratic term is
  473. scaled by the smoothing factor, and the cubic term is scaled by the
  474. product of the smoothing factor and the smoothing curve.
  475. params:
  476. logits (torch.Tensor): The logits to be transformed.
  477. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  478. term in the transformation.
  479. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  480. in the transformation.
  481. returns:
  482. torch.Tensor: The transformed logits.
  483. Credits: @kalomaze
  484. """
  485. mask = smoothing_factor != 0
  486. smoothing_factor.unsqueeze_(dim=1)
  487. smoothing_curve.unsqueeze_(dim=1)
  488. k = smoothing_factor * (3 - smoothing_curve) / 2
  489. s = smoothing_factor * (smoothing_curve - 1) / 2
  490. quadlogits = logits[mask] # limit to logits using this sampler
  491. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  492. # Construct the delta from each logit to its new value
  493. diff = quadlogits - max_logits
  494. diff -= diff**2 * (s[mask] * diff - k[mask])
  495. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  496. logits[mask] -= diff
  497. return logits
  498. def _apply_xtc_sampling(
  499. logits: torch.Tensor,
  500. xtc_thresholds: torch.Tensor,
  501. xtc_probabilities: torch.Tensor,
  502. ) -> torch.Tensor:
  503. """Apply Exclude Top Choices (XTC) sampling to the logits.
  504. Reference: https://github.com/oobabooga/text-generation-webui/pull/6335
  505. Args:
  506. logits: (num_tokens, vocab_size) The input logits.
  507. xtc_thresholds: (num_tokens,) The threshold for each token.
  508. xtc_probabilities: (num_tokens,) The probability of applying XTC
  509. for each token.
  510. Returns:
  511. torch.Tensor: The modified logits.
  512. """
  513. apply_xtc = torch.rand_like(xtc_probabilities) < xtc_probabilities
  514. if not apply_xtc.any():
  515. return logits
  516. probs = torch.softmax(logits, dim=-1)
  517. sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
  518. # Find indices where the next probability is above the threshold
  519. # Skips the top choice, which later on becomes skipping the last choice.
  520. above_threshold = sorted_probs[..., 1:] >= xtc_thresholds.unsqueeze(-1)
  521. # Apply XTC only to rows where it should be applied
  522. for i in range(logits.shape[0]):
  523. if apply_xtc[i]:
  524. # Count logits above the threshold (skipping the first)
  525. indices_to_remove = above_threshold[i].count_nonzero(dim=-1).item()
  526. if indices_to_remove > 0:
  527. # Implies the top logit and at least one other is >= threshold.
  528. # Mask out above_thresh logits except the last/lowest one.
  529. logits[i].scatter_(
  530. 0, sorted_indices[i, :indices_to_remove], -float('inf'))
  531. return logits
  532. def _apply_top_nsigma(
  533. logits: torch.Tensor,
  534. nsigma: torch.Tensor,
  535. ) -> torch.Tensor:
  536. """Apply top-nsigma truncation to the logits.
  537. Reference: https://arxiv.org/abs/2411.07641
  538. Args:
  539. logits: Logits of shape (num_tokens, vocab_size)
  540. nsigma: Number of standard deviations to use as threshold
  541. Returns:
  542. Modified logits with values below threshold set to -inf
  543. """
  544. std = logits.std(dim=-1, keepdim=True)
  545. threshold = (logits.max(dim=-1, keepdim=True).values -
  546. nsigma.unsqueeze(dim=1) * std)
  547. logits[logits < threshold] = float("-inf")
  548. return logits
  549. def _greedy_sample(
  550. selected_seq_groups: List[SequenceGroupToSample],
  551. samples: torch.Tensor,
  552. ) -> List[Tuple[List[int], List[int]]]:
  553. """Run greedy sampling on a given samples.
  554. Args:
  555. selected_seq_groups: A list of sequence groups batched.
  556. samples: (num_selected_samples,) A tensor of samples. The length of
  557. samples could be smaller than selected_seq_groups if
  558. seq_group.do_sample is False.
  559. Returns:
  560. Tuple of (next_token_ids, parent_ids). The length of returned list is
  561. same as the length of selected_seq_groups. If the corresponding
  562. seq_group has do_sample=False, tuple contains ([], [])
  563. """
  564. samples = samples.tolist()
  565. sample_idx = 0
  566. results = []
  567. for seq_group in selected_seq_groups:
  568. if not seq_group.do_sample:
  569. results.append(([], []))
  570. continue
  571. seq_ids = seq_group.seq_ids
  572. num_parent_seqs = len(seq_ids)
  573. assert num_parent_seqs == 1, (
  574. "Greedy sampling should have only one seq.")
  575. parent_ids = list(range(num_parent_seqs))
  576. next_token_ids = [samples[sample_idx]]
  577. results.append((next_token_ids, parent_ids))
  578. sample_idx += num_parent_seqs
  579. return results
  580. def _random_sample(
  581. selected_seq_groups: List[SequenceGroupToSample],
  582. random_samples: torch.Tensor,
  583. ) -> List[Tuple[List[int], List[int]]]:
  584. """Run random sampling on a given samples.
  585. Args:
  586. selected_seq_groups: A list of sequence groups batched.
  587. random_samples: (num_selected_samples,) A tensor of samples. The
  588. length of samples could be smaller than selected_seq_groups if
  589. seq_group.do_sample is False.
  590. Returns:
  591. Tuple of (next_token_ids, parent_ids). The length of returned list is
  592. same as the length of selected_seq_groups. If the corresponding
  593. seq_group has do_sample=False, tuple contains ([], [])
  594. """
  595. # Find the maximum best_of value of the prompt phase requests.
  596. random_samples = random_samples.cpu()
  597. sample_idx = 0
  598. results = []
  599. for seq_group in selected_seq_groups:
  600. if not seq_group.do_sample:
  601. results.append(([], []))
  602. continue
  603. seq_ids = seq_group.seq_ids
  604. sampling_params = seq_group.sampling_params
  605. is_prompt = seq_group.is_prompt
  606. num_parent_seqs = len(seq_ids)
  607. if is_prompt:
  608. # Prompt phase.
  609. parent_ids = [0] * sampling_params.best_of
  610. next_token_ids = random_samples[
  611. sample_idx, :sampling_params.best_of].tolist()
  612. else:
  613. # Generation phase.
  614. parent_ids = list(range(num_parent_seqs))
  615. next_token_ids = random_samples[sample_idx:sample_idx +
  616. num_parent_seqs, 0].tolist()
  617. results.append((next_token_ids, parent_ids))
  618. sample_idx += num_parent_seqs
  619. return results
  620. def _beam_search_sample(
  621. selected_seq_groups: List[SequenceGroupToSample],
  622. logprobs: torch.Tensor,
  623. ) -> List[Tuple[List[int], List[int]]]:
  624. """Run beam sampling on a given samples.
  625. Args:
  626. selected_seq_groups: A list of sequence groups batched.
  627. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  628. on selected sample indices.
  629. Returns:
  630. Tuple of (next_token_ids, parent_ids). The length of returned list is
  631. same as the length of selected_seq_groups. If the corresponding
  632. seq_group has do_sample=False, tuple contains ([], [])
  633. """
  634. # We sample 2 * beam_width candidates to make sure that with high
  635. # probability we can get `beam_width` candidates in addition to
  636. # the finished sequences for the next iteration. See
  637. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  638. # for details. See also HF reference:
  639. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  640. #
  641. # NOTE: Beam search is not vectorized, so its speed can be slower than
  642. # other sampling methods.
  643. sample_idx = 0
  644. results = []
  645. for seq_group in selected_seq_groups:
  646. if not seq_group.do_sample:
  647. results.append(([], []))
  648. continue
  649. is_prompt = seq_group.is_prompt
  650. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  651. num_parent_seqs = len(seq_ids)
  652. beam_width = sampling_params.best_of
  653. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  654. if is_prompt:
  655. # Prompt phase.
  656. assert num_parent_seqs == 1, (
  657. "Prompt input should have only one seq.")
  658. parent_ids = [0] * (2 * beam_width)
  659. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  660. 2 * beam_width)
  661. next_token_ids = next_token_ids.tolist()
  662. else:
  663. # Generation phase.
  664. cumulative_logprobs = [
  665. seq_group.seq_data[seq_id].cumulative_logprob
  666. for seq_id in seq_ids
  667. ]
  668. cumulative_logprobs = torch.tensor(
  669. cumulative_logprobs,
  670. dtype=torch.float,
  671. device=seq_group_logprobs.device)
  672. seq_group_logprobs = (seq_group_logprobs +
  673. cumulative_logprobs.unsqueeze(dim=1))
  674. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  675. 2 * beam_width)
  676. topk_ids = topk_ids.tolist()
  677. vocab_size = seq_group_logprobs.size(-1)
  678. parent_ids = [i // vocab_size for i in topk_ids]
  679. next_token_ids = [i % vocab_size for i in topk_ids]
  680. results.append((next_token_ids, parent_ids))
  681. sample_idx += num_parent_seqs
  682. assert sample_idx == logprobs.size(0)
  683. return results
  684. # torch.multinomial forces a GPU<->CPU sync.
  685. # Therefore, we use an optimized implementation instead.
  686. # Note that we always sample with replacement.
  687. # probs will be modified in place, but this is fine, as we pass
  688. # in a copy already.
  689. def _multinomial(
  690. probs: torch.Tensor,
  691. num_samples: int,
  692. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  693. ) -> torch.Tensor:
  694. if num_samples > 1:
  695. probs = probs.repeat_interleave(num_samples, dim=0)
  696. q = torch.empty_like(probs)
  697. if seq_groups is None:
  698. q.exponential_()
  699. else:
  700. sample_idx = 0
  701. for seq_group in seq_groups:
  702. seq_ids = seq_group.seq_ids
  703. stride = len(seq_ids) * num_samples
  704. assert seq_group.generator is not None
  705. q[sample_idx:sample_idx +
  706. stride].exponential_(generator=seq_group.generator)
  707. sample_idx += stride
  708. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  709. def _top_k_top_p_multinomial_with_kernels(
  710. probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
  711. num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
  712. max_top_k_round = 32
  713. if num_samples > 1:
  714. probs = probs.repeat_interleave(num_samples, dim=0)
  715. top_ks = top_ks.repeat_interleave(num_samples)
  716. top_ps = top_ps.repeat_interleave(num_samples)
  717. batch_size = probs.shape[0]
  718. uniform_samples = torch.empty((max_top_k_round, batch_size),
  719. device=probs.device)
  720. if seq_groups is None:
  721. uniform_samples.uniform_()
  722. else:
  723. sample_idx = 0
  724. for seq_group in seq_groups:
  725. seq_ids = seq_group.seq_ids
  726. stride = len(seq_ids) * num_samples
  727. assert seq_group.generator is not None
  728. uniform_samples[:, sample_idx:sample_idx +
  729. stride].uniform_(generator=seq_group.generator)
  730. sample_idx += stride
  731. batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
  732. probs,
  733. uniform_samples,
  734. top_ks,
  735. top_ps,
  736. )
  737. if not success.all():
  738. warnings.warn("CUDA rejection sampling failed, fallback.",
  739. stacklevel=1)
  740. probs = ops.top_k_renorm_prob(probs, top_ks)
  741. probs = ops.top_p_renorm_prob(probs, top_ps)
  742. batch_next_token_ids = ops.sampling_from_probs(
  743. probs, uniform_samples[0])
  744. return batch_next_token_ids.view(-1, num_samples)
  745. def _sample_with_torch(
  746. probs: torch.Tensor,
  747. logprobs: torch.Tensor,
  748. sampling_metadata: SamplingMetadata,
  749. sampling_tensors: SamplingTensors,
  750. include_gpu_probs_tensor: bool,
  751. modify_greedy_probs: bool,
  752. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  753. categorized_seq_group_ids = {t: [] for t in SamplingType}
  754. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  755. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  756. sampling_params = seq_group.sampling_params
  757. sampling_type = sampling_params.sampling_type
  758. categorized_seq_group_ids[sampling_type].append(i)
  759. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  760. sample_metadata = {}
  761. multinomial_samples = {}
  762. # Create output tensor for sampled token ids.
  763. if include_gpu_probs_tensor:
  764. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  765. 1,
  766. dtype=torch.long,
  767. device=logprobs.device)
  768. else:
  769. sampled_token_ids_tensor = None
  770. # Counterintuitively, having two loops here is actually faster.
  771. # The first loop can run without waiting on GPU<->CPU sync.
  772. for sampling_type in SamplingType:
  773. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  774. num_tokens = len(sample_indices)
  775. if num_tokens == 0:
  776. continue
  777. seq_group_id = categorized_seq_group_ids[sampling_type]
  778. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  779. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  780. long_sample_indices = sample_indices.long()
  781. if sampling_type == SamplingType.GREEDY:
  782. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  783. dim=-1)
  784. if include_gpu_probs_tensor:
  785. # Store sampled tokens in output tensor.
  786. sampled_token_ids_tensor[
  787. long_sample_indices] = greedy_samples.unsqueeze(-1)
  788. if modify_greedy_probs:
  789. # If required, modify the probabilities such that sampling from
  790. # the modified distribution would always sample the argmax
  791. # token id.
  792. _modify_greedy_probs_inplace(logprobs, probs,
  793. long_sample_indices,
  794. greedy_samples)
  795. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  796. max_best_of_in_batch = 1
  797. for seq_group in seq_groups:
  798. if seq_group.is_prompt:
  799. sampling_params = seq_group.sampling_params
  800. max_best_of_in_batch = max(max_best_of_in_batch,
  801. sampling_params.best_of)
  802. seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
  803. seq_groups)
  804. if APHRODITE_USE_SAMPLING_KERNELS is not None:
  805. multinomial_samples[
  806. sampling_type] = _top_k_top_p_multinomial_with_kernels(
  807. probs[long_sample_indices],
  808. sampling_tensors.top_ks[long_sample_indices],
  809. sampling_tensors.top_ps[long_sample_indices],
  810. max_best_of_in_batch,
  811. seq_groups_arg,
  812. )
  813. else:
  814. multinomial_samples[sampling_type] = _multinomial(
  815. probs[long_sample_indices],
  816. max_best_of_in_batch,
  817. seq_groups=seq_groups_arg)
  818. if sampled_token_ids_tensor is not None:
  819. # Store sampled tokens in output tensor.
  820. sampled_token_ids_tensor[long_sample_indices] = \
  821. multinomial_samples[sampling_type].to(torch.long)
  822. elif sampling_type == SamplingType.BEAM:
  823. beam_search_logprobs = logprobs[sample_indices]
  824. else:
  825. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  826. # GPU<->CPU sync happens in the loop below.
  827. # This also converts the sample output to Python objects.
  828. if not sampling_metadata.skip_sampler_cpu_output:
  829. for sampling_type in SamplingType:
  830. if sampling_type not in sample_metadata:
  831. continue
  832. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  833. if sampling_type == SamplingType.GREEDY:
  834. sample_results = _greedy_sample(seq_groups, greedy_samples)
  835. elif sampling_type in (SamplingType.RANDOM,
  836. SamplingType.RANDOM_SEED):
  837. sample_results = _random_sample(
  838. seq_groups, multinomial_samples[sampling_type])
  839. elif sampling_type == SamplingType.BEAM:
  840. sample_results = _beam_search_sample(seq_groups,
  841. beam_search_logprobs)
  842. sample_results_dict.update(zip(seq_group_id, sample_results))
  843. sample_results = [
  844. sample_results_dict.get(i, ([], []))
  845. for i in range(len(sampling_metadata.seq_groups))
  846. ]
  847. else:
  848. sample_results = []
  849. return sample_results, sampled_token_ids_tensor
  850. def _sample_with_triton_kernel(
  851. probs: torch.Tensor,
  852. logprobs: torch.Tensor,
  853. sampling_metadata: SamplingMetadata,
  854. sampling_tensors: SamplingTensors,
  855. ) -> List[Tuple[List[int], List[int]]]:
  856. categorized_seq_group_ids = {t: [] for t in SamplingType}
  857. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  858. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  859. sampling_params = seq_group.sampling_params
  860. sampling_type = sampling_params.sampling_type
  861. categorized_seq_group_ids[sampling_type].append(i)
  862. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  863. sample_metadata = {}
  864. max_best_of_in_batch = 1
  865. # Counterintuitively, having two loops here is actually faster.
  866. # The first loop can run without waiting on GPU<->CPU sync.
  867. for sampling_type in SamplingType:
  868. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  869. sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
  870. num_tokens = len(sample_indices)
  871. if num_tokens == 0:
  872. continue
  873. seq_group_id = categorized_seq_group_ids[sampling_type]
  874. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  875. sample_metadata[sampling_type] = (seq_group_id, seq_groups,
  876. sample_indices,
  877. sampled_token_indices)
  878. if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
  879. SamplingType.RANDOM_SEED):
  880. for seq_group in seq_groups:
  881. if seq_group.is_prompt:
  882. sampling_params = seq_group.sampling_params
  883. max_best_of_in_batch = max(max_best_of_in_batch,
  884. sampling_params.best_of)
  885. elif sampling_type == SamplingType.BEAM:
  886. beam_search_logprobs = logprobs[sample_indices]
  887. else:
  888. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  889. sampled_tokens, _, _ = sample_triton(
  890. probs=probs,
  891. seeds=sampling_tensors.sampling_seeds,
  892. max_best_of=max_best_of_in_batch,
  893. sample_indices=sampling_tensors.sample_indices,
  894. logprobs=logprobs,
  895. # don't save logprobs because we have logic for that below
  896. # TODO: use this instead of the CPU-based logic below
  897. save_logprobs=False,
  898. )
  899. # GPU<->CPU sync happens in the loop below.
  900. for sampling_type in SamplingType:
  901. if sampling_type not in sample_metadata:
  902. continue
  903. (seq_group_id, seq_groups, sample_indices,
  904. sampled_token_indices) = sample_metadata[sampling_type]
  905. if sampling_type == SamplingType.GREEDY:
  906. sample_results = _greedy_sample(
  907. seq_groups, sampled_tokens[sampled_token_indices][:, 0])
  908. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  909. sample_results = _random_sample(
  910. seq_groups, sampled_tokens[sampled_token_indices])
  911. elif sampling_type == SamplingType.BEAM:
  912. sample_results = _beam_search_sample(seq_groups,
  913. beam_search_logprobs)
  914. sample_results_dict.update(zip(seq_group_id, sample_results))
  915. sample_results = [
  916. sample_results_dict.get(i, ([], []))
  917. for i in range(len(sampling_metadata.seq_groups))
  918. ]
  919. return sample_results
  920. def _sample(
  921. probs: torch.Tensor, logprobs: torch.Tensor,
  922. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  923. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  924. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  925. """
  926. Args:
  927. probs: (num_query_tokens_in_batch, num_vocab)
  928. logprobs: (num_query_tokens_in_batch, num_vocab)
  929. sampling_metadata: The metadata for a batch for sampling.
  930. sampling_tensors: Tensors that include sampling related metadata.
  931. Returns:
  932. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  933. If sampling is skipped, it returns ([], [])
  934. sampled_token_ids_tensor: A tensor of sampled token ids.
  935. """
  936. return _sample_with_torch(
  937. probs,
  938. logprobs,
  939. sampling_metadata,
  940. sampling_tensors,
  941. include_gpu_probs_tensor=include_gpu_probs_tensor,
  942. modify_greedy_probs=modify_greedy_probs,
  943. )
  944. # TODO: Enable once Triton kernel & associated code is faster.
  945. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  946. # sampling_tensors)
  947. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  948. """
  949. This function calculates the ranks of the chosen tokens in a logprob tensor.
  950. Args:
  951. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  952. where N is the no. of tokens and M is the vocab dim.
  953. indices (torch.Tensor): List of chosen token indices.
  954. Returns:
  955. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  956. Each element in the returned tensor represents the rank
  957. of the chosen token in the input logprob tensor.
  958. """
  959. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  960. indices]
  961. return (x > vals[:, None]).long().sum(1).add_(1)
  962. def _get_logprobs(
  963. logprobs: torch.Tensor,
  964. sampling_metadata: SamplingMetadata,
  965. sample_results: List[Tuple[List[int], List[int]]],
  966. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  967. """Return sample lobprobs and prompt logprobs.
  968. The logic consists of 3 parts.
  969. - Select indices to compute logprob from, ranks of token ids, and
  970. the top k token ids from logprobs.
  971. - Compute prompt logprobs if required.
  972. - Compute sample logprobs if required.
  973. Args:
  974. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  975. logprob per vocab. Sequence groups' query tokens are batched in a
  976. single flattened tensor. For example, assuming there are N
  977. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  978. prompt logprob is enabled), decode tokens for seq_group_1 (if
  979. sampling is required), prefill tokens for seq_group_2, ...
  980. sampling_metadata: The sampling metadata.
  981. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  982. parent_ids) for each sequence group. When beam search is enabled,
  983. sample_results can contain different number of seq_ids from
  984. sampling_metadata.seq_groups. It is because beam search creates
  985. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  986. BEAM_WIDTH number of seq_ids).
  987. Returns:
  988. A tuple of prompt and sample logprobs per sequence group in a batch.
  989. """
  990. # The index of query token to calculate logprobs. It includes both
  991. # prompt and sample logprob indices.
  992. query_indices: List[int] = []
  993. # The next token ids to get the logprob value from.
  994. next_token_ids: List[int] = []
  995. # The largest requested number of logprobs. We find logprobs as many as the
  996. # largest num logprobs in this API. If every logprobs is None, it will be
  997. # set to -1.
  998. largest_num_logprobs = -1
  999. # If beam search is enabled.
  1000. use_beam_search = False
  1001. # Select indices to compute logprob from, ranks of token ids, and the top
  1002. # k token ids from logprobs.
  1003. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  1004. sample_results):
  1005. sampling_params = seq_group.sampling_params
  1006. # Update indices and tokens for prompt logprobs.
  1007. if (seq_group.is_prompt
  1008. and sampling_params.prompt_logprobs is not None):
  1009. largest_num_logprobs = max(largest_num_logprobs,
  1010. sampling_params.prompt_logprobs)
  1011. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1012. query_indices.extend(seq_group.prompt_logprob_indices)
  1013. next_token_ids.extend(next_prompt_tokens)
  1014. # Update indices and next tokenes for sample logprob.
  1015. if seq_group.do_sample:
  1016. token_ids, parent_seq_ids = sample_result
  1017. # NOTE: We cannot directly use sample_indices because
  1018. # sample_indices only contain parent seq_ids of a previous step.
  1019. # The current step may have different number of seq_ids, and
  1020. # we can obtain it from `sample_result[1]`.
  1021. query_idx = seq_group.sample_indices[0]
  1022. query_indices.extend(
  1023. [query_idx + parent_id for parent_id in parent_seq_ids])
  1024. next_token_ids.extend(token_ids)
  1025. if sampling_params.logprobs is not None:
  1026. largest_num_logprobs = max(largest_num_logprobs,
  1027. sampling_params.logprobs)
  1028. use_beam_search = use_beam_search or sampling_params.use_beam_search
  1029. assert len(next_token_ids) == len(query_indices)
  1030. if len(query_indices) == 0:
  1031. empty_sampled_logprob = []
  1032. empty_prompt_logprob = None
  1033. return [empty_prompt_logprob], [empty_sampled_logprob]
  1034. selected_logprobs, ranks = None, None
  1035. top_logprobs, top_token_ids = None, None
  1036. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  1037. # skip the whole logprob calculation.
  1038. if largest_num_logprobs >= 0 or use_beam_search:
  1039. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  1040. next_token_ids_gpu = torch.tensor(next_token_ids,
  1041. device=logprobs.device)
  1042. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  1043. # contain duplicates if beam search is enabled.
  1044. selected_logprobs = logprobs[[
  1045. query_indices_gpu,
  1046. next_token_ids_gpu,
  1047. ]]
  1048. ranks = _get_ranks(
  1049. logprobs[query_indices_gpu],
  1050. next_token_ids_gpu,
  1051. )
  1052. assert selected_logprobs.shape[0] == ranks.shape[0]
  1053. # We need to compute top k only if there exists logprobs > 0.
  1054. if largest_num_logprobs > 0:
  1055. # Logprobs of topk tokens for a batch of sequence groups.
  1056. # (num_query_tokens_across_batch).
  1057. top_logprobs, top_token_ids = torch.topk(logprobs,
  1058. largest_num_logprobs,
  1059. dim=-1)
  1060. top_logprobs = top_logprobs.to('cpu')
  1061. top_token_ids = top_token_ids.to('cpu')
  1062. selected_logprobs = selected_logprobs.to('cpu')
  1063. ranks = ranks.to('cpu')
  1064. # Find prompt/sample logprobs.
  1065. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  1066. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  1067. top_logprob_idx = 0
  1068. selected_logprobs_idx = 0
  1069. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  1070. sample_results):
  1071. (prompt_logprobs, top_logprob_idx,
  1072. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  1073. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  1074. selected_logprobs_idx, top_logprob_idx)
  1075. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  1076. (sampled_logprobs, top_logprob_idx,
  1077. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  1078. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  1079. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  1080. sample_logprobs_per_seq_group.append(sampled_logprobs)
  1081. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  1082. def _get_prompt_logprob_if_needed(
  1083. seq_group: SequenceGroupToSample,
  1084. selected_logprobs: torch.Tensor,
  1085. ranks: torch.Tensor,
  1086. top_token_ids: torch.Tensor,
  1087. top_logprobs: torch.Tensor,
  1088. selected_logprobs_idx: int,
  1089. top_logprob_idx: int,
  1090. ):
  1091. """Compute the prompt logprob from a sequence group if needed."""
  1092. sampling_params = seq_group.sampling_params
  1093. is_prompt = seq_group.is_prompt
  1094. # Find prompt logprobs
  1095. prompt_logprobs: Optional[PromptLogprobs] = None
  1096. if is_prompt and sampling_params.prompt_logprobs is not None:
  1097. prompt_logprobs = []
  1098. num_logprobs = sampling_params.prompt_logprobs
  1099. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  1100. # Pre-select indexes and create a list. It is faster than calling .item
  1101. # repetitively.
  1102. selected_logprob_items = selected_logprobs[
  1103. selected_logprobs_idx:selected_logprobs_idx +
  1104. len(next_prompt_tokens)].tolist()
  1105. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1106. len(next_prompt_tokens)].tolist()
  1107. for idx, token_id in enumerate(next_prompt_tokens):
  1108. # Calculate the prompt logprob of the real prompt tokens.
  1109. # {token_id: (logprob, rank_from_vocab)}
  1110. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1111. token_id: (selected_logprob_items[idx], rank_items[idx])
  1112. }
  1113. # Add top K prompt logprobs along with its rank.
  1114. if num_logprobs > 0:
  1115. top_ids = top_token_ids[
  1116. top_logprob_idx, :num_logprobs].tolist()
  1117. top_probs = top_logprobs[
  1118. top_logprob_idx, :num_logprobs].tolist()
  1119. # Top K is already sorted by rank, so we can use 1 ~
  1120. # num_logprobs + 1 for rank.
  1121. top_ranks = range(1, num_logprobs + 1)
  1122. prompt_logprobs_dict.update({
  1123. top_id: (top_prob, rank)
  1124. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1125. top_ranks)
  1126. })
  1127. prompt_logprobs.append({
  1128. token_id: Logprob(*logprob_and_rank)
  1129. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1130. })
  1131. # + 1 to go to the next prompt token.
  1132. top_logprob_idx += 1
  1133. # + len(next_prompt_tokens) to go to the next prompt.
  1134. selected_logprobs_idx += len(next_prompt_tokens)
  1135. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1136. def _get_sampled_logprob_if_needed(
  1137. seq_group: SequenceGroupToSample,
  1138. sample_result: Tuple[List[int], List[int]],
  1139. selected_logprobs: torch.Tensor,
  1140. ranks: torch.Tensor,
  1141. top_token_ids: torch.Tensor,
  1142. top_logprobs: torch.Tensor,
  1143. selected_logprobs_idx: int,
  1144. top_logprob_idx: int,
  1145. ):
  1146. """Compute the sample logprob if needed."""
  1147. seq_ids = seq_group.seq_ids
  1148. num_logprobs = seq_group.sampling_params.logprobs
  1149. use_beam_search = seq_group.sampling_params.use_beam_search
  1150. sampled_logprobs: SampleLogprobs = []
  1151. next_token_ids, parent_seq_ids = sample_result
  1152. if seq_group.do_sample:
  1153. assert len(next_token_ids) > 0
  1154. if num_logprobs is None and not use_beam_search:
  1155. for next_token_id in next_token_ids:
  1156. # Use a dummy logprob
  1157. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1158. else:
  1159. # Pre-select items from tensor. tolist() is faster than repetitive
  1160. # `.item()` calls.
  1161. selected_logprob_items = selected_logprobs[
  1162. selected_logprobs_idx:selected_logprobs_idx +
  1163. len(next_token_ids)].tolist()
  1164. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1165. len(next_token_ids)].tolist()
  1166. for idx, (next_token_id, parent_id) in enumerate(
  1167. zip(next_token_ids, parent_seq_ids)):
  1168. # Get the logprob of a sampled token.
  1169. sampled_logprobs_dict = {
  1170. next_token_id:
  1171. (selected_logprob_items[idx], rank_items[idx])
  1172. }
  1173. if num_logprobs is not None and num_logprobs > 0:
  1174. # Get top K logprobs.
  1175. top_ids = top_token_ids[top_logprob_idx +
  1176. parent_id, :num_logprobs].tolist()
  1177. top_probs = top_logprobs[
  1178. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1179. # Top K is already sorted by rank, so we can use 1 ~
  1180. # num_logprobs + 1 for rank.
  1181. top_ranks = range(1, num_logprobs + 1)
  1182. sampled_logprobs_dict.update({
  1183. top_id: (top_prob, rank)
  1184. for top_id, top_prob, rank in zip(
  1185. top_ids, top_probs, top_ranks)
  1186. })
  1187. sampled_logprobs.append({
  1188. token_id: Logprob(*logprob_and_rank)
  1189. for token_id, logprob_and_rank in
  1190. sampled_logprobs_dict.items()
  1191. })
  1192. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1193. # logprobs for the current step, which has len(next_token_ids) tokens
  1194. # per sequence group. `logprobs` includes logprobs from the previous
  1195. # steps, which has len(seq_ids) tokens per sequence group.
  1196. # Iterate to the next sequence group in a batch.
  1197. selected_logprobs_idx += len(next_token_ids)
  1198. # Iterate to the next sequence group in a batch.
  1199. top_logprob_idx += len(seq_ids)
  1200. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1201. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1202. sample_indices: torch.Tensor,
  1203. greedy_samples: torch.Tensor) -> None:
  1204. """Modify the probability distributions of the greedily-sampled tokens such
  1205. that each sampled token has a "probability" of 1.0. This is required by
  1206. speculative decoding, which depends on the sampling method being encoded
  1207. within the probability distribution for correctness.
  1208. # Why do we only need to do this for greedy sampling?
  1209. Aphrodite's sampler performs the following steps for greedy or multinomial
  1210. (random) sampling:
  1211. 1. Get logits from model.
  1212. 2. Modify logits according to per-sequence sampling parameters.
  1213. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1214. according to their frequency, etc.
  1215. 3. Sample a token.
  1216. - Random sampling simply samples from the modified probability
  1217. distribution.
  1218. - Greedy sampling performs `argmax` to obtain the token with the
  1219. highest likelihood.
  1220. Ignoring greedy sampling for a moment, we find that the computed probability
  1221. distribution has the following property: we can sample from it independently
  1222. and find that the token sampled by the Sampler has a frequency corresponding
  1223. to how often we see it in our sampling. In other words, for tokens sampled
  1224. with Aphrodite's random SamplingType, the computed probability distribution
  1225. encodes the sampling methodology completely.
  1226. Greedy sampling does not normally have this property. Aphrodite modifies
  1227. logits according to sampling params, then performs `argmax`, then returns
  1228. the sampled token and the computed probability distribution. If we sample
  1229. from the distribution, we'll find the likelihood of the greedily-sampled
  1230. token is not always 1.0.
  1231. Since lossless speculative decoding requires that the sampling methodology
  1232. be encoded within the probability distribution, we are motivated to modify
  1233. the probability distribution such that the sampled token has probability 1
  1234. when speculative decoding is used.
  1235. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1236. greedy sampling using multinomial computation and unite the codepaths. This
  1237. has implications on the overall design of the sampler, e.g. how to record
  1238. accurate logprobs for the user, so this improvement is deferred to later.
  1239. """
  1240. # NOTE: logprobs are not modified so they can be returned to the user.
  1241. probs[sample_indices, :] = 0
  1242. probs[sample_indices, greedy_samples] = 1.0
  1243. def _build_sampler_output(
  1244. sample_results: SampleResultType,
  1245. sampling_metadata: SamplingMetadata,
  1246. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1247. sample_logprobs: Optional[List[SampleLogprobs]],
  1248. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1249. torch.Tensor]],
  1250. skip_sampler_cpu_output: bool = False,
  1251. ) -> SamplerOutput:
  1252. """Construct Python objects with the output of sampling.
  1253. Args:
  1254. on_device_tensors: Tuple containing on-device tensors with the
  1255. probabilities used in sampling and the sampled token ids. This
  1256. allows post-processing without copies to CPU/serialization, e.g. in
  1257. speculative decoding rejection sampling.
  1258. """
  1259. sampler_output: List[CompletionSequenceGroupOutput] = []
  1260. if not skip_sampler_cpu_output:
  1261. assert prompt_logprobs is not None
  1262. assert sample_logprobs is not None
  1263. for (seq_group, sample_result, group_prompt_logprobs,
  1264. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1265. sample_results, prompt_logprobs,
  1266. sample_logprobs):
  1267. seq_ids = seq_group.seq_ids
  1268. next_token_ids, parent_ids = sample_result
  1269. seq_outputs: List[SequenceOutput] = []
  1270. for parent_id, next_token_id, logprobs in zip(
  1271. parent_ids, next_token_ids, group_sample_logprobs):
  1272. seq_outputs.append(
  1273. SequenceOutput(seq_ids[parent_id], next_token_id,
  1274. logprobs))
  1275. sampler_output.append(
  1276. CompletionSequenceGroupOutput(seq_outputs,
  1277. group_prompt_logprobs))
  1278. # If not specified, store None values in SamplerOutput.
  1279. if on_device_tensors is not None:
  1280. (sampled_token_probs, logprobs_tensor,
  1281. sampled_token_ids) = on_device_tensors
  1282. else:
  1283. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1284. None)
  1285. return SamplerOutput(
  1286. outputs=sampler_output,
  1287. sampled_token_probs=sampled_token_probs,
  1288. sampled_token_ids=sampled_token_ids,
  1289. logprobs=logprobs_tensor,
  1290. )
  1291. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1292. """Get a list of next prompt tokens to compute logprob from a
  1293. given sequence group.
  1294. It is used to compute prompt logprob. Imagine you have logprob for each
  1295. query token. Query token needs to know the next prompt token id to compute
  1296. prompt logprob. This is a helper to obtain next prompt token ids.
  1297. This API has to be used only when the caller knows seq_group is in prefill
  1298. stage.
  1299. Returns:
  1300. A list of next prompt tokens to compute logprob.
  1301. """
  1302. assert seq_group.is_prompt, (
  1303. "Caller should ensure the sequence group is in a prefill stage.")
  1304. seq_ids = seq_group.seq_ids
  1305. query_len = seq_group.query_len
  1306. assert query_len is not None
  1307. # prompt has only 1 seq id.
  1308. assert len(seq_ids) == 1
  1309. seq_data = seq_group.seq_data[seq_ids[0]]
  1310. computed_len = seq_data.get_num_computed_tokens()
  1311. prompt_tokens = seq_data.prompt_token_ids
  1312. # +1 because we are looking for a next prompt token.
  1313. next_token_index_start = computed_len + 1
  1314. next_token_index_end = min(computed_len + query_len + 1,
  1315. len(prompt_tokens))
  1316. next_prompt_tokens = prompt_tokens[
  1317. next_token_index_start:next_token_index_end]
  1318. return next_prompt_tokens
  1319. # def _apply_mirostat_v2(logits: torch.Tensor,
  1320. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1321. # # Reduce our view to just the affected logits
  1322. # logit_view = logits[sampling_tensors.miro_indices]
  1323. # # Calculate surprise value per token
  1324. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1325. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1326. # # Mask out "too-surprising" tokens (surprisal > mu)
  1327. # mus = sampling_tensors.miro_mus
  1328. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1329. # # Unmask most-likely logit to guarantee a selection.
  1330. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1331. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1332. # # Apply logit mask (effectively a top-k filter).
  1333. # logit_view[miro_mask] = -float("inf")
  1334. # # Project logit changes made to the view onto the original.
  1335. # # I think this step might be redundant.
  1336. # logits[sampling_tensors.miro_indices] = logit_view
  1337. # return logits
  1338. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1339. # sample_results: List[Tuple[List[int], List[int]]],
  1340. # sampling_metadata: SamplingMetadata,
  1341. # output_metadata: OutputMetadata) -> None:
  1342. # """Based on whichever token was finally sampled, we calculate the
  1343. # final surprisal values to update the mus.
  1344. # Because a single sequence can have multiple samples, we must fork
  1345. # the mu accordingly."""
  1346. # assert sampling_metadata.seq_groups is not None
  1347. # seqid_to_tokens = {}
  1348. # seqid_to_indices = {}
  1349. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1350. # sample_results):
  1351. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1352. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1353. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1354. # seqids = args.miro_seqids
  1355. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1356. # device=logits.device,
  1357. # dtype=torch.long)
  1358. # # Clumsily, we recalculate token surprisals.
  1359. # logits_view = logits[args.miro_indices]
  1360. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1361. # dim=-1,
  1362. # index=picked_tokens) / -math.log(2)
  1363. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1364. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1365. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1366. # nu_mus = mus - (picked_surprise - taus) * etas
  1367. # # Record updated mu values for use in the next iteration
  1368. # # Note how each mu is split into multiple based on the number of samples.
  1369. # for seqid, seq_mus in zip(seqids, nu_mus):
  1370. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1371. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)