sampler.py

  1. """A layer that samples the next tokens from the model's outputs."""
  2. import itertools
  3. import math
  4. from typing import Dict, List, Tuple, Optional
  5. import torch
  6. import torch.nn as nn
  7. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  8. OutputMetadata,
  9. SamplingTensors)
  10. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  11. from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
  12. SamplerOutput, SequenceData,
  13. SequenceGroupOutput, SequenceOutput)
  14. from aphrodite.modeling.layers.ops.sample import sample as sample_triton


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply all the different sampler functions in the specified order.
    4. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row
    in logits for the next token to be sampled; however, for a seq_group with
    a prompt request with the prompt_logprobs sampling parameter, there are
    rows in logits for each token in the input prompt.
    """

    def __init__(self):
        super().__init__()

        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        assert logits is not None
        _, vocab_size = logits.shape

        output_metadata = OutputMetadata()

        # Apply the min_tokens penalty, which sets stop tokens to -inf until
        # at least min_tokens tokens have been generated.
        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        sampling_tensors = SamplingTensors.from_sampling_metadata(
            sampling_metadata, vocab_size, logits.device, logits.dtype)

        if sampling_tensors.do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.pres_penalties,
                                      sampling_tensors.freq_penalties,
                                      sampling_tensors.rep_penalties)

        if sampling_tensors.do_temperatures or sampling_tensors.do_dynatemps:
            logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                        sampling_tensors.dynatemp_mins,
                                        sampling_tensors.dynatemp_maxs,
                                        sampling_tensors.dynatemp_exps)

        if (sampling_tensors.do_top_ks or sampling_tensors.do_top_ps
                or sampling_tensors.do_top_as or sampling_tensors.do_min_ps):
            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                          sampling_tensors.top_ks,
                                          sampling_tensors.top_as,
                                          sampling_tensors.min_ps)
        if sampling_tensors.do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)
        if sampling_tensors.do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
        if sampling_tensors.do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)
        if sampling_tensors.do_typical_ps:
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps)
        if sampling_tensors.do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_indices,
                sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        banned_tokens = _get_custom_token_bans(sampling_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)
        if sampling_tensors.do_mirostat:
            logits = _apply_mirostat_v2(logits, sampling_tensors)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            sampled_tokens_tensor = maybe_sampled_tokens_tensor
            on_device_tensors = (probs, sampled_tokens_tensor)
        else:
            on_device_tensors = None

        if sampling_tensors.do_mirostat:
            _mirostat_store_args(logits, sampling_tensors, sample_results,
                                 sampling_metadata, output_metadata)

        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs,
                                     output_metadata, on_device_tensors)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability
        distribution of greedily-sampled tokens such that multinomial sampling
        would sample the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        # Modify greedy probs if include_gpu_probs_tensor is set.
        return self.include_gpu_probs_tensor


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
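# For intuition: with vocab_size = 6 and a single sequence whose tokens are
# [2, 2, 5], bin_counts becomes [0, 0, 2, 0, 0, 1] and mask marks ids 2 and 5.
# The extra (vocab_size + 1)-th column only absorbs the padding token id and
# is sliced off before returning.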


def _get_custom_token_bans(
        sampling_metadata: SamplingMetadata) -> List[List[int]]:
    assert sampling_metadata.seq_groups is not None
    assert sampling_metadata.prompt_lens is not None

    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        custom_token_bans = sampling_params.custom_token_bans
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = sampling_metadata.prompt_lens[i]
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
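# In other words, any token id that has already appeared in the prompt or the
# output has its logit scaled by the repetition penalty (divided if positive,
# multiplied if negative); token ids that appeared in the output are further
# reduced by frequency_penalty * count and by presence_penalty, following the
# OpenAI definitions referenced above.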


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    assert sampling_metadata.seq_groups is not None
    assert sampling_metadata.seq_data is not None

    # List of (row, token_id) indices in logits that will be set to -inf.
    logits_to_penalize = []
    start_idx = 0
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        min_tokens = sampling_params.min_tokens
        if min_tokens > 0:
            seqs_to_penalize = []
            for i, seq_id in enumerate(seq_ids):
                seq_data = sampling_metadata.seq_data[seq_id]
                if len(seq_data.output_token_ids) < min_tokens:
                    seqs_to_penalize.append(i)

            if seqs_to_penalize:
                # Convert to indices into logits.
                seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
                # Use set() to remove any duplicates.
                token_ids_to_penalize = set(sampling_params.stop_token_ids +
                                            [sampling_params.eos_token_id])
                # itertools.product pairs each seq index with every token id.
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

        start_idx += len(seq_ids)

    if logits_to_penalize:
        # Use zip and * to group indices along each dimension,
        # e.g. [(1, 2), (1, 3), (5, 6)] -> ((1, 1, 5), (2, 3, 6)).
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    return logits


def _apply_alphabet_soup(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    m: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p, min-p and top-a.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
    min_p_thresholds = probs_sort[:, 0] * m
    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
    # Cull logits below the combined min-p / top-a threshold.
    mask = probs_sort < threshold.unsqueeze(1)
    # Cull logits beyond the top-p cumulative-probability threshold.
    mask.logical_or_(probs_sum > p.unsqueeze(dim=1))
    mask[:, 0] = False  # Guarantee at least one token is pickable.
    logits_sort[mask] = -float("inf")

    # Apply top-k.
    for i, topk in enumerate(k):
        logits_sort[i, topk:] = -float("inf")

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
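# For intuition on the thresholds above: if the top sorted probability is
# 0.50, then min_p = 0.2 keeps only tokens with probability >= 0.50 * 0.2 =
# 0.10, and top_a = 0.5 keeps only tokens with probability >=
# 0.50 ** 2 * 0.5 = 0.125; the stricter of the two wins. Top-p masks a token
# once the cumulative probability of the tokens ranked above it exceeds p,
# and top-k keeps only the k highest-ranked entries per row.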


def _apply_tfs(
    logits: torch.Tensor,
    tfs: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)

    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))

    return logits
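# Tail-free sampling, in short: d2 is the absolute second derivative of the
# sorted probability curve (each .diff() shortens the row by one, hence the
# two boolean padding columns concatenated back on). Tokens past the point
# where the normalized curvature mass exceeds the tfs parameter are masked;
# the highest-probability token is always kept and the lowest is always
# dropped.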


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoff: torch.Tensor,
) -> torch.Tensor:
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta_cutoff,
                    torch.sqrt(eta_cutoff) *
                    torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps

    # guard against nulling out all the logits
    top_idx = torch.argmax(probs, dim=1, keepdim=True)
    eta_mask.scatter_(dim=1, index=top_idx, value=False)

    logits[eta_mask] = -float("inf")
    return logits
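# The cutoff computed above is eps = min(eta, sqrt(eta) * exp(-H)), where H is
# the entropy of the row (neg_entropy = sum(p * log p) = -H). Flatter,
# higher-entropy distributions therefore get a smaller cutoff and keep more of
# their tokens.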


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoff: torch.Tensor,
) -> torch.Tensor:
    probs = logits.softmax(dim=-1)

    eps_mask = probs < epsilon_cutoff.unsqueeze(dim=1)

    # guard against nulling out all the logits
    top_idx = torch.argmax(probs, dim=1, keepdim=True)
    eps_mask.scatter_(dim=1, index=top_idx, value=False)

    logits[eps_mask] = -float("inf")
    return logits


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)

    surprisal_deviations = (neg_entropy - shifted_logits).abs()
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typical_p.unsqueeze(
        dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
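# Locally typical sampling: tokens are ordered by how little their surprisal
# (-log p) deviates from the row's entropy, and the most "typical" tokens are
# kept until their cumulative probability reaches typical_p; everything ranked
# after that point is masked (subject to keeping at least one token).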


# pulls double duty for temperature and dynatemp
def _apply_temperature(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    dynatemp_mins: torch.Tensor,
    dynatemp_maxs: torch.Tensor,
    dynatemp_exps: torch.Tensor,
) -> torch.Tensor:
    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
    dynatemp_mins = dynatemp_mins[dynatemp_mask]
    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
    dynatemp_exps = dynatemp_exps[dynatemp_mask]
    dynatemp_mins = dynatemp_mins.clamp_(min=0)

    dynatemp_logits = logits[dynatemp_mask]
    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
    dynatemp_probs = dynatemp_shifted_logits.exp()
    dynatemp_entropies = -(dynatemp_probs *
                           dynatemp_shifted_logits).nansum(dim=-1)
    dynatemp_max_entropies = torch.log_(
        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                normalized_entropies.pow_(dynatemp_exps))

    temperatures[dynatemp_mask] = dyn_temp
    temperatures[temperatures == 0.0] = 1.0
    logits.div_(temperatures.unsqueeze_(dim=1))
    return logits
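# Dynamic temperature, as computed above: T = T_min + (T_max - T_min) *
# (H / H_max) ** exponent, where H is the row's entropy and H_max is the
# entropy of a uniform distribution over the still-unmasked tokens (the log of
# their count). Near-deterministic rows thus get a temperature close to T_min
# while flat rows approach T_max. Temperatures of 0.0 are mapped to 1.0, which
# avoids a division by zero; greedy sampling takes the argmax, so the scale
# does not matter there.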


def _apply_quadratic_sampling(
    logits: torch.Tensor,
    indices: torch.Tensor,
    factors: torch.Tensor,
    curves: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a quadratic transformation to the logits based on the
    provided smoothing factors and curves. The transformation is
    centered around the maximum logit value in the batch.

    The transformation involves a quadratic and cubic term, with the
    cubic term controlled by the smoothing curve. The quadratic term is
    scaled by the smoothing factor, and the cubic term is scaled by the
    product of the smoothing factor and the smoothing curve.

    params:
        logits (torch.Tensor): The logits to be transformed.
        indices (torch.Tensor): Indices selecting the rows of `logits`
            that the smoothing parameters apply to.
        factors (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
        curves (torch.Tensor): The factors to scale the cubic term
            in the transformation.

    returns:
        torch.Tensor: The transformed logits.

    Credits: @kalomaze
    """
    factors.unsqueeze_(dim=1)
    curves.unsqueeze_(dim=1)
    k = factors * (3 - curves) / 2
    s = factors * (curves - 1) / 2

    quadlogits = logits[indices]  # project to only relevant logits
    max_logits = quadlogits.max(dim=-1, keepdim=True).values

    # Construct the delta to subtract from each logit (new = old - diff).
    diff = quadlogits - max_logits
    diff -= diff**2 * (s * diff - k)
    diff[diff != diff] = 0  # Eliminate NaNs from infs
    logits[indices] -= diff
    return logits
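# Net effect of the update above, with d = logit - max: the new logit is
# max - factor * (((3 - curve) / 2) * d**2 - ((curve - 1) / 2) * d**3).
# A curve of 1.0 reduces this to the purely quadratic max - factor * d**2,
# while a curve of 3.0 gives the purely cubic max + factor * d**3.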


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx].item()]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
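# The trick below: dividing each probability by an independent Exponential(1)
# sample and taking the argmax picks index i with probability p_i, since
# q_i / p_i is exponential with rate p_i and the smallest of these "clocks"
# belongs to index i with probability proportional to p_i. This reproduces
# multinomial sampling without the sync.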
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        assert generators is not None
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    """Returns list of (selected_tokens, parent_seq_ids) tuples
    corresponding to sampling_metadata.seq_groups."""
    assert sampling_metadata.seq_groups is not None
    assert sampling_metadata.categorized_sample_indices is not None
    assert sampling_metadata.seq_data is not None

    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type, sample_indices in categorized_sample_indices.items():
        sample_indices = sample_indices[:, 0]
        if len(sample_indices) == 0:
            continue

        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of_in_batch = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    for sampling_type, metadata in sample_metadata.items():
        seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results, sampled_token_ids_tensor


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    assert sampling_metadata.seq_groups is not None
    assert sampling_metadata.categorized_sample_indices is not None
    assert sampling_metadata.seq_data is not None

    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type, sample_indices in categorized_sample_indices.items():
        sampled_token_indices = sample_indices[:, 1]
        sample_indices = sample_indices[:, 0]
        if len(sample_indices) == 0:
            continue

        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.seed_transpose,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.seed_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_ids, seq_groups, is_prompts, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor, logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob
    tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
            where N is the number of tokens and M is the vocab dim.
        indices (torch.Tensor): 1D tensor of chosen token indices, one per
            row.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the number of tokens.
            Each element in the returned tensor represents the rank
            of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    return (x > vals[:, None]).long().sum(1).add_(1)
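# Rank 1 means the chosen token had the highest logprob in its row. For
# example (illustrative):
#   _get_ranks(torch.tensor([[-0.1, -2.3, -0.5]]), torch.tensor([2]))
#   -> tensor([2]), since exactly one logprob (-0.1) is greater than -0.5.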


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    assert sampling_metadata.seq_groups is not None
    assert sampling_metadata.prompt_lens is not None
    assert sampling_metadata.seq_data is not None

    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    # at least get one logprob for each token
    largest_num_logprobs = 1
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    batched_logprobs_query_seq_indices_gpu = torch.tensor(
        batched_logprobs_query_seq_indices, device=logprobs.device)
    batched_logprobs_query_token_indices_gpu = torch.tensor(
        batched_logprobs_query_token_indices, device=logprobs.device)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices_gpu,
        batched_logprobs_query_token_indices_gpu
    ]]

    batched_ranks_query_result = _get_ranks(
        logprobs[batched_logprobs_query_seq_indices_gpu],
        batched_logprobs_query_token_indices_gpu)

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
    batched_ranks_query_result = batched_ranks_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    (batched_logprobs_query_result[query_result_idx].item(),
                     batched_ranks_query_result[query_result_idx].item())
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(
                            top_token_ids[sample_idx, :num_logprobs].tolist(),
                            zip(
                                top_logprobs[
                                    sample_idx, :num_logprobs].tolist(),
                                range(1, num_logprobs + 1))))
                group_prompt_logprobs.append({
                    token_id: Logprob(*logprob_rank)
                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                (batched_logprobs_query_result[query_result_idx].item(),
                 batched_ranks_query_result[query_result_idx].item())
            }
            query_result_idx += 1
            if num_logprobs >= 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        zip(
                            top_logprobs[sample_idx +
                                         parent_id, :num_logprobs].tolist(),
                            range(1, num_logprobs + 1))))
            group_sample_logprobs.append({
                token_id: Logprob(*logprob_rank)
                for token_id, logprob_rank in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens
    such that each sampled token has a "probability" of 1.0. This is required
    by speculative decoding, which depends on the sampling method being
    encoded within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    Aphrodite's sampler performs the following steps for greedy or multinomial
    (random) sampling:
    1. Get logits from model.
    2. Modify logits according to per-sequence sampling parameters.
        - Multiply by temperature, top-k and top-p masking, penalize tokens
          according to their frequency, etc.
    3. Sample a token.
        - Random sampling simply samples from the modified probability
          distribution.
        - Greedy sampling performs `argmax` to obtain the token with the
          highest likelihood.

    Ignoring greedy sampling for a moment, the computed probability
    distribution has the following property: sampling from it repeatedly
    yields each token with a frequency equal to its probability. In other
    words, for tokens sampled with Aphrodite's random SamplingType, the
    computed probability distribution encodes the sampling methodology
    completely.

    Greedy sampling does not normally have this property. Aphrodite modifies
    logits according to sampling params, then performs `argmax`, then returns
    the sampled token and the computed probability distribution. If we sample
    from the distribution, we'll find the likelihood of the greedily-sampled
    token is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths.
    This has implications on the overall design of the sampler, e.g. how to
    record accurate logprobs for the user, so this improvement is deferred to
    later.
    """
    logprobs[sample_indices, :] = -float('inf')
    logprobs[sample_indices, greedy_samples] = 0.0
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    output_metadata: OutputMetadata,
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g.
            in speculative decoding rejection sampling.
    """
    assert sampling_metadata.seq_groups is not None

    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        seq_outputs = [
            SequenceOutput(seq_ids[parent_id], token_id, logprobs,
                           output_metadata.get(seq_ids[parent_id], idx))
            for idx, (token_id, parent_id, logprobs) in enumerate(
                zip(*sample_result, group_sample_logprobs))
        ]
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        sampled_token_probs, sampled_token_ids = on_device_tensors
    else:
        sampled_token_probs, sampled_token_ids = (None, None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
    )


def _apply_mirostat_v2(logits: torch.Tensor,
                       sampling_tensors: SamplingTensors) -> torch.Tensor:
    # Reduce our view to just the affected logits.
    logit_view = logits[sampling_tensors.miro_indices]

    # Calculate surprise value per token.
    # Convert nats to bits for compatibility with ooba/kobold parameters.
    logit_surprise = torch.log_softmax(logit_view, dim=-1) / -math.log(2)

    # Mask out "too-surprising" tokens (surprisal > mu).
    mus = sampling_tensors.miro_mus
    miro_mask = logit_surprise > mus.unsqueeze(dim=-1)

    # Unmask most-likely logit to guarantee a selection.
    maxinds = torch.argmax(logit_view, dim=-1, keepdim=True)
    miro_mask.scatter_(dim=1, index=maxinds, value=False)

    # Apply logit mask (effectively a top-k filter).
    logit_view[miro_mask] = -float("inf")

    # Project the changes back onto the original logits tensor. This is
    # required: advanced indexing above returns a copy, not a view.
    logits[sampling_tensors.miro_indices] = logit_view
    return logits
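# Mirostat v2 in brief: each affected row keeps only the tokens whose
# surprisal in bits (-log2 p) does not exceed that sequence's running
# threshold mu; the single most likely token is always kept so the row can
# never be emptied entirely.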


def _mirostat_store_args(logits: torch.Tensor, args: SamplingTensors,
                         sample_results: List[Tuple[List[int], List[int]]],
                         sampling_metadata: SamplingMetadata,
                         output_metadata: OutputMetadata) -> None:
    """Based on whichever token was finally sampled, we calculate the
    final surprisal values to update the mus.

    Because a single sequence can have multiple samples, we must fork
    the mu accordingly."""
    assert sampling_metadata.seq_groups is not None
    seqid_to_tokens = {}
    seqid_to_indices = {}
    for (sids, _), (toks, parents) in zip(sampling_metadata.seq_groups,
                                          sample_results):
        for idx, (token, parent) in enumerate(zip(toks, parents)):
            seqid_to_tokens.setdefault(sids[parent], []).append(token)
            seqid_to_indices.setdefault(sids[parent], []).append(idx)

    seqids = args.miro_seqids

    picked_tokens = torch.tensor([seqid_to_tokens[x] for x in seqids],
                                 device=logits.device,
                                 dtype=torch.long)

    # Clumsily, we recalculate token surprisals.
    logits_view = logits[args.miro_indices]
    picked_surprise = torch.gather(torch.log_softmax(logits_view, dim=-1),
                                   dim=-1,
                                   index=picked_tokens) / -math.log(2)

    taus = args.miro_taus.unsqueeze(dim=-1)  # AKA target surprisals
    etas = args.miro_etas.unsqueeze(dim=-1)  # AKA accumulation rates
    mus = args.miro_mus.unsqueeze(dim=-1)  # AKA surprisal accumulators
    nu_mus = mus - (picked_surprise - taus) * etas

    # Record updated mu values for use in the next iteration.
    # Note how each mu is split into multiple based on the number of samples.
    for seqid, seq_mus in zip(seqids, nu_mus):
        for sample_idx, mu in zip(seqid_to_indices[seqid], seq_mus):
            output_metadata.add(seqid, sample_idx, "miro_mu", mu)
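# The update above is the mirostat feedback step mu_new = mu - eta *
# (observed_surprise - tau): the threshold is lowered when the sampled token
# was more surprising than the target tau and raised when it was less
# surprising.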