# sampling_metadata.py

import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                    is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some Triton sampler related code is gated behind this flag until it is
# ready.
_USE_TRITON_SAMPLER = False


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if
    # it is in a decode stage. query_len <= seq_len if chunked prefill is
    # enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits, used to compute prompt logprobs. Empty
    # if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
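
# Illustrative example (numbers assumed, not taken from the code): with
# chunked prefill, a 10-token prompt whose first 6 tokens were computed in
# earlier steps has context_len = 6, query_len = 4 and seq_len = 10 for the
# current step, i.e. seq_len = context_len + query_len. In a pure decode step
# both seq_len and query_len are None and one new token is computed per
# sequence.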


def gen_seq_group_to_sample_builder(num_seqs: int):
    return lambda: SequenceGroupToSample(
        seq_ids=[0] * num_seqs,
        sampling_params=None,
        seq_data=None,  # type: ignore
        seq_len=0,
        query_len=0,
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[])


class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations."""

    def __init__(self):
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
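
# A minimal usage sketch (illustrative only; the real driver loop lives in the
# model runner): a single cache is kept across scheduler iterations and passed
# to SamplingMetadata.prepare(), which then recycles SequenceGroupToSample
# objects of the same num_seqs instead of reallocating them every step.
#
#   cache = SamplingMetadataCache()
#   for _ in range(num_steps):  # hypothetical driver loop
#       metadata = SamplingMetadata.prepare(
#           seq_group_metadata_list, seq_lens, query_lens,
#           device="cuda", pin_memory=True, cache=cache)
#       ...  # run the model and the sampler with `metadata`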


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each entry is a 2D tensor of shape (num_indices, 2), where the
            first item is the sample index within the returned logits
            (before pruning padding), and the second item is the sample
            index after pruning using selected_token_indices.
            For example, if the returned logits are [1, 2, 3] and we select
            [1, 2] for sampling, the pruned logits will be [2, 3]. In this
            case, the first tuple is [1, 2] (sampled indices within the
            original logits), and the second tuple is [0, 1] (sampled
            indices within the pruned logits).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}), ")
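
# Illustrative note on the shapes produced by prepare() (assumed from the
# conversions above, not enforced anywhere): selected_token_indices becomes a
# 1-D torch.long tensor on `device`, and each value in
# categorized_sample_indices becomes a (num_indices, 2) torch.int tensor whose
# rows pair an index into the pruned logits with a running sample index, e.g.
#
#   categorized_sample_indices[SamplingType.RANDOM]
#   # tensor([[3, 0],
#   #         [4, 1]], dtype=torch.int32)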


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int],
           Dict[SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths per sequence group. Indices
            must match seq_group_metadata_list.
        query_lens: A list of query lengths per sequence group. A query
            covers only the tokens computed in the current step, so it can
            be shorter than the full prompt when chunked prefill is enabled.
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the model's output logits for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    #   indices to sample/prompt logprob within pruned output logits,
    #   indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Index to sample from a sample tensor. It is used by the triton sample
    # kernel. See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = \
            sample_obj.prompt_logprob_indices if cache is not None else []
        sample_indices: List[int] = \
            sample_obj.sample_indices if cache is not None else []
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
            logit_idx += sample_len
            sample_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices))

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
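
# Worked example (illustrative numbers only, assuming both groups use the same
# SamplingType): consider a batch with one prefill group A (query_len = 4,
# prompt_logprobs requested, do_sample) followed by one decode group B (one
# sequence, do_sample). For A, prompt_logprob_len = 3 and sample_len = 1; for
# B, sample_len = 1. The function then returns
#
#   selected_token_indices        == [0, 1, 2, 3, 4]
#   A.prompt_logprob_indices      == [0, 1, 2]
#   A.sample_indices              == [3]
#   B.sample_indices              == [4]
#   categorized_sample_indices[t] == [(3, 0), (4, 1)]
#   num_prompts                   == 1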


@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    temperature_lasts: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    xtc_thresholds: torch.Tensor
    xtc_probabilities: torch.Tensor
    nsigmas: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
        *,
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool, bool, bool, bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
        extra_entropy: extra entropy to use when generating seeds.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        dynatemp_mins: List[float] = []
        dynatemp_maxs: List[float] = []
        dynatemp_exps: List[float] = []
        temperature_lasts: List[bool] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        xtc_thresholds: List[float] = []
        xtc_probabilities: List[float] = []
        nsigmas: List[float] = []
        sampling_seeds: List[List[int]] = []
        sample_indices: List[int] = []

        do_penalties = False
        do_temperatures = False
        do_top_p_top_k = False
        do_top_as = False
        do_min_p = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False
        do_xtc = False
        do_nsigmas = False
        do_temp_last = False

        if _USE_TRITON_SAMPLER:
            prompt_best_of: List[int] = []

            # We need one base seed per Triton slice.
            seeds_to_generate = (extra_seeds_to_generate +
                                 get_num_triton_sampler_splits(vocab_size))

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            params = seq_group.sampling_params
            # k should not be greater than the vocab size.
            top_k = min(params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            temperature = params.temperature
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0

            do_temperatures |= (temperature != 1.0 or
                                params.dynatemp_min > _SAMPLING_EPS or
                                params.dynatemp_max > _SAMPLING_EPS)
            do_top_p_top_k |= (params.top_p < 1.0 - _SAMPLING_EPS or
                               top_k != vocab_size)
            do_top_as |= params.top_a > 0.0
            do_min_p |= params.min_p > _SAMPLING_EPS
            do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
                             abs(params.frequency_penalty) >= _SAMPLING_EPS or
                             params.repetition_penalty > 1.0)
            do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
            do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
            do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
            do_typical_ps |= params.typical_p < 1.0 - _SAMPLING_EPS
            do_quadratic |= (params.smoothing_factor > _SAMPLING_EPS or
                             params.smoothing_curve > 1.0)
            do_xtc |= params.xtc_probability > _SAMPLING_EPS
            do_nsigmas |= params.nsigma > _SAMPLING_EPS
            do_temp_last |= params.temperature_last

            is_prompt = seq_group.is_prompt
            wants_prompt_logprobs = params.prompt_logprobs is not None

            n_seqs = 0
            if seq_group.is_prompt and wants_prompt_logprobs:
                assert seq_group.query_len is not None
                n_seqs += len(seq_group.prompt_logprob_indices)
            if seq_group.do_sample:
                assert len(seq_group.sample_indices) == len(seq_ids)
                n_seqs += len(seq_ids)

            temperatures += [temperature] * n_seqs
            dynatemp_mins += [params.dynatemp_min] * n_seqs
            dynatemp_maxs += [params.dynatemp_max] * n_seqs
            dynatemp_exps += [params.dynatemp_exponent] * n_seqs
            temperature_lasts += [params.temperature_last] * n_seqs
            top_ps += [params.top_p] * n_seqs
            top_ks += [top_k] * n_seqs
            top_as += [params.top_a] * n_seqs
            min_ps += [params.min_p] * n_seqs
            presence_penalties += [params.presence_penalty] * n_seqs
            frequency_penalties += [params.frequency_penalty] * n_seqs
            repetition_penalties += [params.repetition_penalty] * n_seqs
            tfss += [params.tfs] * n_seqs
            eta_cutoffs += [params.eta_cutoff] * n_seqs
            epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
            typical_ps += [params.typical_p] * n_seqs
            smoothing_factors += [params.smoothing_factor] * n_seqs
            smoothing_curves += [params.smoothing_curve] * n_seqs
            xtc_thresholds += [params.xtc_threshold] * n_seqs
            xtc_probabilities += [params.xtc_probability] * n_seqs
            nsigmas += [params.nsigma] * n_seqs

            if _USE_TRITON_SAMPLER:
                if is_prompt:
                    prompt_best_of.append(params.best_of)
                    query_len = seq_group.query_len
                    assert query_len is not None

                seed = params.seed
                is_greedy = params.sampling_type == SamplingType.GREEDY
                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    extra_entropy = extra_entropy or ()
                    seq_seeds = cls._get_sequence_seeds(
                        seed,
                        seq_data.get_len(),
                        *extra_entropy,
                        seq_id,
                        seeds_to_generate=seeds_to_generate,
                        is_greedy=is_greedy)
                    sampling_seeds.append(seq_seeds)
                sample_indices.extend(seq_group.sample_indices)

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                # Use this group's own params rather than the value left over
                # from the loop above.
                params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
            temperature_lasts, top_ps, top_ks, top_as, min_ps,
            presence_penalties, frequency_penalties, repetition_penalties,
            tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
            smoothing_curves, xtc_thresholds, xtc_probabilities, nsigmas,
            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
            vocab_size, extra_seeds_to_generate, device, dtype)
        return (sampling_tensors, do_penalties, do_temperatures,
                do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
                do_nsigmas, do_temp_last)
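
    # Caller-side sketch (illustrative only; the real sampler code lives
    # elsewhere, and `logits` here is a hypothetical tensor): the tuple above
    # is unpacked positionally, and each do_* flag is True only if at least
    # one sequence group in the batch needs that transform, so the sampler can
    # skip whole steps.
    #
    #   (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
    #    do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
    #    do_typical_ps, do_quadratic, do_xtc, do_nsigmas,
    #    do_temp_last) = SamplingTensors.from_sampling_metadata(
    #        sampling_metadata, vocab_size, logits.device, logits.dtype)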

    @classmethod
    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
                   temperature_lasts: List[bool], top_ps: List[float],
                   top_ks: List[int], top_as: List[float],
                   min_ps: List[float], presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float], tfss: List[float],
                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                   typical_ps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], xtc_thresholds: List[float],
                   xtc_probabilities: List[float], nsigmas: List[float],
                   sampling_seeds: List[List[int]],
                   sample_indices: List[int], prompt_tokens: List[array],
                   output_tokens: List[array], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens
        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_mins_t = torch.tensor(
            dynatemp_mins,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_maxs_t = torch.tensor(
            dynatemp_maxs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_exps_t = torch.tensor(
            dynatemp_exps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        temp_lasts_t = torch.tensor(
            temperature_lasts,
            device="cpu",
            dtype=torch.bool,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_as_t = torch.tensor(top_as,
                                device="cpu",
                                dtype=dtype,
                                pin_memory=pin_memory)
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        tfss_t = torch.tensor(tfss,
                              device="cpu",
                              dtype=dtype,
                              pin_memory=pin_memory)
        eta_cutoffs_t = torch.tensor(eta_cutoffs,
                                     device="cpu",
                                     dtype=dtype,
                                     pin_memory=pin_memory)
        epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
                                         device="cpu",
                                         dtype=dtype,
                                         pin_memory=pin_memory)
        typical_ps_t = torch.tensor(typical_ps,
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
        smoothing_factors_t = torch.tensor(smoothing_factors,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        smoothing_curves_t = torch.tensor(smoothing_curves,
                                          device="cpu",
                                          dtype=dtype,
                                          pin_memory=pin_memory)
        xtc_thresholds_t = torch.tensor(xtc_thresholds,
                                        device="cpu",
                                        dtype=dtype,
                                        pin_memory=pin_memory)
        xtc_probabilities_t = torch.tensor(xtc_probabilities,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        nsigmas_t = torch.tensor(nsigmas,
                                 device="cpu",
                                 dtype=dtype,
                                 pin_memory=pin_memory)
        sample_indices_t = torch.tensor(
            sample_indices,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # Need to transpose and make contiguous to
        # copy the tensor correctly.
        # [batch_size, n_seeds] -> [n_seeds, batch_size]
        sampling_seeds_t = torch.tensor(
            sampling_seeds,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        ).t().contiguous()

        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        # How many seeds the sample operation itself will need.
        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
                                                 non_blocking=True)
        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
        if not extra_seeds_gpu.numel():
            extra_seeds_gpu = None
        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            dynatemp_mins=dynatemp_mins_t.to(device=device,
                                             non_blocking=True),
            dynatemp_maxs=dynatemp_maxs_t.to(device=device,
                                             non_blocking=True),
            dynatemp_exps=dynatemp_exps_t.to(device=device,
                                             non_blocking=True),
            temperature_lasts=temp_lasts_t.to(device=device,
                                              non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            top_as=top_as_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            tfss=tfss_t.to(device=device, non_blocking=True),
            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                 non_blocking=True),
            smoothing_factors=smoothing_factors_t.to(device=device,
                                                     non_blocking=True),
            smoothing_curves=smoothing_curves_t.to(device=device,
                                                   non_blocking=True),
            xtc_thresholds=xtc_thresholds_t.to(device=device,
                                               non_blocking=True),
            xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                     non_blocking=True),
            nsigmas=nsigmas_t.to(device=device, non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
            sampling_seeds=sampling_seeds_gpu,
            sample_indices=sample_indices_t.to(device=device,
                                               non_blocking=True),
            extra_seeds=extra_seeds_gpu,
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: Optional[int],
        *extra_entropy: int,
        seeds_to_generate: int,
        is_greedy: bool,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if not is_greedy:
            if seed is None:
                randint_fn = random.randint
            else:
                generator = random.Random(str((seed, ) + extra_entropy))
                randint_fn = generator.randint
            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
            # If the user/random sets seed = 0 but the request should
            # have sampling, we need to change it to something
            # else. We use a constant in that case.
            # This way we don't need to create and load a bool
            # matrix in the sampling kernel, which reduces CPU
            # overhead and latency.
            seq_seeds = [
                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
                for _ in range(seeds_to_generate)
            ]
        else:
            # For the kernel, seed == 0 means greedy decoding.
            seq_seeds = [0] * seeds_to_generate
        return seq_seeds
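
# Illustrative example (assumed values, not part of the module): for a seeded
# request the child seeds are a pure function of (seed, *extra_entropy), so
# repeating the same call reproduces them, e.g.
#
#   SamplingTensors._get_sequence_seeds(1234, 7, 0,
#                                       seeds_to_generate=2,
#                                       is_greedy=False)
#
# always returns the same two non-zero integers, while is_greedy=True returns
# [0, 0] because the kernel interprets seed 0 as greedy decoding. Here 7 and 0
# stand in for the sequence length and seq_id that from_sampling_metadata
# passes as extra entropy.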