sampler.py 61 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. from math import inf
  4. from typing import Dict, List, Optional, Tuple
  5. import torch
  6. import torch.nn as nn
  7. from aphrodite.common.sampling_params import SamplingType
  8. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  9. PromptLogprobs, SampleLogprobs,
  10. SamplerOutput, SequenceOutput)
  11. from aphrodite.triton_utils import HAS_TRITON
  12. if HAS_TRITON:
  13. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  14. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  15. SamplingTensors,
  16. SequenceGroupToSample)
  17. # (num_token_ids, num_parent_ids) per sequence group.
  18. SampleResultType = List[Tuple[List[int], List[int]]]
  19. class Sampler(nn.Module):
  20. """Samples the next tokens from the model's outputs.
  21. This layer does the following:
  22. 1. Discard the hidden states that are not used for sampling (i.e., all
  23. tokens except the final one in each prompt).
  24. 2. Compute the logits for the next tokens.
  25. 3. Apply presence, frequency and repetition penalties.
  26. 4. Apply temperature scaling.
  27. 5. Apply top-p and top-k truncation.
  28. 6. Sample the next tokens.
  29. Here, each sequence group within the batch can have different sampling
  30. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  31. The structure of the logits tensor is coupled with the seq_groups in
  32. sampling_metadata. Typically, each sequence in each seq_group has one row in
  33. logits for the next token to be sampled; however, for a seq_group with a
  34. prompt request with the prompt_logprobs sampling parameter, there are rows
  35. in logits for each token in the input prompt.
  36. """
  37. def __init__(self):
  38. super().__init__()
  39. # Whether or not the SamplerOutput should have on-device tensors
  40. # containing the sampled token ids and probabilities. This is used by
  41. # speculative decoding.
  42. self.include_gpu_probs_tensor = False
  43. self.should_modify_greedy_probs_inplace = False
  44. def _init_sampling_tensors(
  45. self,
  46. logits: torch.Tensor,
  47. sampling_metadata: SamplingMetadata,
  48. ):
  49. """The goal here is to reuse sampling tensors between similar decode
  50. runs. This is possible because sampling logic does not change between
  51. decodes of the same sequences.
  52. """
  53. _, vocab_size = logits.shape
  54. # First free any existing stored sampling tensors.
  55. # This is necessary because some sampling tensors may
  56. # have pinned memory.
  57. self._sampling_tensors = None
  58. # Initialize new sampling tensors
  59. (sampling_tensors, do_penalties, do_dries, do_top_p_top_k, do_top_as,
  60. do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
  61. do_quadratic, do_temp_last) = SamplingTensors.from_sampling_metadata(
  62. sampling_metadata, vocab_size, logits.device, logits.dtype)
  63. self._sampling_tensors = sampling_tensors
  64. self._do_penalties = do_penalties
  65. self._do_dries = do_dries
  66. self._do_top_p_top_k = do_top_p_top_k
  67. self._do_top_as = do_top_as
  68. self._do_min_p = do_min_p
  69. self._do_tfss = do_tfss
  70. self._do_eta_cutoffs = do_eta_cutoffs
  71. self._do_epsilon_cutoffs = do_epsilon_cutoffs
  72. self._do_typical_ps = do_typical_ps
  73. self._do_quadratic = do_quadratic
  74. self._do_temp_last = do_temp_last
  75. def forward(
  76. self,
  77. logits: torch.Tensor,
  78. sampling_metadata: SamplingMetadata,
  79. ) -> Optional[SamplerOutput]:
  80. """
  81. Args:
  82. logits: (num_tokens, vocab_size).
  83. sampling_metadata: Metadata for sampling.
  84. """
  85. assert logits is not None
  86. _, vocab_size = logits.shape
  87. # Prepare sampling tensors with pinned memory to avoid blocking.
  88. if not sampling_metadata.reuse_sampling_tensors:
  89. self._init_sampling_tensors(logits, sampling_metadata)
  90. elif self._do_penalties:
  91. # In this case, the sampling tensors logic depends on
  92. # "output_tokens" of a sequence. As a result, we cannot
  93. # reuse sampling tensors, since "output_tokens" changes
  94. # between decode runs.
  95. self._init_sampling_tensors(logits, sampling_metadata)
  96. assert self._sampling_tensors is not None
  97. sampling_tensors = self._sampling_tensors
  98. do_penalties = self._do_penalties
  99. do_dries = self._do_dries
  100. do_top_p_top_k = self._do_top_p_top_k
  101. do_top_as = self._do_top_as
  102. do_min_p = self._do_min_p
  103. do_tfss = self._do_tfss
  104. do_eta_cutoffs = self._do_eta_cutoffs
  105. do_epsilon_cutoffs = self._do_epsilon_cutoffs
  106. do_typical_ps = self._do_typical_ps
  107. do_quadratic = self._do_quadratic
  108. do_temp_last = self._do_temp_last
  109. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  110. # Apply presence and frequency penalties.
  111. if do_penalties:
  112. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  113. sampling_tensors.output_tokens,
  114. sampling_tensors.presence_penalties,
  115. sampling_tensors.frequency_penalties,
  116. sampling_tensors.repetition_penalties)
  117. if do_dries:
  118. logits = _apply_dry_penalty(logits, sampling_tensors.prompt_tokens,
  119. sampling_tensors.output_tokens,
  120. sampling_tensors.dry_multipliers,
  121. sampling_tensors.dry_bases,
  122. sampling_tensors.dry_allowed_lengths,
  123. sampling_tensors.dry_sequence_breakerss)
  124. # Apply temperature scaling if not doing temp_last.
  125. if not do_temp_last:
  126. # Use float32 to apply temp.
  127. # Use in-place division to avoid creating a new tensor.
  128. logits = logits.to(torch.float)
  129. logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
  130. if do_top_p_top_k:
  131. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  132. sampling_tensors.top_ks)
  133. if do_top_as:
  134. logits = _apply_top_a(logits, sampling_tensors.top_as)
  135. if do_min_p:
  136. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  137. if do_tfss:
  138. logits = _apply_tfs(logits, sampling_tensors.tfss)
  139. if do_eta_cutoffs:
  140. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  141. if do_epsilon_cutoffs:
  142. logits = _apply_epsilon_cutoff(logits,
  143. sampling_tensors.epsilon_cutoffs)
  144. if do_typical_ps:
  145. logits = _apply_typical_sampling(logits,
  146. sampling_tensors.typical_ps)
  147. if do_quadratic:
  148. logits = _apply_quadratic_sampling(
  149. logits, sampling_tensors.smoothing_factors,
  150. sampling_tensors.smoothing_curves)
  151. if do_temp_last:
  152. # Use float32 to apply temp.
  153. # Use in-place division to avoid creating a new tensor.
  154. logits = logits.to(torch.float)
  155. logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
  156. # banned_tokens = _get_custom_token_bans(sampling_metadata)
  157. # logits = _apply_token_bans(logits, banned_tokens)
  158. # We use float32 for probabilities and log probabilities.
  159. # Compute the probabilities.
  160. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  161. # Compute the log probabilities.
  162. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  163. # Sample the next tokens.
  164. sample_results, maybe_sampled_tokens_tensor = _sample(
  165. probs,
  166. logprobs,
  167. sampling_metadata,
  168. sampling_tensors,
  169. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  170. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  171. )
  172. if self.include_gpu_probs_tensor:
  173. assert maybe_sampled_tokens_tensor is not None
  174. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  175. else:
  176. on_device_tensors = None
  177. # Get the logprobs query results.
  178. prompt_logprobs = None
  179. sample_logprobs = None
  180. if not sampling_metadata.skip_sampler_cpu_output:
  181. prompt_logprobs, sample_logprobs = _get_logprobs(
  182. logprobs, sampling_metadata, sample_results)
  183. return _build_sampler_output(
  184. sample_results,
  185. sampling_metadata,
  186. prompt_logprobs,
  187. sample_logprobs,
  188. on_device_tensors=on_device_tensors,
  189. skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
  190. @property
  191. def _should_modify_greedy_probs_inplace(self) -> bool:
  192. """Whether or not the sampler should modify the probability distribution
  193. of greedily-sampled tokens such that multinomial sampling would sample
  194. the greedily-sampled token.
  195. In other words, if True then we set the probability of the greedily-
  196. sampled token to 1.
  197. This is used by speculative decoding, which requires that the sampling
  198. method be encoded into the probability distribution.
  199. """
  200. return self.should_modify_greedy_probs_inplace
  201. def _get_bin_counts_and_mask(
  202. tokens: torch.Tensor,
  203. vocab_size: int,
  204. num_seqs: int,
  205. ) -> Tuple[torch.Tensor, torch.Tensor]:
  206. # Compute the bin counts for the tokens.
  207. # vocab_size + 1 for padding.
  208. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  209. dtype=torch.long,
  210. device=tokens.device)
  211. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  212. bin_counts = bin_counts[:, :vocab_size]
  213. mask = bin_counts > 0
  214. return bin_counts, mask
  215. def _get_custom_token_bans(
  216. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  217. assert sampling_metadata.seq_groups is not None
  218. banned_tokens: List[List[int]] = []
  219. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  220. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  221. seq_ids = seq_group.seq_ids
  222. custom_token_bans = sampling_params.custom_token_bans
  223. if (i < sampling_metadata.num_prompts
  224. and sampling_params.prompt_logprobs is not None):
  225. prompt_len = len(seq_group.prompt_logprob_indices)
  226. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  227. banned_tokens += [custom_token_bans] * len(seq_ids)
  228. return banned_tokens
  229. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  230. output_tokens_tensor: torch.Tensor,
  231. presence_penalties: torch.Tensor,
  232. frequency_penalties: torch.Tensor,
  233. repetition_penalties: torch.Tensor) -> torch.Tensor:
  234. num_seqs, vocab_size = logits.shape
  235. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  236. num_seqs)
  237. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  238. output_tokens_tensor, vocab_size, num_seqs)
  239. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  240. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  241. logits = torch.where(logits > 0, logits / repetition_penalties,
  242. logits * repetition_penalties)
  243. # We follow the definition in OpenAI API.
  244. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  245. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  246. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  247. return logits
  248. def _apply_dry_penalty(
  249. logits: torch.Tensor,
  250. prompt_tokens: torch.Tensor,
  251. output_tokens: torch.Tensor,
  252. dry_multipliers: torch.Tensor,
  253. dry_bases: torch.Tensor,
  254. dry_allowed_lengths: torch.Tensor,
  255. dry_sequence_breakerss: torch.Tensor,
  256. ) -> torch.Tensor:
  257. num_seqs, vocab_size = logits.shape
  258. device = logits.device
  259. for i in range(num_seqs):
  260. input_ids = torch.cat([prompt_tokens[i], output_tokens[i]])
  261. if input_ids.numel() == 0:
  262. continue
  263. last_token = input_ids[-1].item()
  264. if any(last_token in breaker for breaker in dry_sequence_breakerss[i]):
  265. continue
  266. match_indices = (input_ids[:-1] == last_token).nonzero().squeeze()
  267. match_lengths = {}
  268. for idx in match_indices:
  269. next_token = input_ids[idx + 1].item()
  270. if any(next_token in breaker for breaker in dry_sequence_breakerss[i]):
  271. continue
  272. match_length = 1
  273. while idx - match_length >= 0:
  274. previous_token = input_ids[-(match_length + 1)].item()
  275. if (input_ids[idx - match_length] != previous_token or
  276. any(previous_token in breaker for breaker in dry_sequence_breakerss[i])):
  277. break
  278. match_length += 1
  279. match_lengths[next_token] = max(match_lengths.get(next_token, 0), match_length)
  280. for token, match_length in match_lengths.items():
  281. if match_length >= dry_allowed_lengths[i]:
  282. penalty = dry_multipliers[i] * (dry_bases[i] ** (match_length - dry_allowed_lengths[i]))
  283. logits[i, token] -= penalty
  284. return logits
  285. def _apply_token_bans(logits: torch.Tensor,
  286. banned_tokens: List[List[int]]) -> torch.Tensor:
  287. for i, banned_token_ids in enumerate(banned_tokens):
  288. if not banned_token_ids:
  289. continue
  290. logits[i, banned_token_ids] = -float("inf")
  291. return logits
  292. def _apply_min_tokens_penalty(
  293. logits: torch.Tensor,
  294. sampling_metadata: SamplingMetadata,
  295. ) -> torch.Tensor:
  296. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  297. have not been generated yet
  298. """
  299. # list of indices in logits that will be set to -inf
  300. logits_to_penalize = []
  301. logits_applied = 0
  302. for seq_group in sampling_metadata.seq_groups:
  303. seq_ids = seq_group.seq_ids
  304. sampling_params = seq_group.sampling_params
  305. sample_indices = seq_group.sample_indices
  306. logits_applied += len(sample_indices) + len(
  307. seq_group.prompt_logprob_indices)
  308. if not seq_group.do_sample:
  309. continue
  310. start_idx = sample_indices[0]
  311. min_tokens = sampling_params.min_tokens
  312. token_ids_to_penalize = sampling_params.all_stop_token_ids
  313. if min_tokens > 0 and token_ids_to_penalize:
  314. seqs_to_penalize = []
  315. for j, seq_id in enumerate(seq_ids):
  316. seq_data = seq_group.seq_data[seq_id]
  317. if len(seq_data.output_token_ids_array) < min_tokens:
  318. seqs_to_penalize.append(j)
  319. if seqs_to_penalize:
  320. # convert to the index into logits
  321. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  322. # itertools.product pairs each seq index with every token id
  323. logits_to_penalize.extend(
  324. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  325. if logits_to_penalize:
  326. # use zip and * to group indices along each dimension
  327. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  328. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  329. # verifies that no rows in logits were missed unexpectedly
  330. assert logits_applied == logits.shape[0]
  331. return logits
  332. def _apply_top_k_top_p(
  333. logits: torch.Tensor,
  334. p: torch.Tensor,
  335. k: torch.Tensor,
  336. ) -> torch.Tensor:
  337. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  338. # Apply top-k.
  339. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  340. # Get all the top_k values.
  341. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  342. top_k_mask = logits_sort < top_k_mask
  343. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  344. # Apply top-p.
  345. probs_sort = logits_sort.softmax(dim=-1)
  346. probs_sum = probs_sort.cumsum(dim=-1)
  347. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  348. # at least one
  349. top_p_mask[:, -1] = False
  350. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  351. # Re-sort the probabilities.
  352. src = torch.arange(logits_idx.shape[-1],
  353. device=logits_idx.device).expand_as(logits_idx)
  354. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  355. index=logits_idx,
  356. src=src)
  357. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  358. return logits
  359. def _apply_min_p(
  360. logits: torch.Tensor,
  361. min_p: torch.Tensor,
  362. ) -> torch.Tensor:
  363. """
  364. Adapted from
  365. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  366. """
  367. probs = torch.softmax(logits, dim=-1)
  368. top_probs, _ = probs.max(dim=-1, keepdim=True)
  369. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  370. tokens_to_remove = probs < scaled_min_p
  371. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  372. return logits
  373. def _apply_top_a(
  374. logits: torch.Tensor,
  375. top_a: torch.Tensor,
  376. ) -> torch.Tensor:
  377. probs = torch.softmax(logits, dim=-1)
  378. top_probs, _ = probs.max(dim=-1, keepdim=True)
  379. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  380. tokens_to_remove = probs < threshold
  381. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  382. return logits
  383. def _apply_tfs(
  384. logits: torch.Tensor,
  385. tfs: torch.Tensor,
  386. ) -> torch.Tensor:
  387. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  388. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  389. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  390. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  391. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  392. tfs_mask = torch.cat(
  393. (
  394. torch.zeros(
  395. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  396. tfs_mask,
  397. torch.ones(
  398. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  399. ),
  400. dim=-1,
  401. )
  402. logits_sort[tfs_mask] = -float("inf")
  403. logits = torch.gather(logits_sort,
  404. dim=-1,
  405. index=torch.argsort(logits_idx, dim=-1))
  406. return logits
  407. def _apply_eta_cutoff(
  408. logits: torch.Tensor,
  409. eta_cutoff: torch.Tensor,
  410. ) -> torch.Tensor:
  411. shifted_logits = torch.log_softmax(logits, dim=-1)
  412. probs = shifted_logits.exp()
  413. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  414. eps = torch.min(eta_cutoff,
  415. torch.sqrt(eta_cutoff) *
  416. torch.exp(neg_entropy)).unsqueeze(dim=1)
  417. eta_mask = probs < eps
  418. # guard against nulling out all the logits
  419. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  420. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  421. logits[eta_mask] = -float("inf")
  422. return logits
  423. def _apply_epsilon_cutoff(
  424. logits: torch.Tensor,
  425. epsilon_cutoff: torch.Tensor,
  426. ) -> torch.Tensor:
  427. probs = logits.softmax(dim=-1)
  428. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  429. # guard against nulling out all the logits
  430. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  431. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  432. logits[eps_mask] = -float("inf")
  433. return logits
  434. def _apply_typical_sampling(
  435. logits: torch.Tensor,
  436. typical_p: torch.Tensor,
  437. ) -> torch.Tensor:
  438. shifted_logits = torch.log_softmax(logits, dim=-1)
  439. probs = shifted_logits.exp()
  440. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  441. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  442. _, indices = torch.sort(surprisal_deviations)
  443. reordered_probs = probs.gather(-1, indices)
  444. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  445. dim=1)
  446. min_tokens_to_keep = 1
  447. # Keep at least min_tokens_to_keep
  448. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  449. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  450. logits[typ_mask] = -float("inf")
  451. return logits
  452. def _apply_quadratic_sampling(
  453. logits: torch.Tensor,
  454. smoothing_factor: torch.Tensor,
  455. smoothing_curve: torch.Tensor,
  456. ) -> torch.Tensor:
  457. """
  458. Applies a quadratic transformation to the logits based on the
  459. provided smoothing factors and curves. The transformation is
  460. centered around the maximum logit value in the batch.
  461. The transformation involves a quadratic and cubic term, with the
  462. cubic term controlled by the smoothing curve. The quadratic term is
  463. scaled by the smoothing factor, and the cubic term is scaled by the
  464. product of the smoothing factor and the smoothing curve.
  465. params:
  466. logits (torch.Tensor): The logits to be transformed.
  467. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  468. term in the transformation.
  469. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  470. in the transformation.
  471. returns:
  472. torch.Tensor: The transformed logits.
  473. Credits: @kalomaze
  474. """
  475. mask = smoothing_factor != 0
  476. smoothing_factor.unsqueeze_(dim=1)
  477. smoothing_curve.unsqueeze_(dim=1)
  478. k = smoothing_factor * (3 - smoothing_curve) / 2
  479. s = smoothing_factor * (smoothing_curve - 1) / 2
  480. quadlogits = logits[mask] # limit to logits using this sampler
  481. max_logits = quadlogits.max(dim=-1, keepdim=True).values
  482. # Construct the delta from each logit to its new value
  483. diff = quadlogits - max_logits
  484. diff -= diff**2 * (s[mask] * diff - k[mask])
  485. diff[diff != diff] = 0 # Eliminate NaNs due to infs
  486. logits[mask] -= diff
  487. return logits
  488. def _greedy_sample(
  489. selected_seq_groups: List[SequenceGroupToSample],
  490. samples: torch.Tensor,
  491. ) -> List[Tuple[List[int], List[int]]]:
  492. """Run greedy sampling on a given samples.
  493. Args:
  494. selected_seq_groups: A list of sequence groups batched.
  495. samples: (num_selected_samples,) A tensor of samples. The length of
  496. samples could be smaller than selected_seq_groups if
  497. seq_group.do_sample is False.
  498. Returns:
  499. Tuple of (next_token_ids, parent_ids). The length of returned list is
  500. same as the length of selected_seq_groups. If the corresponding
  501. seq_group has do_sample=False, tuple contains ([], [])
  502. """
  503. samples = samples.tolist()
  504. sample_idx = 0
  505. results = []
  506. for seq_group in selected_seq_groups:
  507. if not seq_group.do_sample:
  508. results.append(([], []))
  509. continue
  510. seq_ids = seq_group.seq_ids
  511. num_parent_seqs = len(seq_ids)
  512. assert num_parent_seqs == 1, (
  513. "Greedy sampling should have only one seq.")
  514. parent_ids = list(range(num_parent_seqs))
  515. next_token_ids = [samples[sample_idx]]
  516. results.append((next_token_ids, parent_ids))
  517. sample_idx += num_parent_seqs
  518. return results
  519. def _random_sample(
  520. selected_seq_groups: List[SequenceGroupToSample],
  521. random_samples: torch.Tensor,
  522. ) -> List[Tuple[List[int], List[int]]]:
  523. """Run random sampling on a given samples.
  524. Args:
  525. selected_seq_groups: A list of sequence groups batched.
  526. random_samples: (num_selected_samples,) A tensor of samples. The
  527. length of samples could be smaller than selected_seq_groups if
  528. seq_group.do_sample is False.
  529. Returns:
  530. Tuple of (next_token_ids, parent_ids). The length of returned list is
  531. same as the length of selected_seq_groups. If the corresponding
  532. seq_group has do_sample=False, tuple contains ([], [])
  533. """
  534. # Find the maximum best_of value of the prompt phase requests.
  535. random_samples = random_samples.cpu()
  536. sample_idx = 0
  537. results = []
  538. for seq_group in selected_seq_groups:
  539. if not seq_group.do_sample:
  540. results.append(([], []))
  541. continue
  542. seq_ids = seq_group.seq_ids
  543. sampling_params = seq_group.sampling_params
  544. is_prompt = seq_group.is_prompt
  545. num_parent_seqs = len(seq_ids)
  546. if is_prompt:
  547. # Prompt phase.
  548. parent_ids = [0] * sampling_params.best_of
  549. next_token_ids = random_samples[
  550. sample_idx, :sampling_params.best_of].tolist()
  551. else:
  552. # Generation phase.
  553. parent_ids = list(range(num_parent_seqs))
  554. next_token_ids = random_samples[sample_idx:sample_idx +
  555. num_parent_seqs, 0].tolist()
  556. results.append((next_token_ids, parent_ids))
  557. sample_idx += num_parent_seqs
  558. return results
  559. def _beam_search_sample(
  560. selected_seq_groups: List[SequenceGroupToSample],
  561. logprobs: torch.Tensor,
  562. ) -> List[Tuple[List[int], List[int]]]:
  563. """Run beam sampling on a given samples.
  564. Args:
  565. selected_seq_groups: A list of sequence groups batched.
  566. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  567. on selected sample indices.
  568. Returns:
  569. Tuple of (next_token_ids, parent_ids). The length of returned list is
  570. same as the length of selected_seq_groups. If the corresponding
  571. seq_group has do_sample=False, tuple contains ([], [])
  572. """
  573. # We sample 2 * beam_width candidates to make sure that with high
  574. # probability we can get `beam_width` candidates in addition to
  575. # the finished sequences for the next iteration. See
  576. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  577. # for details. See also HF reference:
  578. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  579. #
  580. # NOTE: Beam search is not vectorized, so its speed can be slower than
  581. # other sampling methods.
  582. sample_idx = 0
  583. results = []
  584. for seq_group in selected_seq_groups:
  585. if not seq_group.do_sample:
  586. results.append(([], []))
  587. continue
  588. is_prompt = seq_group.is_prompt
  589. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  590. num_parent_seqs = len(seq_ids)
  591. beam_width = sampling_params.best_of
  592. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  593. if is_prompt:
  594. # Prompt phase.
  595. assert num_parent_seqs == 1, (
  596. "Prompt input should have only one seq.")
  597. parent_ids = [0] * (2 * beam_width)
  598. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  599. 2 * beam_width)
  600. next_token_ids = next_token_ids.tolist()
  601. else:
  602. # Generation phase.
  603. cumulative_logprobs = [
  604. seq_group.seq_data[seq_id].cumulative_logprob
  605. for seq_id in seq_ids
  606. ]
  607. cumulative_logprobs = torch.tensor(
  608. cumulative_logprobs,
  609. dtype=torch.float,
  610. device=seq_group_logprobs.device)
  611. seq_group_logprobs = (seq_group_logprobs +
  612. cumulative_logprobs.unsqueeze(dim=1))
  613. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  614. 2 * beam_width)
  615. topk_ids = topk_ids.tolist()
  616. vocab_size = seq_group_logprobs.size(-1)
  617. parent_ids = [i // vocab_size for i in topk_ids]
  618. next_token_ids = [i % vocab_size for i in topk_ids]
  619. results.append((next_token_ids, parent_ids))
  620. sample_idx += num_parent_seqs
  621. assert sample_idx == logprobs.size(0)
  622. return results
  623. # torch.multinomial forces a GPU<->CPU sync.
  624. # Therefore, we use an optimized implementation instead.
  625. # Note that we always sample with replacement.
  626. # probs will be modified in place, but this is fine, as we pass
  627. # in a copy already.
  628. def _multinomial(
  629. probs: torch.Tensor,
  630. num_samples: int,
  631. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  632. ) -> torch.Tensor:
  633. if num_samples > 1:
  634. # This is equivalent to torch.repeat_interleaved (which also
  635. # forces a GPU<->CPU sync).
  636. # This allows us to do sampling with replacement by creating
  637. # num_samples copies of each row in the tensor, and then
  638. # batch sampling the resulting tensor.
  639. probs = probs[:, None, :].expand(probs.shape[0], num_samples,
  640. probs.shape[1]).contiguous().view(
  641. -1, probs.shape[1])
  642. q = torch.empty_like(probs)
  643. if seq_groups is None:
  644. q.exponential_()
  645. else:
  646. sample_idx = 0
  647. for seq_group in seq_groups:
  648. seq_ids = seq_group.seq_ids
  649. next_sample_idx = sample_idx + len(seq_ids) * num_samples
  650. q[sample_idx:next_sample_idx].exponential_(
  651. generator=seq_group.generator)
  652. sample_idx = next_sample_idx
  653. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  654. def _sample_with_torch(
  655. probs: torch.Tensor,
  656. logprobs: torch.Tensor,
  657. sampling_metadata: SamplingMetadata,
  658. include_gpu_probs_tensor: bool,
  659. modify_greedy_probs: bool,
  660. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  661. categorized_seq_group_ids = {t: [] for t in SamplingType}
  662. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  663. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  664. sampling_params = seq_group.sampling_params
  665. sampling_type = sampling_params.sampling_type
  666. categorized_seq_group_ids[sampling_type].append(i)
  667. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  668. sample_metadata = {}
  669. multinomial_samples = {}
  670. # Create output tensor for sampled token ids.
  671. if include_gpu_probs_tensor:
  672. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  673. 1,
  674. dtype=torch.long,
  675. device=logprobs.device)
  676. else:
  677. sampled_token_ids_tensor = None
  678. # Counterintuitively, having two loops here is actually faster.
  679. # The first loop can run without waiting on GPU<->CPU sync.
  680. for sampling_type in SamplingType:
  681. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  682. num_tokens = len(sample_indices)
  683. if num_tokens == 0:
  684. continue
  685. seq_group_id = categorized_seq_group_ids[sampling_type]
  686. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  687. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  688. long_sample_indices = sample_indices.long()
  689. if sampling_type == SamplingType.GREEDY:
  690. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  691. dim=-1)
  692. if include_gpu_probs_tensor:
  693. # Store sampled tokens in output tensor.
  694. sampled_token_ids_tensor[
  695. long_sample_indices] = greedy_samples.unsqueeze(-1)
  696. if modify_greedy_probs:
  697. # If required, modify the probabilities such that sampling from
  698. # the modified distribution would always sample the argmax
  699. # token id.
  700. _modify_greedy_probs_inplace(logprobs, probs,
  701. long_sample_indices,
  702. greedy_samples)
  703. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  704. max_best_of_in_batch = 1
  705. for seq_group in seq_groups:
  706. if seq_group.is_prompt:
  707. sampling_params = seq_group.sampling_params
  708. max_best_of_in_batch = max(max_best_of_in_batch,
  709. sampling_params.best_of)
  710. seeded_args = {} if sampling_type == SamplingType.RANDOM else {
  711. "seq_groups": seq_groups,
  712. }
  713. multinomial_samples[sampling_type] = _multinomial(
  714. probs[long_sample_indices], max_best_of_in_batch,
  715. **seeded_args)
  716. if include_gpu_probs_tensor:
  717. # Store sampled tokens in output tensor.
  718. sampled_token_ids_tensor[
  719. long_sample_indices] = multinomial_samples[sampling_type]
  720. elif sampling_type == SamplingType.BEAM:
  721. beam_search_logprobs = logprobs[sample_indices]
  722. else:
  723. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  724. # GPU<->CPU sync happens in the loop below.
  725. # This also converts the sample output to Python objects.
  726. if not sampling_metadata.skip_sampler_cpu_output:
  727. for sampling_type in SamplingType:
  728. if sampling_type not in sample_metadata:
  729. continue
  730. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  731. if sampling_type == SamplingType.GREEDY:
  732. sample_results = _greedy_sample(seq_groups, greedy_samples)
  733. elif sampling_type in (SamplingType.RANDOM,
  734. SamplingType.RANDOM_SEED):
  735. sample_results = _random_sample(
  736. seq_groups, multinomial_samples[sampling_type])
  737. elif sampling_type == SamplingType.BEAM:
  738. sample_results = _beam_search_sample(seq_groups,
  739. beam_search_logprobs)
  740. sample_results_dict.update(zip(seq_group_id, sample_results))
  741. sample_results = [
  742. sample_results_dict.get(i, ([], []))
  743. for i in range(len(sampling_metadata.seq_groups))
  744. ]
  745. else:
  746. sample_results = []
  747. return sample_results, sampled_token_ids_tensor
  748. def _sample_with_triton_kernel(
  749. probs: torch.Tensor,
  750. logprobs: torch.Tensor,
  751. sampling_metadata: SamplingMetadata,
  752. sampling_tensors: SamplingTensors,
  753. ) -> List[Tuple[List[int], List[int]]]:
  754. categorized_seq_group_ids = {t: [] for t in SamplingType}
  755. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  756. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  757. sampling_params = seq_group.sampling_params
  758. sampling_type = sampling_params.sampling_type
  759. categorized_seq_group_ids[sampling_type].append(i)
  760. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  761. sample_metadata = {}
  762. max_best_of_in_batch = 1
  763. # Counterintuitively, having two loops here is actually faster.
  764. # The first loop can run without waiting on GPU<->CPU sync.
  765. for sampling_type in SamplingType:
  766. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  767. sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
  768. num_tokens = len(sample_indices)
  769. if num_tokens == 0:
  770. continue
  771. seq_group_id = categorized_seq_group_ids[sampling_type]
  772. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  773. sample_metadata[sampling_type] = (seq_group_id, seq_groups,
  774. sample_indices,
  775. sampled_token_indices)
  776. if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
  777. SamplingType.RANDOM_SEED):
  778. for seq_group in seq_groups:
  779. if seq_group.is_prompt:
  780. sampling_params = seq_group.sampling_params
  781. max_best_of_in_batch = max(max_best_of_in_batch,
  782. sampling_params.best_of)
  783. elif sampling_type == SamplingType.BEAM:
  784. beam_search_logprobs = logprobs[sample_indices]
  785. else:
  786. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  787. sampled_tokens, _, _ = sample_triton(
  788. probs=probs,
  789. seeds=sampling_tensors.sampling_seeds,
  790. max_best_of=max_best_of_in_batch,
  791. sample_indices=sampling_tensors.sample_indices,
  792. logprobs=logprobs,
  793. # don't save logprobs because we have logic for that below
  794. # TODO: use this instead of the CPU-based logic below
  795. save_logprobs=False,
  796. )
  797. # GPU<->CPU sync happens in the loop below.
  798. for sampling_type in SamplingType:
  799. if sampling_type not in sample_metadata:
  800. continue
  801. (seq_group_id, seq_groups, sample_indices,
  802. sampled_token_indices) = sample_metadata[sampling_type]
  803. if sampling_type == SamplingType.GREEDY:
  804. sample_results = _greedy_sample(
  805. seq_groups, sampled_tokens[sampled_token_indices][:, 0])
  806. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  807. sample_results = _random_sample(
  808. seq_groups, sampled_tokens[sampled_token_indices])
  809. elif sampling_type == SamplingType.BEAM:
  810. sample_results = _beam_search_sample(seq_groups,
  811. beam_search_logprobs)
  812. sample_results_dict.update(zip(seq_group_id, sample_results))
  813. sample_results = [
  814. sample_results_dict.get(i, ([], []))
  815. for i in range(len(sampling_metadata.seq_groups))
  816. ]
  817. return sample_results
  818. def _sample(
  819. probs: torch.Tensor, logprobs: torch.Tensor,
  820. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  821. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  822. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  823. """
  824. Args:
  825. probs: (num_query_tokens_in_batch, num_vocab)
  826. logprobs: (num_query_tokens_in_batch, num_vocab)
  827. sampling_metadata: The metadata for a batch for sampling.
  828. sampling_tensors: Tensors that include sampling related metadata.
  829. Returns:
  830. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  831. If sampling is skipped, it returns ([], [])
  832. sampled_token_ids_tensor: A tensor of sampled token ids.
  833. """
  834. return _sample_with_torch(
  835. probs,
  836. logprobs,
  837. sampling_metadata,
  838. include_gpu_probs_tensor=include_gpu_probs_tensor,
  839. modify_greedy_probs=modify_greedy_probs,
  840. )
  841. # TODO: Enable once Triton kernel & associated code is faster.
  842. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  843. # sampling_tensors)
  844. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  845. """
  846. This function calculates the ranks of the chosen tokens in a logprob tensor.
  847. Args:
  848. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  849. where N is the no. of tokens and M is the vocab dim.
  850. indices (torch.Tensor): List of chosen token indices.
  851. Returns:
  852. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  853. Each element in the returned tensor represents the rank
  854. of the chosen token in the input logprob tensor.
  855. """
  856. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  857. indices]
  858. return (x > vals[:, None]).long().sum(1).add_(1)
  859. def _get_logprobs(
  860. logprobs: torch.Tensor,
  861. sampling_metadata: SamplingMetadata,
  862. sample_results: List[Tuple[List[int], List[int]]],
  863. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  864. """Return sample lobprobs and prompt logprobs.
  865. The logic consists of 3 parts.
  866. - Select indices to compute logprob from, ranks of token ids, and
  867. the top k token ids from logprobs.
  868. - Compute prompt logprobs if required.
  869. - Compute sample logprobs if required.
  870. Args:
  871. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  872. logprob per vocab. Sequence groups' query tokens are batched in a
  873. single flattened tensor. For example, assuming there are N
  874. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  875. prompt logprob is enabled), decode tokens for seq_group_1 (if
  876. sampling is required), prefill tokens for seq_group_2, ...
  877. sampling_metadata: The sampling metadata.
  878. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  879. parent_ids) for each sequence group. When beam search is enabled,
  880. sample_results can contain different number of seq_ids from
  881. sampling_metadata.seq_groups. It is because beam search creates
  882. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  883. BEAM_WIDTH number of seq_ids).
  884. Returns:
  885. A tuple of prompt and sample logprobs per sequence group in a batch.
  886. """
  887. # The index of query token to calculate logprobs. It includes both
  888. # prompt and sample logprob indices.
  889. query_indices: List[int] = []
  890. # The next token ids to get the logprob value from.
  891. next_token_ids: List[int] = []
  892. # The largest requested number of logprobs. We find logprobs as many as the
  893. # largest num logprobs in this API. If every logprobs is None, it will be
  894. # set to -1.
  895. largest_num_logprobs = -1
  896. # If beam search is enabled.
  897. use_beam_search = False
  898. # Select indices to compute logprob from, ranks of token ids, and the top
  899. # k token ids from logprobs.
  900. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  901. sample_results):
  902. sampling_params = seq_group.sampling_params
  903. # Update indices and tokens for prompt logprobs.
  904. if (seq_group.is_prompt
  905. and sampling_params.prompt_logprobs is not None):
  906. largest_num_logprobs = max(largest_num_logprobs,
  907. sampling_params.prompt_logprobs)
  908. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  909. query_indices.extend(seq_group.prompt_logprob_indices)
  910. next_token_ids.extend(next_prompt_tokens)
  911. # Update indices and next tokenes for sample logprob.
  912. if seq_group.do_sample:
  913. token_ids, parent_seq_ids = sample_result
  914. # NOTE: We cannot directly use sample_indices because
  915. # sample_indices only contain parent seq_ids of a previous step.
  916. # The current step may have different number of seq_ids, and
  917. # we can obtain it from `sample_result[1]`.
  918. query_idx = seq_group.sample_indices[0]
  919. query_indices.extend(
  920. [query_idx + parent_id for parent_id in parent_seq_ids])
  921. next_token_ids.extend(token_ids)
  922. if sampling_params.logprobs is not None:
  923. largest_num_logprobs = max(largest_num_logprobs,
  924. sampling_params.logprobs)
  925. use_beam_search = use_beam_search or sampling_params.use_beam_search
  926. assert len(next_token_ids) == len(query_indices)
  927. if len(query_indices) == 0:
  928. empty_sampled_logprob = []
  929. empty_prompt_logprob = None
  930. return [empty_prompt_logprob], [empty_sampled_logprob]
  931. selected_logprobs, ranks = None, None
  932. top_logprobs, top_token_ids = None, None
  933. # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
  934. # skip the whole logprob calculation.
  935. if largest_num_logprobs >= 0 or use_beam_search:
  936. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  937. next_token_ids_gpu = torch.tensor(next_token_ids,
  938. device=logprobs.device)
  939. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  940. # contain duplicates if beam search is enabled.
  941. selected_logprobs = logprobs[[
  942. query_indices_gpu,
  943. next_token_ids_gpu,
  944. ]]
  945. ranks = _get_ranks(
  946. logprobs[query_indices_gpu],
  947. next_token_ids_gpu,
  948. )
  949. assert selected_logprobs.shape[0] == ranks.shape[0]
  950. # We need to compute top k only if there exists logprobs > 0.
  951. if largest_num_logprobs > 0:
  952. # Logprobs of topk tokens for a batch of sequence groups.
  953. # (num_query_tokens_across_batch).
  954. top_logprobs, top_token_ids = torch.topk(logprobs,
  955. largest_num_logprobs,
  956. dim=-1)
  957. top_logprobs = top_logprobs.to('cpu')
  958. top_token_ids = top_token_ids.to('cpu')
  959. selected_logprobs = selected_logprobs.to('cpu')
  960. ranks = ranks.to('cpu')
  961. # Find prompt/sample logprobs.
  962. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  963. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  964. top_logprob_idx = 0
  965. selected_logprobs_idx = 0
  966. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  967. sample_results):
  968. (prompt_logprobs, top_logprob_idx,
  969. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  970. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  971. selected_logprobs_idx, top_logprob_idx)
  972. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  973. (sampled_logprobs, top_logprob_idx,
  974. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  975. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  976. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  977. sample_logprobs_per_seq_group.append(sampled_logprobs)
  978. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  979. def _get_prompt_logprob_if_needed(
  980. seq_group: SequenceGroupToSample,
  981. selected_logprobs: torch.Tensor,
  982. ranks: torch.Tensor,
  983. top_token_ids: torch.Tensor,
  984. top_logprobs: torch.Tensor,
  985. selected_logprobs_idx: int,
  986. top_logprob_idx: int,
  987. ):
  988. """Compute the prompt logprob from a sequence group if needed."""
  989. sampling_params = seq_group.sampling_params
  990. is_prompt = seq_group.is_prompt
  991. # Find prompt logprobs
  992. prompt_logprobs: Optional[PromptLogprobs] = None
  993. if is_prompt and sampling_params.prompt_logprobs is not None:
  994. prompt_logprobs = []
  995. num_logprobs = sampling_params.prompt_logprobs
  996. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  997. # Pre-select indexes and create a list. It is faster than calling .item
  998. # repetitively.
  999. selected_logprob_items = selected_logprobs[
  1000. selected_logprobs_idx:selected_logprobs_idx +
  1001. len(next_prompt_tokens)].tolist()
  1002. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1003. len(next_prompt_tokens)].tolist()
  1004. for idx, token_id in enumerate(next_prompt_tokens):
  1005. # Calculate the prompt logprob of the real prompt tokens.
  1006. # {token_id: (logprob, rank_from_vocab)}
  1007. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  1008. token_id: (selected_logprob_items[idx], rank_items[idx])
  1009. }
  1010. # Add top K prompt logprobs along with its rank.
  1011. if num_logprobs > 0:
  1012. top_ids = top_token_ids[
  1013. top_logprob_idx, :num_logprobs].tolist()
  1014. top_probs = top_logprobs[
  1015. top_logprob_idx, :num_logprobs].tolist()
  1016. # Top K is already sorted by rank, so we can use 1 ~
  1017. # num_logprobs + 1 for rank.
  1018. top_ranks = range(1, num_logprobs + 1)
  1019. prompt_logprobs_dict.update({
  1020. top_id: (top_prob, rank)
  1021. for top_id, top_prob, rank in zip(top_ids, top_probs,
  1022. top_ranks)
  1023. })
  1024. prompt_logprobs.append({
  1025. token_id: Logprob(*logprob_and_rank)
  1026. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  1027. })
  1028. # + 1 to go to the next prompt token.
  1029. top_logprob_idx += 1
  1030. # + len(next_prompt_tokens) to go to the next prompt.
  1031. selected_logprobs_idx += len(next_prompt_tokens)
  1032. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  1033. def _get_sampled_logprob_if_needed(
  1034. seq_group: SequenceGroupToSample,
  1035. sample_result: Tuple[List[int], List[int]],
  1036. selected_logprobs: torch.Tensor,
  1037. ranks: torch.Tensor,
  1038. top_token_ids: torch.Tensor,
  1039. top_logprobs: torch.Tensor,
  1040. selected_logprobs_idx: int,
  1041. top_logprob_idx: int,
  1042. ):
  1043. """Compute the sample logprob if needed."""
  1044. seq_ids = seq_group.seq_ids
  1045. num_logprobs = seq_group.sampling_params.logprobs
  1046. use_beam_search = seq_group.sampling_params.use_beam_search
  1047. sampled_logprobs: SampleLogprobs = []
  1048. next_token_ids, parent_seq_ids = sample_result
  1049. if seq_group.do_sample:
  1050. assert len(next_token_ids) > 0
  1051. if num_logprobs is None and not use_beam_search:
  1052. for next_token_id in next_token_ids:
  1053. # Use a dummy logprob
  1054. sampled_logprobs.append({next_token_id: Logprob(inf)})
  1055. else:
  1056. # Pre-select items from tensor. tolist() is faster than repetitive
  1057. # `.item()` calls.
  1058. selected_logprob_items = selected_logprobs[
  1059. selected_logprobs_idx:selected_logprobs_idx +
  1060. len(next_token_ids)].tolist()
  1061. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  1062. len(next_token_ids)].tolist()
  1063. for idx, (next_token_id, parent_id) in enumerate(
  1064. zip(next_token_ids, parent_seq_ids)):
  1065. # Get the logprob of a sampled token.
  1066. sampled_logprobs_dict = {
  1067. next_token_id:
  1068. (selected_logprob_items[idx], rank_items[idx])
  1069. }
  1070. if num_logprobs is not None and num_logprobs > 0:
  1071. # Get top K logprobs.
  1072. top_ids = top_token_ids[top_logprob_idx +
  1073. parent_id, :num_logprobs].tolist()
  1074. top_probs = top_logprobs[
  1075. top_logprob_idx + parent_id, :num_logprobs].tolist()
  1076. # Top K is already sorted by rank, so we can use 1 ~
  1077. # num_logprobs + 1 for rank.
  1078. top_ranks = range(1, num_logprobs + 1)
  1079. sampled_logprobs_dict.update({
  1080. top_id: (top_prob, rank)
  1081. for top_id, top_prob, rank in zip(
  1082. top_ids, top_probs, top_ranks)
  1083. })
  1084. sampled_logprobs.append({
  1085. token_id: Logprob(*logprob_and_rank)
  1086. for token_id, logprob_and_rank in
  1087. sampled_logprobs_dict.items()
  1088. })
  1089. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  1090. # logprobs for the current step, which has len(next_token_ids) tokens
  1091. # per sequence group. `logprobs` includes logprobs from the previous
  1092. # steps, which has len(seq_ids) tokens per sequence group.
  1093. # Iterate to the next sequence group in a batch.
  1094. selected_logprobs_idx += len(next_token_ids)
  1095. # Iterate to the next sequence group in a batch.
  1096. top_logprob_idx += len(seq_ids)
  1097. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  1098. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  1099. sample_indices: torch.Tensor,
  1100. greedy_samples: torch.Tensor) -> None:
  1101. """Modify the probability distributions of the greedily-sampled tokens such
  1102. that each sampled token has a "probability" of 1.0. This is required by
  1103. speculative decoding, which depends on the sampling method being encoded
  1104. within the probability distribution for correctness.
  1105. # Why do we only need to do this for greedy sampling?
  1106. Aphrodite's sampler performs the following steps for greedy or multinomial
  1107. (random) sampling:
  1108. 1. Get logits from model.
  1109. 2. Modify logits according to per-sequence sampling parameters.
  1110. - Multiply by temperature, top-k and top-p masking, penalize tokens
  1111. according to their frequency, etc.
  1112. 3. Sample a token.
  1113. - Random sampling simply samples from the modified probability
  1114. distribution.
  1115. - Greedy sampling performs `argmax` to obtain the token with the
  1116. highest likelihood.
  1117. Ignoring greedy sampling for a moment, we find that the computed probability
  1118. distribution has the following property: we can sample from it independently
  1119. and find that the token sampled by the Sampler has a frequency corresponding
  1120. to how often we see it in our sampling. In other words, for tokens sampled
  1121. with Aphrodite's random SamplingType, the computed probability distribution
  1122. encodes the sampling methodology completely.
  1123. Greedy sampling does not normally have this property. Aphrodite modifies
  1124. logits according to sampling params, then performs `argmax`, then returns
  1125. the sampled token and the computed probability distribution. If we sample
  1126. from the distribution, we'll find the likelihood of the greedily-sampled
  1127. token is not always 1.0.
  1128. Since lossless speculative decoding requires that the sampling methodology
  1129. be encoded within the probability distribution, we are motivated to modify
  1130. the probability distribution such that the sampled token has probability 1
  1131. when speculative decoding is used.
  1132. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1133. greedy sampling using multinomial computation and unite the codepaths. This
  1134. has implications on the overall design of the sampler, e.g. how to record
  1135. accurate logprobs for the user, so this improvement is deferred to later.
  1136. """
  1137. # NOTE: logprobs are not modified so they can be returned to the user.
  1138. probs[sample_indices, :] = 0
  1139. probs[sample_indices, greedy_samples] = 1.0
  1140. def _build_sampler_output(
  1141. sample_results: SampleResultType,
  1142. sampling_metadata: SamplingMetadata,
  1143. prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
  1144. sample_logprobs: Optional[List[SampleLogprobs]],
  1145. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1146. torch.Tensor]],
  1147. skip_sampler_cpu_output: bool = False,
  1148. ) -> SamplerOutput:
  1149. """Construct Python objects with the output of sampling.
  1150. Args:
  1151. on_device_tensors: Tuple containing on-device tensors with the
  1152. probabilities used in sampling and the sampled token ids. This
  1153. allows post-processing without copies to CPU/serialization, e.g. in
  1154. speculative decoding rejection sampling.
  1155. """
  1156. sampler_output: List[CompletionSequenceGroupOutput] = []
  1157. if not skip_sampler_cpu_output:
  1158. assert prompt_logprobs is not None
  1159. assert sample_logprobs is not None
  1160. for (seq_group, sample_result, group_prompt_logprobs,
  1161. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1162. sample_results, prompt_logprobs,
  1163. sample_logprobs):
  1164. seq_ids = seq_group.seq_ids
  1165. next_token_ids, parent_ids = sample_result
  1166. seq_outputs: List[SequenceOutput] = []
  1167. for parent_id, next_token_id, logprobs in zip(
  1168. parent_ids, next_token_ids, group_sample_logprobs):
  1169. seq_outputs.append(
  1170. SequenceOutput(seq_ids[parent_id], next_token_id,
  1171. logprobs))
  1172. sampler_output.append(
  1173. CompletionSequenceGroupOutput(seq_outputs,
  1174. group_prompt_logprobs))
  1175. # If not specified, store None values in SamplerOutput.
  1176. if on_device_tensors is not None:
  1177. (sampled_token_probs, logprobs_tensor,
  1178. sampled_token_ids) = on_device_tensors
  1179. else:
  1180. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1181. None)
  1182. return SamplerOutput(
  1183. outputs=sampler_output,
  1184. sampled_token_probs=sampled_token_probs,
  1185. sampled_token_ids=sampled_token_ids,
  1186. logprobs=logprobs_tensor,
  1187. )
  1188. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1189. """Get a list of next prompt tokens to compute logprob from a
  1190. given sequence group.
  1191. It is used to compute prompt logprob. Imagine you have logprob for each
  1192. query token. Query token needs to know the next prompt token id to compute
  1193. prompt logprob. This is a helper to obtain next prompt token ids.
  1194. This API has to be used only when the caller knows seq_group is in prefill
  1195. stage.
  1196. Returns:
  1197. A list of next prompt tokens to compute logprob.
  1198. """
  1199. assert seq_group.is_prompt, (
  1200. "Caller should ensure the sequence group is in a prefill stage.")
  1201. seq_ids = seq_group.seq_ids
  1202. query_len = seq_group.query_len
  1203. assert query_len is not None
  1204. # prompt has only 1 seq id.
  1205. assert len(seq_ids) == 1
  1206. seq_data = seq_group.seq_data[seq_ids[0]]
  1207. computed_len = seq_data.get_num_computed_tokens()
  1208. prompt_tokens = seq_data.prompt_token_ids
  1209. # +1 because we are looking for a next prompt token.
  1210. next_token_index_start = computed_len + 1
  1211. next_token_index_end = min(computed_len + query_len + 1,
  1212. len(prompt_tokens))
  1213. next_prompt_tokens = prompt_tokens[
  1214. next_token_index_start:next_token_index_end]
  1215. return next_prompt_tokens
  1216. # def _apply_mirostat_v2(logits: torch.Tensor,
  1217. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1218. # # Reduce our view to just the affected logits
  1219. # logit_view = logits[sampling_tensors.miro_indices]
  1220. # # Calculate surprise value per token
  1221. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1222. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1223. # # Mask out "too-surprising" tokens (surprisal > mu)
  1224. # mus = sampling_tensors.miro_mus
  1225. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1226. # # Unmask most-likely logit to guarantee a selection.
  1227. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1228. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1229. # # Apply logit mask (effectively a top-k filter).
  1230. # logit_view[miro_mask] = -float("inf")
  1231. # # Project logit changes made to the view onto the original.
  1232. # # I think this step might be redundant.
  1233. # logits[sampling_tensors.miro_indices] = logit_view
  1234. # return logits
  1235. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1236. # sample_results: List[Tuple[List[int], List[int]]],
  1237. # sampling_metadata: SamplingMetadata,
  1238. # output_metadata: OutputMetadata) -> None:
  1239. # """Based on whichever token was finally sampled, we calculate the
  1240. # final surprisal values to update the mus.
  1241. # Because a single sequence can have multiple samples, we must fork
  1242. # the mu accordingly."""
  1243. # assert sampling_metadata.seq_groups is not None
  1244. # seqid_to_tokens = {}
  1245. # seqid_to_indices = {}
  1246. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1247. # sample_results):
  1248. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1249. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1250. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1251. # seqids = args.miro_seqids
  1252. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1253. # device=logits.device,
  1254. # dtype=torch.long)
  1255. # # Clumsily, we recalculate token surprisals.
  1256. # logits_view = logits[args.miro_indices]
  1257. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1258. # dim=-1,
  1259. # index=picked_tokens) / -math.log(2)
  1260. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1261. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1262. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1263. # nu_mus = mus - (picked_surprise - taus) * etas
  1264. # # Record updated mu values for use in the next iteration
  1265. # # Note how each mu is split into multiple based on the number of samples.
  1266. # for seqid, seq_mus in zip(seqids, nu_mus):
  1267. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1268. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)