1
0

sampling_metadata.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. import random
  2. from dataclasses import dataclass
  3. from typing import Dict, List, Optional, Tuple
  4. import torch
  5. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  6. from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
  7. from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
  8. maybe_expand_dim)
  9. from aphrodite.modeling.layers.ops.sample import get_num_triton_sampler_splits
  10. _SAMPLING_EPS = 1e-5
  11. _SEED_0_REPLACEMENT = 3403598558
  12. @dataclass
  13. class SequenceGroupToSample:
  14. # Sequence ids for the sequence group in a previous step.
  15. seq_ids: List[int]
  16. sampling_params: SamplingParams
  17. # seq_id -> sequence data.
  18. seq_data: Dict[int, SequenceData]
  19. # The length of the prompt of the sequence group. None if it is in a decode
  20. # stage.
  21. prompt_len: Optional[int]
  22. # The length of the query tokens to compute in the current step. None if it
  23. # is in a decode stage. The length of subquery_len <= prompt_len.
  24. subquery_len: Optional[int]
  25. # A random number generator for sampling.
  26. generator: Optional[torch.Generator]
  27. # True if the sequence group is in prefill stage. False if it is in a
  28. # decode stage.
  29. is_prompt: bool
  30. # Query token indices from logits. to compute prompt logprob. Empty if
  31. # prompt logprob is not required.
  32. prompt_logprob_indices: List[int]
  33. # Sample token indices from logits. Empty if sampling is not required.
  34. sample_indices: List[int]
  35. @property
  36. def do_sample(self):
  37. return len(self.sample_indices) > 0
  38. def __post_init__(self):
  39. if len(self.prompt_logprob_indices) > 0:
  40. assert self.sampling_params.prompt_logprobs is not None
  41. if self.is_prompt:
  42. assert self.prompt_len is not None
  43. assert self.subquery_len is not None
  44. class SamplingMetadata:
  45. """Metadata for input sequences. Used in sampler.
  46. The usage is as follow;
  47. ```
  48. hidden_states = execute_model(...)
  49. logits = hidden_states[sampling_metadata.selected_token_indices]
  50. sample(logits)
  51. def sample(logits):
  52. # Use categorized_sample_indices for sampling....
  53. ```
  54. Args:
  55. seq_groups: List of batched sequence groups.
  56. selected_token_indices: (num_query_tokens_to_logprob). Indices to find
  57. logits from the initial model output hidden states.
  58. categorized_sample_indices: SamplingType -> token indices to sample.
  59. Each token indices is 2D tensor of (num_indices, num_indices) where
  60. the first item means the sample index within the returned logit
  61. (before pruning padding), and the second item means the sample
  62. index after pruning using selected_token_indices.
  63. For example, if the returned logit is [1, 2, 3], and we select
  64. [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
  65. The first tuple is [1, 2] (sampled index within original logit),
  66. and the second tuple is [0, 1] (sampled index within pruned logit).
  67. num_prompts: Number of prompt sequence groups in seq_groups.
  68. """
  69. def __init__(
  70. self,
  71. seq_groups: List[SequenceGroupToSample],
  72. selected_token_indices: torch.Tensor,
  73. categorized_sample_indices: Dict[SamplingType, torch.Tensor],
  74. num_prompts: int,
  75. ) -> None:
  76. self.seq_groups = seq_groups
  77. self.selected_token_indices = selected_token_indices
  78. self.categorized_sample_indices = categorized_sample_indices
  79. self.num_prompts = num_prompts
  80. @staticmethod
  81. def prepare(
  82. seq_group_metadata_list: List[SequenceGroupMetadata],
  83. prompt_lens: List[int],
  84. subquery_lens: Optional[List[int]],
  85. device: str,
  86. pin_memory: bool,
  87. ) -> "SamplingMetadata":
  88. (
  89. seq_groups,
  90. selected_token_indices,
  91. categorized_sample_indices,
  92. num_prompts,
  93. ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
  94. subquery_lens, device)
  95. selected_token_indices = async_tensor_h2d(selected_token_indices,
  96. dtype=torch.long,
  97. target_device=device,
  98. pin_memory=pin_memory)
  99. categorized_sample_indices = {
  100. t: maybe_expand_dim(
  101. async_tensor_h2d(seq_ids,
  102. dtype=torch.int,
  103. target_device=device,
  104. pin_memory=pin_memory), 2, 2)
  105. for t, seq_ids in categorized_sample_indices.items()
  106. }
  107. sampling_metadata = SamplingMetadata(
  108. seq_groups=seq_groups,
  109. selected_token_indices=selected_token_indices,
  110. categorized_sample_indices=categorized_sample_indices,
  111. num_prompts=num_prompts,
  112. )
  113. return sampling_metadata
  114. def __repr__(self) -> str:
  115. return (
  116. "SamplingMetadata("
  117. f"seq_groups={self.seq_groups}, "
  118. f"selected_token_indices={self.selected_token_indices}, "
  119. f"categorized_sample_indices={self.categorized_sample_indices}), ")
  120. def _prepare_seq_groups(
  121. seq_group_metadata_list: List[SequenceGroupMetadata],
  122. prompt_lens: List[int],
  123. subquery_lens: Optional[List[int]],
  124. device: str,
  125. ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
  126. SamplingType, List[Tuple[int, int]]], int]:
  127. """Prepare sequence groups and indices for sampling.
  128. Args:
  129. seq_group_metadata_list: A list of sequence group to batch.
  130. prompt_lens: A list of prompt lens per sequence group.
  131. Index of prompt len should match with seq_group_metadata_list.
  132. subquery_lens: A list of query lengths. Prompt lens include the length
  133. of entire prompt tokens, and it could be shorter.
  134. device: A device to use for random number generator,
  135. `SequenceGroupToSample.generator`.
  136. Returns:
  137. seq_groups: A list of sequence group to sample.
  138. selected_token_indices: See the definition from `SamplingMetadata`.
  139. categorized_sample_indices: See the definition from `SamplingMetadata`.
  140. num_prompts: Total number of prompts from `seq_group_metadata_list`.
  141. """
  142. # Batched sequence groups for the current model forward stsep.
  143. seq_groups: List[SequenceGroupToSample] = []
  144. # A list of token indices to sample/compute logprob. It is used to
  145. # prune the outcome logits from the model for the performance.
  146. selected_token_indices: List[int] = []
  147. # Used for selected_token_indices.
  148. model_output_idx = 0
  149. # Sampling type -> (
  150. # indices to sample/prompt logprob within pruned output logits,
  151. # indices to sample within pruned logits)
  152. categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
  153. t: []
  154. for t in SamplingType
  155. }
  156. # Index of logits to compute logprob. Logits include both prompt logprob
  157. # and sample logprob indices.
  158. logit_idx = 0
  159. # Index to sample from a sample tensor. It is used by triton sample kernel.
  160. # See `_sample_with_triton_kernel` for more details.
  161. sample_idx = 0
  162. # Total number of prompts from given sequence groups.
  163. num_prompts = 0
  164. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  165. seq_ids = list(seq_group_metadata.seq_data.keys())
  166. sampling_params = seq_group_metadata.sampling_params
  167. is_prompt = seq_group_metadata.is_prompt
  168. generator: Optional[torch.Generator] = None
  169. # If the current seq group is in decode stage, it is None.
  170. prompt_len: Optional[int] = None
  171. subquery_len: Optional[int] = None
  172. prompt_logprob_indices: List[int] = []
  173. sample_indices: List[int] = []
  174. do_sample = seq_group_metadata.do_sample
  175. if seq_group_metadata.is_prompt:
  176. if sampling_params.seed is not None:
  177. seq_group_metadata.state.generator = torch.Generator(
  178. device=device).manual_seed(sampling_params.seed)
  179. num_prompts += 1
  180. num_prefill_sample = len(seq_ids)
  181. assert num_prefill_sample == 1
  182. assert subquery_lens is not None and prompt_lens is not None
  183. subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
  184. # If we need sampling, exclude num_prefill_sample tokens from
  185. # prompt logprob.
  186. prompt_logprob_len = (subquery_len - num_prefill_sample
  187. if do_sample else subquery_len)
  188. sample_len = num_prefill_sample if do_sample else 0
  189. else:
  190. # Decode
  191. prompt_logprob_len = 0
  192. sample_len = len(seq_ids) if do_sample else 0
  193. # Update indices to select from the model output.
  194. """
  195. This blocks computes selected_token_indices which is used in the
  196. following way.
  197. hidden_states = model(...)
  198. logits = hidden_states[selected_token_indices]
  199. """
  200. if sampling_params.prompt_logprobs:
  201. selected_token_indices.extend(
  202. range(model_output_idx, model_output_idx + prompt_logprob_len))
  203. model_output_idx += prompt_logprob_len
  204. if do_sample:
  205. selected_token_indices.extend(
  206. range(model_output_idx, model_output_idx + sample_len))
  207. model_output_idx += sample_len
  208. # We now find indices for logprob computation and sampling.
  209. """
  210. This block computes categorized_sample_indices which is used in the
  211. following way.
  212. hidden_states = model(...)
  213. logits = hidden_states[selected_token_indices]
  214. def sample(logits):
  215. # Use categorized_sample_indices for sampling.
  216. # prompt_logprob_indices to find prompt logprob indices.
  217. # sample_indices to find sample indices.
  218. """
  219. if sampling_params.prompt_logprobs is not None:
  220. prompt_logprob_indices.extend(
  221. range(logit_idx, logit_idx + prompt_logprob_len))
  222. logit_idx += prompt_logprob_len
  223. if do_sample:
  224. sample_indices.extend(range(logit_idx, logit_idx + sample_len))
  225. categorized_sample_indices[sampling_params.sampling_type].extend(
  226. list(
  227. zip(range(logit_idx, logit_idx + sample_len),
  228. range(sample_idx, sample_idx + sample_len))))
  229. logit_idx += sample_len
  230. sample_idx += sample_len
  231. if sampling_params.seed is not None:
  232. generator = seq_group_metadata.state.generator
  233. seq_groups.append(
  234. SequenceGroupToSample(
  235. seq_ids=seq_ids,
  236. sampling_params=sampling_params,
  237. seq_data=seq_group_metadata.seq_data,
  238. prompt_len=prompt_len,
  239. subquery_len=subquery_len,
  240. generator=generator,
  241. is_prompt=is_prompt,
  242. prompt_logprob_indices=list(prompt_logprob_indices),
  243. sample_indices=list(sample_indices)))
  244. return (seq_groups, selected_token_indices, categorized_sample_indices,
  245. num_prompts)
  246. @dataclass
  247. class SamplingTensors:
  248. """Tensors for sampling."""
  249. temperatures: torch.Tensor
  250. top_ps: torch.Tensor
  251. top_ks: torch.Tensor
  252. top_as: torch.Tensor
  253. min_ps: torch.Tensor
  254. presence_penalties: torch.Tensor
  255. frequency_penalties: torch.Tensor
  256. repetition_penalties: torch.Tensor
  257. tfss: torch.Tensor
  258. eta_cutoffs: torch.Tensor
  259. epsilon_cutoffs: torch.Tensor
  260. typical_ps: torch.Tensor
  261. dynatemp_mins: torch.Tensor
  262. dynatemp_maxs: torch.Tensor
  263. dynatemp_exps: torch.Tensor
  264. smoothing_factors: torch.Tensor
  265. smoothing_curves: torch.Tensor
  266. sampling_seeds: torch.Tensor
  267. sample_indices: torch.Tensor
  268. extra_seeds: Optional[torch.Tensor]
  269. prompt_tokens: torch.Tensor
  270. output_tokens: torch.Tensor
  271. @classmethod
  272. def from_sampling_metadata(
  273. cls,
  274. sampling_metadata: "SamplingMetadata",
  275. vocab_size: int,
  276. device: torch.device,
  277. dtype: torch.dtype,
  278. *,
  279. extra_seeds_to_generate: int = 0,
  280. extra_entropy: Optional[Tuple[int, ...]] = None
  281. ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
  282. bool, bool, bool, bool]:
  283. prompt_tokens: List[List[int]] = []
  284. output_tokens: List[List[int]] = []
  285. top_ks: List[int] = []
  286. temperatures: List[float] = []
  287. top_ps: List[float] = []
  288. top_as: List[float] = []
  289. min_ps: List[float] = []
  290. presence_penalties: List[float] = []
  291. frequency_penalties: List[float] = []
  292. repetition_penalties: List[float] = []
  293. tfss: List[float] = []
  294. eta_cutoffs: List[float] = []
  295. epsilon_cutoffs: List[float] = []
  296. typical_ps: List[float] = []
  297. dynatemp_mins: List[float] = []
  298. dynatemp_maxs: List[float] = []
  299. dynatemp_exps: List[float] = []
  300. smoothing_factors: List[float] = []
  301. smoothing_curves: List[float] = []
  302. sampling_seeds: List[int] = []
  303. sample_indices: List[int] = []
  304. prompt_best_of: List[int] = []
  305. do_temperatures = False
  306. do_penalties = False
  307. do_topks = False
  308. do_topps = False
  309. do_topas = False
  310. do_minps = False
  311. do_tfss = False
  312. do_eta_cutoffs = False
  313. do_epsilon_cutoffs = False
  314. do_typical_ps = False
  315. do_quadratic = False
  316. # We need one base seed per Triton slice.
  317. seeds_to_generate = (extra_seeds_to_generate +
  318. get_num_triton_sampler_splits(vocab_size))
  319. assert sampling_metadata.seq_groups is not None
  320. for seq_group in sampling_metadata.seq_groups:
  321. seq_ids = seq_group.seq_ids
  322. sampling_params = seq_group.sampling_params
  323. temperature = sampling_params.temperature
  324. p = sampling_params.presence_penalty
  325. f = sampling_params.frequency_penalty
  326. r = sampling_params.repetition_penalty
  327. top_p = sampling_params.top_p
  328. # k should not be greater than the vocab size
  329. top_k = min(sampling_params.top_k, vocab_size)
  330. top_k = vocab_size if top_k == -1 else top_k
  331. top_a = sampling_params.top_a
  332. min_p = sampling_params.min_p
  333. tfs = sampling_params.tfs
  334. eta_cutoff = sampling_params.eta_cutoff
  335. epsilon_cutoff = sampling_params.epsilon_cutoff
  336. typical_p = sampling_params.typical_p
  337. dynatemp_min = sampling_params.dynatemp_min
  338. dynatemp_max = sampling_params.dynatemp_max
  339. dynatemp_exp = sampling_params.dynatemp_exponent
  340. smoothing_factor = sampling_params.smoothing_factor
  341. smoothing_curve = sampling_params.smoothing_curve
  342. seed = sampling_params.seed
  343. is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
  344. if do_temperatures is False and temperature > _SAMPLING_EPS:
  345. do_temperatures = True
  346. if not do_penalties and (abs(p) >= _SAMPLING_EPS
  347. or abs(f) >= _SAMPLING_EPS
  348. or abs(r - 1.0) >= _SAMPLING_EPS):
  349. do_penalties = True
  350. if do_topks is False and top_k != vocab_size:
  351. do_topks = True
  352. if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
  353. do_topps = True
  354. if do_topas is False and top_a > 0.0:
  355. do_topas = True
  356. if do_minps is False and min_p > _SAMPLING_EPS:
  357. do_minps = True
  358. if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
  359. do_tfss = True
  360. if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
  361. do_eta_cutoffs = True
  362. if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
  363. do_epsilon_cutoffs = True
  364. if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
  365. do_typical_ps = True
  366. if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
  367. or smoothing_curve > 1.0):
  368. do_quadratic = True
  369. is_prompt = seq_group.is_prompt
  370. if (seq_group.is_prompt
  371. and sampling_params.prompt_logprobs is not None):
  372. # For tokens in the prompt that we only need to get their
  373. # logprobs
  374. subquery_len = seq_group.subquery_len
  375. assert subquery_len is not None
  376. prefill_len = len(seq_group.prompt_logprob_indices)
  377. temperatures += [temperature] * prefill_len
  378. top_ps += [top_p] * prefill_len
  379. top_ks += [top_k] * prefill_len
  380. top_as += [top_a] * prefill_len
  381. min_ps += [min_p] * prefill_len
  382. presence_penalties += [0] * prefill_len
  383. frequency_penalties += [0] * prefill_len
  384. repetition_penalties += [1] * prefill_len
  385. tfss += [1] * prefill_len
  386. eta_cutoffs += [0] * prefill_len
  387. epsilon_cutoffs += [0] * prefill_len
  388. typical_ps += [1] * prefill_len
  389. dynatemp_mins += [dynatemp_min] * prefill_len
  390. dynatemp_maxs += [dynatemp_max] * prefill_len
  391. dynatemp_exps += [dynatemp_exp] * prefill_len
  392. smoothing_factors += [smoothing_factor] * prefill_len
  393. smoothing_curves += [smoothing_curve] * prefill_len
  394. prompt_tokens.extend([] for _ in range(prefill_len))
  395. output_tokens.extend([] for _ in range(prefill_len))
  396. if seq_group.do_sample:
  397. sample_lens = len(seq_group.sample_indices)
  398. assert sample_lens == len(seq_ids)
  399. for seq_id in seq_ids:
  400. seq_data = seq_group.seq_data[seq_id]
  401. prompt_tokens.append(seq_data.prompt_token_ids)
  402. output_tokens.append(seq_data.output_token_ids)
  403. temperatures += [temperature] * len(seq_ids)
  404. top_ps += [top_p] * len(seq_ids)
  405. top_ks += [top_k] * len(seq_ids)
  406. top_as += [top_a] * len(seq_ids)
  407. min_ps += [min_p] * len(seq_ids)
  408. presence_penalties += [p] * len(seq_ids)
  409. frequency_penalties += [f] * len(seq_ids)
  410. repetition_penalties += [r] * len(seq_ids)
  411. tfss += [tfs] * len(seq_ids)
  412. eta_cutoffs += [eta_cutoff] * len(seq_ids)
  413. epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
  414. typical_ps += [typical_p] * len(seq_ids)
  415. dynatemp_mins += [dynatemp_min] * len(seq_ids)
  416. dynatemp_maxs += [dynatemp_max] * len(seq_ids)
  417. dynatemp_exps += [dynatemp_exp] * len(seq_ids)
  418. smoothing_factors += [smoothing_factor] * len(seq_ids)
  419. smoothing_curves += [smoothing_curve] * len(seq_ids)
  420. if is_prompt:
  421. prompt_best_of.append(sampling_params.best_of)
  422. subquery_len = seq_group.subquery_len
  423. assert subquery_len is not None
  424. for seq_id in seq_ids:
  425. seq_data = seq_group.seq_data[seq_id]
  426. extra_entropy = extra_entropy or ()
  427. seq_seeds = cls._get_sequence_seeds(
  428. seed,
  429. seq_data.get_len(),
  430. *extra_entropy,
  431. seq_id,
  432. seeds_to_generate=seeds_to_generate,
  433. is_greedy=is_greedy)
  434. sampling_seeds.append(seq_seeds)
  435. sample_indices.extend(seq_group.sample_indices)
  436. sampling_tensors = SamplingTensors.from_lists(
  437. temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
  438. frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
  439. epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
  440. dynatemp_exps, smoothing_factors, smoothing_curves, sampling_seeds,
  441. sample_indices, prompt_tokens, output_tokens, vocab_size,
  442. extra_seeds_to_generate, device, dtype)
  443. return (sampling_tensors, do_temperatures, do_penalties, do_topks,
  444. do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
  445. do_epsilon_cutoffs, do_typical_ps, do_quadratic)
  446. @classmethod
  447. def from_lists(cls, temperatures: List[float], top_ps: List[float],
  448. top_ks: List[int], top_as: List[float], min_ps: List[float],
  449. presence_penalties: List[float],
  450. frequency_penalties: List[float],
  451. repetition_penalties: List[float], tfss: List[float],
  452. eta_cutoffs: List[float], epsilon_cutoffs: List[float],
  453. typical_ps: List[float], dynatemp_mins: List[float],
  454. dynatemp_maxs: List[float], dynatemp_exps: List[float],
  455. smoothing_factors: List[float],
  456. smoothing_curves: List[float], sampling_seeds: List[int],
  457. sample_indices: List[int], prompt_tokens: List[List[int]],
  458. output_tokens: List[List[int]], vocab_size: int,
  459. extra_seeds_to_generate: int, device: torch.device,
  460. dtype: torch.dtype) -> "SamplingTensors":
  461. # Note that the performance will be very bad without
  462. # pinned memory.
  463. pin_memory = is_pin_memory_available()
  464. prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
  465. default=0)
  466. prompt_padded_tokens = [
  467. tokens + [vocab_size] * (prompt_max_len - len(tokens))
  468. for tokens in prompt_tokens
  469. ]
  470. output_max_len = max([len(tokens) for tokens in output_tokens],
  471. default=0)
  472. output_padded_tokens = [
  473. tokens + [vocab_size] * (output_max_len - len(tokens))
  474. for tokens in output_tokens
  475. ]
  476. temperatures_t = torch.tensor(temperatures,
  477. device="cpu",
  478. dtype=dtype,
  479. pin_memory=pin_memory)
  480. top_ps_t = torch.tensor(top_ps,
  481. device="cpu",
  482. dtype=dtype,
  483. pin_memory=pin_memory)
  484. top_ks_t = torch.tensor(top_ks,
  485. device="cpu",
  486. dtype=torch.int,
  487. pin_memory=pin_memory)
  488. top_as_t = torch.tensor(top_as,
  489. device="cpu",
  490. dtype=dtype,
  491. pin_memory=pin_memory)
  492. min_ps_t = torch.tensor(min_ps,
  493. device="cpu",
  494. dtype=dtype,
  495. pin_memory=pin_memory)
  496. presence_penalties_t = torch.tensor(presence_penalties,
  497. device="cpu",
  498. dtype=dtype,
  499. pin_memory=pin_memory)
  500. frequency_penalties_t = torch.tensor(frequency_penalties,
  501. device="cpu",
  502. dtype=dtype,
  503. pin_memory=pin_memory)
  504. repetition_penalties_t = torch.tensor(repetition_penalties,
  505. device="cpu",
  506. dtype=dtype,
  507. pin_memory=pin_memory)
  508. tfss_t = torch.tensor(tfss,
  509. device="cpu",
  510. dtype=dtype,
  511. pin_memory=pin_memory)
  512. eta_cutoffs_t = torch.tensor(eta_cutoffs,
  513. device="cpu",
  514. dtype=dtype,
  515. pin_memory=pin_memory)
  516. epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
  517. device="cpu",
  518. dtype=dtype,
  519. pin_memory=pin_memory)
  520. typical_ps_t = torch.tensor(typical_ps,
  521. device="cpu",
  522. dtype=dtype,
  523. pin_memory=pin_memory)
  524. dynatemp_mins_t = torch.tensor(dynatemp_mins,
  525. device="cpu",
  526. dtype=dtype,
  527. pin_memory=pin_memory)
  528. dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
  529. device="cpu",
  530. dtype=dtype,
  531. pin_memory=pin_memory)
  532. dynatemp_exps_t = torch.tensor(dynatemp_exps,
  533. device="cpu",
  534. dtype=dtype,
  535. pin_memory=pin_memory)
  536. smoothing_factors_t = torch.tensor(smoothing_factors,
  537. device="cpu",
  538. dtype=dtype,
  539. pin_memory=pin_memory)
  540. smoothing_curves_t = torch.tensor(smoothing_curves,
  541. device="cpu",
  542. dtype=dtype,
  543. pin_memory=pin_memory)
  544. sample_indices_t = torch.tensor(sample_indices,
  545. device="cpu",
  546. dtype=torch.int,
  547. pin_memory=pin_memory)
  548. prompt_tensor = torch.tensor(prompt_padded_tokens,
  549. device=device,
  550. dtype=torch.long,
  551. pin_memory=pin_memory)
  552. output_tensor = torch.tensor(output_padded_tokens,
  553. device=device,
  554. dtype=torch.long,
  555. pin_memory=pin_memory)
  556. # need to transpose and make contiguous to
  557. # copy the tensor correctly.
  558. # [batch_size, n_seeds] -> [n_seeds, batch_size]
  559. sampling_seeds_t = torch.tensor(
  560. sampling_seeds,
  561. device="cpu",
  562. dtype=torch.long,
  563. pin_memory=pin_memory,
  564. ).T.contiguous()
  565. # Because the memory is pinned, we can do non-blocking
  566. # transfer to device.
  567. # How many seeds the sample operation itself will need.
  568. num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
  569. sampling_seeds_gpu = sampling_seeds_t.to(device=device,
  570. non_blocking=True)
  571. extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
  572. if not extra_seeds_gpu.numel():
  573. extra_seeds_gpu = None
  574. sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
  575. return cls(
  576. temperatures=temperatures_t.to(device=device, non_blocking=True),
  577. top_ps=top_ps_t.to(device=device, non_blocking=True),
  578. top_ks=top_ks_t.to(device=device, non_blocking=True),
  579. top_as=top_as_t.to(device=device, non_blocking=True),
  580. min_ps=min_ps_t.to(device=device, non_blocking=True),
  581. presence_penalties=presence_penalties_t.to(device=device,
  582. non_blocking=True),
  583. frequency_penalties=frequency_penalties_t.to(device=device,
  584. non_blocking=True),
  585. repetition_penalties=repetition_penalties_t.to(device=device,
  586. non_blocking=True),
  587. tfss=tfss_t.to(device=device, non_blocking=True),
  588. eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
  589. epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
  590. non_blocking=True),
  591. dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
  592. dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
  593. dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
  594. smoothing_factors=smoothing_factors_t.to(device=device,
  595. non_blocking=True),
  596. smoothing_curves=smoothing_curves_t.to(device=device,
  597. non_blocking=True),
  598. typical_ps=typical_ps_t.to(device=device, non_blocking=True),
  599. prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
  600. output_tokens=output_tensor.to(device=device, non_blocking=True),
  601. sampling_seeds=sampling_seeds_gpu,
  602. sample_indices=sample_indices_t.to(device=device,
  603. non_blocking=True),
  604. extra_seeds=extra_seeds_gpu,
  605. )
  606. @staticmethod
  607. def _get_sequence_seeds(
  608. seed: int,
  609. *extra_entropy: int,
  610. seeds_to_generate: int,
  611. is_greedy: bool,
  612. ):
  613. """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
  614. if not is_greedy:
  615. if seed is None:
  616. randint_fn = random.randint
  617. else:
  618. generator = random.Random(str((seed, ) + extra_entropy))
  619. randint_fn = generator.randint
  620. lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
  621. # If the user/random sets seed = 0 but request should
  622. # have sampling, we need to change it to something
  623. # else. We use a constant in that case.
  624. # This way we don't need to create and load a bool
  625. # matrix in the sampling kernel, which reduces CPU
  626. # overhead and latency.
  627. seq_seeds = [
  628. randint_fn(lo, hi) or _SEED_0_REPLACEMENT
  629. for _ in range(seeds_to_generate)
  630. ]
  631. else:
  632. # For the kernel, seed == 0 means greedy decoding.
  633. seq_seeds = [0] * seeds_to_generate
  634. return seq_seeds