sampling_metadata.py

import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
                                    maybe_expand_dim)
from aphrodite.modeling.layers.ops.sample import get_num_triton_sampler_splits

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
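# The Triton sampling kernel treats seed == 0 as greedy decoding, so a
# randomly drawn per-sequence seed of 0 is replaced with this arbitrary
# nonzero constant (see `SamplingTensors._get_sequence_seeds` below).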


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new tokens to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if
    # it is in a decode stage. query_len <= seq_len if chunked prefill is
    # enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices into the logits, used to compute prompt logprobs.
    # Empty if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices into the logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
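
# Illustrative sketch (not used anywhere in this module): for a chunked
# prefill that has already processed 8 of 12 prompt tokens and computes the
# remaining 4 in this step, a group might look roughly like
#   SequenceGroupToSample(seq_ids=[0], sampling_params=..., seq_data={0: ...},
#                         seq_len=12, query_len=4, generator=None,
#                         is_prompt=True, prompt_logprob_indices=[],
#                         sample_indices=[0])
# (sample_indices assumes this is the first group in the batch and no prompt
# logprobs are requested). A decode-stage group instead carries seq_len=None
# and query_len=None.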


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each entry is a 2D tensor of shape (num_indices, 2), where the
            first item is the index of the token within the logits selected
            via selected_token_indices (prompt-logprob and sample positions),
            and the second item is the index of the token within only the
            sample positions.
            For example, if the selected logits are [1, 2, 3] and positions
            [1, 2] are the ones to sample, the sample-only logits are [2, 3].
            The first column is then [1, 2] (indices within the selected
            logits) and the second column is [0, 1] (indices within the
            sample-only logits).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
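
        # The per-type index lists below are (index into the selected logits,
        # index into the sample tensor) pairs. Each list is moved to the
        # target device; maybe_expand_dim presumably guarantees a 2D result of
        # shape (num_indices, 2) even when a list is empty (an empty list
        # would otherwise produce a 1D tensor).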
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")
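

# Rough usage sketch (illustrative only; call sites are outside this module,
# and `model`/`compute_logits` below are stand-in names):
#
#   sampling_metadata = SamplingMetadata.prepare(
#       seq_group_metadata_list, seq_lens, query_lens,
#       device="cuda", pin_memory=True)
#   hidden_states = model(...)
#   logits = compute_logits(
#       hidden_states[sampling_metadata.selected_token_indices])
#   # The sampler then uses seq_groups and categorized_sample_indices to pick
#   # prompt-logprob rows and sample rows out of `logits`.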


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths per sequence group.
            Indices should match those of seq_group_metadata_list.
        query_lens: A list of query lengths per sequence group. A query
            length can be shorter than the prompt length, e.g. when chunked
            prefill is enabled.
        device: A device to use for the random number generator,
            `SequenceGroupToSample.generator`.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition in `SamplingMetadata`.
        categorized_sample_indices: See the definition in `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprobs for. It is used to
    # prune the output logits of the model for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    #   indices to sample/compute prompt logprobs for within the pruned
    #   logits,
    #   indices to sample within the sample tensor)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprobs for. Logits include both prompt
    # logprob and sample logprob indices.
    logit_idx = 0
    # Index to sample from the sample tensor. It is used by the triton sample
    # kernel. See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from the given sequence groups.
    num_prompts = 0
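
    # Three counters track three different index spaces:
    #   * model_output_idx: position within the raw model output (every token
    #     of every group); used to build selected_token_indices.
    #   * logit_idx: position within the logits kept after pruning with
    #     selected_token_indices (prompt-logprob rows plus sample rows).
    #   * sample_idx: position within the sampled rows only; consumed by the
    #     triton sampling kernel.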
    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, these stay None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = []
        sample_indices: List[int] = []
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                seq_group_metadata.state.generator = torch.Generator(
                    device=device).manual_seed(sampling_params.seed)

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]

            # If we need sampling, exclude num_prefill_sample tokens from
            # the prompt logprobs.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.
        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """
        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.
        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """
        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
            logit_idx += sample_len
            sample_idx += sample_len

        if sampling_params.seed is not None:
            generator = seq_group_metadata.state.generator

        seq_groups.append(
            SequenceGroupToSample(
                seq_ids=seq_ids,
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices)))

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
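
# Worked example (illustrative): suppose the batch is one prompt group with
# query_len=4 that requests prompt logprobs and samples one token, followed by
# one single-sequence decode group with no prompt logprobs, both using the
# same sampling type. The loop above then produces
#   selected_token_indices = [0, 1, 2, 3, 4]
#   prompt group: prompt_logprob_indices=[0, 1, 2], sample_indices=[3]
#   decode group: prompt_logprob_indices=[],        sample_indices=[4]
#   categorized_sample_indices[<that type>] = [(3, 0), (4, 1)]
#   num_prompts = 1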


@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
        *,
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
        extra_entropy: extra entropy to use when generating seeds.
        """
        prompt_tokens: List[List[int]] = []
        output_tokens: List[List[int]] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        sampling_seeds: List[int] = []
        sample_indices: List[int] = []
        prompt_best_of: List[int] = []
        do_penalties = False
        do_top_p_top_k = False
        do_top_as = False
        do_min_p = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False

        # We need one base seed per Triton slice.
        seeds_to_generate = (extra_seeds_to_generate +
                             get_num_triton_sampler_splits(vocab_size))
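        # get_num_triton_sampler_splits presumably reports how many slices the
        # Triton sampling kernel splits the vocabulary into; each slice needs
        # its own base seed per sequence, on top of any extra seeds requested
        # by the caller via extra_seeds_to_generate.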

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            top_a = sampling_params.top_a
            min_p = sampling_params.min_p
            tfs = sampling_params.tfs
            eta_cutoff = sampling_params.eta_cutoff
            epsilon_cutoff = sampling_params.epsilon_cutoff
            typical_p = sampling_params.typical_p
            smoothing_factor = sampling_params.smoothing_factor
            smoothing_curve = sampling_params.smoothing_curve
            seed = sampling_params.seed
            is_greedy = sampling_params.sampling_type == SamplingType.GREEDY

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if do_top_as is False and top_a > 0.0:
                do_top_as = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                do_tfss = True
            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
                do_eta_cutoffs = True
            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                do_epsilon_cutoffs = True
            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
                do_typical_ps = True
            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                          or smoothing_curve > 1.0):
                do_quadratic = True

            is_prompt = seq_group.is_prompt
            if (seq_group.is_prompt
                    and sampling_params.prompt_logprobs is not None):
                # These prompt positions only need logprobs, not sampling;
                # penalties, TFS, cutoffs and typical-p are pinned to neutral
                # values for them.
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                top_as += [top_a] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len
                tfss += [1] * prefill_len
                eta_cutoffs += [0] * prefill_len
                epsilon_cutoffs += [0] * prefill_len
                typical_ps += [1] * prefill_len
                smoothing_factors += [smoothing_factor] * prefill_len
                smoothing_curves += [smoothing_curve] * prefill_len

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens == len(seq_ids)
                temperatures += [temperature] * len(seq_ids)
                top_ps += [top_p] * len(seq_ids)
                top_ks += [top_k] * len(seq_ids)
                top_as += [top_a] * len(seq_ids)
                min_ps += [min_p] * len(seq_ids)
                presence_penalties += [p] * len(seq_ids)
                frequency_penalties += [f] * len(seq_ids)
                repetition_penalties += [r] * len(seq_ids)
                tfss += [tfs] * len(seq_ids)
                eta_cutoffs += [eta_cutoff] * len(seq_ids)
                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
                typical_ps += [typical_p] * len(seq_ids)
                smoothing_factors += [smoothing_factor] * len(seq_ids)
                smoothing_curves += [smoothing_curve] * len(seq_ids)

            if is_prompt:
                prompt_best_of.append(sampling_params.best_of)
                query_len = seq_group.query_len
                assert query_len is not None

            for seq_id in seq_ids:
                seq_data = seq_group.seq_data[seq_id]
                extra_entropy = extra_entropy or ()
                seq_seeds = cls._get_sequence_seeds(
                    seed,
                    seq_data.get_len(),
                    *extra_entropy,
                    seq_id,
                    seeds_to_generate=seeds_to_generate,
                    is_greedy=is_greedy)
                sampling_seeds.append(seq_seeds)
            sample_indices.extend(seq_group.sample_indices)

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend([] for _ in range(prefill_len))
                    output_tokens.extend([] for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(list(seq_data.prompt_token_ids))
                        output_tokens.append(list(seq_data.output_token_ids))

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
            epsilon_cutoffs, typical_ps, smoothing_factors, smoothing_curves,
            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
            vocab_size, extra_seeds_to_generate, device, dtype)
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
                do_typical_ps, do_quadratic)
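
    # Callers are expected to unpack the flags positionally, e.g. (a rough
    # sketch, not a call site in this module):
    #   (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
    #    do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
    #    do_quadratic) = SamplingTensors.from_sampling_metadata(
    #        sampling_metadata, vocab_size, device, dtype)
    # Each do_* flag is True if any sequence in the batch enables that
    # feature, letting the sampler skip the corresponding pass otherwise.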

    @classmethod
    def from_lists(cls, temperatures: List[float], top_ps: List[float],
                   top_ks: List[int], top_as: List[float], min_ps: List[float],
                   presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float], tfss: List[float],
                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                   typical_ps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], sampling_seeds: List[int],
                   sample_indices: List[int], prompt_tokens: List[List[int]],
                   output_tokens: List[List[int]], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without pinned memory.
        pin_memory = is_pin_memory_available()
        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
                                 default=0)
            prompt_padded_tokens = np.full(
                (len(prompt_tokens), prompt_max_len),
                vocab_size,
                dtype=np.int64)
            for i, tokens in enumerate(prompt_tokens):
                prompt_padded_tokens[i, :len(tokens)] = tokens
            output_max_len = max([len(tokens) for tokens in output_tokens],
                                 default=0)
            output_padded_tokens = np.full(
                (len(output_tokens), output_max_len),
                vocab_size,
                dtype=np.int64)
            for i, tokens in enumerate(output_tokens):
                output_padded_tokens[i, :len(tokens)] = tokens
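            # Rows are right-padded with `vocab_size`, an id one past the last
            # real token. Presumably the penalty kernels can then treat that
            # column as out-of-vocabulary and ignore it when counting token
            # occurrences.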

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_as_t = torch.tensor(top_as,
                                device="cpu",
                                dtype=dtype,
                                pin_memory=pin_memory)
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        tfss_t = torch.tensor(tfss,
                              device="cpu",
                              dtype=dtype,
                              pin_memory=pin_memory)
        eta_cutoffs_t = torch.tensor(eta_cutoffs,
                                     device="cpu",
                                     dtype=dtype,
                                     pin_memory=pin_memory)
        epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
                                         device="cpu",
                                         dtype=dtype,
                                         pin_memory=pin_memory)
        typical_ps_t = torch.tensor(typical_ps,
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
        smoothing_factors_t = torch.tensor(smoothing_factors,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        smoothing_curves_t = torch.tensor(smoothing_curves,
                                          device="cpu",
                                          dtype=dtype,
                                          pin_memory=pin_memory)
        sample_indices_t = torch.tensor(
            sample_indices,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )

        if do_penalties:
            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
            output_tensor = torch.from_numpy(output_padded_tokens)
            if pin_memory:
                prompt_tensor = prompt_tensor.pin_memory()
                output_tensor = output_tensor.pin_memory()
        else:
            prompt_tensor = None
            output_tensor = None

        # We need to transpose and make the tensor contiguous to copy it
        # correctly: [batch_size, n_seeds] -> [n_seeds, batch_size].
        sampling_seeds_t = torch.tensor(
            sampling_seeds,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        ).T.contiguous()

        # Because the memory is pinned, we can do non-blocking
        # transfers to the device.

        # How many seeds the sampling operation itself will need.
        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
                                                 non_blocking=True)
        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
        if not extra_seeds_gpu.numel():
            extra_seeds_gpu = None
        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

        if do_penalties:
            prompt_tokens_gpu = prompt_tensor.to(device=device,
                                                 non_blocking=True)
            output_tokens_gpu = output_tensor.to(device=device,
                                                 non_blocking=True)
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_tokens_gpu = empty_tensor
            output_tokens_gpu = empty_tensor

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            top_as=top_as_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            tfss=tfss_t.to(device=device, non_blocking=True),
            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                 non_blocking=True),
            smoothing_factors=smoothing_factors_t.to(device=device,
                                                     non_blocking=True),
            smoothing_curves=smoothing_curves_t.to(device=device,
                                                   non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_tokens_gpu,
            output_tokens=output_tokens_gpu,
            sampling_seeds=sampling_seeds_gpu,
            sample_indices=sample_indices_t.to(device=device,
                                               non_blocking=True),
            extra_seeds=extra_seeds_gpu,
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: int,
        *extra_entropy: int,
        seeds_to_generate: int,
        is_greedy: bool,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if not is_greedy:
            if seed is None:
                randint_fn = random.randint
            else:
                generator = random.Random(str((seed, ) + extra_entropy))
                randint_fn = generator.randint
            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
            # If the user or the RNG sets seed = 0 but the request should
            # still be sampled, we change it to something else. We use a
            # constant in that case. This way we don't need to create and
            # load a bool matrix in the sampling kernel, which reduces CPU
            # overhead and latency.
            seq_seeds = [
                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
                for _ in range(seeds_to_generate)
            ]
        else:
            # For the kernel, seed == 0 means greedy decoding.
            seq_seeds = [0] * seeds_to_generate
        return seq_seeds
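

# Behaviour sketch for `_get_sequence_seeds` (illustrative, not executed):
#   - greedy request:  _get_sequence_seeds(None, 7, 0, seeds_to_generate=2,
#                      is_greedy=True) -> [0, 0]  (0 signals greedy decoding
#                      to the kernel)
#   - seeded request:  the same (seed, extra entropy) tuple always yields the
#                      same child seeds, since they are drawn from
#                      random.Random(str((seed,) + extra_entropy)); a drawn
#                      seed of 0 is replaced with _SEED_0_REPLACEMENT.
#   - unseeded request: seeds come from random.randint and differ per call.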