  1. """A layer that samples the next tokens from the model's outputs."""
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  6. OutputMetadata,
  7. SamplingTensors)
  8. from aphrodite.modeling.megatron.communication_op import (
  9. tensor_model_parallel_gather)
  10. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  11. from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
  12. SamplerOutput, SequenceData,
  13. SequenceGroupOutput, SequenceOutput)


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits
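
    # Note on the slice above: the gathered logits can have more columns than
    # the real vocabulary (extra slots for LoRA-added tokens, plus any padding
    # used to keep the tensor-parallel shards evenly sized), so they are
    # trimmed back to `org_vocab_size` before any sampling logic runs.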

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[SamplerOutput]:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, embedding, embedding_bias)

        # Only perform sampling in the driver worker.
        # Note: `_get_logits` is still distributed across TP workers because
        # the `embedding` weight is distributed across TP workers.
        # TODO: Change the get_logits part to a separate stage.
        if not sampling_metadata.perform_sampling:
            return None

        assert logits is not None
        _, vocab_size = logits.shape

        output_metadata = OutputMetadata()

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
         do_typical_ps, do_quadratic,
         do_mirostat) = (SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype))

        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        if do_temperatures:
            logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                        sampling_tensors.dynatemp_mins,
                                        sampling_tensors.dynatemp_maxs,
                                        sampling_tensors.dynatemp_exps)

        if do_topks or do_topps or do_topas or do_minps:
            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                          sampling_tensors.top_ks,
                                          sampling_tensors.top_as,
                                          sampling_tensors.min_ps)
        if do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)
        if do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
        if do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)
        if do_typical_ps:
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps)
        if do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        banned_tokens = _get_custom_token_bans(sampling_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)

        if do_mirostat:
            logits = _mirostat(logits, sampling_tensors, output_metadata)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs,
                                     output_metadata)
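
# Usage sketch (illustrative only; the real call site lives in the model
# runner outside this file, and `lm_head_weight` is a hypothetical name for
# whatever tensor is passed as `embedding`):
#
#     sampler = Sampler(vocab_size)
#     output = sampler(lm_head_weight, hidden_states, sampling_metadata)
#
# On non-driver TP workers the call returns None once the logits have been
# gathered, per the `perform_sampling` check above.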


# FIXME: This is a hack for the missing GPU blocks. This should be removed
# once a proper fix is implemented.
class QuantSampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # Original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # Prune the logits down to the positions used for sampling.
        logits = _prune_hidden_states(logits, sampling_metadata)
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.vocab_size]

        # Only perform sampling in the driver worker.
        # Note: the logits gather above is still distributed across TP workers
        # because the `embedding` weight is distributed across TP workers.
        # TODO: Change the get_logits part to a separate stage.
        if not sampling_metadata.perform_sampling:
            return None

        assert logits is not None
        _, vocab_size = logits.shape

        output_metadata = OutputMetadata()

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
         do_typical_ps, do_quadratic,
         do_mirostat) = (SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype))

        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        if do_temperatures:
            logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                        sampling_tensors.dynatemp_mins,
                                        sampling_tensors.dynatemp_maxs,
                                        sampling_tensors.dynatemp_exps)

        if do_topks or do_topps or do_topas or do_minps:
            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                          sampling_tensors.top_ks,
                                          sampling_tensors.top_as,
                                          sampling_tensors.min_ps)
        if do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)
        if do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
        if do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)
        if do_typical_ps:
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps)
        if do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        banned_tokens = _get_custom_token_bans(sampling_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)

        if do_mirostat:
            logits = _mirostat(logits, sampling_tensors, output_metadata)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs,
                                     output_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
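
# The extra bin above exists because the ragged token-id tensors are padded
# with an id that falls just outside the real vocabulary, so the pad counts
# land in the final column and are sliced away. Illustrative example, assuming
# vocab_size=4 and a pad id of 4: tokens [[1, 1, 4, 4]] yield
# bin_counts [[0, 2, 0, 0]] and mask [[False, True, False, False]].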


def _get_custom_token_bans(
        sampling_metadata: SamplingMetadata) -> List[List[int]]:
    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        custom_token_bans = sampling_params.custom_token_bans
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = sampling_metadata.prompt_lens[i]
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


# def _apply_logits_processors(
#     logits: torch.Tensor,
#     metadata: SamplingMetadata,
# ) -> torch.Tensor:
#     seq_offset = 0
#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
#         seq_size = len(seq_ids)
#         output_tokens = []
#         if (i < metadata.num_prompts
#                 and sampling_params.prompt_logprobs is not None):
#             prompt_seqs = metadata.prompt_lens[i] - 1
#             seq_size += prompt_seqs
#             output_tokens.extend([[]] * prompt_seqs)
#         seq_end = seq_offset + seq_size
#         if sampling_params.logits_processors:
#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
#                                  for sid in seq_ids)
#             for proc in sampling_params.logits_processors:
#                 proc(logits[seq_offset:seq_end], output_tokens)
#         seq_offset = seq_end
#     return logits


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
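
# Combined penalty rule implemented above, per token t:
#
#     logits[t]  = logits[t] / rep_pen   if logits[t] > 0 and t was seen
#     logits[t]  = logits[t] * rep_pen   if logits[t] <= 0 and t was seen
#     logits[t] -= freq_pen * count_in_output(t) + pres_pen * seen_in_output(t)
#
# where "seen" means the token appears in the prompt or in the output so far.
# The frequency/presence part matches the OpenAI API definition linked above;
# the repetition penalty is the multiplicative form popularized by the CTRL
# paper and HF transformers.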


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _apply_alphabet_soup(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    m: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p, min-p and top-a.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
    min_p_thresholds = probs_sort[:, 0] * m
    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
    mask = (probs_sort < threshold.unsqueeze(1)
            )  # Cull logits below the top-a threshold
    mask.logical_or_(
        probs_sum >
        p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
    mask[:, 0] = False  # Guarantee at least one token is pickable
    logits_sort[mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)

    # Final mask.
    mask = (mask | top_k_mask)
    logits_sort.masked_fill_(mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
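
# Worked example for the thresholds above (illustrative numbers): if a row's
# sorted probabilities are [0.50, 0.30, 0.15, 0.05] with min_p = 0.2 and
# top_a = 0.5, the min-p threshold is 0.50 * 0.2 = 0.10 and the top-a
# threshold is 0.50**2 * 0.5 = 0.125, so the effective cutoff is 0.125 and
# only the 0.05 token is culled by this step. The top-p and top-k masks are
# then OR-ed on top of that; position 0 is exempted from the probability
# cutoffs so at least one token survives (given top_k >= 1).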


def _apply_tfs(
    logits: torch.Tensor,
    tfs: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)

    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))

    return logits
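
# Tail-free sampling note: `d2` is the absolute second difference of the
# sorted probabilities, i.e. a discrete curvature estimate. Normalizing it to
# sum to 1 and cutting where its CDF exceeds `tfs` keeps the head of the
# distribution up to the point where the curvature has flattened out. The
# concatenated columns re-align the mask after the two `diff()` calls (each
# shortens the last dimension by one): the leading False always keeps the
# most-probable token, and the trailing True always culls the last one.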


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoff: torch.Tensor,
) -> torch.Tensor:
    eta = torch.tensor(eta_cutoff, dtype=logits.dtype,
                       device=logits.device) * 1e-4
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta,
                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps

    if torch.all(eta_mask):  # guard against nulling out all the logits
        # keepdim so the per-row maximum broadcasts against [num_seqs, vocab].
        topk_prob, _ = torch.max(probs, dim=-1, keepdim=True)
        eta_mask = probs < topk_prob

    logits[eta_mask] = -float("inf")
    return logits
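
# Eta sampling (cf. Hewitt et al., "Truncation Sampling as Language Model
# Desmoothing"): the cutoff used above is
#
#     eps = min(eta, sqrt(eta) * exp(-H(p)))
#
# where H(p) is the entropy of the distribution and `eta` is the user value
# scaled by 1e-4. Tokens with probability below eps are removed; the
# `torch.all` guard keeps the argmax token when the rule would cull
# everything.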


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoff: torch.Tensor,
) -> torch.Tensor:
    eps = torch.tensor(epsilon_cutoff,
                       dtype=logits.dtype,
                       device=logits.device).unsqueeze(dim=1)
    probs = logits.softmax(dim=-1)

    eps_mask = probs < (eps * 1e-4)

    if torch.all(eps_mask):  # guard against nulling out all the logits
        # keepdim so the per-row maximum broadcasts against [num_seqs, vocab].
        topk_prob, _ = torch.max(probs, dim=-1, keepdim=True)
        eps_mask = probs < topk_prob

    logits[eps_mask] = -float("inf")
    return logits


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
) -> torch.Tensor:
    typ_p = torch.tensor(typical_p, dtype=logits.dtype, device=logits.device)
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)

    surprisal_deviations = (neg_entropy - shifted_logits).abs()
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
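
# Locally typical sampling (cf. Meister et al., 2022): each token's surprisal
# deviation |log p(t) + H(p)| measures how far its information content is from
# the distribution's entropy. Tokens are ranked by that deviation (most
# typical first), and everything from the point where the cumulative
# probability of that ranking reaches `typical_p` onward is masked, keeping at
# least one token.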


# pulls double duty for temperature and dynatemp
def _apply_temperature(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    dynatemp_mins: torch.Tensor,
    dynatemp_maxs: torch.Tensor,
    dynatemp_exps: torch.Tensor,
) -> torch.Tensor:
    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
    dynatemp_mins = dynatemp_mins[dynatemp_mask]
    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
    dynatemp_exps = dynatemp_exps[dynatemp_mask]
    dynatemp_mins = dynatemp_mins.clamp_(min=0)

    dynatemp_logits = logits[dynatemp_mask]
    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
    dynatemp_probs = dynatemp_shifted_logits.exp()
    dynatemp_entropies = -(dynatemp_probs *
                           dynatemp_shifted_logits).nansum(dim=-1)
    dynatemp_max_entropies = torch.log_(
        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                normalized_entropies.pow_(dynatemp_exps))
    temperatures[dynatemp_mask] = dyn_temp

    temperatures[temperatures == 0.0] = 1.0
    logits.div_(temperatures.unsqueeze_(dim=1))
    return logits
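
# Dynamic temperature sketch: for rows with dynatemp enabled, the effective
# temperature is interpolated from the normalized entropy of the current
# distribution,
#
#     T = T_min + (T_max - T_min) * (H / H_max) ** exponent
#
# where H_max = log(number of non-masked tokens), so near-deterministic
# distributions get a temperature close to T_min and near-uniform ones close
# to T_max. Rows with temperature 0.0 are set to 1.0 here because greedy
# decoding is handled by the argmax path in `_sample`.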


def _apply_quadratic_sampling(
    logits: torch.Tensor,
    smoothing_factors: torch.Tensor,
    smoothing_curves: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a quadratic transformation to the logits based on the
    provided smoothing factors and curves. The transformation is
    centered around the maximum logit value in the batch.

    The transformation involves a quadratic and cubic term, with the
    cubic term controlled by the smoothing curve. The quadratic term is
    scaled by the smoothing factor, and the cubic term is scaled by the
    product of the smoothing factor and the smoothing curve.

    params:
        logits (torch.Tensor): The logits to be transformed.
        smoothing_factors (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
        smoothing_curves (torch.Tensor): The factors to scale the cubic term
            in the transformation.

    returns:
        torch.Tensor: The transformed logits.

    Credits: @kalomaze
    """
    max_logits = logits.max(dim=-1, keepdim=True).values
    diff = logits - max_logits
    smoothing_factors.unsqueeze_(dim=1)
    smoothing_curves.unsqueeze_(dim=1)

    k = (3 - smoothing_curves) / 2
    s = (smoothing_curves - 1) / 2

    mask = smoothing_factors > 0
    mask = mask.flatten()
    transformed_logits = torch.where(
        logits != float('-inf'), -(k * smoothing_factors * diff**2) +
        (s * smoothing_factors * diff**3) + max_logits, logits)
    logits[mask, :] = transformed_logits[mask, :]
    return logits
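
# Quadratic/cubic smoothing sketch: with d = logit - max_logit (so d <= 0),
# the transform above is
#
#     logit' = max_logit - k * factor * d**2 + s * factor * d**3,
#     k = (3 - curve) / 2,  s = (curve - 1) / 2
#
# With curve = 1 this reduces to the pure quadratic form (k = 1, s = 0);
# larger curves shift weight to the cubic term. Only rows with a positive
# smoothing factor are overwritten, and -inf logits pass through unchanged.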


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
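
# Why this works (exponential-race trick, closely related to Gumbel-max): if
# E_i are i.i.d. Exponential(1) samples, then argmax_i p_i / E_i is
# distributed as Categorical(p). Dividing the probabilities by `q` and taking
# the argmax therefore draws one token per row without a torch.multinomial
# call and without the GPU<->CPU sync it would force.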


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type, sample_indices in categorized_sample_indices.items():
        if len(sample_indices) == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of = max(max_best_of, sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }
            multinomial_samples[sampling_type] = _multinomial(
                probs[sample_indices], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    for sampling_type, metadata in sample_metadata.items():
        seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, is_prompts, multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices
    ]]

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append({
                    token_id: Logprob(logprob)
                    for token_id, logprob in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append({
                token_id: Logprob(logprob)
                for token_id, logprob in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    output_metadata: OutputMetadata,
) -> SamplerOutput:
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
                               output_metadata.get(seq_ids[parent_id])))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    return sampler_output


def _miro_store_args(seqids: List[int], mus: torch.Tensor,
                     output_metadata: OutputMetadata) -> None:
    for sid, mu in zip(seqids,
                       mus.tolist()):  # tolist might be premature optimization
        output_metadata.add(sid, "miro_mu", mu)


def _apply_mirostat_v2(
    logits: torch.Tensor,
    taus: torch.Tensor,  # AKA the targeted surprise
    etas: torch.Tensor,  # AKA the learning rate
    mus: torch.Tensor,  # AKA the accumulator that always tries to approach [tau]
) -> torch.Tensor:
    logit_surprise = torch.softmax(
        logits, dim=-1).log2_().neg_()  # Calculate surprise value per token
    # For compatibility with ooba/kobold, this is done in units of bits
    # (log base 2), not nats (ln).
    # Ideally this would be a log_softmax, for numerical stability and
    # elegance purposes.
    # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()

    miro_mask = logit_surprise > mus.unsqueeze(
        dim=-1)  # Mask out "too-surprising" tokens (above mu)
    mininds = torch.argmin(logit_surprise, dim=-1)
    miro_mask.scatter_(
        1, mininds.unsqueeze(dim=-1), False
    )  # Force at least one outcome to be possible, ideally the most likely one

    logits[miro_mask] = -float("inf")

    probs = torch.softmax(logits, dim=-1,
                          dtype=logits.dtype)  # Get probs, post-mask

    # NOTE: Mirostat updates its `mu` values based on the sample chosen.
    # The silly approach here is to just sample it and make the logits one-hot.
    # This breaks fine-grained seeding, but we don't have that yet.
    # TODO: FIX when it gets added.
    next_token_ids = _multinomial(probs, num_samples=1)

    # Calculate the new `mu` values.
    # NOTE: If we can know the logit values of the PREVIOUS iteration,
    # it should be possible to update `mu` before applying mirostat each
    # iteration, thus letting us keep _sample as the last thing that happens.
    picked_surprises = torch.gather(logit_surprise,
                                    dim=-1,
                                    index=next_token_ids)
    eps = picked_surprises.squeeze() - taus
    mus.sub_(etas * eps)

    logits.fill_(-float("inf"))
    # This value doesn't actually matter, so long as it's not -inf.
    # Vectors are now one-hot, after all.
    logits.scatter_(1, next_token_ids, 1.0)
    return logits
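
# Mirostat v2 control loop: after masking out tokens whose surprise (in bits)
# exceeds the running `mu`, the chosen token's surprise s is used to update
#
#     mu <- mu - eta * (s - tau)
#
# so `mu` tracks the target surprise `tau` over time. The one-hot logits
# returned here make the later `_sample` call deterministic for these rows,
# since the token was effectively already sampled above.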


def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
              output_metadata: OutputMetadata) -> torch.Tensor:
    idx = sampling_tensors.miro_indices
    seqids = sampling_tensors.miro_seqids
    taus = sampling_tensors.miro_taus
    etas = sampling_tensors.miro_etas
    mus = sampling_tensors.miro_mus

    logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
                                     mus)  # mus is an i/o param, :vomit:
    _miro_store_args(seqids, mus, output_metadata)
    return logits