  1. """A layer that samples the next tokens from the model's outputs."""
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  6. OutputMetadata,
  7. SamplingTensors)
  8. from aphrodite.modeling.megatron.communication_op import (
  9. tensor_model_parallel_gather)
  10. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  11. from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
  12. SamplerOutput, SequenceData,
  13. SequenceGroupOutput, SequenceOutput)


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # Get the hidden states that we use for sampling.
        logits = _prune_hidden_states(logits, sampling_metadata)
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.vocab_size]

        # Only perform sampling in the driver worker.
        # Note: `_get_logits` is still distributed across TP workers because
        # the `embedding` weight is distributed across TP workers.
        # TODO: Change the get_logits part to a separate stage.
        if not sampling_metadata.perform_sampling:
            return None

        assert logits is not None
        _, vocab_size = logits.shape

        output_metadata = OutputMetadata()

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
         do_typical_ps, do_quadratic,
         do_mirostat) = (SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype))

        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        if do_temperatures:
            logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                        sampling_tensors.dynatemp_mins,
                                        sampling_tensors.dynatemp_maxs,
                                        sampling_tensors.dynatemp_exps)

        if do_topks or do_topps or do_topas or do_minps:
            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                          sampling_tensors.top_ks,
                                          sampling_tensors.top_as,
                                          sampling_tensors.min_ps)
        if do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)
        if do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
        if do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)
        if do_typical_ps:
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps)
        if do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        banned_tokens = _get_custom_token_bans(sampling_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)
        if do_mirostat:
            logits = _mirostat(logits, sampling_tensors, output_metadata)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs,
                                     output_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(
        0, sampling_metadata.selected_token_indices)


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
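
# Illustrative example (a sketch, not used by the code above): with
# vocab_size=4, num_seqs=1 and tokens=[[1, 1, 3, 4]], where index 4 is the
# padding slot, this returns bin_counts=[[0, 2, 0, 1]] and
# mask=[[False, True, False, True]]; the padding column is dropped before the
# mask is computed.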


def _get_custom_token_bans(
        sampling_metadata: SamplingMetadata) -> List[List[int]]:
    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        custom_token_bans = sampling_params.custom_token_bans
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = sampling_metadata.prompt_lens[i]
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


# def _apply_logits_processors(
#         logits: torch.Tensor,
#         metadata: SamplingMetadata,
# ) -> torch.Tensor:
#     seq_offset = 0
#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
#         seq_size = len(seq_ids)
#         output_tokens = []
#         if (i < metadata.num_prompts
#                 and sampling_params.prompt_logprobs is not None):
#             prompt_seqs = metadata.prompt_lens[i] - 1
#             seq_size += prompt_seqs
#             output_tokens.extend([[]] * prompt_seqs)
#         seq_end = seq_offset + seq_size
#         if sampling_params.logits_processors:
#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
#                                  for sid in seq_ids)
#             for proc in sampling_params.logits_processors:
#                 proc(logits[seq_offset:seq_end], output_tokens)
#         seq_offset = seq_end
#     return logits


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits
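
# Sketch of a logits processor compatible with the loop above. The name and
# behavior are hypothetical (nothing in this module defines it); a processor
# receives the token ids generated so far plus one row of logits and returns
# the modified row:
#
#     def ban_immediate_repeat(token_ids: List[int],
#                              logits_row: torch.Tensor) -> torch.Tensor:
#         # Hypothetical example: forbid sampling the most recent token again.
#         if token_ids:
#             logits_row[token_ids[-1]] = -float("inf")
#         return logits_row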


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
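
# For reference, the step above is equivalent to, for every vocab token t:
#
#     logits[t]  = logits[t] / rep_pen   if logits[t] > 0 and t was seen
#     logits[t]  = logits[t] * rep_pen   if logits[t] <= 0 and t was seen
#     logits[t] -= freq_pen * count(t) + pres_pen * 1[count(t) > 0]
#
# where count(t) is how many times t appears in the generated output and
# "seen" means t occurs in either the prompt or the output.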


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _apply_alphabet_soup(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    m: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p, min-p and top-a.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
    min_p_thresholds = probs_sort[:, 0] * m
    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
    mask = (probs_sort < threshold.unsqueeze(1)
            )  # Cull logits below the top-a/min-p threshold
    mask.logical_or_(
        probs_sum >
        p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
    mask[:, 0] = False  # Guarantee at least one token is pickable
    logits_sort[mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)

    # Final mask.
    mask = (mask | top_k_mask)
    logits_sort.masked_fill_(mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
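
# The combined cutoff applied above, written out (p_max is the largest
# probability in the row):
#
#     min-p keeps tokens with prob >= m * p_max
#     top-a keeps tokens with prob >= a * p_max**2
#     top-p keeps the smallest prefix of sorted tokens whose probabilities
#           sum to at least p
#     top-k keeps the k highest-probability tokens
#
# A token survives only if it passes every enabled filter, and the single
# most-probable token is always kept.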


def _apply_tfs(
    logits: torch.Tensor,
    tfs: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)

    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))

    return logits
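
# Tail-free sampling, as implemented above: sort the probabilities, take the
# absolute second difference (a discrete curvature), normalize it to sum to 1,
# and drop the tail of tokens whose cumulative curvature exceeds `tfs`. The
# two concatenated columns pin the first sorted token to "keep" and the last
# one to "drop", since the double diff shortens each row by two entries.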


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoff: torch.Tensor,
) -> torch.Tensor:
    eta = torch.tensor(eta_cutoff, dtype=logits.dtype,
                       device=logits.device) * 1e-4
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta,
                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps

    if torch.all(eta_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        eta_mask = probs < topk_prob

    logits[eta_mask] = -float("inf")
    return logits
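
# Eta cutoff used above, with H the entropy of the row's distribution
# (neg_entropy == -H):
#
#     eps = min(eta, sqrt(eta) * exp(-H))
#
# Tokens with probability below eps are removed; the guard keeps at least the
# most probable token when the cutoff would eliminate everything.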


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoff: torch.Tensor,
) -> torch.Tensor:
    eps = torch.tensor(epsilon_cutoff,
                       dtype=logits.dtype,
                       device=logits.device).unsqueeze(dim=1)
    probs = logits.softmax(dim=-1)

    eps_mask = probs < (eps * 1e-4)

    if torch.all(eps_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        eps_mask = probs < topk_prob

    logits[eps_mask] = -float("inf")
    return logits


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)

    surprisal_deviations = (neg_entropy - shifted_logits).abs()
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
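
# Locally typical sampling, as above: rank tokens by how far their surprisal
# -log p(t) deviates from the distribution's entropy, then keep the
# lowest-deviation tokens whose running probability mass stays below
# typical_p (at least one token is always kept); everything ranked after that
# point is masked out.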


# pulls double duty for temperature and dynatemp
def _apply_temperature(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    dynatemp_mins: torch.Tensor,
    dynatemp_maxs: torch.Tensor,
    dynatemp_exps: torch.Tensor,
) -> torch.Tensor:
    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
    dynatemp_mins = dynatemp_mins[dynatemp_mask]
    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
    dynatemp_exps = dynatemp_exps[dynatemp_mask]
    dynatemp_mins = dynatemp_mins.clamp_(min=0)

    dynatemp_logits = logits[dynatemp_mask]
    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
    dynatemp_probs = dynatemp_shifted_logits.exp()
    dynatemp_entropies = -(dynatemp_probs *
                           dynatemp_shifted_logits).nansum(dim=-1)
    dynatemp_max_entropies = torch.log_(
        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                normalized_entropies.pow_(dynatemp_exps))

    temperatures[dynatemp_mask] = dyn_temp
    temperatures[temperatures == 0.0] = 1.0
    logits.div_(temperatures.unsqueeze_(dim=1))
    return logits
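
# Dynamic temperature (dynatemp), as computed above: for rows that enable it,
# the per-request temperature is interpolated from the row's normalized
# entropy,
#
#     T = T_min + (T_max - T_min) * (H / H_max) ** exponent
#
# where H is the row's entropy and H_max = log(number of non-masked tokens),
# i.e. the entropy of a uniform distribution over those tokens. Rows with a
# temperature of 0 (greedy) are divided by 1 instead, leaving them unchanged.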


def _apply_quadratic_sampling(
    logits: torch.Tensor,
    smoothing_factors: torch.Tensor,
    smoothing_curves: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a quadratic transformation to the logits based on the
    provided smoothing factors and curves. The transformation is
    centered around the maximum logit value in the batch.

    The transformation involves a quadratic and cubic term, with the
    cubic term controlled by the smoothing curve. The quadratic term is
    scaled by the smoothing factor, and the cubic term is scaled by the
    product of the smoothing factor and the smoothing curve.

    params:
        logits (torch.Tensor): The logits to be transformed.
        smoothing_factors (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
        smoothing_curves (torch.Tensor): The factors to scale the cubic term
            in the transformation.

    returns:
        torch.Tensor: The transformed logits.

    Credits: @kalomaze
    """
    max_logits = logits.max(dim=-1, keepdim=True).values
    diff = logits - max_logits
    smoothing_factors.unsqueeze_(dim=1)
    smoothing_curves.unsqueeze_(dim=1)

    k = (3 - smoothing_curves) / 2
    s = (smoothing_curves - 1) / 2

    mask = smoothing_factors > 0
    mask = mask.flatten()
    transformed_logits = torch.where(
        logits != float('-inf'), -(k * smoothing_factors * diff**2) +
        (s * smoothing_factors * diff**3) + max_logits, logits)
    logits[mask, :] = transformed_logits[mask, :]
    return logits
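
# The transformation above, written out per logit x with d = x - max_logit,
# smoothing factor f and smoothing curve c:
#
#     k = (3 - c) / 2
#     s = (c - 1) / 2
#     x' = max_logit - k * f * d**2 + s * f * d**3
#
# With c = 1 the cubic term vanishes and this reduces to the purely quadratic
# case; rows with a non-positive smoothing factor are left untouched.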


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
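
# Why dividing by Exp(1) noise works (the "exponential race", equivalent to
# the Gumbel-max trick): if q_i ~ Exp(1) independently, then p_i / q_i is
# maximized at index i with probability proportional to p_i, so a single
# argmax draws one sample from the categorical distribution given by `probs`
# without a GPU<->CPU sync.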


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type, sample_indices in categorized_sample_indices.items():
        if len(sample_indices) == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of = max(max_best_of, sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }
            multinomial_samples[sampling_type] = _multinomial(
                probs[sample_indices], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    for sampling_type, metadata in sample_metadata.items():
        seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices
    ]]

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append({
                    token_id: Logprob(logprob)
                    for token_id, logprob in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append({
                token_id: Logprob(logprob)
                for token_id, logprob in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    output_metadata: OutputMetadata,
) -> SamplerOutput:
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
                               output_metadata.get(seq_ids[parent_id])))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    return sampler_output


def _miro_store_args(seqids: List[int], mus: torch.Tensor,
                     output_metadata: OutputMetadata) -> None:
    for sid, mu in zip(seqids,
                       mus.tolist()):  # tolist might be premature optimization
        output_metadata.add(sid, "miro_mu", mu)


def _apply_mirostat_v2(
    logits: torch.Tensor,
    taus: torch.Tensor,  # AKA the targeted surprise
    etas: torch.Tensor,  # AKA the learning rate
    mus: torch.Tensor,  # AKA the accumulator that always tries to approach [tau]
) -> torch.Tensor:
    logit_surprise = torch.softmax(
        logits, dim=-1).log2_().neg_()  # Calculate surprise value per token
    # For compatibility with ooba/kobold, done in units of bits (log base 2),
    # not nats (ln).
    # Ideally this would be a log_softmax, for numerical stability and
    # elegance purposes.
    # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()

    miro_mask = logit_surprise > mus.unsqueeze(
        dim=-1)  # Mask out "too-surprising" tokens (above mu)

    mininds = torch.argmin(logit_surprise, dim=-1)
    miro_mask.scatter_(
        1, mininds.unsqueeze(dim=-1), False
    )  # Force at least one outcome to be possible, ideally the most likely one

    logits[miro_mask] = -float("inf")

    probs = torch.softmax(logits, dim=-1,
                          dtype=logits.dtype)  # Get probs, post-mask

    # NOTE: Mirostat updates its `mu` values based on the sample chosen.
    # The silly approach here is to just sample it and make the logits one-hot.
    # This breaks fine-grained seeding, but we don't have that yet.
    # TODO: FIX when it gets added
    next_token_ids = _multinomial(probs, num_samples=1)

    # Calculate the new `mu` values.
    # NOTE: If we can know the logit values of the PREVIOUS iteration,
    # it should be possible to update `mu` before applying mirostat each
    # iteration, thus letting us keep _sample as the last thing that happens.
    picked_surprises = torch.gather(logit_surprise,
                                    dim=-1,
                                    index=next_token_ids)
    eps = picked_surprises.squeeze() - taus
    mus.sub_(etas * eps)

    logits.fill_(-float("inf"))
    # This value doesn't actually matter, as long as it's not -inf.
    # Vectors are now one-hot, after all.
    logits.scatter_(1, next_token_ids, 1.0)
    return logits
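
# Mirostat v2 update rule, as applied above: after a token with surprise S
# (in bits) is drawn, the accumulator is moved toward the target surprise tau,
#
#     mu <- mu - eta * (S - tau)
#
# so a sequence that just produced an unusually surprising token tightens the
# mask on the next step, and vice versa.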


def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
              output_metadata: OutputMetadata) -> torch.Tensor:
    idx = sampling_tensors.miro_indices
    seqids = sampling_tensors.miro_seqids
    taus = sampling_tensors.miro_taus
    etas = sampling_tensors.miro_etas
    mus = sampling_tensors.miro_mus

    logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
                                     mus)  # mus is an inout param, :vomit:
    _miro_store_args(seqids, mus, output_metadata)
    return logits