sampler.py 54 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. from typing import Dict, List, Optional, Tuple
  4. import torch
  5. import torch.nn as nn
  6. from aphrodite.common.sampling_params import SamplingType
  7. from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
  8. PromptLogprobs, SampleLogprobs,
  9. SamplerOutput, SequenceOutput)
  10. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  11. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  12. SamplingTensors,
  13. SequenceGroupToSample)
  14. class Sampler(nn.Module):
  15. """Samples the next tokens from the model's outputs.
  16. This layer does the following:
  17. 1. Discard the hidden states that are not used for sampling (i.e., all
  18. tokens except the final one in each prompt).
  19. 2. Compute the logits for the next tokens.
  20. 3. Apply presence, frequency and repetition penalties.
  21. 4. Apply temperature scaling.
  22. 5. Apply top-p and top-k truncation.
  23. 6. Sample the next tokens.
  24. Here, each sequence group within the batch can have different sampling
  25. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  26. The structure of the logits tensor is coupled with the seq_groups in
  27. sampling_metadata. Typically, each sequence in each seq_group has one row in
  28. logits for the next token to be sampled; however, for a seq_group with a
  29. prompt request with the prompt_logprobs sampling parameter, there are rows
  30. in logits for each token in the input prompt.
  31. """
  32. def __init__(self):
  33. super().__init__()
  34. # Whether or not the SamplerOutput should have on-device tensors
  35. # containing the sampled token ids and probabilities. This is used by
  36. # speculative decoding.
  37. self.include_gpu_probs_tensor = False
  38. def forward(
  39. self,
  40. logits: torch.Tensor,
  41. sampling_metadata: SamplingMetadata,
  42. ) -> Optional[SamplerOutput]:
  43. """
  44. Args:
  45. logits: (num_tokens, vocab_size).
  46. sampling_metadata: Metadata for sampling.
  47. """
  48. assert logits is not None
  49. _, vocab_size = logits.shape
  50. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  51. # Prepare sampling tensors with pinned memory to avoid blocking.
  52. (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
  53. do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
  54. do_quadratic) = SamplingTensors.from_sampling_metadata(
  55. sampling_metadata, vocab_size, logits.device, logits.dtype)
  56. # Apply presence and frequency penalties.
  57. if do_penalties:
  58. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  59. sampling_tensors.output_tokens,
  60. sampling_tensors.presence_penalties,
  61. sampling_tensors.frequency_penalties,
  62. sampling_tensors.repetition_penalties)
  63. # Apply temperature scaling.
  64. # Use in-place division to avoid creating a new tensor.
  65. logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
  66. if do_top_p_top_k:
  67. logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
  68. sampling_tensors.top_ks)
  69. if do_top_as:
  70. logits = _apply_top_a(logits, sampling_tensors.top_as)
  71. if do_min_p:
  72. logits = _apply_min_p(logits, sampling_tensors.min_ps)
  73. if do_tfss:
  74. logits = _apply_tfs(logits, sampling_tensors.tfss)
  75. if do_eta_cutoffs:
  76. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  77. if do_epsilon_cutoffs:
  78. logits = _apply_epsilon_cutoff(logits,
  79. sampling_tensors.epsilon_cutoffs)
  80. if do_typical_ps:
  81. logits = _apply_typical_sampling(logits,
  82. sampling_tensors.typical_ps)
  83. if do_quadratic:
  84. logits = _apply_quadratic_sampling(
  85. logits, sampling_tensors.smoothing_factors,
  86. sampling_tensors.smoothing_curves)
  87. # banned_tokens = _get_custom_token_bans(sampling_metadata)
  88. # logits = _apply_token_bans(logits, banned_tokens)
  89. # We use float32 for probabilities and log probabilities.
  90. # Compute the probabilities.
  91. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  92. # Compute the log probabilities.
  93. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  94. # Sample the next tokens.
  95. sample_results, maybe_sampled_tokens_tensor = _sample(
  96. probs,
  97. logprobs,
  98. sampling_metadata,
  99. sampling_tensors,
  100. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  101. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  102. )
  103. if self.include_gpu_probs_tensor:
  104. assert maybe_sampled_tokens_tensor is not None
  105. on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
  106. else:
  107. on_device_tensors = None
  108. # Get the logprobs query results.
  109. prompt_logprobs, sample_logprobs = _get_logprobs(
  110. logprobs, sampling_metadata, sample_results)
  111. return _build_sampler_output(sample_results,
  112. sampling_metadata,
  113. prompt_logprobs,
  114. sample_logprobs,
  115. on_device_tensors=on_device_tensors)
  116. @property
  117. def _should_modify_greedy_probs_inplace(self) -> bool:
  118. """Whether or not the sampler should modify the probability distribution
  119. of greedily-sampled tokens such that multinomial sampling would sample
  120. the greedily-sampled token.
  121. In other words, if True then we set the probability of the greedily-
  122. sampled token to 1.
  123. This is used by speculative decoding, which requires that the sampling
  124. method be encoded into the probability distribution.
  125. """
  126. # Modify greedy probs if include_gpu_probs_tensor is set.
  127. return self.include_gpu_probs_tensor
  128. def _get_bin_counts_and_mask(
  129. tokens: torch.Tensor,
  130. vocab_size: int,
  131. num_seqs: int,
  132. ) -> Tuple[torch.Tensor, torch.Tensor]:
  133. # Compute the bin counts for the tokens.
  134. # vocab_size + 1 for padding.
  135. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  136. dtype=torch.long,
  137. device=tokens.device)
  138. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  139. bin_counts = bin_counts[:, :vocab_size]
  140. mask = bin_counts > 0
  141. return bin_counts, mask
  142. def _get_custom_token_bans(
  143. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  144. assert sampling_metadata.seq_groups is not None
  145. banned_tokens: List[List[int]] = []
  146. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  147. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  148. seq_ids = seq_group.seq_ids
  149. custom_token_bans = sampling_params.custom_token_bans
  150. if (i < sampling_metadata.num_prompts
  151. and sampling_params.prompt_logprobs is not None):
  152. prompt_len = len(seq_group.prompt_logprob_indices)
  153. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  154. banned_tokens += [custom_token_bans] * len(seq_ids)
  155. return banned_tokens
  156. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  157. output_tokens_tensor: torch.Tensor,
  158. presence_penalties: torch.Tensor,
  159. frequency_penalties: torch.Tensor,
  160. repetition_penalties: torch.Tensor) -> torch.Tensor:
  161. num_seqs, vocab_size = logits.shape
  162. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  163. num_seqs)
  164. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  165. output_tokens_tensor, vocab_size, num_seqs)
  166. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  167. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  168. logits = torch.where(logits > 0, logits / repetition_penalties,
  169. logits * repetition_penalties)
  170. # We follow the definition in OpenAI API.
  171. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  172. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  173. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  174. return logits
  175. def _apply_token_bans(logits: torch.Tensor,
  176. banned_tokens: List[List[int]]) -> torch.Tensor:
  177. for i, banned_token_ids in enumerate(banned_tokens):
  178. if not banned_token_ids:
  179. continue
  180. logits[i, banned_token_ids] = -float("inf")
  181. return logits
  182. def _apply_min_tokens_penalty(
  183. logits: torch.Tensor,
  184. sampling_metadata: SamplingMetadata,
  185. ) -> torch.Tensor:
  186. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  187. have not been generated yet
  188. """
  189. # list of indices in logits that will be set to -inf
  190. logits_to_penalize = []
  191. logits_applied = 0
  192. for seq_group in sampling_metadata.seq_groups:
  193. seq_ids = seq_group.seq_ids
  194. sampling_params = seq_group.sampling_params
  195. sample_indices = seq_group.sample_indices
  196. logits_applied += len(sample_indices) + len(
  197. seq_group.prompt_logprob_indices)
  198. if not seq_group.do_sample:
  199. continue
  200. start_idx = sample_indices[0]
  201. min_tokens = sampling_params.min_tokens
  202. token_ids_to_penalize = sampling_params.all_stop_token_ids
  203. if min_tokens > 0 and token_ids_to_penalize:
  204. seqs_to_penalize = []
  205. for j, seq_id in enumerate(seq_ids):
  206. seq_data = seq_group.seq_data[seq_id]
  207. if len(seq_data.output_token_ids) < min_tokens:
  208. seqs_to_penalize.append(j)
  209. if seqs_to_penalize:
  210. # convert to the index into logits
  211. seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
  212. # itertools.product pairs each seq index with every token id
  213. logits_to_penalize.extend(
  214. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  215. if logits_to_penalize:
  216. # use zip and * to group indices along each dimension
  217. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  218. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  219. # verifies that no rows in logits were missed unexpectedly
  220. assert logits_applied == logits.shape[0]
  221. return logits
  222. def _apply_top_k_top_p(
  223. logits: torch.Tensor,
  224. p: torch.Tensor,
  225. k: torch.Tensor,
  226. ) -> torch.Tensor:
  227. logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
  228. # Apply top-k.
  229. top_k_mask = logits_sort.size(1) - k.to(torch.long)
  230. # Get all the top_k values.
  231. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
  232. top_k_mask = logits_sort < top_k_mask
  233. logits_sort.masked_fill_(top_k_mask, -float("inf"))
  234. # Apply top-p.
  235. probs_sort = logits_sort.softmax(dim=-1)
  236. probs_sum = probs_sort.cumsum(dim=-1)
  237. top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
  238. # at least one
  239. top_p_mask[:, -1] = False
  240. logits_sort.masked_fill_(top_p_mask, -float("inf"))
  241. # Re-sort the probabilities.
  242. src = torch.arange(logits_idx.shape[-1],
  243. device=logits_idx.device).expand_as(logits_idx)
  244. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  245. index=logits_idx,
  246. src=src)
  247. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  248. return logits
  249. def _apply_min_p(
  250. logits: torch.Tensor,
  251. min_p: torch.Tensor,
  252. ) -> torch.Tensor:
  253. """
  254. Adapted from
  255. https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
  256. """
  257. probs = torch.softmax(logits, dim=-1)
  258. top_probs, _ = probs.max(dim=-1, keepdim=True)
  259. scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
  260. tokens_to_remove = probs < scaled_min_p
  261. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  262. return logits
  263. def _apply_top_a(
  264. logits: torch.Tensor,
  265. top_a: torch.Tensor,
  266. ) -> torch.Tensor:
  267. probs = torch.softmax(logits, dim=-1)
  268. top_probs, _ = probs.max(dim=-1, keepdim=True)
  269. threshold = torch.pow(top_probs, 2) * top_a.unsqueeze_(dim=1)
  270. tokens_to_remove = probs < threshold
  271. logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
  272. return logits
  273. def _apply_tfs(
  274. logits: torch.Tensor,
  275. tfs: torch.Tensor,
  276. ) -> torch.Tensor:
  277. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  278. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  279. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  280. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  281. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  282. tfs_mask = torch.cat(
  283. (
  284. torch.zeros(
  285. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  286. tfs_mask,
  287. torch.ones(
  288. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  289. ),
  290. dim=-1,
  291. )
  292. logits_sort[tfs_mask] = -float("inf")
  293. logits = torch.gather(logits_sort,
  294. dim=-1,
  295. index=torch.argsort(logits_idx, dim=-1))
  296. return logits
  297. def _apply_eta_cutoff(
  298. logits: torch.Tensor,
  299. eta_cutoff: torch.Tensor,
  300. ) -> torch.Tensor:
  301. shifted_logits = torch.log_softmax(logits, dim=-1)
  302. probs = shifted_logits.exp()
  303. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  304. eps = torch.min(eta_cutoff,
  305. torch.sqrt(eta_cutoff) *
  306. torch.exp(neg_entropy)).unsqueeze(dim=1)
  307. eta_mask = probs < eps
  308. # guard against nulling out all the logits
  309. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  310. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  311. logits[eta_mask] = -float("inf")
  312. return logits
  313. def _apply_epsilon_cutoff(
  314. logits: torch.Tensor,
  315. epsilon_cutoff: torch.Tensor,
  316. ) -> torch.Tensor:
  317. probs = logits.softmax(dim=-1)
  318. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  319. # guard against nulling out all the logits
  320. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  321. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  322. logits[eps_mask] = -float("inf")
  323. return logits
  324. def _apply_typical_sampling(
  325. logits: torch.Tensor,
  326. typical_p: torch.Tensor,
  327. ) -> torch.Tensor:
  328. shifted_logits = torch.log_softmax(logits, dim=-1)
  329. probs = shifted_logits.exp()
  330. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  331. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  332. _, indices = torch.sort(surprisal_deviations)
  333. reordered_probs = probs.gather(-1, indices)
  334. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  335. dim=1)
  336. min_tokens_to_keep = 1
  337. # Keep at least min_tokens_to_keep
  338. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  339. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  340. logits[typ_mask] = -float("inf")
  341. return logits
  342. def _apply_quadratic_sampling(
  343. logits: torch.Tensor,
  344. smoothing_factor: torch.Tensor,
  345. smoothing_curve: torch.Tensor,
  346. ) -> torch.Tensor:
  347. """
  348. Applies a quadratic transformation to the logits based on the
  349. provided smoothing factors and curves. The transformation is
  350. centered around the maximum logit value in the batch.
  351. The transformation involves a quadratic and cubic term, with the
  352. cubic term controlled by the smoothing curve. The quadratic term is
  353. scaled by the smoothing factor, and the cubic term is scaled by the
  354. product of the smoothing factor and the smoothing curve.
  355. params:
  356. logits (torch.Tensor): The logits to be transformed.
  357. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  358. term in the transformation.
  359. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  360. in the transformation.
  361. returns:
  362. torch.Tensor: The transformed logits.
  363. Credits: @kalomaze
  364. """
  365. max_logits = logits.max(dim=-1, keepdim=True).values
  366. diff = logits - max_logits
  367. smoothing_factor.unsqueeze_(dim=1)
  368. smoothing_curve.unsqueeze_(dim=1)
  369. k = (3 - smoothing_curve) / 2
  370. s = (smoothing_curve - 1) / 2
  371. mask = smoothing_factor > 0
  372. mask = mask.flatten()
  373. transformed_logits = torch.where(
  374. logits != float('-inf'), -(k * smoothing_factor * diff**2) +
  375. (s * smoothing_factor * diff**3) + max_logits, logits)
  376. logits[mask, :] = transformed_logits[mask, :]
  377. return logits
  378. def _greedy_sample(
  379. selected_seq_groups: List[SequenceGroupToSample],
  380. samples: torch.Tensor,
  381. ) -> List[Tuple[List[int], List[int]]]:
  382. """Run greedy sampling on a given samples.
  383. Args:
  384. selected_seq_groups: A list of sequence groups batched.
  385. samples: (num_selected_samples,) A tensor of samples. The length of
  386. samples could be smaller than selected_seq_groups if
  387. seq_group.do_sample is False.
  388. Returns:
  389. Tuple of (next_token_ids, parent_ids). The length of returned list is
  390. same as the length of selected_seq_groups. If the corresponding
  391. seq_group has do_sample=False, tuple contains ([], [])
  392. """
  393. samples = samples.tolist()
  394. sample_idx = 0
  395. results = []
  396. for seq_group in selected_seq_groups:
  397. if not seq_group.do_sample:
  398. results.append(([], []))
  399. continue
  400. seq_ids = seq_group.seq_ids
  401. num_parent_seqs = len(seq_ids)
  402. assert num_parent_seqs == 1, (
  403. "Greedy sampling should have only one seq.")
  404. parent_ids = list(range(num_parent_seqs))
  405. next_token_ids = [samples[sample_idx]]
  406. results.append((next_token_ids, parent_ids))
  407. sample_idx += num_parent_seqs
  408. return results
  409. def _random_sample(
  410. selected_seq_groups: List[SequenceGroupToSample],
  411. random_samples: torch.Tensor,
  412. ) -> List[Tuple[List[int], List[int]]]:
  413. """Run random sampling on a given samples.
  414. Args:
  415. selected_seq_groups: A list of sequence groups batched.
  416. random_samples: (num_selected_samples,) A tensor of samples. The
  417. length of samples could be smaller than selected_seq_groups if
  418. seq_group.do_sample is False.
  419. Returns:
  420. Tuple of (next_token_ids, parent_ids). The length of returned list is
  421. same as the length of selected_seq_groups. If the corresponding
  422. seq_group has do_sample=False, tuple contains ([], [])
  423. """
  424. # Find the maximum best_of value of the prompt phase requests.
  425. random_samples = random_samples.cpu()
  426. sample_idx = 0
  427. results = []
  428. for seq_group in selected_seq_groups:
  429. if not seq_group.do_sample:
  430. results.append(([], []))
  431. continue
  432. seq_ids = seq_group.seq_ids
  433. sampling_params = seq_group.sampling_params
  434. is_prompt = seq_group.is_prompt
  435. num_parent_seqs = len(seq_ids)
  436. if is_prompt:
  437. # Prompt phase.
  438. parent_ids = [0] * sampling_params.best_of
  439. next_token_ids = random_samples[
  440. sample_idx, :sampling_params.best_of].tolist()
  441. else:
  442. # Generation phase.
  443. parent_ids = list(range(num_parent_seqs))
  444. next_token_ids = random_samples[sample_idx:sample_idx +
  445. num_parent_seqs, 0].tolist()
  446. results.append((next_token_ids, parent_ids))
  447. sample_idx += num_parent_seqs
  448. return results
  449. def _beam_search_sample(
  450. selected_seq_groups: List[SequenceGroupToSample],
  451. logprobs: torch.Tensor,
  452. ) -> List[Tuple[List[int], List[int]]]:
  453. """Run beam sampling on a given samples.
  454. Args:
  455. selected_seq_groups: A list of sequence groups batched.
  456. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  457. on selected sample indices.
  458. Returns:
  459. Tuple of (next_token_ids, parent_ids). The length of returned list is
  460. same as the length of selected_seq_groups. If the corresponding
  461. seq_group has do_sample=False, tuple contains ([], [])
  462. """
  463. # We sample 2 * beam_width candidates to make sure that with high
  464. # probability we can get `beam_width` candidates in addition to
  465. # the finished sequences for the next iteration. See
  466. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  467. # for details. See also HF reference:
  468. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  469. #
  470. # NOTE: Beam search is not vectorized, so its speed can be slower than
  471. # other sampling methods.
  472. sample_idx = 0
  473. results = []
  474. for seq_group in selected_seq_groups:
  475. if not seq_group.do_sample:
  476. results.append(([], []))
  477. continue
  478. is_prompt = seq_group.is_prompt
  479. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  480. num_parent_seqs = len(seq_ids)
  481. beam_width = sampling_params.best_of
  482. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  483. if is_prompt:
  484. # Prompt phase.
  485. assert num_parent_seqs == 1, (
  486. "Prompt input should have only one seq.")
  487. parent_ids = [0] * (2 * beam_width)
  488. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  489. 2 * beam_width)
  490. next_token_ids = next_token_ids.tolist()
  491. else:
  492. # Generation phase.
  493. cumulative_logprobs = [
  494. seq_group.seq_data[seq_id].cumulative_logprob
  495. for seq_id in seq_ids
  496. ]
  497. cumulative_logprobs = torch.tensor(
  498. cumulative_logprobs,
  499. dtype=torch.float,
  500. device=seq_group_logprobs.device)
  501. seq_group_logprobs = (seq_group_logprobs +
  502. cumulative_logprobs.unsqueeze(dim=1))
  503. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  504. 2 * beam_width)
  505. topk_ids = topk_ids.tolist()
  506. vocab_size = seq_group_logprobs.size(-1)
  507. parent_ids = [i // vocab_size for i in topk_ids]
  508. next_token_ids = [i % vocab_size for i in topk_ids]
  509. results.append((next_token_ids, parent_ids))
  510. sample_idx += num_parent_seqs
  511. assert sample_idx == logprobs.size(0)
  512. return results
  513. # torch.multinomial forces a GPU<->CPU sync.
  514. # Therefore, we use an optimized implementation instead.
  515. # Note that we always sample with replacement.
  516. # probs will be modified in place, but this is fine, as we pass
  517. # in a copy already.
  518. def _multinomial(
  519. probs: torch.Tensor,
  520. num_samples: int,
  521. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  522. ) -> torch.Tensor:
  523. if num_samples > 1:
  524. # This is equivalent to torch.repeat_interleaved (which also
  525. # forces a GPU<->CPU sync).
  526. # This allows us to do sampling with replacement by creating
  527. # num_samples copies of each row in the tensor, and then
  528. # batch sampling the resulting tensor.
  529. probs = probs[:, None, :].expand(probs.shape[0], num_samples,
  530. probs.shape[1]).contiguous().view(
  531. -1, probs.shape[1])
  532. q = torch.empty_like(probs)
  533. if seq_groups is None:
  534. q.exponential_()
  535. else:
  536. sample_idx = 0
  537. for seq_group in seq_groups:
  538. seq_ids = seq_group.seq_ids
  539. next_sample_idx = sample_idx + len(seq_ids) * num_samples
  540. q[sample_idx:next_sample_idx].exponential_(
  541. generator=seq_group.generator)
  542. sample_idx = next_sample_idx
  543. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
  544. def _sample_with_torch(
  545. probs: torch.Tensor,
  546. logprobs: torch.Tensor,
  547. sampling_metadata: SamplingMetadata,
  548. include_gpu_probs_tensor: bool,
  549. modify_greedy_probs: bool,
  550. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  551. categorized_seq_group_ids = {t: [] for t in SamplingType}
  552. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  553. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  554. sampling_params = seq_group.sampling_params
  555. sampling_type = sampling_params.sampling_type
  556. categorized_seq_group_ids[sampling_type].append(i)
  557. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  558. sample_metadata = {}
  559. multinomial_samples = {}
  560. # Create output tensor for sampled token ids.
  561. if include_gpu_probs_tensor:
  562. sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
  563. 1,
  564. dtype=torch.long,
  565. device=logprobs.device)
  566. else:
  567. sampled_token_ids_tensor = None
  568. # Counterintuitively, having two loops here is actually faster.
  569. # The first loop can run without waiting on GPU<->CPU sync.
  570. for sampling_type in SamplingType:
  571. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  572. num_tokens = len(sample_indices)
  573. if num_tokens == 0:
  574. continue
  575. seq_group_id = categorized_seq_group_ids[sampling_type]
  576. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  577. sample_metadata[sampling_type] = (seq_group_id, seq_groups)
  578. long_sample_indices = sample_indices.long()
  579. if sampling_type == SamplingType.GREEDY:
  580. greedy_samples = torch.argmax(logprobs[long_sample_indices],
  581. dim=-1)
  582. if include_gpu_probs_tensor:
  583. # Store sampled tokens in output tensor.
  584. sampled_token_ids_tensor[
  585. long_sample_indices] = greedy_samples.unsqueeze(-1)
  586. if modify_greedy_probs:
  587. # If required, modify the probabilities such that sampling from
  588. # the modified distribution would always sample the argmax
  589. # token id.
  590. _modify_greedy_probs_inplace(logprobs, probs,
  591. long_sample_indices,
  592. greedy_samples)
  593. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  594. max_best_of_in_batch = 1
  595. for seq_group in seq_groups:
  596. if seq_group.is_prompt:
  597. sampling_params = seq_group.sampling_params
  598. max_best_of_in_batch = max(max_best_of_in_batch,
  599. sampling_params.best_of)
  600. seeded_args = {} if sampling_type == SamplingType.RANDOM else {
  601. "seq_groups": seq_groups,
  602. }
  603. multinomial_samples[sampling_type] = _multinomial(
  604. probs[long_sample_indices], max_best_of_in_batch,
  605. **seeded_args)
  606. if include_gpu_probs_tensor:
  607. # Store sampled tokens in output tensor.
  608. sampled_token_ids_tensor[
  609. long_sample_indices] = multinomial_samples[sampling_type]
  610. elif sampling_type == SamplingType.BEAM:
  611. beam_search_logprobs = logprobs[sample_indices]
  612. else:
  613. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  614. # GPU<->CPU sync happens in the loop below.
  615. # This also converts the sample output to Python objects.
  616. for sampling_type in SamplingType:
  617. if sampling_type not in sample_metadata:
  618. continue
  619. (seq_group_id, seq_groups) = sample_metadata[sampling_type]
  620. if sampling_type == SamplingType.GREEDY:
  621. sample_results = _greedy_sample(seq_groups, greedy_samples)
  622. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  623. sample_results = _random_sample(seq_groups,
  624. multinomial_samples[sampling_type])
  625. elif sampling_type == SamplingType.BEAM:
  626. sample_results = _beam_search_sample(seq_groups,
  627. beam_search_logprobs)
  628. sample_results_dict.update(zip(seq_group_id, sample_results))
  629. sample_results = [
  630. sample_results_dict.get(i, ([], []))
  631. for i in range(len(sampling_metadata.seq_groups))
  632. ]
  633. return sample_results, sampled_token_ids_tensor
  634. def _sample_with_triton_kernel(
  635. probs: torch.Tensor,
  636. logprobs: torch.Tensor,
  637. sampling_metadata: SamplingMetadata,
  638. sampling_tensors: SamplingTensors,
  639. ) -> List[Tuple[List[int], List[int]]]:
  640. categorized_seq_group_ids = {t: [] for t in SamplingType}
  641. categorized_sample_indices = sampling_metadata.categorized_sample_indices
  642. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  643. sampling_params = seq_group.sampling_params
  644. sampling_type = sampling_params.sampling_type
  645. categorized_seq_group_ids[sampling_type].append(i)
  646. sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
  647. sample_metadata = {}
  648. max_best_of_in_batch = 1
  649. # Counterintuitively, having two loops here is actually faster.
  650. # The first loop can run without waiting on GPU<->CPU sync.
  651. for sampling_type in SamplingType:
  652. sample_indices = categorized_sample_indices[sampling_type][:, 0]
  653. sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
  654. num_tokens = len(sample_indices)
  655. if num_tokens == 0:
  656. continue
  657. seq_group_id = categorized_seq_group_ids[sampling_type]
  658. seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
  659. sample_metadata[sampling_type] = (seq_group_id, seq_groups,
  660. sample_indices,
  661. sampled_token_indices)
  662. if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
  663. SamplingType.RANDOM_SEED):
  664. for seq_group in seq_groups:
  665. if seq_group.is_prompt:
  666. sampling_params = seq_group.sampling_params
  667. max_best_of_in_batch = max(max_best_of_in_batch,
  668. sampling_params.best_of)
  669. elif sampling_type == SamplingType.BEAM:
  670. beam_search_logprobs = logprobs[sample_indices]
  671. else:
  672. raise ValueError(f"Unsupported sampling type: {sampling_type}")
  673. sampled_tokens, _, _ = sample_triton(
  674. probs=probs,
  675. seeds=sampling_tensors.sampling_seeds,
  676. max_best_of=max_best_of_in_batch,
  677. sample_indices=sampling_tensors.sample_indices,
  678. logprobs=logprobs,
  679. # don't save logprobs because we have logic for that below
  680. # TODO: use this instead of the CPU-based logic below
  681. save_logprobs=False,
  682. )
  683. # GPU<->CPU sync happens in the loop below.
  684. for sampling_type in SamplingType:
  685. if sampling_type not in sample_metadata:
  686. continue
  687. (seq_group_id, seq_groups, sample_indices,
  688. sampled_token_indices) = sample_metadata[sampling_type]
  689. if sampling_type == SamplingType.GREEDY:
  690. sample_results = _greedy_sample(
  691. seq_groups, sampled_tokens[sampled_token_indices][:, 0])
  692. elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
  693. sample_results = _random_sample(
  694. seq_groups, sampled_tokens[sampled_token_indices])
  695. elif sampling_type == SamplingType.BEAM:
  696. sample_results = _beam_search_sample(seq_groups,
  697. beam_search_logprobs)
  698. sample_results_dict.update(zip(seq_group_id, sample_results))
  699. sample_results = [
  700. sample_results_dict.get(i, ([], []))
  701. for i in range(len(sampling_metadata.seq_groups))
  702. ]
  703. return sample_results
  704. def _sample(
  705. probs: torch.Tensor, logprobs: torch.Tensor,
  706. sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
  707. include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  708. ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
  709. """
  710. Args:
  711. probs: (num_query_tokens_in_batch, num_vocab)
  712. logprobs: (num_query_tokens_in_batch, num_vocab)
  713. sampling_metadata: The metadata for a batch for sampling.
  714. sampling_tensors: Tensors that include sampling related metadata.
  715. Returns:
  716. (next_token_ids, parent_seq_ids) for each seq group in a batch.
  717. If sampling is skipped, it returns ([], [])
  718. sampled_token_ids_tensor: A tensor of sampled token ids.
  719. """
  720. return _sample_with_torch(
  721. probs,
  722. logprobs,
  723. sampling_metadata,
  724. include_gpu_probs_tensor=include_gpu_probs_tensor,
  725. modify_greedy_probs=modify_greedy_probs,
  726. )
  727. # TODO: Enable once Triton kernel & associated code is faster.
  728. # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
  729. # sampling_tensors)
  730. def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  731. """
  732. This function calculates the ranks of the chosen tokens in a logprob tensor.
  733. Args:
  734. x (torch.Tensor): 2D logprob tensor of shape (N, M)
  735. where N is the no. of tokens and M is the vocab dim.
  736. indices (torch.Tensor): List of chosen token indices.
  737. Returns:
  738. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
  739. Each element in the returned tensor represents the rank
  740. of the chosen token in the input logprob tensor.
  741. """
  742. vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
  743. indices]
  744. return (x > vals[:, None]).long().sum(1).add_(1)
  745. def _get_logprobs(
  746. logprobs: torch.Tensor,
  747. sampling_metadata: SamplingMetadata,
  748. sample_results: List[Tuple[List[int], List[int]]],
  749. ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
  750. """Return sample lobprobs and prompt logprobs.
  751. The logic consists of 3 parts.
  752. - Select indices to compute logprob from, ranks of token ids, and
  753. the top k token ids from logprobs.
  754. - Compute prompt logprobs if required.
  755. - Compute sample logprobs if required.
  756. Args:
  757. logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
  758. logprob per vocab. Sequence groups' query tokens are batched in a
  759. single flattened tensor. For example, assuming there are N
  760. seq groups, it is sorted by prefill tokens for seq_group_1 (if
  761. prompt logprob is enabled), decode tokens for seq_group_1 (if
  762. sampling is required), prefill tokens for seq_group_2, ...
  763. sampling_metadata: The sampling metadata.
  764. sample_results: (num_seq_groups) The tuple of (next_token_ids,
  765. parent_ids) for each sequence group. When beam search is enabled,
  766. sample_results can contain different number of seq_ids from
  767. sampling_metadata.seq_groups. It is because beam search creates
  768. 2 * BEAM_WIDTH number of samples (whereas there are only up to
  769. BEAM_WIDTH number of seq_ids).
  770. Returns:
  771. A tuple of prompt and sample logprobs per sequence group in a batch.
  772. """
  773. # The index of query token to calculate logprobs. It includes both
  774. # prompt and sample logprob indices.
  775. query_indices: List[int] = []
  776. # The next token ids to get the logprob value from.
  777. next_token_ids: List[int] = []
  778. # The largest requested number of logprobs. We find logprobs as many as the
  779. # largest num logprobs in this API.
  780. largest_num_logprobs = 1
  781. # Select indices to compute logprob from, ranks of token ids, and the top
  782. # k token ids from logprobs.
  783. for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
  784. sample_results):
  785. sampling_params = seq_group.sampling_params
  786. # Update indices and tokens for prompt logprobs.
  787. if (seq_group.is_prompt
  788. and sampling_params.prompt_logprobs is not None):
  789. largest_num_logprobs = max(largest_num_logprobs,
  790. sampling_params.prompt_logprobs)
  791. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  792. query_indices.extend(seq_group.prompt_logprob_indices)
  793. next_token_ids.extend(next_prompt_tokens)
  794. # Update indices and next tokenes for sample logprob.
  795. if seq_group.do_sample:
  796. token_ids, parent_seq_ids = sample_result
  797. # NOTE: We cannot directly use sample_indices because
  798. # sample_indices only contain parent seq_ids of a previous step.
  799. # The current step may have different number of seq_ids, and
  800. # we can obtain it from `sample_result[1]`.
  801. query_idx = seq_group.sample_indices[0]
  802. query_indices.extend(
  803. [query_idx + parent_id for parent_id in parent_seq_ids])
  804. next_token_ids.extend(token_ids)
  805. if sampling_params.logprobs is not None:
  806. largest_num_logprobs = max(largest_num_logprobs,
  807. sampling_params.logprobs)
  808. assert len(next_token_ids) == len(query_indices)
  809. if len(query_indices) == 0:
  810. empty_sampled_logprob = []
  811. empty_prompt_logprob = None
  812. return [empty_prompt_logprob], [empty_sampled_logprob]
  813. query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
  814. next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
  815. # (num_selected_query_tokens, num_logprobs). Note that query_indices can
  816. # contain duplicates if beam search is enabled.
  817. selected_logprobs = logprobs[[
  818. query_indices_gpu,
  819. next_token_ids_gpu,
  820. ]]
  821. ranks = _get_ranks(
  822. logprobs[query_indices_gpu],
  823. next_token_ids_gpu,
  824. )
  825. assert selected_logprobs.shape[0] == ranks.shape[0]
  826. # Logprobs of topk tokens for a batch of sequence groups.
  827. # (num_query_tokens_across_batch).
  828. if largest_num_logprobs > 0:
  829. top_logprobs, top_token_ids = torch.topk(logprobs,
  830. largest_num_logprobs,
  831. dim=-1)
  832. else:
  833. top_logprobs, top_token_ids = None, None
  834. selected_logprobs = selected_logprobs.to('cpu')
  835. ranks = ranks.to('cpu')
  836. if top_logprobs is not None and top_token_ids is not None:
  837. top_logprobs = top_logprobs.to('cpu')
  838. top_token_ids = top_token_ids.to('cpu')
  839. # Find prompt/sample logprobs.
  840. prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
  841. sample_logprobs_per_seq_group: List[SampleLogprobs] = []
  842. top_logprob_idx = 0
  843. selected_logprobs_idx = 0
  844. for seq_group, sample_result in zip(sampling_metadata.seq_groups,
  845. sample_results):
  846. (prompt_logprobs, top_logprob_idx,
  847. selected_logprobs_idx) = _get_prompt_logprob_if_needed(
  848. seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
  849. selected_logprobs_idx, top_logprob_idx)
  850. prompt_logprobs_per_seq_group.append(prompt_logprobs)
  851. (sampled_logprobs, top_logprob_idx,
  852. selected_logprobs_idx) = _get_sampled_logprob_if_needed(
  853. seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
  854. top_logprobs, selected_logprobs_idx, top_logprob_idx)
  855. sample_logprobs_per_seq_group.append(sampled_logprobs)
  856. return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  857. def _get_prompt_logprob_if_needed(
  858. seq_group: SequenceGroupToSample,
  859. selected_logprobs: torch.Tensor,
  860. ranks: torch.Tensor,
  861. top_token_ids: torch.Tensor,
  862. top_logprobs: torch.Tensor,
  863. selected_logprobs_idx: int,
  864. top_logprob_idx: int,
  865. ):
  866. """Compute the prompt logprob from a sequence group if needed."""
  867. sampling_params = seq_group.sampling_params
  868. is_prompt = seq_group.is_prompt
  869. # Find prompt logprobs
  870. prompt_logprobs: Optional[PromptLogprobs] = None
  871. if is_prompt and sampling_params.prompt_logprobs is not None:
  872. prompt_logprobs = []
  873. num_logprobs = sampling_params.prompt_logprobs
  874. next_prompt_tokens = _get_next_prompt_tokens(seq_group)
  875. # Pre-select indexes and create a list. It is faster than calling .item
  876. # repetitively.
  877. selected_logprob_items = selected_logprobs[
  878. selected_logprobs_idx:selected_logprobs_idx +
  879. len(next_prompt_tokens)].tolist()
  880. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  881. len(next_prompt_tokens)].tolist()
  882. for idx, token_id in enumerate(next_prompt_tokens):
  883. # Calculate the prompt logprob of the real prompt tokens.
  884. # {token_id: (logprob, rank_from_vocab)}
  885. prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
  886. token_id: (selected_logprob_items[idx], rank_items[idx])
  887. }
  888. # Add top K prompt logprobs along with its rank.
  889. if num_logprobs > 0:
  890. top_ids = top_token_ids[
  891. top_logprob_idx, :num_logprobs].tolist()
  892. top_probs = top_logprobs[
  893. top_logprob_idx, :num_logprobs].tolist()
  894. # Top K is already sorted by rank, so we can use 1 ~
  895. # num_logprobs + 1 for rank.
  896. top_ranks = range(1, num_logprobs + 1)
  897. prompt_logprobs_dict.update({
  898. top_id: (top_prob, rank)
  899. for top_id, top_prob, rank in zip(top_ids, top_probs,
  900. top_ranks)
  901. })
  902. prompt_logprobs.append({
  903. token_id: Logprob(*logprob_and_rank)
  904. for token_id, logprob_and_rank in prompt_logprobs_dict.items()
  905. })
  906. # + 1 to go to the next prompt token.
  907. top_logprob_idx += 1
  908. # + len(next_prompt_tokens) to go to the next prompt.
  909. selected_logprobs_idx += len(next_prompt_tokens)
  910. return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  911. def _get_sampled_logprob_if_needed(
  912. seq_group: SequenceGroupToSample,
  913. sample_result: Tuple[List[int], List[int]],
  914. selected_logprobs: torch.Tensor,
  915. ranks: torch.Tensor,
  916. top_token_ids: torch.Tensor,
  917. top_logprobs: torch.Tensor,
  918. selected_logprobs_idx: int,
  919. top_logprob_idx: int,
  920. ):
  921. """Compute the sample logprob if needed."""
  922. seq_ids = seq_group.seq_ids
  923. num_logprobs = seq_group.sampling_params.logprobs or 0
  924. sampled_logprobs: SampleLogprobs = []
  925. next_token_ids, parent_seq_ids = sample_result
  926. if seq_group.do_sample:
  927. assert len(next_token_ids) > 0
  928. # Pre-select items from tensor. tolist() is faster than repetitive
  929. # `.item()` calls.
  930. selected_logprob_items = selected_logprobs[
  931. selected_logprobs_idx:selected_logprobs_idx +
  932. len(next_token_ids)].tolist()
  933. rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
  934. len(next_token_ids)].tolist()
  935. for idx, (next_token_id,
  936. parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
  937. # Get the logprob of a sampled token.
  938. sampled_logprobs_dict = {
  939. next_token_id: (selected_logprob_items[idx], rank_items[idx])
  940. }
  941. # Get top K logprobs.
  942. if num_logprobs > 0:
  943. top_ids = top_token_ids[top_logprob_idx +
  944. parent_id, :num_logprobs].tolist()
  945. top_probs = top_logprobs[top_logprob_idx +
  946. parent_id, :num_logprobs].tolist()
  947. # Top K is already sorted by rank, so we can use 1 ~
  948. # num_logprobs + 1 for rank.
  949. top_ranks = range(1, num_logprobs + 1)
  950. sampled_logprobs_dict.update({
  951. top_id: (top_prob, rank)
  952. for top_id, top_prob, rank in zip(top_ids, top_probs,
  953. top_ranks)
  954. })
  955. sampled_logprobs.append({
  956. token_id: Logprob(*logprob_and_rank)
  957. for token_id, logprob_and_rank in
  958. sampled_logprobs_dict.items()
  959. })
  960. # NOTE: This part of code is not intuitive. `selected_logprobs` include
  961. # logprobs for the current step, which has len(next_token_ids) tokens
  962. # per sequence group. `logprobs` includes logprobs from the previous
  963. # steps, which has len(seq_ids) tokens per sequence group.
  964. # Iterate to the next sequence group in a batch.
  965. selected_logprobs_idx += len(next_token_ids)
  966. # Iterate to the next sequence group in a batch.
  967. top_logprob_idx += len(seq_ids)
  968. return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  969. def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
  970. sample_indices: torch.Tensor,
  971. greedy_samples: torch.Tensor) -> None:
  972. """Modify the probability distributions of the greedily-sampled tokens such
  973. that each sampled token has a "probability" of 1.0. This is required by
  974. speculative decoding, which depends on the sampling method being encoded
  975. within the probability distribution for correctness.
  976. # Why do we only need to do this for greedy sampling?
  977. Aphrodite's sampler performs the following steps for greedy or multinomial
  978. (random) sampling:
  979. 1. Get logits from model.
  980. 2. Modify logits according to per-sequence sampling parameters.
  981. - Multiply by temperature, top-k and top-p masking, penalize tokens
  982. according to their frequency, etc.
  983. 3. Sample a token.
  984. - Random sampling simply samples from the modified probability
  985. distribution.
  986. - Greedy sampling performs `argmax` to obtain the token with the
  987. highest likelihood.
  988. Ignoring greedy sampling for a moment, we find that the computed probability
  989. distribution has the following property: we can sample from it independently
  990. and find that the token sampled by the Sampler has a frequency corresponding
  991. to how often we see it in our sampling. In other words, for tokens sampled
  992. with Aphrodite's random SamplingType, the computed probability distribution
  993. encodes the sampling methodology completely.
  994. Greedy sampling does not normally have this property. Aphrodite modifies
  995. logits according to sampling params, then performs `argmax`, then returns
  996. the sampled token and the computed probability distribution. If we sample
  997. from the distribution, we'll find the likelihood of the greedily-sampled
  998. token is not always 1.0.
  999. Since lossless speculative decoding requires that the sampling methodology
  1000. be encoded within the probability distribution, we are motivated to modify
  1001. the probability distribution such that the sampled token has probability 1
  1002. when speculative decoding is used.
  1003. NOTE: Alternatively, we could use an extremely low temperature to achieve
  1004. greedy sampling using multinomial computation and unite the codepaths. This
  1005. has implications on the overall design of the sampler, e.g. how to record
  1006. accurate logprobs for the user, so this improvement is deferred to later.
  1007. """
  1008. # NOTE: logprobs are not modified so they can be returned to the user.
  1009. probs[sample_indices, :] = 0
  1010. probs[sample_indices, greedy_samples] = 1.0
  1011. def _build_sampler_output(
  1012. sample_results: List[Tuple[List[int], List[int]]],
  1013. sampling_metadata: SamplingMetadata,
  1014. prompt_logprobs: List[Optional[PromptLogprobs]],
  1015. sample_logprobs: List[SampleLogprobs],
  1016. on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
  1017. torch.Tensor]],
  1018. ) -> SamplerOutput:
  1019. """Construct Python objects with the output of sampling.
  1020. Args:
  1021. on_device_tensors: Tuple containing on-device tensors with the
  1022. probabilities used in sampling and the sampled token ids. This
  1023. allows post-processing without copies to CPU/serialization, e.g. in
  1024. speculative decoding rejection sampling.
  1025. """
  1026. sampler_output = []
  1027. for (seq_group, sample_result, group_prompt_logprobs,
  1028. group_sample_logprobs) in zip(sampling_metadata.seq_groups,
  1029. sample_results, prompt_logprobs,
  1030. sample_logprobs):
  1031. seq_ids = seq_group.seq_ids
  1032. next_token_ids, parent_ids = sample_result
  1033. seq_outputs = []
  1034. for parent_id, next_token_id, logprobs in zip(parent_ids,
  1035. next_token_ids,
  1036. group_sample_logprobs):
  1037. seq_outputs.append(
  1038. SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
  1039. sampler_output.append(
  1040. CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
  1041. # If not specified, store None values in SamplerOutput.
  1042. if on_device_tensors is not None:
  1043. (sampled_token_probs, logprobs_tensor,
  1044. sampled_token_ids) = on_device_tensors
  1045. else:
  1046. sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
  1047. None)
  1048. return SamplerOutput(
  1049. outputs=sampler_output,
  1050. sampled_token_probs=sampled_token_probs,
  1051. sampled_token_ids=sampled_token_ids,
  1052. logprobs=logprobs_tensor,
  1053. )
  1054. def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]:
  1055. """Get a list of next prompt tokens to compute logprob from a
  1056. given sequence group.
  1057. It is used to compute prompt logprob. Imagine you have logprob for each
  1058. query token. Query token needs to know the next prompt token id to compute
  1059. prompt logprob. This is a helper to obtain next prompt token ids.
  1060. This API has to be used only when the caller knows seq_group is in prefill
  1061. stage.
  1062. Returns:
  1063. A list of next prompt tokens to compute logprob.
  1064. """
  1065. assert seq_group.is_prompt, (
  1066. "Caller should ensure the sequence group is in a prefill stage.")
  1067. seq_ids = seq_group.seq_ids
  1068. query_len = seq_group.query_len
  1069. assert query_len is not None
  1070. # prompt has only 1 seq id.
  1071. assert len(seq_ids) == 1
  1072. seq_data = seq_group.seq_data[seq_ids[0]]
  1073. computed_len = seq_data.get_num_computed_tokens()
  1074. prompt_tokens = seq_data.prompt_token_ids
  1075. # +1 because we are looking for a next prompt token.
  1076. next_token_index_start = computed_len + 1
  1077. next_token_index_end = min(computed_len + query_len + 1,
  1078. len(prompt_tokens))
  1079. next_prompt_tokens = prompt_tokens[
  1080. next_token_index_start:next_token_index_end]
  1081. return next_prompt_tokens
  1082. # def _apply_mirostat_v2(logits: torch.Tensor,
  1083. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1084. # # Reduce our view to just the affected logits
  1085. # logit_view = logits[sampling_tensors.miro_indices]
  1086. # # Calculate surprise value per token
  1087. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1088. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1089. # # Mask out "too-surprising" tokens (surprisal > mu)
  1090. # mus = sampling_tensors.miro_mus
  1091. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1092. # # Unmask most-likely logit to guarantee a selection.
  1093. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1094. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1095. # # Apply logit mask (effectively a top-k filter).
  1096. # logit_view[miro_mask] = -float("inf")
  1097. # # Project logit changes made to the view onto the original.
  1098. # # I think this step might be redundant.
  1099. # logits[sampling_tensors.miro_indices] = logit_view
  1100. # return logits
  1101. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1102. # sample_results: List[Tuple[List[int], List[int]]],
  1103. # sampling_metadata: SamplingMetadata,
  1104. # output_metadata: OutputMetadata) -> None:
  1105. # """Based on whichever token was finally sampled, we calculate the
  1106. # final surprisal values to update the mus.
  1107. # Because a single sequence can have multiple samples, we must fork
  1108. # the mu accordingly."""
  1109. # assert sampling_metadata.seq_groups is not None
  1110. # seqid_to_tokens = {}
  1111. # seqid_to_indices = {}
  1112. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1113. # sample_results):
  1114. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1115. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1116. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1117. # seqids = args.miro_seqids
  1118. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1119. # device=logits.device,
  1120. # dtype=torch.long)
  1121. # # Clumsily, we recalculate token surprisals.
  1122. # logits_view = logits[args.miro_indices]
  1123. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1124. # dim=-1,
  1125. # index=picked_tokens) / -math.log(2)
  1126. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1127. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1128. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1129. # nu_mus = mus - (picked_surprise - taus) * etas
  1130. # # Record updated mu values for use in the next iteration
  1131. # # Note how each mu is split into multiple based on the number of samples.
  1132. # for seqid, seq_mus in zip(seqids, nu_mus):
  1133. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1134. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)