# sampler.py

  1. """A layer that samples the next tokens from the model's outputs."""
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  6. OutputMetadata,
  7. SamplingTensors)
  8. from aphrodite.modeling.megatron.communication_op import (
  9. tensor_model_parallel_gather)
  10. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  11. from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
  12. SamplerOutput, SequenceData,
  13. SequenceGroupOutput, SequenceOutput)


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[SamplerOutput]:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, embedding, embedding_bias)
        return _perform_sampling(logits, sampling_metadata)
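
# Illustrative usage (a sketch, not executed here): the model runner typically
# owns the sampler and calls it with the LM head weight after the final
# transformer layer, roughly:
#
#     sampler = Sampler(vocab_size=config.vocab_size)
#     output = sampler(embedding=lm_head_weight,
#                      hidden_states=hidden_states,
#                      sampling_metadata=sampling_metadata)
#
# The names `config`, `lm_head_weight`, and `hidden_states` above are
# assumptions for illustration; the actual call site lives in the model
# definitions and the model runner.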


# FIXME: This is a hack for the missing GPU blocks. This should be removed
# once a proper fix is implemented.
class QuantSampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # Prune the logits rows that are not used for sampling.
        logits = _prune_hidden_states(logits, sampling_metadata)
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.vocab_size]
        return _perform_sampling(logits, sampling_metadata)


def _perform_sampling(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
    # Only perform sampling in the driver worker.
    # Note: `_get_logits` is still distributed across TP workers because
    # the `embedding` weight is distributed across TP workers.
    # TODO: Change the get_logits part to a separate stage.
    if not sampling_metadata.perform_sampling:
        return None

    assert logits is not None
    _, vocab_size = logits.shape

    output_metadata = OutputMetadata()

    # Apply logits processors (if any)
    logits = _apply_logits_processors(logits, sampling_metadata)

    # Prepare sampling tensors with pinned memory to avoid blocking.
    (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
     do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
     do_typical_ps, do_quadratic,
     do_mirostat) = (SamplingTensors.from_sampling_metadata(
         sampling_metadata, vocab_size, logits.device, logits.dtype))

    if do_penalties:
        logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                  sampling_tensors.output_tokens,
                                  sampling_tensors.presence_penalties,
                                  sampling_tensors.frequency_penalties,
                                  sampling_tensors.repetition_penalties)

    if do_temperatures:
        logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                    sampling_tensors.dynatemp_mins,
                                    sampling_tensors.dynatemp_maxs,
                                    sampling_tensors.dynatemp_exps)

    if do_topks or do_topps or do_topas or do_minps:
        logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                      sampling_tensors.top_ks,
                                      sampling_tensors.top_as,
                                      sampling_tensors.min_ps)
    if do_tfss:
        logits = _apply_tfs(logits, sampling_tensors.tfss)
    if do_eta_cutoffs:
        logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
    if do_epsilon_cutoffs:
        logits = _apply_epsilon_cutoff(logits,
                                       sampling_tensors.epsilon_cutoffs)
    if do_typical_ps:
        logits = _apply_typical_sampling(logits, sampling_tensors.typical_ps)
    if do_quadratic:
        logits = _apply_quadratic_sampling(logits,
                                           sampling_tensors.smoothing_factors,
                                           sampling_tensors.smoothing_curves)

    banned_tokens = _get_custom_token_bans(sampling_metadata)
    assert len(banned_tokens) == logits.shape[0]
    logits = _apply_token_bans(logits, banned_tokens)
    if do_mirostat:
        logits = _mirostat(logits, sampling_tensors, output_metadata)

    # We use float32 for probabilities and log probabilities.
    # Compute the probabilities.
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    # Compute the log probabilities.
    # Use log_softmax to ensure numerical stability.
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # Sample the next tokens.
    sample_results = _sample(probs, logprobs, sampling_metadata)
    # Get the logprobs query results.
    prompt_logprobs, sample_logprobs = _get_logprobs(logprobs,
                                                     sampling_metadata,
                                                     sample_results)
    return _build_sampler_output(sample_results, sampling_metadata,
                                 prompt_logprobs, sample_logprobs,
                                 output_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
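
# Worked example (illustrative): with vocab_size=6 and one sequence whose
# token tensor is [[2, 2, 5, 6]] (6 being the padding bin), scatter_add_
# produces bin_counts [[0, 0, 2, 0, 0, 1]] once the padding column is
# dropped, and mask marks token ids 2 and 5 as "seen". The counts feed the
# frequency penalty and the mask feeds the presence/repetition penalties.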


def _get_custom_token_bans(
        sampling_metadata: SamplingMetadata) -> List[List[int]]:
    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        custom_token_bans = sampling_params.custom_token_bans
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = sampling_metadata.prompt_lens[i]
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


# def _apply_logits_processors(
#     logits: torch.Tensor,
#     metadata: SamplingMetadata,
# ) -> torch.Tensor:
#     seq_offset = 0
#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
#         seq_size = len(seq_ids)
#         output_tokens = []
#         if (i < metadata.num_prompts
#                 and sampling_params.prompt_logprobs is not None):
#             prompt_seqs = metadata.prompt_lens[i] - 1
#             seq_size += prompt_seqs
#             output_tokens.extend([[]] * prompt_seqs)
#         seq_end = seq_offset + seq_size
#         if sampling_params.logits_processors:
#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
#                                  for sid in seq_ids)
#             for proc in sampling_params.logits_processors:
#                 proc(logits[seq_offset:seq_end], output_tokens)
#         seq_offset = seq_end
#     return logits


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits
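
# A logits processor is any callable taking (output_token_ids, logits_row)
# and returning the modified row, as the loop above shows. A minimal sketch
# (hypothetical names, for illustration only):
#
#     def ban_eos_until(min_tokens: int, eos_token_id: int):
#         def processor(token_ids, logits_row):
#             if len(token_ids) < min_tokens:
#                 logits_row[eos_token_id] = -float("inf")
#             return logits_row
#         return processor
#
# Real processors are supplied by the caller through SamplingParams.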


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
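
# Worked example (illustrative): suppose a token has been generated 3 times
# so far, its current logit is 2.0, frequency_penalty=0.2,
# presence_penalty=0.5 and repetition_penalty=1.25. The repetition penalty
# divides the positive logit (2.0 / 1.25 = 1.6; a negative logit would be
# multiplied instead), then the OpenAI-style penalties subtract
# 0.2 * 3 + 0.5 * 1 = 1.1, leaving a final logit of 0.5.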


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _apply_alphabet_soup(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    m: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p, min-p and top-a.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
    min_p_thresholds = probs_sort[:, 0] * m
    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
    mask = (probs_sort < threshold.unsqueeze(1)
            )  # Cull logits below the top-a threshold
    mask.logical_or_(
        probs_sum >
        p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
    mask[:, 0] = False  # Guarantee at least one token is pickable
    logits_sort[mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)

    # Final mask.
    mask = (mask | top_k_mask)
    logits_sort.masked_fill_(mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
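
# Worked example (illustrative): for one row with sorted probabilities
# [0.5, 0.3, 0.15, 0.05], top_p=0.9, min_p=0.2, top_a=0.3 and top_k=3:
#   - min-p threshold = 0.5 * 0.2       = 0.10
#   - top-a threshold = 0.5 ** 2 * 0.3  = 0.075
#   - the combined threshold max(0.10, 0.075) culls the 0.05 token;
#   - the exclusive cumulative sums are [0, 0.5, 0.8, 0.95], so top-p also
#     culls only the 0.05 token (0.95 > 0.9);
#   - top_k=3 masks everything from sorted position 3 onward.
# All three agree here, leaving the three most likely tokens alive.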


def _apply_tfs(
    logits: torch.Tensor,
    tfs: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)

    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))

    return logits
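
# Tail-free sampling, in brief: the "curvature" of the sorted probability
# curve is approximated by the absolute second difference `d2`, normalized to
# sum to 1. Tokens past the point where the cumulative curvature exceeds the
# `tfs` threshold are dropped. Because taking two diffs shortens the row by
# two, the mask is padded with "always keep" for the most likely token and
# "always drop" for the least likely one.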


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoff: torch.Tensor,
) -> torch.Tensor:
    eta = torch.tensor(eta_cutoff, dtype=logits.dtype,
                       device=logits.device) * 1e-4
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta,
                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps

    if torch.all(eta_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        eta_mask = probs < topk_prob

    logits[eta_mask] = -float("inf")
    return logits
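
# Eta sampling drops tokens whose probability falls below
# eps = min(eta, sqrt(eta) * exp(-H)), where H is the row's entropy
# (neg_entropy above is -H). Worked example (illustrative): with
# eta_cutoff=3 the effective eta is 3e-4; at entropy H = 5 nats the cutoff
# becomes sqrt(3e-4) * exp(-5) ~= 1.2e-4, i.e. flatter (higher-entropy)
# distributions get a lower cutoff and keep more of the tail.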


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoff: torch.Tensor,
) -> torch.Tensor:
    eps = torch.tensor(epsilon_cutoff,
                       dtype=logits.dtype,
                       device=logits.device).unsqueeze(dim=1)
    probs = logits.softmax(dim=-1)

    eps_mask = probs < (eps * 1e-4)

    if torch.all(eps_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        eps_mask = probs < topk_prob

    logits[eps_mask] = -float("inf")
    return logits


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)

    surprisal_deviations = (neg_entropy - shifted_logits).abs()
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
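
# Locally typical sampling keeps the tokens whose surprisal (-log p) is
# closest to the distribution's entropy: each row is sorted by
# |H - (-log p)| ascending, tokens are accumulated until their total
# probability mass reaches typical_p, and everything after that point is
# masked out.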


# pulls double duty for temperature and dynatemp
def _apply_temperature(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    dynatemp_mins: torch.Tensor,
    dynatemp_maxs: torch.Tensor,
    dynatemp_exps: torch.Tensor,
) -> torch.Tensor:
    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
    dynatemp_mins = dynatemp_mins[dynatemp_mask]
    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
    dynatemp_exps = dynatemp_exps[dynatemp_mask]
    dynatemp_mins = dynatemp_mins.clamp_(min=0)

    dynatemp_logits = logits[dynatemp_mask]
    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
    dynatemp_probs = dynatemp_shifted_logits.exp()
    dynatemp_entropies = -(dynatemp_probs *
                           dynatemp_shifted_logits).nansum(dim=-1)
    dynatemp_max_entropies = torch.log_(
        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                normalized_entropies.pow_(dynatemp_exps))

    temperatures[dynatemp_mask] = dyn_temp
    temperatures[temperatures == 0.0] = 1.0
    logits.div_(temperatures.unsqueeze_(dim=1))
    return logits
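
# Dynamic temperature, in brief: the row's entropy is normalized by the
# maximum possible entropy log(N) over the N non-masked tokens, then mapped
# into [dynatemp_min, dynatemp_max]. Worked example (illustrative): with
# dynatemp_min=0.5, dynatemp_max=2.0, dynatemp_exp=1.0 and a normalized
# entropy of 0.6, the effective temperature is 0.5 + (2.0 - 0.5) * 0.6 = 1.4.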


def _apply_quadratic_sampling(
    logits: torch.Tensor,
    smoothing_factors: torch.Tensor,
    smoothing_curves: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a quadratic transformation to the logits based on the
    provided smoothing factors and curves. The transformation is
    centered around the maximum logit value in the batch.

    The transformation involves a quadratic and cubic term, with the
    cubic term controlled by the smoothing curve. The quadratic term is
    scaled by the smoothing factor, and the cubic term is scaled by the
    product of the smoothing factor and the smoothing curve.

    params:
        logits (torch.Tensor): The logits to be transformed.
        smoothing_factors (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
        smoothing_curves (torch.Tensor): The factors to scale the cubic term
            in the transformation.

    returns:
        torch.Tensor: The transformed logits.

    Credits: @kalomaze
    """
    max_logits = logits.max(dim=-1, keepdim=True).values
    diff = logits - max_logits
    smoothing_factors.unsqueeze_(dim=1)
    smoothing_curves.unsqueeze_(dim=1)

    k = (3 - smoothing_curves) / 2
    s = (smoothing_curves - 1) / 2
    mask = smoothing_factors > 0
    mask = mask.flatten()

    transformed_logits = torch.where(
        logits != float('-inf'), -(k * smoothing_factors * diff**2) +
        (s * smoothing_factors * diff**3) + max_logits, logits)
    logits[mask, :] = transformed_logits[mask, :]
    return logits
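
# Worked example (illustrative): with smoothing_factor=0.2 and
# smoothing_curve=1.0 we get k=1 and s=0, so each finite logit becomes
# max_logit - 0.2 * (logit - max_logit) ** 2: a downward parabola centered
# on the top logit that compresses the gap to weaker candidates. A curve
# above 1.0 shrinks the quadratic term (k < 1) and adds a cubic term that
# mainly affects tokens far from the top logit.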


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
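
# Why this works (the exponential-race trick): if E_i are i.i.d.
# Exponential(1) draws, then argmax_i(p_i / E_i) = argmin_i(E_i / p_i), and
# E_i / p_i is Exponential with rate p_i, so index i wins the race with
# probability proportional to p_i. This gives an exact categorical sample
# without torch.multinomial's device sync.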


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type, sample_indices in categorized_sample_indices.items():
        if len(sample_indices) == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of = max(max_best_of, sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }
            multinomial_samples[sampling_type] = _multinomial(
                probs[sample_indices], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    for sampling_type, metadata in sample_metadata.items():
        seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices
    ]]

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append({
                    token_id: Logprob(logprob)
                    for token_id, logprob in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append({
                token_id: Logprob(logprob)
                for token_id, logprob in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    output_metadata: OutputMetadata,
) -> SamplerOutput:
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
                               output_metadata.get(seq_ids[parent_id])))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    return SamplerOutput(outputs=sampler_output)


def _miro_store_args(seqids: List[int], mus: torch.Tensor,
                     output_metadata: OutputMetadata) -> None:
    for sid, mu in zip(seqids,
                       mus.tolist()):  # tolist might be premature optimization
        output_metadata.add(sid, "miro_mu", mu)


def _apply_mirostat_v2(
    logits: torch.Tensor,
    taus: torch.Tensor,  # AKA the targeted surprise
    etas: torch.Tensor,  # AKA the learning rate
    mus: torch.Tensor,  # AKA the accumulator that always tries to approach [tau]
) -> torch.Tensor:
    logit_surprise = torch.softmax(
        logits, dim=-1).log2_().neg_()  # Calculate surprise value per token
    # For compatibility with ooba/kobold, done in units of bits (log base 2),
    # not nats (ln).
    # Ideally this would be a log_softmax, for numerical stability and
    # elegance purposes.
    # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()

    miro_mask = logit_surprise > mus.unsqueeze(
        dim=-1)  # Mask out "too-surprising" tokens (above mu)

    # Force at least one outcome to be possible, ideally the most likely one.
    mininds = torch.argmin(logit_surprise, dim=-1)
    miro_mask.scatter_(1, mininds.unsqueeze(dim=-1), False)

    logits[miro_mask] = -float("inf")

    probs = torch.softmax(logits, dim=-1,
                          dtype=logits.dtype)  # Get probs, post-mask

    # NOTE: Mirostat updates its `mu` values based on the sample chosen.
    # The silly approach here is to just sample it and make the logits one-hot.
    # This breaks fine grained seeding, but we don't have that yet.
    # TODO: FIX when it gets added
    next_token_ids = _multinomial(probs, num_samples=1)

    # Calculate the new `mu` values.
    # NOTE: If we can know the logit values of the PREVIOUS iteration,
    # it should be possible to update `mu` before applying mirostat each
    # iteration, thus letting us keep _sample as the last thing that happens.
    picked_surprises = torch.gather(logit_surprise,
                                    dim=-1,
                                    index=next_token_ids)
    eps = picked_surprises.squeeze() - taus
    mus.sub_(etas * eps)

    logits.fill_(-float("inf"))
    # This value doesn't actually matter, so long as it's not -inf.
    # Vectors are now one-hot, after all.
    logits.scatter_(1, next_token_ids, 1.0)
    return logits
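
# Mirostat v2 control loop, in brief: after a token is drawn, its surprise
# (in bits) is compared against the target tau, and mu is nudged by the
# learning rate: mu -= eta * (surprise - tau). Worked example (illustrative):
# tau=5.0, eta=0.1, and a picked token with 7.0 bits of surprise move mu down
# by 0.1 * (7.0 - 5.0) = 0.2, tightening the mask on the next step.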


def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
              output_metadata: OutputMetadata) -> torch.Tensor:
    idx = sampling_tensors.miro_indices
    seqids = sampling_tensors.miro_seqids
    taus = sampling_tensors.miro_taus
    etas = sampling_tensors.miro_etas
    mus = sampling_tensors.miro_mus

    logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
                                     mus)  # mus is an i/o param, :vomit:
    _miro_store_args(seqids, mus, output_metadata)
    return logits