# sampling_metadata.py

import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER = False


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
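    #
    # For example (hypothetical numbers): with chunked prefill, a prompt of 8
    # tokens that already had 5 tokens computed in iteration N-1 has
    # context_len=5, query_len=3, and seq_len=8 in iteration N; without
    # chunked prefill the whole prompt is the query, so context_len=0 and
    # query_len == seq_len for the prefill step.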

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if
    # it is in a decode stage. query_len <= seq_len if chunked prefill is
    # enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits, used to compute prompt logprobs. Empty
    # if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each entry is a 2D tensor of shape (num_indices, 2), where the
            first item is the sample index within the returned logits
            (before pruning padding) and the second item is the sample
            index after pruning using selected_token_indices.
            For example, if the returned logits are [1, 2, 3] and we select
            [1, 2] for sampling, the pruned logits will be [2, 3]. In this
            case, the first column is [1, 2] (sample indices within the
            original logits) and the second column is [0, 1] (sample indices
            within the pruned logits).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")
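
# Example usage (an illustrative sketch only; `execute_model` and `sample`
# are hypothetical callers, not defined in this module):
#
#     sampling_metadata = SamplingMetadata.prepare(
#         seq_group_metadata_list, seq_lens, query_lens,
#         device="cuda", pin_memory=True)
#     hidden_states = execute_model(...)
#     logits = hidden_states[sampling_metadata.selected_token_indices]
#     next_tokens = sample(logits, sampling_metadata)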


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths, one per sequence group. Indices
            must match those of seq_group_metadata_list.
        query_lens: A list of query lengths, one per sequence group. A query
            length covers only the tokens computed in this step, so it can be
            shorter than the full prompt length when chunked prefill is
            enabled.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.

    Returns:
        seq_groups: A list of sequence groups to sample from.
        selected_token_indices: See the definition in `SamplingMetadata`.
        categorized_sample_indices: See the definition in `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the output logits from the model for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Index to sample from a sample tensor. It is used by the triton sample
    # kernel. See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, these are None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = []
        sample_indices: List[int] = []
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]

            # If we need sampling, exclude num_prefill_sample tokens from
            # the prompt logprob count.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices, which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """
        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices, which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """
        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
            logit_idx += sample_len
            sample_idx += sample_len

        seq_groups.append(
            SequenceGroupToSample(
                seq_ids=seq_ids,
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices)))
    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
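

# Illustrative example (hypothetical batch, not part of the original code):
# suppose _prepare_seq_groups receives two sequence groups that both use
# SamplingType.RANDOM:
#   * group 0: prefill with query_len=4 and prompt_logprobs requested
#   * group 1: decode with one running sequence and no prompt_logprobs
# Then the returned indices would be
#   selected_token_indices = [0, 1, 2, 3, 4]
#   categorized_sample_indices[SamplingType.RANDOM] = [(3, 0), (4, 1)]
# i.e. rows 0-2 of the pruned logits are prompt-logprob-only positions, while
# rows 3 and 4 are the two positions that are actually sampled.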


@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
        *,
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
        extra_entropy: extra entropy to use when generating seeds.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        dynatemp_mins: List[float] = []
        dynatemp_maxs: List[float] = []
        dynatemp_exps: List[float] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        sampling_seeds: List[int] = []
        sample_indices: List[int] = []
        do_penalties = False
        do_top_p_top_k = False
        do_top_as = False
        do_min_p = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False

        if _USE_TRITON_SAMPLER:
            prompt_best_of: List[int] = []

            # We need one base seed per Triton slice.
            seeds_to_generate = (extra_seeds_to_generate +
                                 get_num_triton_sampler_splits(vocab_size))

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            top_a = sampling_params.top_a
            min_p = sampling_params.min_p
            tfs = sampling_params.tfs
            eta_cutoff = sampling_params.eta_cutoff
            epsilon_cutoff = sampling_params.epsilon_cutoff
            typical_p = sampling_params.typical_p
            smoothing_factor = sampling_params.smoothing_factor
            smoothing_curve = sampling_params.smoothing_curve

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if do_top_as is False and top_a > 0.0:
                do_top_as = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                do_tfss = True
            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
                do_eta_cutoffs = True
            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                do_epsilon_cutoffs = True
            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
                do_typical_ps = True
            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                          or smoothing_curve > 1.0):
                do_quadratic = True
            is_prompt = seq_group.is_prompt
            if (is_prompt and sampling_params.prompt_logprobs is not None):
                # For tokens in the prompt that we only need to get
                # their logprobs
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                dynatemp_mins += [sampling_params.dynatemp_min] * prefill_len
                dynatemp_maxs += [sampling_params.dynatemp_max] * prefill_len
                dynatemp_exps += ([sampling_params.dynatemp_exponent] *
                                  prefill_len)
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                top_as += [top_a] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len
                tfss += [1] * prefill_len
                eta_cutoffs += [0] * prefill_len
                epsilon_cutoffs += [0] * prefill_len
                typical_ps += [1] * prefill_len
                smoothing_factors += [smoothing_factor] * prefill_len
                smoothing_curves += [smoothing_curve] * prefill_len

            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens == len(seq_ids)
                temperatures += [temperature] * len(seq_ids)
                dynatemp_mins += [sampling_params.dynatemp_min] * len(seq_ids)
                dynatemp_maxs += [sampling_params.dynatemp_max] * len(seq_ids)
                dynatemp_exps += ([sampling_params.dynatemp_exponent] *
                                  len(seq_ids))
                top_ps += [top_p] * len(seq_ids)
                top_ks += [top_k] * len(seq_ids)
                top_as += [top_a] * len(seq_ids)
                min_ps += [min_p] * len(seq_ids)
                presence_penalties += [p] * len(seq_ids)
                frequency_penalties += [f] * len(seq_ids)
                repetition_penalties += [r] * len(seq_ids)
                tfss += [tfs] * len(seq_ids)
                eta_cutoffs += [eta_cutoff] * len(seq_ids)
                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
                typical_ps += [typical_p] * len(seq_ids)
                smoothing_factors += [smoothing_factor] * len(seq_ids)
                smoothing_curves += [smoothing_curve] * len(seq_ids)

            if _USE_TRITON_SAMPLER:
                if is_prompt:
                    prompt_best_of.append(sampling_params.best_of)
                    query_len = seq_group.query_len
                    assert query_len is not None

                seed = sampling_params.seed
                is_greedy = sampling_params.sampling_type == SamplingType.GREEDY

                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    extra_entropy = extra_entropy or ()
                    seq_seeds = cls._get_sequence_seeds(
                        seed,
                        seq_data.get_len(),
                        *extra_entropy,
                        seq_id,
                        seeds_to_generate=seeds_to_generate,
                        is_greedy=is_greedy)
                    sampling_seeds.append(seq_seeds)
                sample_indices.extend(seq_group.sample_indices)

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                sampling_params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array('l') for _ in range(prefill_len))
                    output_tokens.extend(
                        array('l') for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
            top_ps, top_ks, top_as, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
            epsilon_cutoffs, typical_ps, smoothing_factors, smoothing_curves,
            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
            vocab_size, extra_seeds_to_generate, device, dtype)
        return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
                do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
                do_typical_ps, do_quadratic)
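
    # Illustrative call site (hypothetical; the actual sampler lives in a
    # different module):
    #
    #     (tensors, do_penalties, do_top_p_top_k, do_top_as, do_min_p,
    #      do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
    #      do_quadratic) = SamplingTensors.from_sampling_metadata(
    #          sampling_metadata, vocab_size, logits.device, logits.dtype)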

    @classmethod
    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
                   top_ps: List[float], top_ks: List[int], top_as: List[float],
                   min_ps: List[float], presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float], tfss: List[float],
                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                   typical_ps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], sampling_seeds: List[int],
                   sample_indices: List[int], prompt_tokens: List[array],
                   output_tokens: List[array], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens
        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_mins_t = torch.tensor(
            dynatemp_mins,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_maxs_t = torch.tensor(
            dynatemp_maxs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_exps_t = torch.tensor(
            dynatemp_exps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_as_t = torch.tensor(top_as,
                                device="cpu",
                                dtype=dtype,
                                pin_memory=pin_memory)
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        tfss_t = torch.tensor(tfss,
                              device="cpu",
                              dtype=dtype,
                              pin_memory=pin_memory)
        eta_cutoffs_t = torch.tensor(eta_cutoffs,
                                     device="cpu",
                                     dtype=dtype,
                                     pin_memory=pin_memory)
        epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
                                         device="cpu",
                                         dtype=dtype,
                                         pin_memory=pin_memory)
        typical_ps_t = torch.tensor(typical_ps,
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
        smoothing_factors_t = torch.tensor(smoothing_factors,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        smoothing_curves_t = torch.tensor(smoothing_curves,
                                          device="cpu",
                                          dtype=dtype,
                                          pin_memory=pin_memory)
        sample_indices_t = torch.tensor(
            sample_indices,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # need to transpose and make contiguous to
        # copy the tensor correctly.
        # [batch_size, n_seeds] -> [n_seeds, batch_size]
        sampling_seeds_t = torch.tensor(
            sampling_seeds,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        ).t().contiguous()

        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        # How many seeds the sample operation itself will need.
        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
                                                 non_blocking=True)
        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
        if not extra_seeds_gpu.numel():
            extra_seeds_gpu = None
        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
            dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
            dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            top_as=top_as_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            tfss=tfss_t.to(device=device, non_blocking=True),
            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                 non_blocking=True),
            smoothing_factors=smoothing_factors_t.to(device=device,
                                                     non_blocking=True),
            smoothing_curves=smoothing_curves_t.to(device=device,
                                                   non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
            sampling_seeds=sampling_seeds_gpu,
            sample_indices=sample_indices_t.to(device=device,
                                               non_blocking=True),
            extra_seeds=extra_seeds_gpu,
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: Optional[int],
        *extra_entropy: int,
        seeds_to_generate: int,
        is_greedy: bool,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if not is_greedy:
            if seed is None:
                randint_fn = random.randint
            else:
                generator = random.Random(str((seed, ) + extra_entropy))
                randint_fn = generator.randint
            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
            # If the user/random sets seed = 0 but the request should
            # have sampling, we need to change it to something
            # else. We use a constant in that case.
            # This way we don't need to create and load a bool
            # matrix in the sampling kernel, which reduces CPU
            # overhead and latency.
            seq_seeds = [
                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
                for _ in range(seeds_to_generate)
            ]
        else:
            # For the kernel, seed == 0 means greedy decoding.
            seq_seeds = [0] * seeds_to_generate
        return seq_seeds
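

# Illustrative sketch of the seed helper (not part of the original module):
# for a seeded, non-greedy request the child seeds are deterministic in the
# (seed, *extra_entropy) arguments and are never 0, while greedy requests
# always get zeros, e.g.
#
#     SamplingTensors._get_sequence_seeds(
#         1234, 17, 0, seeds_to_generate=2, is_greedy=False)
#     # -> two reproducible non-zero seeds (a drawn 0 is replaced by
#     #    _SEED_0_REPLACEMENT)
#     SamplingTensors._get_sequence_seeds(
#         None, 17, 0, seeds_to_generate=2, is_greedy=True)
#     # -> [0, 0], because seed == 0 signals greedy decoding to the kernel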