  1. """A layer that samples the next tokens from the model's outputs."""
  2. from typing import Dict, List, Tuple, Optional
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from aphrodite.modeling.metadata import InputMetadata
  7. from aphrodite.modeling.megatron.tensor_parallel import (
  8. gather_from_tensor_model_parallel_region)
  9. from aphrodite.common.sampling_params import SamplingParams
  10. from aphrodite.common.sequence import SamplerOutput, SequenceOutputs
  11. _SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
       tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.

    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove paddings in vocab (if any).
        logits = logits[:, :self.vocab_size]

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties, self.vocab_size)
        logits = _apply_logits_processors(input_metadata, logits,
                                          output_tokens)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
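    """Keep only the hidden states that will be sampled from: the hidden
    state of the last token of each prompt, plus those of all generation
    tokens."""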
    start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states.index_select(
        0, torch.tensor(last_token_indices, device=hidden_states.device))


def _get_penalties(
        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
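    """Collect per-row presence and frequency penalties: one entry per prompt
    and one per running sequence in a generation group, matching the rows of
    the logits tensor."""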
    # Collect the presence and frequency penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        if i < input_metadata.num_prompts:
            # A prompt input.
            presence_penalties.append(p)
            frequency_penalties.append(f)
        else:
            # A generation token.
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
    return presence_penalties, frequency_penalties


def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
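    """Gather the output token ids generated so far for each sequence, in the
    same order as the rows of the logits tensor."""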
    output_tokens: List[List[int]] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, _ = seq_group
        if i < input_metadata.num_prompts:
            # A prompt input.
            # NOTE: While the prompt input usually has no output tokens,
            # it may have output tokens in the case of recomputation.
            seq_id = seq_ids[0]
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
        else:
            # A generation token.
            for seq_id in seq_ids:
                seq_data = input_metadata.seq_data[seq_id]
                output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _apply_logits_processors(
    input_metadata: InputMetadata,
    logits: torch.Tensor,
    output_tokens: List[List[int]],
) -> torch.Tensor:
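    """Run any user-provided logits processors attached to the sampling
    parameters over the logits."""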
    for seq_group in input_metadata.seq_groups:
        _, sampling_params = seq_group
        logits_processors = sampling_params.logits_processors
        if logits_processors is not None:
            for logits_processor in logits_processors:
                logits = logits_processor(logits, output_tokens)
    return logits


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    vocab_size: int,
) -> torch.Tensor:
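    """Apply presence and frequency penalties as defined by the OpenAI API:
    each logit is reduced by ``frequency_penalty`` times the number of times
    its token already appears in the output, plus ``presence_penalty`` if the
    token has appeared at least once."""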
    num_seqs = logits.shape[0]
    # Collect the indices of sequences that have non-zero penalties.
    indices = []
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        p = presence_penalties[i]
        f = frequency_penalties[i]
        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
            continue
        indices.append(i)

    # Return early if all sequences have zero penalties.
    if not indices:
        return logits

    bin_counts = []
    for i in indices:
        bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
    bin_counts = np.stack(bin_counts, axis=0)
    bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
                                                 device=logits.device)

    frequency_penalties = [frequency_penalties[i] for i in indices]
    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = [presence_penalties[i] for i in indices]
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
    logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
    return logits


def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
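    """Collect one temperature per row of the logits tensor, mapping
    (near-)zero temperatures to 1.0 since greedy sampling and beam search do
    not need scaling."""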
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
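    """Collect one top-p and one top-k value per row of the logits tensor,
    clamping top-k to the vocab size and treating k == -1 as no truncation."""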
    top_ps: List[float] = []
    top_ks: List[int] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(top_p)
            top_ks.append(top_k)
        else:
            # A generation token.
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks


def _apply_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
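    """Mask out tokens outside each row's nucleus (top-p) and top-k set by
    setting their logits to -inf, working in sorted order and then scattering
    the results back to the original token order."""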
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    logits_sort[top_p_mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Re-sort the probabilities.
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: Optional[int],
) -> Dict[int, float]:
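    """Return a mapping from token id to log probability for the
    ``num_logprobs`` highest-probability tokens of a single row."""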
    if num_logprobs is None or num_logprobs == 0:
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
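    """Sample the next token(s) for a prompt input: the top ``2 * beam_width``
    candidates for beam search, the argmax for greedy sampling, or ``best_of``
    multinomial samples otherwise."""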
    if sampling_params.use_beam_search:
        # Beam search.
        beam_width = sampling_params.best_of
        # Sample 2 * beam_width candidates to make sure that with high
        # probability we can get `beam_width` candidates in addition to
        # the finished sequences for the next iteration. See
        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
        # for details. See also HF reference:
        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
        _, next_token_ids = torch.topk(prob, 2 * beam_width)
        next_token_ids = next_token_ids.tolist()
    elif sampling_params.temperature < _SAMPLING_EPS:
        # Greedy sampling.
        assert sampling_params.best_of == 1
        next_token_id = torch.argmax(prob)
        next_token_ids = [next_token_id.item()]
    else:
        # Random sampling.
        # Sample `best_of` tokens for the prompt.
        num_seqs = sampling_params.best_of
        next_token_ids = torch.multinomial(prob,
                                           num_samples=num_seqs,
                                           replacement=True)
        next_token_ids = next_token_ids.tolist()
    return next_token_ids


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
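    """Pick next-token candidates (and their parent sequences) for a group of
    running sequences: the top ``2 * beam_width`` candidates over the summed
    cumulative logprobs for beam search, the argmax for greedy sampling, or
    one multinomial sample per sequence otherwise."""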
    # NOTE: sampling_params.best_of can be greater than len(seq_ids)
    # because some sequences in the group might have already been
    # terminated.
    if sampling_params.use_beam_search:
        # Beam search.
        # Add cumulative logprobs for the sequences in the group.
        seq_logprobs = torch.tensor(seq_logprobs,
                                    dtype=torch.float,
                                    device=logprobs.device)
        logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)

        vocab_size = logprobs.size(-1)
        beam_width = len(seq_ids)
        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
        topk_ids = topk_ids.tolist()
        seq_idx = [i // vocab_size for i in topk_ids]
        parent_seq_ids = [seq_ids[i] for i in seq_idx]
        next_token_ids = [i % vocab_size for i in topk_ids]
    elif sampling_params.temperature < _SAMPLING_EPS:
        # Greedy sampling.
        assert len(seq_ids) == 1
        next_token_id = torch.argmax(probs, dim=-1)
        next_token_ids = [int(next_token_id.item())]
        parent_seq_ids = seq_ids
    else:
        # Random sampling.
        # Sample 1 token for each sequence in the group.
        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True)
        next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
        parent_seq_ids = seq_ids
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> SamplerOutput:
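    """Build the sampler output for every sequence group in the batch,
    dispatching to prompt or generation-token sampling and attaching the
    top-k log probabilities for each sampled token."""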
    seq_outputs: SamplerOutput = []

    # TODO: Optimize.
    idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_group_outputs: List[SequenceOutputs] = []
        seq_ids, sampling_params = seq_group
        if i < input_metadata.num_prompts:
            # Generate the next tokens for a prompt input.
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            parent_seq_id = seq_ids[0]
            prob = probs[idx]
            logprob = logprobs[idx]
            idx += 1

            # Sample the next tokens.
            next_token_ids = _sample_from_prompt(prob, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs = _get_topk_logprobs(logprob,
                                               sampling_params.logprobs)

            # Build the output.
            for next_token_id in next_token_ids:
                output_logprobs = next_logprobs.copy()
                output_logprobs[next_token_id] = logprob[next_token_id].item()
                seq_group_outputs.append(
                    SequenceOutputs(parent_seq_id, next_token_id,
                                    output_logprobs))
        else:
            # Generate the next tokens for generation tokens.
            num_parent_seqs = len(seq_ids)
            prob = probs[idx:idx + num_parent_seqs]
            logprob = logprobs[idx:idx + num_parent_seqs]
            idx += num_parent_seqs

            # Sample the next tokens.
            seq_logprobs = [
                input_metadata.seq_data[seq_id].cumulative_logprob
                for seq_id in seq_ids
            ]
            parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                seq_ids, prob, logprob, seq_logprobs, sampling_params)

            # Get top-k log probabilities for the next tokens.
            next_logprobs: Dict[int, Dict[int, float]] = {}
            for j, seq_id in enumerate(seq_ids):
                next_logprobs[seq_id] = _get_topk_logprobs(
                    logprob[j], sampling_params.logprobs)

            # Build the output.
            for parent_seq_id, next_token_id in zip(parent_seq_ids,
                                                    next_token_ids):
                j = seq_ids.index(parent_seq_id)
                output_logprobs = next_logprobs[parent_seq_id].copy()
                output_logprobs[next_token_id] = logprob[j,
                                                         next_token_id].item()
                seq_group_outputs.append(
                    SequenceOutputs(parent_seq_id, next_token_id,
                                    output_logprobs))
        seq_outputs.append(seq_group_outputs)
    return seq_outputs