  1. """A layer that samples the next tokens from the model's outputs."""
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
  6. OutputMetadata,
  7. SamplingTensors)
  8. from aphrodite.modeling.megatron.communication_op import (
  9. tensor_model_parallel_gather)
  10. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  11. from aphrodite.common.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
  12. SamplerOutput, SequenceData,
  13. SequenceGroupOutput, SequenceOutput)


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[SamplerOutput]:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

        # Get the logits for the next tokens.
        logits = self._get_logits(hidden_states, embedding, embedding_bias)

        # Only perform sampling in the driver worker.
        # Note: `_get_logits` is still distributed across TP workers because
        # the `embedding` weight is distributed across TP workers.
        # TODO: Change the get_logits part to a separate stage.
        if not sampling_metadata.perform_sampling:
            return None

        assert logits is not None
        _, vocab_size = logits.shape

        output_metadata = OutputMetadata()

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
         do_typical_ps, do_quadratic,
         do_mirostat) = (SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype))

        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)
        if do_temperatures:
            logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                        sampling_tensors.dynatemp_mins,
                                        sampling_tensors.dynatemp_maxs,
                                        sampling_tensors.dynatemp_exps)
        if do_topks or do_topps or do_topas or do_minps:
            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                          sampling_tensors.top_ks,
                                          sampling_tensors.top_as,
                                          sampling_tensors.min_ps)
        if do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)
        if do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
        if do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)
        if do_typical_ps:
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps,
                                             sampling_tensors.typical_p_sigmas)
        if do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        banned_tokens = _get_custom_token_bans(sampling_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)

        if do_mirostat:
            logits = _mirostat(logits, sampling_tensors, output_metadata)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs,
                                     output_metadata)


# FIXME: This is a hack for the missing GPU blocks. This should be removed
# once a proper fix is implemented.
class QuantSampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # Prune the logit rows down to the ones we actually sample from.
        logits = _prune_hidden_states(logits, sampling_metadata)
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.vocab_size]

        # Only perform sampling in the driver worker.
        # Note: the logits are still distributed across TP workers because
        # the `embedding` weight is distributed across TP workers.
        # TODO: Change the get_logits part to a separate stage.
        if not sampling_metadata.perform_sampling:
            return None

        assert logits is not None
        _, vocab_size = logits.shape

        output_metadata = OutputMetadata()

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
         do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
         do_typical_ps, do_quadratic,
         do_mirostat) = (SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype))

        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)
        if do_temperatures:
            logits = _apply_temperature(logits, sampling_tensors.temperatures,
                                        sampling_tensors.dynatemp_mins,
                                        sampling_tensors.dynatemp_maxs,
                                        sampling_tensors.dynatemp_exps)
        if do_topks or do_topps or do_topas or do_minps:
            logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                          sampling_tensors.top_ks,
                                          sampling_tensors.top_as,
                                          sampling_tensors.min_ps)
        if do_tfss:
            logits = _apply_tfs(logits, sampling_tensors.tfss)
        if do_eta_cutoffs:
            logits = _apply_eta_cutoff(logits, sampling_tensors.eta_cutoffs)
        if do_epsilon_cutoffs:
            logits = _apply_epsilon_cutoff(logits,
                                           sampling_tensors.epsilon_cutoffs)
        if do_typical_ps:
            # NOTE: `_apply_typical_sampling` takes both `typical_ps` and
            # `typical_p_sigmas`; pass the sigma tensor as well.
            logits = _apply_typical_sampling(logits,
                                             sampling_tensors.typical_ps,
                                             sampling_tensors.typical_p_sigmas)
        if do_quadratic:
            logits = _apply_quadratic_sampling(
                logits, sampling_tensors.smoothing_factors,
                sampling_tensors.smoothing_curves)

        banned_tokens = _get_custom_token_bans(sampling_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)

        if do_mirostat:
            logits = _mirostat(logits, sampling_tensors, output_metadata)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs,
                                     output_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0
    return bin_counts, mask
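
# Editor's note (illustrative example, not from the original source): for one
# sequence with vocab_size=4 and tokens [[2, 2, 3, 4]], where 4 is the padding
# id, scatter_add_ yields bin_counts [[0, 0, 2, 1]] after the padding column is
# dropped, and mask [[False, False, True, True]].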


def _get_custom_token_bans(
        sampling_metadata: SamplingMetadata) -> List[List[int]]:
    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        custom_token_bans = sampling_params.custom_token_bans
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = sampling_metadata.prompt_lens[i]
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


# def _apply_logits_processors(
#         logits: torch.Tensor,
#         metadata: SamplingMetadata,
# ) -> torch.Tensor:
#     seq_offset = 0
#     for i, (seq_ids, sampling_params) in enumerate(metadata.seq_groups):
#         seq_size = len(seq_ids)
#         output_tokens = []
#         if (i < metadata.num_prompts
#                 and sampling_params.prompt_logprobs is not None):
#             prompt_seqs = metadata.prompt_lens[i] - 1
#             seq_size += prompt_seqs
#             output_tokens.extend([[]] * prompt_seqs)
#         seq_end = seq_offset + seq_size
#         if sampling_params.logits_processors:
#             output_tokens.extend(metadata.seq_data[sid].output_token_ids
#                                  for sid in seq_ids)
#             for proc in sampling_params.logits_processors:
#                 proc(logits[seq_offset:seq_end], output_tokens)
#         seq_offset = seq_end
#     return logits


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits
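
# Editor's note (illustrative arithmetic, not from the original source): if a
# token has already been generated 3 times, frequency_penalty=0.5 and
# presence_penalty=1.0 subtract 0.5 * 3 + 1.0 * 1 = 2.5 from its logit, while
# repetition_penalty=1.2 first divides a positive logit (or multiplies a
# negative one) by 1.2 for any token seen in the prompt or output.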


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _apply_alphabet_soup(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    m: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p, min-p and top-a.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
    min_p_thresholds = probs_sort[:, 0] * m
    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * a
    threshold = torch.maximum(min_p_thresholds, top_a_thresholds)
    mask = (probs_sort < threshold.unsqueeze(1)
            )  # Cull logits below the top-a threshold
    mask.logical_or_(
        probs_sum >
        p.unsqueeze(dim=1))  # Cull logits above the top-p summation threshold
    mask[:, 0] = False  # Guarantee at least one token is pickable
    logits_sort[mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)

    # Final mask.
    mask = (mask | top_k_mask)
    logits_sort.masked_fill_(mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
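
# Editor's note (illustrative arithmetic, not from the original source): for a
# row whose sorted probabilities are [0.5, 0.2, 0.15, 0.1, 0.05], min_p=0.1
# gives a floor of 0.5 * 0.1 = 0.05 and top_a=0.2 gives 0.5**2 * 0.2 = 0.05,
# so the effective threshold is 0.05; top_p=0.9 additionally culls any token
# once the cumulative probability of the tokens before it exceeds 0.9.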


def _apply_tfs(
    logits: torch.Tensor,
    tfs: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > tfs.unsqueeze(dim=-1)
    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits
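
# Editor's note (sketch of the idea, not from the original source): tail-free
# sampling takes the absolute second differences of the sorted probability
# curve, normalizes them to sum to 1, and keeps tokens until their cumulative
# "curvature mass" exceeds the tfs value; the two concatenated mask columns
# above always keep the first sorted token and always drop the last one.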


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoff: torch.Tensor,
) -> torch.Tensor:
    eta = torch.tensor(eta_cutoff, dtype=logits.dtype,
                       device=logits.device) * 1e-4
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta,
                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps

    if torch.all(eta_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        # Compare each row against its own maximum (unsqueeze for broadcasting).
        eta_mask = probs < topk_prob.unsqueeze(dim=1)

    logits[eta_mask] = -float("inf")
    return logits
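
# Editor's note (illustrative arithmetic, not from the original source): with
# eta_cutoff=3 (i.e. eta = 3e-4) and a distribution whose entropy is 2 nats,
# the cutoff is min(3e-4, sqrt(3e-4) * exp(-2)) ~= min(3e-4, 2.3e-3) = 3e-4,
# so tokens with probability below 3e-4 are removed.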


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoff: torch.Tensor,
) -> torch.Tensor:
    eps = torch.tensor(epsilon_cutoff,
                       dtype=logits.dtype,
                       device=logits.device).unsqueeze(dim=1)
    probs = logits.softmax(dim=-1)

    eps_mask = probs < (eps * 1e-4)

    if torch.all(eps_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        # Compare each row against its own maximum (unsqueeze for broadcasting).
        eps_mask = probs < topk_prob.unsqueeze(dim=1)

    logits[eps_mask] = -float("inf")
    return logits


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_p: torch.Tensor,
    typical_p_sigma: torch.Tensor,
) -> torch.Tensor:
    typ_p = typical_p.clone().detach().to(logits.device).to(logits.dtype)
    typ_threshold = typical_p_sigma.clone().detach().to(logits.device).to(
        logits.dtype)
    THRESHOLD = 1000

    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = torch.exp(shifted_logits)
    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)

    # NOTE: We don't take the absolute value of the surprisal deviations.
    # This deviates from the original implementation.
    surprisal_deviations = neg_entropy - shifted_logits
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    # Invert the minimum deviation from the expected information content of
    # the probability distribution for the next token and scale it based on
    # the provided typical_p_sigma parameter.
    max_threshold = surprisal_deviations.min().negative() * typ_threshold

    # Mask negative deviations and positive deviations above the max threshold.
    max_threshold = max_threshold.unsqueeze(1)
    surprisal_deviations[surprisal_deviations <= 0] = THRESHOLD
    surprisal_deviations[surprisal_deviations > max_threshold] = THRESHOLD
    positive_mask = surprisal_deviations == THRESHOLD

    typ_mask_sorted[..., :1] = 0
    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)

    # Merge the mask created above with the one from the standard typical-p.
    # Masked-out tokens in the distribution are True.
    typ_mask = typ_mask.bitwise_and(positive_mask)

    logits[typ_mask] = -float("inf")
    return logits
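
# Editor's note (sketch of the idea, not from the original source): each
# token's deviation is its surprisal (-log p) minus the distribution's entropy.
# The standard typical-p mask flags tokens outside the lowest-deviation nucleus
# of cumulative mass typical_p, and it is intersected with a second mask that
# flags tokens whose deviation is non-positive or larger than
# -min(deviation) * typical_p_sigma.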


# pulls double duty for temperature and dynatemp
def _apply_temperature(
    logits: torch.Tensor,
    temperatures: torch.Tensor,
    dynatemp_mins: torch.Tensor,
    dynatemp_maxs: torch.Tensor,
    dynatemp_exps: torch.Tensor,
) -> torch.Tensor:
    dynatemp_mask = torch.logical_or(dynatemp_mins > 0, dynatemp_maxs > 0)
    dynatemp_mins = dynatemp_mins[dynatemp_mask]
    dynatemp_maxs = dynatemp_maxs[dynatemp_mask]
    dynatemp_exps = dynatemp_exps[dynatemp_mask]
    dynatemp_mins = dynatemp_mins.clamp_(min=0)

    dynatemp_logits = logits[dynatemp_mask]
    dynatemp_shifted_logits = torch.log_softmax(dynatemp_logits, dim=-1)
    dynatemp_probs = dynatemp_shifted_logits.exp()
    dynatemp_entropies = -(dynatemp_probs *
                           dynatemp_shifted_logits).nansum(dim=-1)
    dynatemp_max_entropies = torch.log_(
        (dynatemp_logits > float("-inf")).sum(dim=-1).float())
    normalized_entropies = dynatemp_entropies.div_(dynatemp_max_entropies)
    dyn_temp = (dynatemp_mins + (dynatemp_maxs - dynatemp_mins) *
                normalized_entropies.pow_(dynatemp_exps))

    temperatures[dynatemp_mask] = dyn_temp
    temperatures[temperatures == 0.0] = 1.0
    logits.div_(temperatures.unsqueeze_(dim=1))
    return logits
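
# Editor's note (illustrative arithmetic, not from the original source): with
# dynatemp_min=0.5, dynatemp_max=1.5, dynatemp_exp=1 and a distribution whose
# entropy is half of the maximum possible entropy (log of the number of live
# tokens), the dynamic temperature is 0.5 + (1.5 - 0.5) * 0.5**1 = 1.0.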


def _apply_quadratic_sampling(
    logits: torch.Tensor,
    smoothing_factors: torch.Tensor,
    smoothing_curves: torch.Tensor,
) -> torch.Tensor:
    """
    Applies a quadratic transformation to the logits based on the
    provided smoothing factors and curves. The transformation is
    centered around the maximum logit value in the batch.

    The transformation involves a quadratic and cubic term, with the
    cubic term controlled by the smoothing curve. The quadratic term is
    scaled by the smoothing factor, and the cubic term is scaled by the
    product of the smoothing factor and the smoothing curve.

    params:
        logits (torch.Tensor): The logits to be transformed.
        smoothing_factors (torch.Tensor): The factors to scale the quadratic
            term in the transformation.
        smoothing_curves (torch.Tensor): The factors to scale the cubic term
            in the transformation.

    returns:
        torch.Tensor: The transformed logits.

    Credits: @kalomaze
    """
    max_logits = logits.max(dim=-1, keepdim=True).values
    diff = logits - max_logits
    smoothing_factors.unsqueeze_(dim=1)
    smoothing_curves.unsqueeze_(dim=1)

    k = (3 - smoothing_curves) / 2
    s = (smoothing_curves - 1) / 2

    mask = smoothing_factors > 0
    mask = mask.flatten()
    transformed_logits = torch.where(
        logits != float('-inf'), -(k * smoothing_factors * diff**2) +
        (s * smoothing_factors * diff**3) + max_logits, logits)
    logits[mask, :] = transformed_logits[mask, :]
    return logits
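
# Editor's note (illustrative arithmetic, not from the original source): with
# smoothing_curve=1 the cubic term vanishes (k=1, s=0) and each logit becomes
# max_logit - smoothing_factor * (logit - max_logit)**2; e.g. a logit that is
# 2 below the row maximum with smoothing_factor=0.25 maps to max_logit - 1.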


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
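
# Editor's note (illustrative arithmetic, not from the original source): in the
# generation phase the top-k is taken over the flattened (num_parents,
# vocab_size) matrix, so with vocab_size=32000 a flat index of 64005 decodes to
# parent_id 64005 // 32000 = 2 and next_token_id 64005 % 32000 = 5.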


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
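
# Editor's note (sketch of the trick, not from the original source): dividing
# each probability by an independent Exponential(1) draw and taking the argmax
# is the exponential-clock form of the Gumbel-max trick:
# argmax_i(p_i / E_i) = argmax_i(log p_i - log E_i), and -log E_i is
# Gumbel-distributed, so the winner is distributed according to p. This is why
# the result matches torch.multinomial with replacement without a CPU sync.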


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type, sample_indices in categorized_sample_indices.items():
        if len(sample_indices) == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of = max(max_best_of, sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }
            multinomial_samples[sampling_type] = _multinomial(
                probs[sample_indices], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    for sampling_type, metadata in sample_metadata.items():
        seq_group_ids, seq_groups, is_prompts, sample_indices = metadata
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices.
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token.
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices
    ]]

    # Batched query for logprobs of topk tokens.
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()

    # Gather results.
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs.
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append({
                    token_id: Logprob(logprob)
                    for token_id, logprob in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs.
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append({
                token_id: Logprob(logprob)
                for token_id, logprob in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    output_metadata: OutputMetadata,
) -> SamplerOutput:
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs,
                               output_metadata.get(seq_ids[parent_id])))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    return sampler_output


def _miro_store_args(seqids: List[int], mus: torch.Tensor,
                     output_metadata: OutputMetadata) -> None:
    for sid, mu in zip(seqids,
                       mus.tolist()):  # tolist might be premature optimization
        output_metadata.add(sid, "miro_mu", mu)


def _apply_mirostat_v2(
    logits: torch.Tensor,
    taus: torch.Tensor,  # AKA the targeted surprise
    etas: torch.Tensor,  # AKA the learning rate
    # AKA the accumulator that always tries to approach [tau]
    mus: torch.Tensor,
) -> torch.Tensor:
    logit_surprise = torch.softmax(
        logits, dim=-1).log2_().neg_()  # Calculate surprise value per token
    # For compatibility with ooba/kobold, done in units of bits (log base 2),
    # not nats (ln).
    # Ideally this would be a log_softmax, for numerical stability and
    # elegance purposes.
    # logit_surprise = torch.log_softmax(logits, dim=-1).neg_()

    miro_mask = logit_surprise > mus.unsqueeze(
        dim=-1)  # Mask out "too-surprising" tokens (above mu)

    mininds = torch.argmin(logit_surprise, dim=-1)
    miro_mask.scatter_(
        1, mininds.unsqueeze(dim=-1), False
    )  # Force at least one outcome to be possible, ideally the most likely one

    logits[miro_mask] = -float("inf")

    probs = torch.softmax(logits, dim=-1,
                          dtype=logits.dtype)  # Get probs, post-mask

    # NOTE: Mirostat updates its `mu` values based on the sample chosen.
    # The silly approach here is to just sample it and make the logits one-hot.
    # This breaks fine-grained seeding, but we don't have that yet.
    # TODO: FIX when it gets added.
    next_token_ids = _multinomial(probs, num_samples=1)

    # Calculate the new `mu` values.
    # NOTE: If we can know the logit values of the PREVIOUS iteration,
    # it should be possible to update `mu` before applying mirostat each
    # iteration, thus letting us keep _sample as the last thing that happens.
    picked_surprises = torch.gather(logit_surprise,
                                    dim=-1,
                                    index=next_token_ids)
    eps = picked_surprises.squeeze() - taus
    mus.sub_(etas * eps)

    logits.fill_(-float("inf"))
    # This value doesn't actually matter, so long as it's not -inf.
    # Vectors are now one-hot, after all.
    logits.scatter_(1, next_token_ids, 1.0)
    return logits
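
# Editor's note (illustrative arithmetic, not from the original source): with a
# target surprise tau=5 bits, learning rate eta=0.1, and a sampled token whose
# surprise is 7 bits, the error is 7 - 5 = 2 and mu is lowered by 0.1 * 2 = 0.2,
# tightening the mask for the next step; a token less surprising than tau would
# instead raise mu.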


def _mirostat(logits: torch.Tensor, sampling_tensors: SamplingTensors,
              output_metadata: OutputMetadata) -> torch.Tensor:
    idx = sampling_tensors.miro_indices
    seqids = sampling_tensors.miro_seqids
    taus = sampling_tensors.miro_taus
    etas = sampling_tensors.miro_etas
    mus = sampling_tensors.miro_mus

    logits[idx] = _apply_mirostat_v2(logits[idx], taus, etas,
                                     mus)  # mus is an i/o param, :vomit:
    _miro_store_args(seqids, mus, output_metadata)
    return logits