sampler.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284
  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. from typing import Dict, List, Optional, Tuple
  4. import torch
  5. import torch.nn as nn
  6. from aphrodite.common.sampling_params import SamplingType
  7. from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
  8. SamplerOutput, SequenceGroupOutput,
  9. SequenceOutput)
  10. from aphrodite.modeling.layers.ops.sample import sample as sample_triton
  11. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  12. SamplingTensors,
  13. SequenceGroupToSample)
  14. class Sampler(nn.Module):
  15. """Samples the next tokens from the model's outputs.
  16. This layer does the following:
  17. 1. Discard the hidden states that are not used for sampling (i.e., all
  18. tokens except the final one in each prompt).
  19. 2. Compute the logits for the next tokens.
  20. 3. Apply all the different sampler functions in the specified order.
  21. 4. Sample the next tokens.
  22. Here, each sequence group within the batch can have different sampling
  23. parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
  24. The structure of the logits tensor is coupled with the seq_groups in
  25. sampling_metadata. Typically, each sequence in each seq_group has one row in
  26. logits for the next token to be sampled; however, for a seq_group with a
  27. prompt request with the prompt_logprobs sampling parameter, there are rows
  28. in logits for each token in the input prompt.
  29. """
  30. def __init__(self):
  31. super().__init__()
  32. # Whether or not the SamplerOutput should have on-device tensors
  33. # containing the sampled token ids and probabilities. This is used by
  34. # speculative decoding.
  35. self.include_gpu_probs_tensor = False
  36. def forward(
  37. self,
  38. logits: torch.Tensor,
  39. sampling_metadata: SamplingMetadata,
  40. ) -> Optional[SamplerOutput]:
  41. """
  42. Args:
  43. logits: (num_tokens, vocab_size).
  44. sampling_metadata: Metadata for sampling.
  45. """
  46. assert logits is not None
  47. _, vocab_size = logits.shape
  48. # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  49. # have not been generated yet
  50. logits = _apply_min_tokens_penalty(logits, sampling_metadata)
  51. # Prepare sampling tensors with pinned memory to avoid blocking.
  52. (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
  53. do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
  54. do_typical_ps,
  55. do_quadratic) = (SamplingTensors.from_sampling_metadata(
  56. sampling_metadata, vocab_size, logits.device, logits.dtype))
  57. if do_penalties:
  58. logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
  59. sampling_tensors.output_tokens,
  60. sampling_tensors.presence_penalties,
  61. sampling_tensors.frequency_penalties,
  62. sampling_tensors.repetition_penalties)
  63. if (do_topks or do_topps or do_topas or do_minps):
  64. logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
  65. sampling_tensors.top_ks,
  66. sampling_tensors.top_as,
  67. sampling_tensors.min_ps)
  68. if do_tfss:
  69. logits = _apply_tfs(logits, sampling_tensors.tfss)
  70. if do_eta_cutoffs:
  71. logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
  72. if do_epsilon_cutoffs:
  73. logits = _apply_epsilon_cutoff(logits,
  74. sampling_tensors.epsilon_cutoffs)
  75. if do_typical_ps:
  76. logits = _apply_typical_sampling(logits,
  77. sampling_tensors.typical_ps)
  78. if do_quadratic:
  79. logits = _apply_quadratic_sampling(
  80. logits, sampling_tensors.smoothing_factors,
  81. sampling_tensors.smoothing_curves)
  82. if do_temperatures:
  83. logits = _apply_temperature(logits, sampling_tensors.temperatures,
  84. sampling_tensors.dynatemp_mins,
  85. sampling_tensors.dynatemp_maxs,
  86. sampling_tensors.dynatemp_exps)
  87. banned_tokens = _get_custom_token_bans(sampling_metadata)
  88. # assert len(banned_tokens) == logits.shape[0]
  89. logits = _apply_token_bans(logits, banned_tokens)
  90. # We use float32 for probabilities and log probabilities.
  91. # Compute the probabilities.
  92. probs = torch.softmax(logits, dim=-1, dtype=torch.float)
  93. # Compute the log probabilities.
  94. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
  95. # Sample the next tokens.
  96. sample_results, maybe_sampled_tokens_tensor = _sample(
  97. probs,
  98. logprobs,
  99. sampling_metadata,
  100. sampling_tensors,
  101. include_gpu_probs_tensor=self.include_gpu_probs_tensor,
  102. modify_greedy_probs=self._should_modify_greedy_probs_inplace,
  103. )
  104. if self.include_gpu_probs_tensor:
  105. assert maybe_sampled_tokens_tensor is not None
  106. sampled_tokens_tensor = maybe_sampled_tokens_tensor
  107. on_device_tensors = (probs, sampled_tokens_tensor)
  108. else:
  109. on_device_tensors = None
  110. # Get the logprobs query results.
  111. prompt_logprobs, sample_logprobs = _get_logprobs(
  112. logprobs, sampling_metadata, sample_results)
  113. return _build_sampler_output(sample_results,
  114. sampling_metadata,
  115. prompt_logprobs,
  116. sample_logprobs,
  117. on_device_tensors=on_device_tensors)
  118. @property
  119. def _should_modify_greedy_probs_inplace(self) -> bool:
  120. """Whether or not the sampler should modify the probability distribution
  121. of greedily-sampled tokens such that multinomial sampling would sample
  122. the greedily-sampled token.
  123. In other words, if True then we set the probability of the greedily-
  124. sampled token to 1.
  125. This is used by speculative decoding, which requires that the sampling
  126. method be encoded into the probability distribution.
  127. """
  128. # Modify greedy probs if include_gpu_probs_tensor is set.
  129. return self.include_gpu_probs_tensor
  130. def _get_bin_counts_and_mask(
  131. tokens: torch.Tensor,
  132. vocab_size: int,
  133. num_seqs: int,
  134. ) -> Tuple[torch.Tensor, torch.Tensor]:
  135. # Compute the bin counts for the tokens.
  136. # vocab_size + 1 for padding.
  137. bin_counts = torch.zeros((num_seqs, vocab_size + 1),
  138. dtype=torch.long,
  139. device=tokens.device)
  140. bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
  141. bin_counts = bin_counts[:, :vocab_size]
  142. mask = bin_counts > 0
  143. return bin_counts, mask
  144. def _get_custom_token_bans(
  145. sampling_metadata: SamplingMetadata) -> List[List[int]]:
  146. assert sampling_metadata.seq_groups is not None
  147. banned_tokens: List[List[int]] = []
  148. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  149. sampling_params = sampling_metadata.seq_groups[i].sampling_params
  150. seq_ids = seq_group.seq_ids
  151. custom_token_bans = sampling_params.custom_token_bans
  152. if (i < sampling_metadata.num_prompts
  153. and sampling_params.prompt_logprobs is not None):
  154. prompt_len = len(seq_group.prompt_logprob_indices)
  155. banned_tokens += [custom_token_bans] * (prompt_len - 1)
  156. banned_tokens += [custom_token_bans] * len(seq_ids)
  157. return banned_tokens
  158. def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
  159. output_tokens_tensor: torch.Tensor,
  160. presence_penalties: torch.Tensor,
  161. frequency_penalties: torch.Tensor,
  162. repetition_penalties: torch.Tensor) -> torch.Tensor:
  163. num_seqs, vocab_size = logits.shape
  164. _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
  165. num_seqs)
  166. output_bin_counts, output_mask = _get_bin_counts_and_mask(
  167. output_tokens_tensor, vocab_size, num_seqs)
  168. repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
  169. repetition_penalties[~(prompt_mask | output_mask)] = 1.0
  170. logits = torch.where(logits > 0, logits / repetition_penalties,
  171. logits * repetition_penalties)
  172. # We follow the definition in OpenAI API.
  173. # Refer to https://platform.openai.com/docs/api-reference/parameter-details
  174. logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
  175. logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
  176. return logits
  177. def _apply_token_bans(logits: torch.Tensor,
  178. banned_tokens: List[List[int]]) -> torch.Tensor:
  179. for i, banned_token_ids in enumerate(banned_tokens):
  180. if not banned_token_ids:
  181. continue
  182. logits[i, banned_token_ids] = -float("inf")
  183. return logits
  184. def _apply_min_tokens_penalty(
  185. logits: torch.Tensor,
  186. sampling_metadata: SamplingMetadata,
  187. ) -> torch.Tensor:
  188. """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
  189. have not been generated yet
  190. """
  191. # list of indices in logits that will be set to -inf
  192. logits_to_penalize = []
  193. logits_applied = 0
  194. for seq_group in sampling_metadata.seq_groups:
  195. seq_ids = seq_group.seq_ids
  196. sampling_params = seq_group.sampling_params
  197. sample_indices = seq_group.sample_indices
  198. logits_applied += len(sample_indices) + len(
  199. seq_group.prompt_logprob_indices)
  200. if not seq_group.do_sample:
  201. continue
  202. start_idx = sample_indices[0]
  203. min_tokens = sampling_params.min_tokens
  204. if min_tokens > 0:
  205. seqs_to_penalize = []
  206. for i, seq_id in enumerate(seq_ids):
  207. seq_data = seq_group.seq_data[seq_id]
  208. if len(seq_data.output_token_ids) < min_tokens:
  209. seqs_to_penalize.append(i)
  210. if seqs_to_penalize:
  211. # convert to the index into logits
  212. seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
  213. # use set() to remove any duplicates
  214. token_ids_to_penalize = set(sampling_params.stop_token_ids +
  215. [sampling_params.eos_token_id])
  216. # itertools.product pairs each seq index with every token id
  217. logits_to_penalize.extend(
  218. itertools.product(seqs_to_penalize, token_ids_to_penalize))
  219. if logits_to_penalize:
  220. # use zip and * to group indices along each dimension
  221. # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
  222. logits[tuple(zip(*logits_to_penalize))] = -float("inf")
  223. # verifies that no rows in logits were missed unexpectedly
  224. assert logits_applied == logits.shape[0]
  225. return logits
  226. def _apply_alphabet_soup(
  227. logits: torch.Tensor,
  228. p: torch.Tensor,
  229. k: torch.Tensor,
  230. a: torch.Tensor,
  231. m: torch.Tensor,
  232. ) -> torch.Tensor:
  233. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  234. # Apply top-p, min-p and top-a.
  235. probs_sort = logits_sort.softmax(dim=-1)
  236. probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
  237. min_p_thresholds = probs_sort[:, 0] * m
  238. top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
  239. threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
  240. mask = (probs_sort < threshold.unsqueeze(1)
  241. ) # Cull logits below the top-a threshold
  242. mask.logical_or_(
  243. probs_sum >
  244. p.unsqueeze(dim=1)) # Cull logits above the top-p summation threshold
  245. mask[:, 0] = False # Guarantee at least one token is pickable
  246. logits_sort[mask] = -float("inf")
  247. # Apply top-k.
  248. # Create a mask for the top-k elements.
  249. top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
  250. top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
  251. top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
  252. # Final mask.
  253. mask = (mask | top_k_mask)
  254. logits_sort.masked_fill_(mask, -float("inf"))
  255. # Re-sort the probabilities.
  256. src = torch.arange(logits_idx.shape[-1],
  257. device=logits_idx.device).expand_as(logits_idx)
  258. logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
  259. index=logits_idx,
  260. src=src)
  261. logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
  262. return logits
  263. def _apply_tfs(
  264. logits: torch.Tensor,
  265. tfs: torch.Tensor,
  266. ) -> torch.Tensor:
  267. logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
  268. d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
  269. normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
  270. curvature_cdf = torch.cumsum(normalized_d2, dim=-1)
  271. tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
  272. tfs_mask = torch.cat(
  273. (
  274. torch.zeros(
  275. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  276. tfs_mask,
  277. torch.ones(
  278. logits.shape[0], 1, dtype=torch.bool, device=logits.device),
  279. ),
  280. dim=-1,
  281. )
  282. logits_sort[tfs_mask] = -float("inf")
  283. logits = torch.gather(logits_sort,
  284. dim=-1,
  285. index=torch.argsort(logits_idx, dim=-1))
  286. return logits
  287. def _apply_eta_cutoff(
  288. logits: torch.Tensor,
  289. eta_cutoff: torch.Tensor,
  290. ) -> torch.Tensor:
  291. shifted_logits = torch.log_softmax(logits, dim=-1)
  292. probs = shifted_logits.exp()
  293. neg_entropy = (probs * shifted_logits).nansum(dim=-1)
  294. eps = torch.min(eta_cutoff,
  295. torch.sqrt(eta_cutoff) *
  296. torch.exp(neg_entropy)).unsqueeze(dim=1)
  297. eta_mask = probs < eps
  298. # guard against nulling out all the logits
  299. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  300. eta_mask.scatter_(dim=1, index=top_idx, value=False)
  301. logits[eta_mask] = -float("inf")
  302. return logits
  303. def _apply_epsilon_cutoff(
  304. logits: torch.Tensor,
  305. epsilon_cutoff: torch.Tensor,
  306. ) -> torch.Tensor:
  307. probs = logits.softmax(dim=-1)
  308. eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)
  309. # guard against nulling out all the logits
  310. top_idx = torch.argmax(probs, dim=1, keepdim=True)
  311. eps_mask.scatter_(dim=1, index=top_idx, value=False)
  312. logits[eps_mask] = -float("inf")
  313. return logits
  314. def _apply_typical_sampling(
  315. logits: torch.Tensor,
  316. typical_p: torch.Tensor,
  317. ) -> torch.Tensor:
  318. shifted_logits = torch.log_softmax(logits, dim=-1)
  319. probs = shifted_logits.exp()
  320. neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
  321. surprisal_deviations = (neg_entropy - shifted_logits).abs()
  322. _, indices = torch.sort(surprisal_deviations)
  323. reordered_probs = probs.gather(-1, indices)
  324. typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
  325. dim=1)
  326. min_tokens_to_keep = 1
  327. # Keep at least min_tokens_to_keep
  328. typ_mask_sorted[..., :min_tokens_to_keep] = 0
  329. typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
  330. logits[typ_mask] = -float("inf")
  331. return logits
  332. # pulls double duty for temperature and dynatemp
  333. def _apply_temperature(
  334. logits: torch.Tensor,
  335. temperatures: torch.Tensor,
  336. dynatemp_mins: torch.Tensor,
  337. dynatemp_maxs: torch.Tensor,
  338. dynatemp_exps: torch.Tensor,
  339. ) -> torch.Tensor:
  340. dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
  341. # Check if dynatemp_mask is not empty
  342. if dynatemp_mask.any():
  343. dynatemp_mins = dynatemp_mins[dynatemp_mask]
  344. dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
  345. dynatemp_exps = dynatemp_exps[dynatemp_mask]
  346. dynatemp_mins = dynatemp_mins.clamp_(min=0)
  347. dynatemp_logits = logits[dynatemp_mask]
  348. dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
  349. dynatemp_probs = dynatemp_shifted_logits.exp()
  350. dynatemp_entropies = -(dynatemp_probs *
  351. dynatemp_shifted_logits).nansum(dim=-1)
  352. dynatemp_max_entropies = torch.log_(
  353. (dynatemp_logits > float("-inf")).sum(dim=-1).float())
  354. normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
  355. dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
  356. normalized_entropies.pow_(dynatemp_exps))
  357. temperatures[dynatemp_mask] = dyn_temp
  358. temperatures[temperatures == 0.0] = 1.0
  359. while temperatures.dim() < logits.dim():
  360. temperatures = temperatures.unsqueeze(-1)
  361. logits = logits.div(temperatures)
  362. return logits
  363. def _apply_quadratic_sampling(
  364. logits: torch.Tensor,
  365. smoothing_factor: torch.Tensor,
  366. smoothing_curve: torch.Tensor,
  367. ) -> torch.Tensor:
  368. """
  369. Applies a quadratic transformation to the logits based on the
  370. provided smoothing factors and curves. The transformation is
  371. centered around the maximum logit value in the batch.
  372. The transformation involves a quadratic and cubic term, with the
  373. cubic term controlled by the smoothing curve. The quadratic term is
  374. scaled by the smoothing factor, and the cubic term is scaled by the
  375. product of the smoothing factor and the smoothing curve.
  376. params:
  377. logits (torch.Tensor): The logits to be transformed.
  378. smoothing_factors (torch.Tensor): The factors to scale the quadratic
  379. term in the transformation.
  380. smoothing_curves (torch.Tensor): The factors to scale the cubic term
  381. in the transformation.
  382. returns:
  383. torch.Tensor: The transformed logits.
  384. Credits: @kalomaze
  385. """
  386. max_logits = logits.max(dim=-1, keepdim=True).values
  387. diff = logits - max_logits
  388. smoothing_factor.unsqueeze_(dim=1)
  389. smoothing_curve.unsqueeze_(dim=1)
  390. k = (3 - smoothing_curve) / 2
  391. s = (smoothing_curve - 1) / 2
  392. mask = smoothing_factor > 0
  393. mask = mask.flatten()
  394. transformed_logits = torch.where(
  395. logits != float('-inf'), -(k * smoothing_factor * diff**2) +
  396. (s * smoothing_factor * diff**3) + max_logits, logits)
  397. logits[mask, :] = transformed_logits[mask, :]
  398. return logits
  399. def _greedy_sample(
  400. selected_seq_groups: List[SequenceGroupToSample],
  401. samples: torch.Tensor,
  402. ) -> List[Tuple[List[int], List[int]]]:
  403. """Run greedy sampling on a given samples.
  404. Args:
  405. selected_seq_groups: A list of sequence groups batched.
  406. samples: (num_selected_samples,) A tensor of samples. The length of
  407. samples could be smaller than selected_seq_groups if
  408. seq_group.do_sample is False.
  409. Returns:
  410. Tuple of (next_token_ids, parent_ids). The length of returned list is
  411. same as the length of selected_seq_groups. If the corresponding
  412. seq_group has do_sample=False, tuple contains ([], [])
  413. """
  414. samples = samples.tolist()
  415. sample_idx = 0
  416. results = []
  417. for seq_group in selected_seq_groups:
  418. if not seq_group.do_sample:
  419. results.append(([], []))
  420. continue
  421. seq_ids = seq_group.seq_ids
  422. num_parent_seqs = len(seq_ids)
  423. assert num_parent_seqs == 1, (
  424. "Greedy sampling should have only one seq.")
  425. parent_ids = list(range(num_parent_seqs))
  426. next_token_ids = [samples[sample_idx]]
  427. results.append((next_token_ids, parent_ids))
  428. sample_idx += num_parent_seqs
  429. return results
  430. def _random_sample(
  431. selected_seq_groups: List[SequenceGroupToSample],
  432. random_samples: torch.Tensor,
  433. ) -> List[Tuple[List[int], List[int]]]:
  434. """Run random sampling on a given samples.
  435. Args:
  436. selected_seq_groups: A list of sequence groups batched.
  437. random_samples: (num_selected_samples,) A tensor of samples. The
  438. length of samples could be smaller than selected_seq_groups if
  439. seq_group.do_sample is False.
  440. Returns:
  441. Tuple of (next_token_ids, parent_ids). The length of returned list is
  442. same as the length of selected_seq_groups. If the corresponding
  443. seq_group has do_sample=False, tuple contains ([], [])
  444. """
  445. # Find the maximum best_of value of the prompt phase requests.
  446. random_samples = random_samples.cpu()
  447. sample_idx = 0
  448. results = []
  449. for seq_group in selected_seq_groups:
  450. if not seq_group.do_sample:
  451. results.append(([], []))
  452. continue
  453. seq_ids = seq_group.seq_ids
  454. sampling_params = seq_group.sampling_params
  455. is_prompt = seq_group.is_prompt
  456. num_parent_seqs = len(seq_ids)
  457. if is_prompt:
  458. # Prompt phase.
  459. parent_ids = [0] * sampling_params.best_of
  460. next_token_ids = random_samples[
  461. sample_idx, :sampling_params.best_of].tolist()
  462. else:
  463. # Generation phase.
  464. parent_ids = list(range(num_parent_seqs))
  465. next_token_ids = random_samples[sample_idx:sample_idx +
  466. num_parent_seqs, 0].tolist()
  467. results.append((next_token_ids, parent_ids))
  468. sample_idx += num_parent_seqs
  469. return results
  470. def _beam_search_sample(
  471. selected_seq_groups: List[SequenceGroupToSample],
  472. logprobs: torch.Tensor,
  473. ) -> List[Tuple[List[int], List[int]]]:
  474. """Run beam sampling on a given samples.
  475. Args:
  476. selected_seq_groups: A list of sequence groups batched.
  477. logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
  478. on selected sample indices.
  479. Returns:
  480. Tuple of (next_token_ids, parent_ids). The length of returned list is
  481. same as the length of selected_seq_groups. If the corresponding
  482. seq_group has do_sample=False, tuple contains ([], [])
  483. """
  484. # We sample 2 * beam_width candidates to make sure that with high
  485. # probability we can get `beam_width` candidates in addition to
  486. # the finished sequences for the next iteration. See
  487. # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
  488. # for details. See also HF reference:
  489. # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
  490. #
  491. # NOTE: Beam search is not vectorized, so its speed can be slower than
  492. # other sampling methods.
  493. sample_idx = 0
  494. results = []
  495. for seq_group in selected_seq_groups:
  496. if not seq_group.do_sample:
  497. results.append(([], []))
  498. continue
  499. is_prompt = seq_group.is_prompt
  500. seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
  501. num_parent_seqs = len(seq_ids)
  502. beam_width = sampling_params.best_of
  503. seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
  504. if is_prompt:
  505. # Prompt phase.
  506. assert num_parent_seqs == 1, (
  507. "Prompt input should have only one seq.")
  508. parent_ids = [0] * (2 * beam_width)
  509. _, next_token_ids = torch.topk(seq_group_logprobs[0],
  510. 2 * beam_width)
  511. next_token_ids = next_token_ids.tolist()
  512. else:
  513. # Generation phase.
  514. cumulative_logprobs = [
  515. seq_group.seq_data[seq_id].cumulative_logprob
  516. for seq_id in seq_ids
  517. ]
  518. cumulative_logprobs = torch.tensor(
  519. cumulative_logprobs,
  520. dtype=torch.float,
  521. device=seq_group_logprobs.device)
  522. seq_group_logprobs = (seq_group_logprobs +
  523. cumulative_logprobs.unsqueeze(dim=1))
  524. _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
  525. 2 * beam_width)
  526. topk_ids = topk_ids.tolist()
  527. vocab_size = seq_group_logprobs.size(-1)
  528. parent_ids = [i // vocab_size for i in topk_ids]
  529. next_token_ids = [i % vocab_size for i in topk_ids]
  530. results.append((next_token_ids, parent_ids))
  531. sample_idx += num_parent_seqs
  532. assert sample_idx == logprobs.size(0)
  533. return results
  534. # torch.multinomial forces a GPU<->CPU sync.
  535. # Therefore, we use an optimized implementation instead.
  536. # Note that we always sample with replacement.
  537. # probs will be modified in place, but this is fine, as we pass
  538. # in a copy already.
  539. def _multinomial(
  540. probs: torch.Tensor,
  541. num_samples: int,
  542. seq_groups: Optional[List[SequenceGroupToSample]] = None,
  543. ) -> torch.Tensor:
  544. if num_samples > 1:
  545. # This is equivalent to torch.repeat_interleaved (which also
  546. # forces a GPU<->CPU sync).
  547. # This allows us to do sampling with replacement by creating
  548. # num_samples copies of each row in the tensor, and then
  549. # batch sampling the resulting tensor.
  550. probs = probs[:, None, :].expand(probs.shape[0], num_samples,
  551. probs.shape[1]).contiguous().view(
  552. -1, probs.shape[1])
  553. q = torch.empty_like(probs)
  554. if seq_groups is None:
  555. q.exponential_()
  556. else:
  557. sample_idx = 0
  558. for seq_group in seq_groups:
  559. seq_ids = seq_group.seq_ids
  560. next_sample_idx = sample_idx + len(seq_ids) * num_samples
  561. q[sample_idx:next_sample_idx].exponential_(
  562. generator=seq_group.generator)
  563. sample_idx = next_sample_idx
  564. return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """Sample next tokens for every sequence group using plain PyTorch ops.

    Sequence groups are bucketed by ``SamplingType``. Each bucket reads its
    rows of ``probs`` / ``logprobs`` from
    ``sampling_metadata.categorized_sample_indices[sampling_type][:, 0]``:
      * GREEDY: argmax of the logprobs;
      * RANDOM / RANDOM_SEED: ``_multinomial`` over the probs (seeded
        buckets also pass their seq groups so per-group generators are used);
      * BEAM: the logprob rows are handed to ``_beam_search_sample``.

    Args:
        probs: probabilities, one row per logits row (``Sampler.forward``
            passes float32 softmax output).
        logprobs: log-probabilities with the same layout as ``probs``.
        sampling_metadata: Metadata for sampling.
        include_gpu_probs_tensor: also return an on-device tensor of the
            sampled token ids.
        modify_greedy_probs: for greedy rows, call
            ``_modify_greedy_probs_inplace`` on ``logprobs`` / ``probs``
            (presumably so that sampling from them yields the argmax token).

    Returns:
        ``(sample_results, sampled_token_ids_tensor)``:
          * ``sample_results`` has one ``(next_token_ids, parent_ids)`` tuple
            per entry of ``sampling_metadata.seq_groups``, in order; groups
            that produced nothing get ``([], [])``.
          * ``sampled_token_ids_tensor`` is a ``(logprobs.shape[0], 1)`` long
            tensor when ``include_gpu_probs_tensor`` is set, else ``None``.
            Rows of beam-search groups are never written (it comes from
            ``torch.empty``).
    """
    # Bucket seq group indices by sampling type.
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    # sampling_type -> (seq group indices, seq groups) for non-empty buckets
    sample_metadata = {}
    # sampling_type -> on-device multinomial samples
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Prompt groups may ask for best_of samples from their single
            # row; every row of the bucket is sampled that many times.
            max_best_of_in_batch = 1
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
            }

            # Indexing with a tensor copies, so _multinomial may clobber it.
            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                # NOTE(review): the target slice is (n, 1) while the samples
                # are (n, max_best_of_in_batch); this assumes best_of == 1
                # whenever include_gpu_probs_tensor is set — confirm.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    # greedy_samples / beam_search_logprobs were bound in the first loop;
    # sample_metadata only contains buckets that reached those assignments.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    # Reassemble results in the original seq group order.
    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results, sampled_token_ids_tensor
  def _sample_with_triton_kernel(
      probs: torch.Tensor,
      logprobs: torch.Tensor,
      sampling_metadata: SamplingMetadata,
      sampling_tensors: SamplingTensors,
  ) -> List[Tuple[List[int], List[int]]]:
      """Sample next tokens for a batch with the Triton `sample_triton` kernel.

      This path is currently not used: `_sample` always dispatches to
      `_sample_with_torch` (the Triton call there is commented out).

      Greedy, random and seeded-random groups are all served by a single
      `sample_triton` launch whose `max_best_of` is the largest `best_of`
      among prompt-stage groups of those types (never less than 1).
      Beam-search groups are resolved by `_beam_search_sample` on their
      logprob rows.

      Args:
          probs: Probabilities for the batch, handed to the kernel.
          logprobs: Log-probabilities for the batch.
          sampling_metadata: Metadata describing the batch's seq groups.
          sampling_tensors: Supplies `sampling_seeds` and `sample_indices`
              for the kernel.

      Returns:
          One (next_token_ids, parent_seq_ids) tuple per sequence group, in
          `sampling_metadata.seq_groups` order. Groups whose sampling type
          had no sample indices get ([], []).

      Raises:
          ValueError: If a sampling type that has sample indices is none of
              GREEDY, RANDOM, RANDOM_SEED or BEAM.
      """
      # Bucket sequence-group positions by sampling type.
      group_ids_by_type = {kind: [] for kind in SamplingType}
      for position, group in enumerate(sampling_metadata.seq_groups):
          group_ids_by_type[group.sampling_params.sampling_type].append(
              position)

      indices_by_type = sampling_metadata.categorized_sample_indices
      pending = {}
      max_best_of = 1

      # Counterintuitively, two passes are faster than one: this first pass
      # can run without waiting on a GPU<->CPU sync.
      for kind in SamplingType:
          type_indices = indices_by_type[kind]
          row_ids = type_indices[:, 0]
          token_slots = type_indices[:, 1]
          if len(row_ids) == 0:
              continue
          group_ids = group_ids_by_type[kind]
          groups = [sampling_metadata.seq_groups[i] for i in group_ids]
          pending[kind] = (group_ids, groups, row_ids, token_slots)
          if kind == SamplingType.BEAM:
              beam_logprobs = logprobs[row_ids]
          elif kind in (SamplingType.GREEDY, SamplingType.RANDOM,
                        SamplingType.RANDOM_SEED):
              prompt_best_ofs = [
                  group.sampling_params.best_of for group in groups
                  if group.is_prompt
              ]
              max_best_of = max([max_best_of, *prompt_best_ofs])
          else:
              raise ValueError(f"Unsupported sampling type: {kind}")

      sampled_tokens, _, _ = sample_triton(
          probs=probs,
          seeds=sampling_tensors.sampling_seeds,
          max_best_of=max_best_of,
          sample_indices=sampling_tensors.sample_indices,
          logprobs=logprobs,
          # Logprobs are not saved here; the CPU-side logic computes them.
          # TODO: use the kernel's logprobs instead of that logic.
          save_logprobs=False,
      )

      # The GPU<->CPU sync happens in this second pass. `pending` was filled
      # in SamplingType order, so iterating it preserves that order.
      results_by_group: Dict[int, Tuple[List[int], List[int]]] = {}
      for kind, (group_ids, groups, _, token_slots) in pending.items():
          if kind == SamplingType.GREEDY:
              kind_results = _greedy_sample(
                  groups, sampled_tokens[token_slots][:, 0])
          elif kind in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
              kind_results = _random_sample(groups,
                                            sampled_tokens[token_slots])
          elif kind == SamplingType.BEAM:
              kind_results = _beam_search_sample(groups, beam_logprobs)
          results_by_group.update(zip(group_ids, kind_results))

      return [
          results_by_group.get(position, ([], []))
          for position in range(len(sampling_metadata.seq_groups))
      ]
  def _sample(
      probs: torch.Tensor, logprobs: torch.Tensor,
      sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
      include_gpu_probs_tensor: bool, modify_greedy_probs: bool
  ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
      """Sample the next tokens for every sequence group in the batch.

      All work is delegated to `_sample_with_torch`. `sampling_tensors` is
      only needed by the Triton implementation, which is currently disabled.

      Args:
          probs: (num_query_tokens_in_batch, num_vocab) probabilities.
          logprobs: (num_query_tokens_in_batch, num_vocab) log-probabilities.
          sampling_metadata: The metadata for a batch for sampling.
          sampling_tensors: Tensors that include sampling related metadata;
              unused by the torch path.
          include_gpu_probs_tensor: Forwarded to `_sample_with_torch`; when
              set, sampled token ids are also written to an on-device tensor.
          modify_greedy_probs: Forwarded to `_sample_with_torch`; when set,
              greedy rows of `probs`/`logprobs` are rewritten in place so the
              argmax token carries all of the probability mass.

      Returns:
          A pair of:
          - (next_token_ids, parent_seq_ids) for each seq group in the batch,
            with ([], []) for groups where sampling is skipped;
          - the tensor of sampled token ids, or None, as produced by
            `_sample_with_torch`.
      """
      # TODO: Switch to the Triton path once the kernel and its surrounding
      # code are faster:
      #     return _sample_with_triton_kernel(probs, logprobs,
      #                                       sampling_metadata,
      #                                       sampling_tensors)
      return _sample_with_torch(
          probs,
          logprobs,
          sampling_metadata,
          modify_greedy_probs=modify_greedy_probs,
          include_gpu_probs_tensor=include_gpu_probs_tensor,
      )
  def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
      """Return the 1-based rank of each chosen token within its row.

      Args:
          x: 2D logprob tensor of shape (N, M), where N is the number of
              tokens and M is the vocab dim.
          indices: Tensor of N chosen token ids, one per row of `x`.

      Returns:
          int64 tensor of shape (N,). Entry i is 1 plus the number of values
          in row i that are strictly greater than `x[i, indices[i]]`, so the
          best token has rank 1 and a tie resolves to the better rank.
      """
      row_ids = torch.arange(0, len(x), device=x.device, dtype=indices.dtype)
      chosen = x[row_ids, indices]
      num_strictly_better = (x > chosen.unsqueeze(-1)).long().sum(dim=1)
      return num_strictly_better + 1
  def _get_logprobs(
      logprobs: torch.Tensor,
      sampling_metadata: SamplingMetadata,
      sample_results: List[Tuple[List[int], List[int]]],
  ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
      """Return prompt logprobs and sample logprobs for every sequence group.

      The logic consists of 3 parts.
      - Select indices to compute logprob from, ranks of token ids, and
        the top k token ids from logprobs.
      - Compute prompt logprobs if required.
      - Compute sample logprobs if required.

      Args:
          logprobs: (num_query_tokens_across_batch, num_vocab). Each query
              token's logprob per vocab. Sequence groups' query tokens are
              batched in a single flattened tensor. For example, assuming
              there are N seq groups, it is sorted by prefill tokens for
              seq_group_1 (if prompt logprob is enabled), decode tokens for
              seq_group_1 (if sampling is required), prefill tokens for
              seq_group_2, ...
          sampling_metadata: The sampling metadata.
          sample_results: (num_seq_groups) The tuple of (next_token_ids,
              parent_ids) for each sequence group. When beam search is
              enabled, sample_results can contain a different number of
              seq_ids from sampling_metadata.seq_groups, because beam search
              creates 2 * BEAM_WIDTH samples (whereas there are only up to
              BEAM_WIDTH seq_ids).

      Returns:
          A tuple (prompt_logprobs, sample_logprobs) with one entry per
          sequence group. A group's prompt entry is None unless it is in the
          prefill stage with prompt logprobs requested; its sample entry is
          empty unless the group samples.
      """
      # The index of query token to calculate logprobs. It includes both
      # prompt and sample logprob indices.
      query_indices: List[int] = []
      # The next token ids to get the logprob value from.
      next_token_ids: List[int] = []
      # The largest requested number of logprobs (at least 1). A single topk
      # of this size serves every sequence group.
      largest_num_logprobs = 1

      # Select indices to compute logprob from, ranks of token ids, and the
      # top k token ids from logprobs.
      for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                            sample_results):
          sampling_params = seq_group.sampling_params

          # Update indices and tokens for prompt logprobs.
          if (seq_group.is_prompt
                  and sampling_params.prompt_logprobs is not None):
              largest_num_logprobs = max(largest_num_logprobs,
                                         sampling_params.prompt_logprobs)
              next_prompt_tokens = _get_next_prompt_tokens(seq_group)
              query_indices.extend(seq_group.prompt_logprob_indices)
              next_token_ids.extend(next_prompt_tokens)

          # Update indices and next tokens for sample logprob.
          if seq_group.do_sample:
              token_ids, parent_seq_ids = sample_result
              # NOTE: We cannot directly use sample_indices because
              # sample_indices only contain parent seq_ids of a previous step.
              # The current step may have a different number of seq_ids, and
              # we can obtain it from `sample_result[1]`.
              query_idx = seq_group.sample_indices[0]
              query_indices.extend(
                  [query_idx + parent_id for parent_id in parent_seq_ids])
              next_token_ids.extend(token_ids)

              if sampling_params.logprobs is not None:
                  largest_num_logprobs = max(largest_num_logprobs,
                                             sampling_params.logprobs)

      assert len(next_token_ids) == len(query_indices)

      if len(query_indices) == 0:
          # Nothing to compute. Still return one entry per sequence group,
          # just like the path below: `_build_sampler_output` zips these lists
          # with `sampling_metadata.seq_groups`, so returning a single entry
          # would silently drop every other group from the sampler output.
          num_seq_groups = len(sampling_metadata.seq_groups)
          empty_prompt_logprobs: List[Optional[PromptLogprobs]] = (
              [None] * num_seq_groups)
          # Distinct list objects so groups never share a mutable entry.
          empty_sample_logprobs: List[SampleLogprobs] = [
              [] for _ in range(num_seq_groups)
          ]
          return empty_prompt_logprobs, empty_sample_logprobs

      query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
      next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)

      # (num_selected_query_tokens,): the logprob of each next token at its
      # query row. Note that query_indices can contain duplicates if beam
      # search is enabled.
      selected_logprobs = logprobs[query_indices_gpu, next_token_ids_gpu]
      ranks = _get_ranks(
          logprobs[query_indices_gpu],
          next_token_ids_gpu,
      )
      assert selected_logprobs.shape[0] == ranks.shape[0]

      # Top-k logprobs and token ids for every query token of the batch:
      # (num_query_tokens_across_batch, largest_num_logprobs).
      # largest_num_logprobs starts at 1, so the first branch is always taken.
      if largest_num_logprobs > 0:
          top_logprobs, top_token_ids = torch.topk(logprobs,
                                                   largest_num_logprobs,
                                                   dim=-1)
          top_logprobs = top_logprobs.cpu()
          top_token_ids = top_token_ids.cpu()
      else:
          top_logprobs, top_token_ids = None, None

      selected_logprobs = selected_logprobs.cpu()
      ranks = ranks.cpu()

      # Find prompt/sample logprobs.
      prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
      sample_logprobs_per_seq_group: List[SampleLogprobs] = []
      top_logprob_idx = 0
      selected_logprobs_idx = 0

      for seq_group, sample_result in zip(sampling_metadata.seq_groups,
                                          sample_results):
          (prompt_logprobs, top_logprob_idx,
           selected_logprobs_idx) = _get_prompt_logprob_if_needed(
               seq_group, selected_logprobs, ranks, top_token_ids,
               top_logprobs, selected_logprobs_idx, top_logprob_idx)
          prompt_logprobs_per_seq_group.append(prompt_logprobs)

          (sampled_logprobs, top_logprob_idx,
           selected_logprobs_idx) = _get_sampled_logprob_if_needed(
               seq_group, sample_result, selected_logprobs, ranks,
               top_token_ids, top_logprobs, selected_logprobs_idx,
               top_logprob_idx)
          sample_logprobs_per_seq_group.append(sampled_logprobs)

      return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
  def _get_prompt_logprob_if_needed(
      seq_group: SequenceGroupToSample,
      selected_logprobs: torch.Tensor,
      ranks: torch.Tensor,
      top_token_ids: torch.Tensor,
      top_logprobs: torch.Tensor,
      selected_logprobs_idx: int,
      top_logprob_idx: int,
  ):
      """Build the prompt logprobs of one sequence group, if it requested them.

      Only a group in the prefill stage whose
      `sampling_params.prompt_logprobs` is not None produces output;
      otherwise (None, top_logprob_idx, selected_logprobs_idx) is returned
      with both indices untouched.

      For each next prompt token, the result holds a dict mapping token id to
      `Logprob(logprob, rank)`. It contains the actual prompt token, with
      logprob and rank read at `selected_logprobs_idx`. When
      `prompt_logprobs > 0`, it also contains the top `prompt_logprobs`
      candidates of row `top_logprob_idx` with ranks 1..k. A top candidate
      equal to the actual token overwrites its entry.

      Returns:
          (prompt_logprobs or None, new top_logprob_idx,
           new selected_logprobs_idx). Both indices advance by one per next
          prompt token.
      """
      params = seq_group.sampling_params
      if not seq_group.is_prompt or params.prompt_logprobs is None:
          return None, top_logprob_idx, selected_logprobs_idx

      k = params.prompt_logprobs
      prompt_logprobs: PromptLogprobs = []
      for prompt_token_id in _get_next_prompt_tokens(seq_group):
          # (logprob, rank) tuples; converted to Logprob objects at the end.
          token_entries: Dict[int, Tuple[float, int]] = {
              prompt_token_id: (selected_logprobs[selected_logprobs_idx].item(),
                                ranks[selected_logprobs_idx].item())
          }
          if k > 0:
              row_top_ids = top_token_ids[top_logprob_idx, :k].tolist()
              row_top_vals = top_logprobs[top_logprob_idx, :k].tolist()
              # Rows of top_logprobs are sorted, so ranks are simply 1..k.
              token_entries.update(
                  zip(row_top_ids, zip(row_top_vals, range(1, k + 1))))
          prompt_logprobs.append({
              tid: Logprob(*value_and_rank)
              for tid, value_and_rank in token_entries.items()
          })
          # Step to the next prompt token's row in both tensors.
          top_logprob_idx += 1
          selected_logprobs_idx += 1
      return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
  def _get_sampled_logprob_if_needed(
      seq_group: SequenceGroupToSample,
      sample_result: Tuple[List[int], List[int]],
      selected_logprobs: torch.Tensor,
      ranks: torch.Tensor,
      top_token_ids: torch.Tensor,
      top_logprobs: torch.Tensor,
      selected_logprobs_idx: int,
      top_logprob_idx: int,
  ):
      """Build the sample logprobs of one sequence group, if it sampled.

      For each (next_token_id, parent_id) pair in `sample_result`, the
      result holds a dict mapping token id to `Logprob(logprob, rank)`. It
      contains the sampled token, with logprob and rank read at
      `selected_logprobs_idx`. When `sampling_params.logprobs > 0`, it also
      contains the top candidates of row `top_logprob_idx + parent_id` with
      ranks 1..k. A top candidate equal to the sampled token overwrites its
      entry.

      Returns:
          (sampled_logprobs, new top_logprob_idx, new selected_logprobs_idx).
          If the group samples, `selected_logprobs_idx` advances by one per
          sampled token and `top_logprob_idx` by `len(seq_group.seq_ids)`.
          Otherwise the list is empty and both indices are returned unchanged.
      """
      seq_ids = seq_group.seq_ids
      num_logprobs = seq_group.sampling_params.logprobs
      if num_logprobs is None:
          num_logprobs = 0
      sampled_logprobs: SampleLogprobs = []
      next_token_ids, parent_seq_ids = sample_result

      if seq_group.do_sample:
          assert len(next_token_ids) > 0
          for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
              # Logprob of the token that was actually sampled, stored as a
              # (logprob, rank_from_vocab) tuple and wrapped in Logprob below.
              sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
                  next_token_id:
                  (selected_logprobs[selected_logprobs_idx].item(),
                   ranks[selected_logprobs_idx].item())
              }
              # +1 to go to the next sampled token. Note that
              # selected_logprobs can contain duplicates unlike top_logprobs
              # when beam search is enabled.
              selected_logprobs_idx += 1

              # Add the top K logprobs along with their ranks. Skipped when
              # no logprobs were requested: with num_logprobs == 0 the slices
              # would be empty anyway. This mirrors the `> 0` guard used for
              # prompt logprobs.
              if num_logprobs > 0:
                  top_row = top_logprob_idx + parent_id
                  sampled_logprobs_dict.update(
                      zip(
                          top_token_ids[top_row, :num_logprobs].tolist(),
                          zip(
                              top_logprobs[top_row, :num_logprobs].tolist(),
                              # Rows of top_logprobs are sorted, so the ranks
                              # are simply 1..num_logprobs.
                              range(1, num_logprobs + 1))))
              sampled_logprobs.append({
                  token_id: Logprob(*logprob_and_rank)
                  for token_id, logprob_and_rank in
                  sampled_logprobs_dict.items()
              })
          # There are len(seq_ids) sampled-token rows for the current
          # sequence group in top_logprobs. Jump to the next seq_group.
          top_logprob_idx += len(seq_ids)
      return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
  def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                   sample_indices: torch.Tensor,
                                   greedy_samples: torch.Tensor) -> None:
      """Turn the distributions of greedily-sampled rows into one-hot ones.

      For every row in `sample_indices`, the paired token in `greedy_samples`
      receives probability 1.0 (logprob 0.0). Every other token in that row
      receives probability 0 (logprob -inf). Both tensors are modified in
      place; nothing is returned.

      Why this is needed: speculative decoding depends on the sampling method
      being fully encoded in the probability distribution. For random
      sampling that already holds. Aphrodite applies the sampling parameters
      (temperature, top-k/top-p masking, frequency penalties, ...) to the
      logits and samples from the resulting distribution, so sampling from
      the returned distribution reproduces the sampler's behaviour.

      Greedy sampling instead takes the `argmax` of that same distribution.
      Sampling from the returned distribution would therefore not always
      yield the greedy token. Rewriting the distribution so the greedy token
      has probability 1 restores the property that lossless speculative
      decoding requires.

      NOTE: An alternative would be an extremely low temperature, so that
      greedy sampling goes through multinomial sampling and the code paths
      merge. That affects the sampler's overall design (e.g. how to report
      accurate logprobs to the user), so it is deferred.
      """
      # Same order as writing logprobs first, then probs: blank out each row,
      # then place the full mass on the greedy token.
      for dist, impossible, certain in ((logprobs, -float('inf'), 0.0),
                                        (probs, 0, 1.0)):
          dist[sample_indices, :] = impossible
          dist[sample_indices, greedy_samples] = certain
  def _build_sampler_output(
      sample_results: List[Tuple[List[int], List[int]]],
      sampling_metadata: SamplingMetadata,
      prompt_logprobs: List[Optional[PromptLogprobs]],
      sample_logprobs: List[SampleLogprobs],
      on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
  ) -> SamplerOutput:
      """Assemble the `SamplerOutput` for a batch from per-group results.

      For each sequence group, every (parent_id, next_token_id, logprobs)
      triple becomes a `SequenceOutput` attributed to
      `seq_group.seq_ids[parent_id]`. The group's outputs and prompt logprobs
      are then wrapped in a `SequenceGroupOutput`. The per-group inputs are
      combined with `zip`, so the result has as many group outputs as the
      shortest of them.

      Args:
          sample_results: (next_token_ids, parent_ids) per sequence group.
          sampling_metadata: Metadata whose `seq_groups` align with the lists.
          prompt_logprobs: Prompt logprobs (or None) per sequence group.
          sample_logprobs: Sample logprobs per sequence group.
          on_device_tensors: Optional (sampled_token_probs, sampled_token_ids)
              pair of on-device tensors. They are passed through so callers
              can post-process without copying to CPU or serializing, e.g.
              speculative decoding rejection sampling. When None, both
              fields of the result are None.
      """
      group_outputs = []
      per_group = zip(sampling_metadata.seq_groups, sample_results,
                      prompt_logprobs, sample_logprobs)
      for (seq_group, (token_ids, parent_ids), group_prompt_lps,
           group_sample_lps) in per_group:
          seq_ids = seq_group.seq_ids
          seq_outputs = [
              SequenceOutput(seq_ids[parent_id], token_id, token_logprobs)
              for parent_id, token_id, token_logprobs in zip(
                  parent_ids, token_ids, group_sample_lps)
          ]
          group_outputs.append(
              SequenceGroupOutput(seq_outputs, group_prompt_lps))

      if on_device_tensors is None:
          sampled_token_probs = sampled_token_ids = None
      else:
          sampled_token_probs, sampled_token_ids = on_device_tensors

      return SamplerOutput(
          outputs=group_outputs,
          sampled_token_probs=sampled_token_probs,
          sampled_token_ids=sampled_token_ids,
      )
  def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
      """Get the next prompt token ids used to compute prompt logprobs.

      Each query token's logprob row must be paired with the id of the
      prompt token that follows it in order to compute a prompt logprob.
      This helper returns those following token ids for the current chunk.

      Only call this when the caller knows `seq_group` is in the prefill
      stage; it asserts that the group is a prompt with a single sequence
      and a known `subquery_len`.

      Returns:
          The prompt token ids at indices `computed_len + 1` up to (but
          excluding) `computed_len + subquery_len + 1`, clipped to the prompt
          length, where `computed_len` is
          `seq_data.get_num_computed_tokens()`. (Previously annotated as
          `List[str]`; the values are token ids.)
      """
      assert seq_group.is_prompt, (
          "Caller should ensure the sequence group is in a prefill stage.")
      seq_ids = seq_group.seq_ids
      subquery_len = seq_group.subquery_len
      assert subquery_len is not None
      # prompt has only 1 seq id.
      assert len(seq_ids) == 1
      seq_data = seq_group.seq_data[seq_ids[0]]
      computed_len = seq_data.get_num_computed_tokens()
      prompt_tokens = seq_data.prompt_token_ids
      # +1 because we are looking for a next prompt token.
      next_token_index_start = computed_len + 1
      next_token_index_end = min(computed_len + subquery_len + 1,
                                 len(prompt_tokens))
      next_prompt_tokens = prompt_tokens[
          next_token_index_start:next_token_index_end]
      return next_prompt_tokens
  1085. # def _apply_mirostat_v2(logits: torch.Tensor,
  1086. # sampling_tensors: SamplingTensors) -> torch.Tensor:
  1087. # # Reduce our view to just the affected logits
  1088. # logit_view = logits[sampling_tensors.miro_indices]
  1089. # # Calculate surprise value per token
  1090. # # Convert nats to bits for compatibility with ooba/kobold parameters.
  1091. # logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)
  1092. # # Mask out "too-surprising" tokens (surprisal > mu)
  1093. # mus = sampling_tensors.miro_mus
  1094. # miro_mask = logit_surprise > mus.unsqueeze(dim=-1)
  1095. # # Unmask most-likely logit to guarantee a selection.
  1096. # maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
  1097. # miro_mask.scatter_(dim=1, index=maxinds, value=False)
  1098. # # Apply logit mask (effectively a top-k filter).
  1099. # logit_view[miro_mask] = -float("inf")
  1100. # # Project logit changes made to the view onto the original.
  1101. # # I think this step might be redundant.
  1102. # logits[sampling_tensors.miro_indices] = logit_view
  1103. # return logits
  1104. # def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
  1105. # sample_results: List[Tuple[List[int], List[int]]],
  1106. # sampling_metadata: SamplingMetadata,
  1107. # output_metadata: OutputMetadata) -> None:
  1108. # """Based on whichever token was finally sampled, we calculate the
  1109. # final surprisal values to update the mus.
  1110. # Because a single sequence can have multiple samples, we must fork
  1111. # the mu accordingly."""
  1112. # assert sampling_metadata.seq_groups is not None
  1113. # seqid_to_tokens = {}
  1114. # seqid_to_indices = {}
  1115. # for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
  1116. # sample_results):
  1117. # for idx, (token, parent) in enumerate(zip(toks, parents)):
  1118. # seqid_to_tokens.setdefault(sids[parent], []).append(token)
  1119. # seqid_to_indices.setdefault(sids[parent], []).append(idx)
  1120. # seqids = args.miro_seqids
  1121. # picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
  1122. # device=logits.device,
  1123. # dtype=torch.long)
  1124. # # Clumsily, we recalculate token surprisals.
  1125. # logits_view = logits[args.miro_indices]
  1126. # picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
  1127. # dim=-1,
  1128. # index=picked_tokens) / -math.log(2)
  1129. # taus = args.miro_taus.unsqueeze(dim=-1) # AKA target surprisals
  1130. # etas = args.miro_etas.unsqueeze(dim=-1) # AKA accumulation rates
  1131. # mus = args.miro_mus.unsqueeze(dim=-1) # AKA surprisal accumulators
  1132. # nu_mus = mus - (picked_surprise - taus) * etas
  1133. # # Record updated mu values for use in the next iteration
  1134. # # Note how each mu is split into multiple based on the number of samples.
  1135. # for seqid, seq_mus in zip(seqids, nu_mus):
  1136. # for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
  1137. # output_metadata.add(seqid, sample_idx, "miro_mu", mu)