sampler.py

  1. """A layer that samples the next tokens from the model's outputs."""
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.modeling.metadata import InputMetadata
  6. from aphrodite.modeling.megatron.communication_op import (
  7. tensor_model_parallel_all_gather)
  8. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  9. from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
  10. SamplerOutput, SequenceData,
  11. SequenceGroupOutputs, SequenceOutputs)
  12. _SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties, custom token bans,
       and any user-supplied logits processors.
    4. Apply the configured filters (eta cutoff, locally typical sampling,
       tail-free sampling, epsilon cutoff).
    5. Apply temperature scaling.
    6. Apply top-p, top-k, and top-a truncation.
    7. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply presence, frequency and repetition penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        (presence_penalties, frequency_penalties,
         repetition_penalties) = _get_penalties(input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties, repetition_penalties,
                                  self.vocab_size)

        # Apply custom token bans and user-supplied logits processors.
        banned_tokens = _get_custom_token_bans(input_metadata)
        assert len(banned_tokens) == logits.shape[0]
        logits = _apply_token_bans(logits, banned_tokens)
        logits = _apply_logits_processors(input_metadata, logits,
                                          output_tokens)

        # Apply Eta sampling, as described in https://arxiv.org/abs/2210.15191
        eta_cutoffs = _get_eta_cutoffs(input_metadata)
        assert len(eta_cutoffs) == logits.shape[0]
        if any(eta > _SAMPLING_EPS for eta in eta_cutoffs):
            logits = _apply_eta_cutoff(logits, eta_cutoffs)

        # Apply locally typical sampling, as described in
        # https://arxiv.org/abs/2202.00666
        typical_ps = _get_typical_ps(input_metadata)
        assert len(typical_ps) == logits.shape[0]
        if any(typ_p < 1.0 - _SAMPLING_EPS for typ_p in typical_ps):
            logits = _apply_typical_sampling(logits, typical_ps)

        # Apply Tail-Free Sampling, as described in
        # https://www.trentonbricken.com/Tail-Free-Sampling/
        tfss = _get_tfs(input_metadata)
        assert len(tfss) == logits.shape[0]
        if any(z < 1.0 - _SAMPLING_EPS for z in tfss):
            logits = _apply_tfs(logits, tfss)

        # Apply the epsilon cutoff (a minimum-probability floor).
        epsilon_cutoffs = _get_epsilon_cutoffs(input_metadata)
        assert len(epsilon_cutoffs) == logits.shape[0]
        if any(epsilon > _SAMPLING_EPS for epsilon in epsilon_cutoffs):
            logits = _apply_epsilon_cutoff(logits, epsilon_cutoffs)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p, top-k, and top-a truncation.
        top_ps, top_ks, top_as = _get_top_a_top_p_top_k(
            input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        do_top_a = any(a > _SAMPLING_EPS for a in top_as)
        if do_top_p or do_top_k or do_top_a:
            logits = _apply_top_a_top_p_top_k(logits, top_ps, top_ks, top_as)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, input_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, input_metadata, sample_results)
        return _build_sampler_output(sample_results, input_metadata,
                                     prompt_logprobs, sample_logprobs)
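
# NOTE: the output-embedding (LM head) weight and optional bias are supplied
# by the caller. A model's generation step typically invokes this layer
# roughly like the sketch below; names such as `model`, `lm_head`, and
# `vocab_size` are illustrative and not defined in this file:
#
#     sampler = Sampler(vocab_size)
#     hidden_states = model(input_ids, positions, kv_caches, input_metadata)
#     output = sampler(lm_head.weight, hidden_states, input_metadata)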


def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    # Get the logits for the next tokens.
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_all_gather(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    selected_token_indices: List[int] = []
    start_idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            prompt_len = input_metadata.prompt_lens[i]
            if sampling_params.prompt_logprobs is not None:
                selected_token_indices.extend(
                    range(start_idx, start_idx + prompt_len - 1))
            selected_token_indices.append(start_idx + prompt_len - 1)
            start_idx += input_metadata.max_prompt_len
        else:
            num_seqs = len(seq_ids)
            selected_token_indices.extend(
                range(start_idx, start_idx + num_seqs))
            start_idx += num_seqs

    selected_token_indices = torch.tensor(selected_token_indices,
                                          dtype=torch.long,
                                          device=hidden_states.device)
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0, selected_token_indices)
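
# The selection above assumes the flattened hidden states are laid out with
# all prompt sequences first, each padded to `max_prompt_len`, followed by one
# row per generation-phase sequence. Only the last prompt position is kept per
# prompt (plus every prompt position when `prompt_logprobs` is requested), so
# the surviving rows line up one-to-one with the per-row lists built by the
# `_get_*` helpers below.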


def _get_penalties(
    input_metadata: InputMetadata
) -> Tuple[List[float], List[float], List[float]]:
    # Collect the presence, frequency, and repetition penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    repetition_penalties: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # Prompt-logprob positions carry neutral penalties
            # (1.0 is the neutral repetition penalty).
            prompt_len = input_metadata.prompt_lens[i]
            presence_penalties += [0] * (prompt_len - 1)
            frequency_penalties += [0] * (prompt_len - 1)
            repetition_penalties += [1] * (prompt_len - 1)
        presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
        frequency_penalties += [sampling_params.frequency_penalty
                                ] * len(seq_ids)
        repetition_penalties += [sampling_params.repetition_penalty
                                 ] * len(seq_ids)
    return presence_penalties, frequency_penalties, repetition_penalties


def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    output_tokens: List[List[int]] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: prompt token positions do not need output tokens to
            # compute penalties.
            prompt_len = input_metadata.prompt_lens[i]
            output_tokens.extend([] for _ in range(prompt_len - 1))
        for seq_id in seq_ids:
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _get_custom_token_bans(input_metadata: InputMetadata) -> List[List[int]]:
    banned_tokens: List[List[int]] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        custom_token_bans = sampling_params.custom_token_bans
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            banned_tokens += [custom_token_bans] * (prompt_len - 1)
        banned_tokens += [custom_token_bans] * len(seq_ids)
    return banned_tokens


def _apply_logits_processors(input_metadata: InputMetadata,
                             logits: torch.Tensor,
                             output_tokens: List[List[int]]) -> torch.Tensor:
    seq_offset = 0
    for seq_ids, sampling_params in input_metadata.seq_groups:
        seq_end = seq_offset + len(seq_ids)
        for proc in sampling_params.logits_processors:
            proc(logits[seq_offset:seq_end], output_tokens[seq_offset:seq_end])
        seq_offset = seq_end
    return logits


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    repetition_penalties: List[float],
    vocab_size: int,
) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        if (abs(presence_penalties[i]) < _SAMPLING_EPS
                and abs(frequency_penalties[i]) < _SAMPLING_EPS
                and repetition_penalties[i] < 1.0 + _SAMPLING_EPS):
            continue
        break
    else:
        # Return early if all sequences have zero penalties.
        return logits

    max_output_len = max(len(tokens) for tokens in output_tokens)
    padded_output_tokens = [
        tokens + [vocab_size] * (max_output_len - len(tokens))
        for tokens in output_tokens
    ]
    output_tokens_tensor = torch.tensor(padded_output_tokens,
                                        dtype=torch.long,
                                        device=logits.device)

    # Compute the bin counts for the output tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, output_tokens_tensor,
                            torch.ones_like(output_tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)
    repetition_penalties = torch.tensor(repetition_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    presence_mask = (bin_counts > 0)
    logits -= presence_penalties.unsqueeze(dim=1) * presence_mask

    # Effectively:
    # If token is present and logit is positive, divide logit by rep_pen.
    # If token is present and logit is negative, multiply logit by rep_pen.
    logits += logits * (1 / repetition_penalties.unsqueeze(dim=1) -
                        1) * presence_mask * (logits > 0)
    logits += logits * (repetition_penalties.unsqueeze(dim=1) -
                        1) * presence_mask * (logits < 0)
    return logits
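
# Worked example of the penalty arithmetic above, for a token that has already
# appeared twice (count = 2) with frequency_penalty = 0.5,
# presence_penalty = 0.2, and repetition_penalty = 1.2:
#   logit = 3.0
#   logit -= 0.5 * 2   # frequency penalty                      -> 2.0
#   logit -= 0.2       # presence penalty                       -> 1.8
#   logit /= 1.2       # repetition penalty (positive logit)    -> 1.5
# A negative logit would instead be multiplied by 1.2 in the last step.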


def _apply_token_bans(logits: torch.Tensor,
                      banned_tokens: List[List[int]]) -> torch.Tensor:
    for i, banned_token_ids in enumerate(banned_tokens):
        if not banned_token_ids:
            continue
        logits[i, banned_token_ids] = -float("inf")
    return logits


def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            temperatures += [temperature] * (prompt_len - 1)
        temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_a_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int], List[float]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    top_as: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            top_ps += [sampling_params.top_p] * (prompt_len - 1)
            top_ks += [top_k] * (prompt_len - 1)
            top_as += [sampling_params.top_a] * (prompt_len - 1)
        top_ps += [sampling_params.top_p] * len(seq_ids)
        top_ks += [top_k] * len(seq_ids)
        top_as += [sampling_params.top_a] * len(seq_ids)
    return top_ps, top_ks, top_as


def _get_tfs(input_metadata: InputMetadata) -> List[float]:
    tfss: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        z = sampling_params.tfs
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            tfss += [z] * (prompt_len - 1)
        tfss += [z] * len(seq_ids)
    return tfss


def _get_eta_cutoffs(input_metadata: InputMetadata) -> List[float]:
    eta_cutoffs: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        eta_cutoff = sampling_params.eta_cutoff
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            eta_cutoffs += [eta_cutoff] * (prompt_len - 1)
        eta_cutoffs += [eta_cutoff] * len(seq_ids)
    return eta_cutoffs


def _get_epsilon_cutoffs(input_metadata: InputMetadata) -> List[float]:
    epsilon_cutoffs: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        epsilon_cutoff = sampling_params.epsilon_cutoff
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            epsilon_cutoffs += [epsilon_cutoff] * (prompt_len - 1)
        epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
    return epsilon_cutoffs


def _get_typical_ps(input_metadata: InputMetadata) -> List[float]:
    typical_ps: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        typical_p = sampling_params.typical_p
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            typical_ps += [typical_p] * (prompt_len - 1)
        typical_ps += [typical_p] * len(seq_ids)
    return typical_ps


def _apply_top_a_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
    top_as: List[float],
) -> torch.Tensor:
    ts_p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    ts_k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    ts_a = torch.tensor(top_as, dtype=logits.dtype, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p and top-a.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_a_thresholds = torch.pow(probs_sort[:, 0], 2) * ts_a
    # Cull logits below the top-a threshold.
    top_ap_mask = (probs_sort < top_a_thresholds.unsqueeze(1))
    # Cull logits above the top-p summation threshold.
    top_ap_mask.logical_or_(probs_sum > ts_p.unsqueeze(dim=1))
    # Guarantee at least one token is pickable.
    top_ap_mask[:, 0] = False
    logits_sort[top_ap_mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= ts_k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Re-sort the probabilities.
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits
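
# In the routine above, top-a keeps only tokens whose probability is at least
# top_a * p_max**2 (p_max being the highest token probability in the row),
# top-p drops the tail once the cumulative probability of the sorted tokens
# exceeds top_p, and top-k keeps the k highest-probability tokens. The final
# gather with argsort(logits_idx) applies the inverse of the sort permutation,
# returning the masked logits to their original vocabulary order.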


def _apply_tfs(
    logits: torch.Tensor,
    tfss: List[float],
) -> torch.Tensor:
    z = torch.tensor(tfss, dtype=logits.dtype, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    d2 = logits_sort.softmax(dim=-1).diff().diff().abs()
    normalized_d2 = d2 / torch.sum(d2, dim=-1, keepdim=True)
    curvature_cdf = torch.cumsum(normalized_d2, dim=-1)

    tfs_mask = curvature_cdf > z.unsqueeze(dim=-1)
    tfs_mask = torch.cat(
        (
            torch.zeros(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
            tfs_mask,
            torch.ones(
                logits.shape[0], 1, dtype=torch.bool, device=logits.device),
        ),
        dim=-1,
    )

    logits_sort[tfs_mask] = -float("inf")
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits
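
# Tail-free sampling, as implemented above: take the absolute second
# difference of the sorted probabilities (a discrete curvature), normalize it
# into a distribution, and drop the tail of tokens once its cumulative sum
# exceeds z. The concatenated mask columns always keep the first
# (highest-probability) token and always drop the last sorted token.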


def _apply_eta_cutoff(
    logits: torch.Tensor,
    eta_cutoffs: List[float],
) -> torch.Tensor:
    eta = torch.tensor(eta_cutoffs, dtype=logits.dtype,
                       device=logits.device) * 1e-4
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1)
    eps = torch.min(eta,
                    torch.sqrt(eta) * torch.exp(neg_entropy)).unsqueeze(dim=1)

    eta_mask = probs < eps
    if torch.all(eta_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        # Keep only the per-row maximum (unsqueeze for row-wise broadcasting).
        eta_mask = probs < topk_prob.unsqueeze(dim=1)

    logits[eta_mask] = -float("inf")
    return logits
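
# Eta sampling (https://arxiv.org/abs/2210.15191) masks tokens whose
# probability falls below eps = min(eta, sqrt(eta) * exp(-H)), where H is the
# entropy of the distribution; note that `eta_cutoff` is supplied in units of
# 1e-4, hence the scaling above. If the mask would remove every token, only
# the most likely token in each row is kept.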


def _apply_epsilon_cutoff(
    logits: torch.Tensor,
    epsilon_cutoffs: List[float],
) -> torch.Tensor:
    eps = torch.tensor(epsilon_cutoffs,
                       dtype=logits.dtype,
                       device=logits.device).unsqueeze(dim=1)
    probs = logits.softmax(dim=-1)

    eps_mask = probs < (eps * 1e-4)
    if torch.all(eps_mask):  # guard against nulling out all the logits
        topk_prob, _ = torch.max(probs, dim=-1)
        # Keep only the per-row maximum (unsqueeze for row-wise broadcasting).
        eps_mask = probs < topk_prob.unsqueeze(dim=1)

    logits[eps_mask] = -float("inf")
    return logits
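
# Epsilon cutoff is a simple probability floor: tokens with probability below
# epsilon_cutoff * 1e-4 are masked out, with the same "keep at least the most
# likely token" guard as the eta cutoff above.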


def _apply_typical_sampling(
    logits: torch.Tensor,
    typical_ps: List[float],
) -> torch.Tensor:
    typ_p = torch.tensor(typical_ps, dtype=logits.dtype, device=logits.device)
    shifted_logits = torch.log_softmax(logits, dim=-1)
    probs = shifted_logits.exp()

    neg_entropy = (probs * shifted_logits).nansum(dim=-1, keepdim=True)
    surprisal_deviations = (neg_entropy - shifted_logits).abs()
    _, indices = torch.sort(surprisal_deviations)
    reordered_probs = probs.gather(-1, indices)
    typ_mask_sorted = reordered_probs.cumsum(dim=-1) >= typ_p.unsqueeze(dim=1)

    min_tokens_to_keep = 1
    # Keep at least min_tokens_to_keep
    typ_mask_sorted[..., :min_tokens_to_keep] = 0

    typ_mask = typ_mask_sorted.scatter(1, indices, typ_mask_sorted)
    logits[typ_mask] = -float("inf")
    return logits
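
# Locally typical sampling (https://arxiv.org/abs/2202.00666): tokens are
# ranked by how close their surprisal (-log p) is to the entropy of the
# distribution, and the closest tokens are kept until their cumulative
# probability reaches typical_p; everything else is masked out.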


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = torch.argmax(logprobs, dim=-1).cpu()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx].item()]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    max_best_of = 1
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        if is_prompt:
            seq_ids, sampling_params = seq_group
            max_best_of = max(max_best_of, sampling_params.best_of)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == probs.size(0)
    return results
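
# Every row is sampled with `max_best_of` draws so a single multinomial call
# covers the whole batch; prompt rows then keep their own `best_of` columns
# (one new sequence per draw), while generation rows keep a single draw per
# parent sequence.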


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = {t: [] for t in SamplingType}
    start_idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            start_idx += prompt_len - 1
        categorized_seq_group_ids[sampling_type].append(i)
        num_seqs = len(seq_ids)
        categorized_sample_indices[sampling_type].extend(
            range(start_idx, start_idx + num_seqs))
        start_idx += num_seqs

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    for sampling_type in SamplingType:
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        if sampling_type == SamplingType.GREEDY:
            category_logprobs = logprobs[sample_indices]
            sample_results = _greedy_sample(seq_groups, category_logprobs)
        elif sampling_type == SamplingType.RANDOM:
            category_probs = probs[sample_indices]
            sample_results = _random_sample(seq_groups, is_prompts,
                                            category_probs)
        elif sampling_type == SamplingType.BEAM:
            category_logprobs = logprobs[sample_indices]
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 input_metadata.seq_data,
                                                 category_logprobs)
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
    ]
    return sample_results
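
# Sequence groups are dispatched to the greedy, random, or beam-search
# samplers by sampling type, and the per-type results are then stitched back
# together in the original order so the returned list lines up with
# input_metadata.seq_groups.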


def _get_logprobs(
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]],
           List[List[Dict[int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(input_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = input_metadata.prompt_lens[i]
            prompt_tokens = input_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices
    ]].cpu()

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(input_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_len = input_metadata.prompt_lens[i]
            prompt_tokens = input_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append(prompt_logprobs_dict)
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append(sample_logprobs_dict)
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    input_metadata: InputMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(input_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs))
        sampler_output.append(
            SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
    return sampler_output
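
# The returned SamplerOutput is a list with one SequenceGroupOutputs per
# sequence group, each holding a SequenceOutputs(parent_seq_id, token_id,
# logprobs) entry per sampled token, plus the group's prompt logprobs
# (or None when prompt logprobs were not requested).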