"""A layer that samples the next tokens from the model's outputs."""
import itertools
from math import inf
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.common.sampling_params import SamplingType
from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       PromptLogprobs, SampleLogprobs,
                                       SamplerOutput, SequenceOutput)
from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
    from aphrodite.modeling.layers.ops.sample import sample as sample_triton

from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                  SamplingTensors,
                                                  SequenceGroupToSample)

# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row
    in logits for the next token to be sampled; however, for a seq_group with
    a prompt request with the prompt_logprobs sampling parameter, there are
    rows in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()
        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
         do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
         do_quadratic, do_temp_last) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_top_as = do_top_as
        self._do_min_p = do_min_p
        self._do_tfss = do_tfss
        self._do_eta_cutoffs = do_eta_cutoffs
        self._do_epsilon_cutoffs = do_epsilon_cutoffs
        self._do_typical_ps = do_typical_ps
        self._do_quadratic = do_quadratic
        self._do_temp_last = do_temp_last

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Args:
            logits: (num_tokens, vocab_size).
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None
        _, vocab_size = logits.shape

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_top_as = self._do_top_as
        do_min_p = self._do_min_p
        do_tfss = self._do_tfss
        do_eta_cutoffs = self._do_eta_cutoffs
        do_epsilon_cutoffs = self._do_epsilon_cutoffs
        do_typical_ps = self._do_typical_ps
        do_quadratic = self._do_quadratic
        do_temp_last = self._do_temp_last

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Apply temperature scaling if not doing temp_last.
        if not do_temp_last:
            # Use float32 to apply temp.
            # Use in-place division to avoid creating a new tensor.
            logits = logits.to(torch.float)
            logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_top_as:
            logits = _apply_top_a(logits, sampling_tensors.top_as)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        if do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)

        if do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)

        if do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)

        if do_typical_ps:
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps)

        if do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        if do_temp_last:
            # Use float32 to apply temp.
            # Use in-place division to avoid creating a new tensor.
            logits = logits.to(torch.float)
            logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        # banned_tokens = _get_custom_token_bans(sampling_metadata)
        # logits = _apply_token_bans(logits, banned_tokens)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            prompt_logprobs, sample_logprobs = _get_logprobs(
                logprobs, sampling_metadata, sample_results)

        return _build_sampler_output(
            sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability
        distribution of greedily-sampled tokens such that multinomial
        sampling would sample the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        return self.should_modify_greedy_probs_inplace
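

# Illustrative usage sketch (hypothetical; in practice the model runner
# drives this): given `logits` of shape (num_tokens, vocab_size) and a
# prepared SamplingMetadata, the layer is invoked via nn.Module.__call__:
#
#     sampler = Sampler()
#     output = sampler(logits=logits, sampling_metadata=sampling_metadata)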


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
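

# Worked example: for tokens = [[2, 2, 5]] and vocab_size = 8, bin_counts is
# [[0, 0, 2, 0, 0, 1, 0, 0]] and mask flags ids 2 and 5. The extra
# vocab_size-th column exists so padding token ids equal to vocab_size can be
# scattered into it and then dropped by the slice above.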


def _get_custom_token_bans(
        sampling_metadata: SamplingMetadata) -> List[List[int]]:
    assert sampling_metadata.seq_groups is not None
    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = sampling_metadata.seq_groups[i].sampling_params
        seq_ids = seq_group.seq_ids
        custom_token_bans = sampling_params.custom_token_bans

        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = len(seq_group.prompt_logprob_indices)
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
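

# Summarizing the transformation above for each sequence and token id t
# (matching the OpenAI parameter definitions linked above):
#   if t occurred in the prompt or the output:
#       logit[t] = logit[t] / rep_pen  if logit[t] > 0, else logit[t] * rep_pen
#   logit[t] -= freq_pen * output_count[t] + pres_pen * (output_count[t] > 0)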


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
    have not been generated yet.
    """
    # list of indices in logits that will be set to -inf
    logits_to_penalize = []
    logits_applied = 0
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params

        sample_indices = seq_group.sample_indices
        logits_applied += len(sample_indices) + len(
            seq_group.prompt_logprob_indices)
        if not seq_group.do_sample:
            continue

        start_idx = sample_indices[0]
        min_tokens = sampling_params.min_tokens
        token_ids_to_penalize = sampling_params.all_stop_token_ids
        if min_tokens > 0 and token_ids_to_penalize:
            seqs_to_penalize = []
            for j, seq_id in enumerate(seq_ids):
                seq_data = seq_group.seq_data[seq_id]
                if len(seq_data.output_token_ids_array) < min_tokens:
                    seqs_to_penalize.append(j)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert logits_applied == logits.shape[0]
    return logits


def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
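

# Worked example for the top-p step: with ascending sorted probs
# [0.1, 0.2, 0.3, 0.4] and p = 0.6, probs_sum is [0.1, 0.3, 0.6, 1.0], and
# the mask probs_sum <= 1 - p = 0.4 drops the first two entries, keeping the
# smallest top set {0.4, 0.3} whose mass (0.7) is >= p. The top-k step works
# the same way on the ascending sort: entries strictly below the k-th largest
# value are masked.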


def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
    return logits
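

# E.g. if the most likely token has probability 0.4 and min_p is 0.1, every
# token with probability below 0.1 * 0.4 = 0.04 is masked out.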


def _apply_top_a(
    logits: torch.Tensor,
    top_a: torch.Tensor,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
    tokens_to_remove = probs < threshold
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
    return logits
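

# Top-a uses a quadratic threshold: e.g. with top_a = 0.2 and a top
# probability of 0.5, tokens below 0.2 * 0.5**2 = 0.05 are masked, so the
# cutoff relaxes quickly as the distribution flattens.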


def _apply_tfs(
    logits: torch.Tensor,
    tfs: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)

    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))

    return logits
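

# Tail-free sampling drops the "tail" where the sorted probability curve goes
# flat, measured by the normalized second difference. E.g. for descending
# probs [0.5, 0.3, 0.15, 0.05]: first diff [-0.2, -0.15, -0.1], second diff
# [0.05, 0.05], normalized CDF [0.5, 1.0]; with tfs z = 0.95 the mask keeps
# only the first two tokens. The torch.cat above pads the mask so the most
# likely token is always kept and the least likely is always dropped.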


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoff: torch.Tensor,
) -> torch.Tensor:
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta_cutoff,
                    torch.sqrt(eta_cutoff) *
                    torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps

    # guard against nulling out all the logits
    top_idx = torch.argmax(probs, dim=1, keepdim=True)
    eta_mask.scatter_(dim=1, index=top_idx, value=False)

    logits[eta_mask] = -float("inf")
    return logits
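

# Eta cutoff masks tokens below eps = min(eta, sqrt(eta) * exp(-H)), where H
# is the entropy of the distribution (neg_entropy above is -H), so the
# threshold adapts to how peaked the distribution is; cf. eta sampling in
# "Truncation Sampling as Language Model Desmoothing" (Hewitt et al., 2022).
# _apply_epsilon_cutoff below is the fixed-threshold variant.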


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoff: torch.Tensor,
) -> torch.Tensor:
    probs = logits.softmax(dim=-1)

    eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)

    # guard against nulling out all the logits
    top_idx = torch.argmax(probs, dim=1, keepdim=True)
    eps_mask.scatter_(dim=1, index=top_idx, value=False)

    logits[eps_mask] = -float("inf")
    return logits


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)

    surprisal_deviations = (neg_entropy - shifted_logits).abs()
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
        dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits


def _apply_quadratic_sampling(
    logits: torch.Tensor,
    smoothing_factor: torch.Tensor,
    smoothing_curve: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a quadratic transformation to the logits based on the
    provided smoothing factors and curves. The transformation is
    centered around the maximum logit value in the batch.

    The transformation involves a quadratic and cubic term, with the
    cubic term controlled by the smoothing curve. The quadratic term is
    scaled by the smoothing factor, and the cubic term is scaled by the
    product of the smoothing factor and the smoothing curve.

    params:
        logits (torch.Tensor): The logits to be transformed.
        smoothing_factor (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
        smoothing_curve (torch.Tensor): The factors to scale the cubic term
            in the transformation.

    returns:
        torch.Tensor: The transformed logits.

    Credits: @kalomaze
    """
    mask = smoothing_factor != 0

    smoothing_factor.unsqueeze_(dim=1)
    smoothing_curve.unsqueeze_(dim=1)
    k = smoothing_factor * (3 - smoothing_curve) / 2
    s = smoothing_factor * (smoothing_curve - 1) / 2

    quadlogits = logits[mask]  # limit to logits using this sampler
    max_logits = quadlogits.max(dim=-1, keepdim=True).values

    # Construct the delta from each logit to its new value
    diff = quadlogits - max_logits
    diff -= diff**2 * (s[mask] * diff - k[mask])
    diff[diff != diff] = 0  # Eliminate NaNs due to infs

    logits[mask] -= diff
    return logits
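

# With d = logit - max_logit, the update above rewrites each selected logit
# as max_logit - k * d**2 + s * d**3. In particular, for smoothing_curve == 1
# we get s == 0 and k == smoothing_factor, so the transform reduces to the
# pure quadratic max_logit - smoothing_factor * (logit - max_logit)**2.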


def _greedy_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Run greedy sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        samples: (num_selected_samples,) A tensor of samples. The length of
            samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned
        list is the same as the length of selected_seq_groups. If the
        corresponding seq_group has do_sample=False, the tuple contains
        ([], []).
    """
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Run random sampling on the given samples.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        random_samples: (num_selected_samples,) A tensor of samples. The
            length of samples could be smaller than selected_seq_groups if
            seq_group.do_sample is False.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned
        list is the same as the length of selected_seq_groups. If the
        corresponding seq_group has do_sample=False, the tuple contains
        ([], []).
    """
    # Move the samples to CPU once up front; this is the GPU<->CPU sync
    # point for random sampling.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
        is_prompt = seq_group.is_prompt
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[SequenceGroupToSample],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Run beam sampling on the given logprobs.

    Args:
        selected_seq_groups: A list of sequence groups batched.
        logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
            on selected sample indices.
    Returns:
        Tuple of (next_token_ids, parent_ids). The length of the returned
        list is the same as the length of selected_seq_groups. If the
        corresponding seq_group has do_sample=False, the tuple contains
        ([], []).
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        if not seq_group.do_sample:
            results.append(([], []))
            continue

        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_group.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(
                generator=seq_group.generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
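

# The argmax over probs/q is the "exponential race" form of the Gumbel-max
# trick: for independent q_i ~ Exponential(1), q_i / p_i is Exponential(p_i),
# and the index of the minimum (equivalently, the argmax of p_i / q_i) equals
# i with probability p_i / sum_j p_j. A single argmax therefore yields an
# exact categorical draw without torch.multinomial's sync.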


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))

        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []
    return sample_results, sampled_token_ids_tensor


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue

        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor, logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """
    Args:
        probs: (num_query_tokens_in_batch, num_vocab)
        logprobs: (num_query_tokens_in_batch, num_vocab)
        sampling_metadata: The metadata for a batch for sampling.
        sampling_tensors: Tensors that include sampling related metadata.
    Returns:
        (next_token_ids, parent_seq_ids) for each seq group in a batch.
            If sampling is skipped, it returns ([], []).
        sampled_token_ids_tensor: A tensor of sampled token ids.
    """
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob
    tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
            where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.
    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
            Each element in the returned tensor represents the rank
            of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    return (x > vals[:, None]).long().sum(1).add_(1)
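

# E.g. for a row [-0.1, -2.3, -0.7] with chosen index 2 (value -0.7), exactly
# one entry (-0.1) is strictly greater, so the returned rank is 2; the most
# likely token always has rank 1.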


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Return sample logprobs and prompt logprobs.

    The logic consists of 3 parts.
    - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
    - Compute prompt logprobs if required.
    - Compute sample logprobs if required.

    Args:
        logprobs: (num_query_tokens_across_batch, num_vocab). Each query
            token's logprob per vocab. Sequence groups' query tokens are
            batched in a single flattened tensor. For example, assuming there
            are N seq groups, it is sorted by prefill tokens for seq_group_1
            (if prompt logprob is enabled), decode tokens for seq_group_1 (if
            sampling is required), prefill tokens for seq_group_2, ...
        sampling_metadata: The sampling metadata.
        sample_results: (num_seq_groups) The tuple of (next_token_ids,
            parent_ids) for each sequence group. When beam search is enabled,
            sample_results can contain a different number of seq_ids from
            sampling_metadata.seq_groups. This is because beam search creates
            2 * BEAM_WIDTH samples (whereas there are only up to BEAM_WIDTH
            seq_ids).
    Returns:
        A tuple of prompt and sample logprobs per sequence group in a batch.
    """
    # The index of query token to calculate logprobs. It includes both
    # prompt and sample logprob indices.
    query_indices: List[int] = []
    # The next token ids to get the logprob value from.
    next_token_ids: List[int] = []
    # The largest requested number of logprobs. We compute logprobs for up to
    # the largest number requested in the batch. If no request asks for
    # logprobs, this stays -1.
    largest_num_logprobs = -1
    # If beam search is enabled.
    use_beam_search = False

    # Select indices to compute logprob from, ranks of token ids, and the top
    # k token ids from logprobs.
    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        sampling_params = seq_group.sampling_params

        # Update indices and tokens for prompt logprobs.
        if (seq_group.is_prompt
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            next_prompt_tokens = _get_next_prompt_tokens(seq_group)
            query_indices.extend(seq_group.prompt_logprob_indices)
            next_token_ids.extend(next_prompt_tokens)

        # Update indices and next tokens for sample logprob.
        if seq_group.do_sample:
            token_ids, parent_seq_ids = sample_result
            # NOTE: We cannot directly use sample_indices because
            # sample_indices only contain parent seq_ids of a previous step.
            # The current step may have a different number of seq_ids, and
            # we can obtain it from `sample_result[1]`.
            query_idx = seq_group.sample_indices[0]
            query_indices.extend(
                [query_idx + parent_id for parent_id in parent_seq_ids])
            next_token_ids.extend(token_ids)

            if sampling_params.logprobs is not None:
                largest_num_logprobs = max(largest_num_logprobs,
                                           sampling_params.logprobs)

            use_beam_search = (use_beam_search
                               or sampling_params.use_beam_search)

    assert len(next_token_ids) == len(query_indices)

    if len(query_indices) == 0:
        empty_sampled_logprob = []
        empty_prompt_logprob = None
        return [empty_prompt_logprob], [empty_sampled_logprob]

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None

    # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
    # skip the whole logprob calculation.
    if largest_num_logprobs >= 0 or use_beam_search:
        query_indices_gpu = torch.tensor(query_indices,
                                         device=logprobs.device)
        next_token_ids_gpu = torch.tensor(next_token_ids,
                                          device=logprobs.device)

        # (num_selected_query_tokens, num_logprobs). Note that query_indices
        # can contain duplicates if beam search is enabled.
        selected_logprobs = logprobs[[
            query_indices_gpu,
            next_token_ids_gpu,
        ]]
        ranks = _get_ranks(
            logprobs[query_indices_gpu],
            next_token_ids_gpu,
        )
        assert selected_logprobs.shape[0] == ranks.shape[0]

        # We need to compute top k only if there exists logprobs > 0.
        if largest_num_logprobs > 0:
            # Logprobs of topk tokens for a batch of sequence groups.
            # (num_query_tokens_across_batch).
            top_logprobs, top_token_ids = torch.topk(logprobs,
                                                     largest_num_logprobs,
                                                     dim=-1)
            top_logprobs = top_logprobs.to('cpu')
            top_token_ids = top_token_ids.to('cpu')

        selected_logprobs = selected_logprobs.to('cpu')
        ranks = ranks.to('cpu')

    # Find prompt/sample logprobs.
    prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
    sample_logprobs_per_seq_group: List[SampleLogprobs] = []
    top_logprob_idx = 0
    selected_logprobs_idx = 0

    for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                        sample_results):
        (prompt_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_prompt_logprob_if_needed(
             seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
             selected_logprobs_idx, top_logprob_idx)
        prompt_logprobs_per_seq_group.append(prompt_logprobs)

        (sampled_logprobs, top_logprob_idx,
         selected_logprobs_idx) = _get_sampled_logprob_if_needed(
             seq_group, sample_result, selected_logprobs, ranks,
             top_token_ids, top_logprobs, selected_logprobs_idx,
             top_logprob_idx)
        sample_logprobs_per_seq_group.append(sampled_logprobs)

    return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group


def _get_prompt_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the prompt logprob from a sequence group if needed."""
    sampling_params = seq_group.sampling_params
    is_prompt = seq_group.is_prompt

    # Find prompt logprobs
    prompt_logprobs: Optional[PromptLogprobs] = None
    if is_prompt and sampling_params.prompt_logprobs is not None:
        prompt_logprobs = []
        num_logprobs = sampling_params.prompt_logprobs
        next_prompt_tokens = _get_next_prompt_tokens(seq_group)
        # Pre-select indexes and create a list. It is faster than calling
        # .item() repeatedly.
        selected_logprob_items = selected_logprobs[
            selected_logprobs_idx:selected_logprobs_idx +
            len(next_prompt_tokens)].tolist()
        rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                           len(next_prompt_tokens)].tolist()

        for idx, token_id in enumerate(next_prompt_tokens):
            # Calculate the prompt logprob of the real prompt tokens.
            # {token_id: (logprob, rank_from_vocab)}
            prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
                token_id: (selected_logprob_items[idx], rank_items[idx])
            }

            # Add top K prompt logprobs along with its rank.
            if num_logprobs > 0:
                top_ids = top_token_ids[
                    top_logprob_idx, :num_logprobs].tolist()
                top_probs = top_logprobs[
                    top_logprob_idx, :num_logprobs].tolist()
                # Top K is already sorted by rank, so we can use 1 ~
                # num_logprobs + 1 for rank.
                top_ranks = range(1, num_logprobs + 1)
                prompt_logprobs_dict.update({
                    top_id: (top_prob, rank)
                    for top_id, top_prob, rank in zip(top_ids, top_probs,
                                                      top_ranks)
                })
            prompt_logprobs.append({
                token_id: Logprob(*logprob_and_rank)
                for token_id, logprob_and_rank in prompt_logprobs_dict.items()
            })
            # + 1 to go to the next prompt token.
            top_logprob_idx += 1

        # + len(next_prompt_tokens) to go to the next prompt.
        selected_logprobs_idx += len(next_prompt_tokens)
    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx


def _get_sampled_logprob_if_needed(
    seq_group: SequenceGroupToSample,
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: torch.Tensor,
    top_logprobs: torch.Tensor,
    selected_logprobs_idx: int,
    top_logprob_idx: int,
):
    """Compute the sample logprob if needed."""
    seq_ids = seq_group.seq_ids
    num_logprobs = seq_group.sampling_params.logprobs
    use_beam_search = seq_group.sampling_params.use_beam_search
    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if seq_group.do_sample:
        assert len(next_token_ids) > 0
        if num_logprobs is None and not use_beam_search:
            for next_token_id in next_token_ids:
                # Use a dummy logprob
                sampled_logprobs.append({next_token_id: Logprob(inf)})
        else:
            # Pre-select items from tensor. tolist() is faster than repeated
            # `.item()` calls.
            selected_logprob_items = selected_logprobs[
                selected_logprobs_idx:selected_logprobs_idx +
                len(next_token_ids)].tolist()
            rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
                               len(next_token_ids)].tolist()
            for idx, (next_token_id, parent_id) in enumerate(
                    zip(next_token_ids, parent_seq_ids)):
                # Get the logprob of a sampled token.
                sampled_logprobs_dict = {
                    next_token_id:
                    (selected_logprob_items[idx], rank_items[idx])
                }
                if num_logprobs is not None and num_logprobs > 0:
                    # Get top K logprobs.
                    top_ids = top_token_ids[top_logprob_idx +
                                            parent_id, :num_logprobs].tolist()
                    top_probs = top_logprobs[
                        top_logprob_idx + parent_id, :num_logprobs].tolist()
                    # Top K is already sorted by rank, so we can use 1 ~
                    # num_logprobs + 1 for rank.
                    top_ranks = range(1, num_logprobs + 1)
                    sampled_logprobs_dict.update({
                        top_id: (top_prob, rank)
                        for top_id, top_prob, rank in zip(
                            top_ids, top_probs, top_ranks)
                    })

                sampled_logprobs.append({
                    token_id: Logprob(*logprob_and_rank)
                    for token_id, logprob_and_rank in
                    sampled_logprobs_dict.items()
                })

        # NOTE: This part of the code is not intuitive. `selected_logprobs`
        # includes logprobs for the current step, which has
        # len(next_token_ids) tokens per sequence group. `logprobs` includes
        # logprobs from the previous steps, which has len(seq_ids) tokens per
        # sequence group.

        # Iterate to the next sequence group in a batch.
        selected_logprobs_idx += len(next_token_ids)
        # Iterate to the next sequence group in a batch.
        top_logprob_idx += len(seq_ids)
    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens
    such that each sampled token has a "probability" of 1.0. This is required
    by speculative decoding, which depends on the sampling method being
    encoded within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    Aphrodite's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Multiply by temperature, top-k and top-p masking, penalize
              tokens according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
              distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
              highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed
    probability distribution has the following property: we can sample from
    it independently and find that the token sampled by the Sampler has a
    frequency corresponding to how often we see it in our sampling. In other
    words, for tokens sampled with Aphrodite's random SamplingType, the
    computed probability distribution encodes the sampling methodology
    completely.

    Greedy sampling does not normally have this property. Aphrodite modifies
    logits according to sampling params, then performs `argmax`, then returns
    the sampled token and the computed probability distribution. If we sample
    from the distribution, we'll find the likelihood of the greedily-sampled
    token is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths.
    This has implications on the overall design of the sampler, e.g. how to
    record accurate logprobs for the user, so this improvement is deferred to
    later.
    """
    # NOTE: logprobs are not modified so they can be returned to the user.
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


def _build_sampler_output(
    sample_results: SampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g.
            in speculative decoding rejection sampling.
    """
    sampler_output: List[CompletionSequenceGroupOutput] = []
    if not skip_sampler_cpu_output:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           sample_results, prompt_logprobs,
                                           sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = []
            for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs):
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   logprobs))
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        (sampled_token_probs, logprobs_tensor,
         sampled_token_ids) = on_device_tensors
    else:
        sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
                                                                   None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
        logprobs=logprobs_tensor,
    )


def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
    """Get a list of next prompt tokens to compute logprob from a
    given sequence group.

    It is used to compute prompt logprob. Imagine you have logprob for each
    query token. Each query token needs to know the next prompt token id to
    compute prompt logprob. This is a helper to obtain next prompt token ids.

    This API has to be used only when the caller knows seq_group is in a
    prefill stage.

    Returns:
        A list of next prompt tokens to compute logprob.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1
    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids
    # +1 because we are looking for a next prompt token.
    next_token_index_start = computed_len + 1
    next_token_index_end = min(computed_len + query_len + 1,
                               len(prompt_tokens))
    next_prompt_tokens = prompt_tokens[
        next_token_index_start:next_token_index_end]
    return next_prompt_tokens
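

# E.g. with computed_len == 4, query_len == 3 and a 10-token prompt, the
# query covers prompt positions 4..6, and this returns prompt_tokens[5:8]:
# the "next" token for each query position.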


# def _apply_mirostat_v2(logits: torch.Tensor,
#                        sampling_tensors: SamplingTensors) -> torch.Tensor:
#     # Reduce our view to just the affected logits
#     logit_view = logits[sampling_tensors.miro_indices]
#
#     # Calculate surprise value per token
#     # Convert nats to bits for compatibility with ooba/kobold parameters.
#     logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
#
#     # Mask out "too-surprising" tokens (surprisal > mu)
#     mus = sampling_tensors.miro_mus
#     miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
#
#     # Unmask most-likely logit to guarantee a selection.
#     maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
#     miro_mask.scatter_(dim=1, index=maxinds, value=False)
#
#     # Apply logit mask (effectively a top-k filter).
#     logit_view[miro_mask] = -float("inf")
#
#     # Project logit changes made to the view onto the original.
#     # I think this step might be redundant.
#     logits[sampling_tensors.miro_indices] = logit_view
#     return logits


# def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
#                          sample_results: List[Tuple[List[int], List[int]]],
#                          sampling_metadata: SamplingMetadata,
#                          output_metadata: OutputMetadata) -> None:
#     """Based on whichever token was finally sampled, we calculate the
#     final surprisal values to update the mus.
#
#     Because a single sequence can have multiple samples, we must fork
#     the mu accordingly."""
#     assert sampling_metadata.seq_groups is not None
#     seqid_to_tokens = {}
#     seqid_to_indices = {}
#     for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
#                                           sample_results):
#         for idx, (token, parent) in enumerate(zip(toks, parents)):
#             seqid_to_tokens.setdefault(sids[parent], []).append(token)
#             seqid_to_indices.setdefault(sids[parent], []).append(idx)
#
#     seqids = args.miro_seqids
#
#     picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
#                                  device=logits.device,
#                                  dtype=torch.long)
#
#     # Clumsily, we recalculate token surprisals.
#     logits_view = logits[args.miro_indices]
#     picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
#                                    dim=-1,
#                                    index=picked_tokens) / -math.log(2)
#
#     taus = args.miro_taus.unsqueeze(dim=-1)  # AKA target surprisals
#     etas = args.miro_etas.unsqueeze(dim=-1)  # AKA accumulation rates
#     mus = args.miro_mus.unsqueeze(dim=-1)  # AKA surprisal accumulators
#     nu_mus = mus - (picked_surprise - taus) * etas
#
#     # Record updated mu values for use in the next iteration
#     # Note how each mu is split into multiple based on the number of samples.
#     for seqid, seq_mus in zip(seqids, nu_mus):
#         for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
#             output_metadata.add(seqid, sample_idx, "miro_mu", mu)