import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                    is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER = False


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
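    #
    # For example (illustrative numbers only): if 8 prompt tokens were already
    # processed in iteration N-1 and 4 new tokens are scheduled in iteration N,
    # then context_len == 8, query_len == 4 and seq_len == 12.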

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if it
    # is in a decode stage. The length of query_len <= seq_len if chunked
    # prefill is enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits, used to compute prompt logprobs. Empty
    # if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]
    negative_seq_data: Optional[Dict[int, SequenceData]] = None
    negative_seq_len: Optional[int] = None
    negative_query_len: Optional[int] = None

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None


def gen_seq_group_to_sample_builder(num_seqs: int):
    return lambda: SequenceGroupToSample(
        seq_ids=[0] * num_seqs,
        sampling_params=None,
        seq_data=None,  # type: ignore
        seq_len=0,
        query_len=0,
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[])


class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations."""

    def __init__(self):
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
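

# A minimal sketch of how the cache is typically recycled across scheduler
# steps (illustrative only; the surrounding model-runner loop is not part of
# this module):
#
#     cache = SamplingMetadataCache()
#     for _ in range(num_steps):
#         metadata = SamplingMetadata.prepare(..., cache=cache)
#         ...  # run the model and the sampler
#
# When a cache is passed, `prepare` resets it at the end of each call so the
# pooled SequenceGroupToSample objects can be reused in the next step.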


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)
    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each value is a 2D tensor of shape (num_indices, 2), where the
            first column is the sample index within the returned logits
            (before pruning padding) and the second column is the sample
            index after pruning using selected_token_indices.
            For example, if the returned logits are [1, 2, 3] and we select
            [1, 2] for sampling, the pruned logits will be [2, 3]. In this
            case, the first column is [1, 2] (sample indices within the
            original logits) and the second column is [0, 1] (sample indices
            within the pruned logits).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}), ")


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths per sequence group. The index of
            each entry should match seq_group_metadata_list.
        query_lens: A list of query lengths. Sequence lengths cover the entire
            prompt, so a query length can be shorter (e.g. when chunked
            prefill is enabled).
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the output logits from the model for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    #   indices to sample/prompt logprob within pruned output logits,
    #   indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Index to sample from a sample tensor. It is used by the triton sample
    # kernel. See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()
        has_negative = seq_group_metadata.negative_seq_data is not None

        if cache is not None:
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        negative_seq_len: Optional[int] = None
        negative_query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = \
            sample_obj.prompt_logprob_indices if cache is not None else []
        sample_indices: List[int] = \
            sample_obj.sample_indices if cache is not None else []
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None

            if has_negative:
                positive_query_lens, positive_seq_lens = (query_lens[::2],
                                                          seq_lens[::2])
                negative_query_lens, negative_seq_lens = (query_lens[1::2],
                                                          seq_lens[1::2])

                query_len, seq_len = (positive_query_lens[i],
                                      positive_seq_lens[i])
                prompt_logprob_len = (query_len - num_prefill_sample
                                      if do_sample else query_len)
                sample_len = num_prefill_sample if do_sample else 0

                negative_query_len, negative_seq_len = (negative_query_lens[i],
                                                        negative_seq_lens[i])
                negative_prompt_logprob_len = (negative_query_len -
                                               num_prefill_sample
                                               if do_sample else
                                               negative_query_len)
                negative_sample_len = num_prefill_sample if do_sample else 0
            else:
                query_len, seq_len = query_lens[i], seq_lens[i]
                # If we need sampling, exclude num_prefill_sample tokens from
                # prompt logprob.
                prompt_logprob_len = (query_len - num_prefill_sample
                                      if do_sample else query_len)
                sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            if has_negative:
                prompt_logprob_len = 0
                sample_len = len(seq_ids) if do_sample else 0
                negative_prompt_logprob_len = 0
                negative_sample_len = len(seq_ids) if do_sample else 0
            else:
                prompt_logprob_len = 0
                sample_len = len(seq_ids) if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len
        if sampling_params.prompt_logprobs is not None and has_negative:
            selected_token_indices.extend(
                range(model_output_idx,
                      model_output_idx + negative_prompt_logprob_len))
        if has_negative:
            model_output_idx += negative_prompt_logprob_len
        if do_sample and has_negative:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + negative_sample_len))
        if has_negative:
            model_output_idx += negative_sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
            logit_idx += sample_len
            sample_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
            if has_negative:
                sample_obj.negative_seq_data = \
                    seq_group_metadata.negative_seq_data
                sample_obj.negative_seq_len = negative_seq_len
                sample_obj.negative_query_len = negative_query_len
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
                negative_seq_data=seq_group_metadata.negative_seq_data
                if has_negative else None,
                negative_seq_len=negative_seq_len if has_negative else None,
                negative_query_len=negative_query_len
                if has_negative else None)

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
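

# Worked example (illustrative only): for a batch with one prefill group that
# requests prompt logprobs (query_len == 4, do_sample == True) followed by one
# single-sequence decode group without prompt logprobs, the loop above yields
#     selected_token_indices == [0, 1, 2, 3, 4]
# (three prompt-logprob positions and one sample position for the prefill,
# then one sample position for the decode). The corresponding entries added to
# categorized_sample_indices are (3, 0) for the prefill sample and (4, 1) for
# the decode sample, i.e. (index within the pruned logits, index within the
# sample tensor), appended under each group's sampling type.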


@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    temperatures: torch.Tensor
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    temperature_lasts: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    no_repeat_ngram_sizes: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    xtc_thresholds: torch.Tensor
    xtc_probabilities: torch.Tensor
    nsigmas: torch.Tensor
    dry_multipliers: torch.Tensor
    dry_bases: torch.Tensor
    dry_allowed_lengths: torch.Tensor
    dry_sequence_breaker_ids: torch.Tensor
    skews: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
        *,
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool, bool, bool, bool, bool, bool, bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
        extra_entropy: extra entropy to use when generating seeds.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        dynatemp_mins: List[float] = []
        dynatemp_maxs: List[float] = []
        dynatemp_exps: List[float] = []
        temperature_lasts: List[bool] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        no_repeat_ngram_sizes: List[int] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        xtc_thresholds: List[float] = []
        xtc_probabilities: List[float] = []
        nsigmas: List[float] = []
        sampling_seeds: List[List[int]] = []
        sample_indices: List[int] = []
        dry_multipliers: List[float] = []
        dry_bases: List[float] = []
        dry_allowed_lengths: List[int] = []
        dry_sequence_breaker_ids: List[List[int]] = []
        skews: List[float] = []

        do_penalties = False
        do_no_repeat_ngrams = False
        do_temperatures = False
        do_top_p_top_k = False
        do_top_as = False
        do_min_p = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False
        do_xtc = False
        do_nsigmas = False
        do_dry = False
        do_skews = False
        do_temp_last = False

        if _USE_TRITON_SAMPLER:
            prompt_best_of: List[int] = []

            # We need one base seed per Triton slice.
            seeds_to_generate = (extra_seeds_to_generate +
                                 get_num_triton_sampler_splits(vocab_size))

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            params = seq_group.sampling_params
            # k should not be greater than the vocab size.
            top_k = min(params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            temperature = params.temperature
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0

            do_temperatures |= (temperature != 1.0 or
                                params.dynatemp_min > _SAMPLING_EPS or
                                params.dynatemp_max > _SAMPLING_EPS)
            do_top_p_top_k |= (params.top_p < 1.0 - _SAMPLING_EPS or
                               top_k != vocab_size)
            do_top_as |= params.top_a > 0.0
            do_min_p |= params.min_p > _SAMPLING_EPS
            do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
                             abs(params.frequency_penalty) >= _SAMPLING_EPS or
                             params.repetition_penalty > 1.0)
            do_no_repeat_ngrams |= params.no_repeat_ngram_size > 0
            do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
            do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
            do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
            do_typical_ps |= params.typical_p < 1.0 - _SAMPLING_EPS
            do_quadratic |= (params.smoothing_factor > _SAMPLING_EPS or
                             params.smoothing_curve > 1.0)
            do_xtc |= params.xtc_probability > _SAMPLING_EPS
            do_nsigmas |= params.nsigma > _SAMPLING_EPS
            do_dry |= params.dry_multiplier > _SAMPLING_EPS
            do_skews |= abs(params.skew) > _SAMPLING_EPS
            do_temp_last |= params.temperature_last

            is_prompt = seq_group.is_prompt
            wants_prompt_logprobs = params.prompt_logprobs is not None

            n_seqs = 0
            if seq_group.is_prompt and wants_prompt_logprobs:
                assert seq_group.query_len is not None
                n_seqs += len(seq_group.prompt_logprob_indices)
            if seq_group.do_sample:
                assert len(seq_group.sample_indices) == len(seq_ids)
                n_seqs += len(seq_ids)

            temperatures += [temperature] * n_seqs
            dynatemp_mins += [params.dynatemp_min] * n_seqs
            dynatemp_maxs += [params.dynatemp_max] * n_seqs
            dynatemp_exps += [params.dynatemp_exponent] * n_seqs
            temperature_lasts += [params.temperature_last] * n_seqs
            top_ps += [params.top_p] * n_seqs
            top_ks += [top_k] * n_seqs
            top_as += [params.top_a] * n_seqs
            min_ps += [params.min_p] * n_seqs
            presence_penalties += [params.presence_penalty] * n_seqs
            frequency_penalties += [params.frequency_penalty] * n_seqs
            repetition_penalties += [params.repetition_penalty] * n_seqs
            no_repeat_ngram_sizes += [params.no_repeat_ngram_size] * n_seqs
            tfss += [params.tfs] * n_seqs
            eta_cutoffs += [params.eta_cutoff] * n_seqs
            epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
            typical_ps += [params.typical_p] * n_seqs
            smoothing_factors += [params.smoothing_factor] * n_seqs
            smoothing_curves += [params.smoothing_curve] * n_seqs
            xtc_thresholds += [params.xtc_threshold] * n_seqs
            xtc_probabilities += [params.xtc_probability] * n_seqs
            nsigmas += [params.nsigma] * n_seqs
            dry_multipliers += [params.dry_multiplier] * n_seqs
            dry_bases += [params.dry_base] * n_seqs
            dry_allowed_lengths += [params.dry_allowed_length] * n_seqs
            dry_sequence_breaker_ids += (
                [params.dry_sequence_breaker_ids] * n_seqs)
            skews += [params.skew] * n_seqs

            if _USE_TRITON_SAMPLER:
                if is_prompt:
                    prompt_best_of.append(params.best_of)
                    query_len = seq_group.query_len
                    assert query_len is not None

                seed = params.seed
                is_greedy = params.sampling_type == SamplingType.GREEDY

                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    extra_entropy = extra_entropy or ()
                    seq_seeds = cls._get_sequence_seeds(
                        seed,
                        seq_data.get_len(),
                        *extra_entropy,
                        seq_id,
                        seeds_to_generate=seeds_to_generate,
                        is_greedy=is_greedy)
                    sampling_seeds.append(seq_seeds)
                sample_indices.extend(seq_group.sample_indices)

        if do_penalties or do_dry or do_no_repeat_ngrams:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                # Use this group's own params rather than the stale loop
                # variable from the pass above.
                params = seq_group.sampling_params
                if (seq_group.is_prompt
                        and params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
            temperature_lasts, top_ps, top_ks, top_as, min_ps,
            presence_penalties, frequency_penalties, repetition_penalties,
            no_repeat_ngram_sizes, tfss, eta_cutoffs, epsilon_cutoffs,
            typical_ps, smoothing_factors, smoothing_curves, xtc_thresholds,
            xtc_probabilities, nsigmas, dry_multipliers, dry_bases,
            dry_allowed_lengths, dry_sequence_breaker_ids, skews,
            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
            vocab_size, extra_seeds_to_generate, device, dtype)
        return (sampling_tensors, do_penalties, do_no_repeat_ngrams,
                do_temperatures, do_top_p_top_k, do_top_as, do_min_p,
                do_tfss, do_eta_cutoffs, do_epsilon_cutoffs, do_typical_ps,
                do_quadratic, do_xtc, do_nsigmas, do_dry, do_skews,
                do_temp_last)
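
    # For example (illustrative only), a request with temperature=0.8 and
    # top_p=0.9 and every other sampling parameter left at its default would
    # set only do_temperatures and do_top_p_top_k above, letting the sampler
    # skip the kernels guarded by the remaining do_* flags.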

    @classmethod
    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
                   temperature_lasts: List[bool], top_ps: List[float],
                   top_ks: List[int], top_as: List[float],
                   min_ps: List[float], presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float],
                   no_repeat_ngram_sizes: List[int], tfss: List[float],
                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                   typical_ps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], xtc_thresholds: List[float],
                   xtc_probabilities: List[float], nsigmas: List[float],
                   dry_multipliers: List[float], dry_bases: List[float],
                   dry_allowed_lengths: List[int],
                   dry_sequence_breaker_ids: List[List[int]],
                   skews: List[float], sampling_seeds: List[List[int]],
                   sample_indices: List[int], prompt_tokens: List[array],
                   output_tokens: List[array], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()

        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_mins_t = torch.tensor(
            dynatemp_mins,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_maxs_t = torch.tensor(
            dynatemp_maxs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_exps_t = torch.tensor(
            dynatemp_exps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        temp_lasts_t = torch.tensor(
            temperature_lasts,
            device="cpu",
            dtype=torch.bool,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_as_t = torch.tensor(top_as,
                                device="cpu",
                                dtype=dtype,
                                pin_memory=pin_memory)
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        no_repeat_ngram_sizes_t = torch.tensor(
            no_repeat_ngram_sizes,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        tfss_t = torch.tensor(tfss,
                              device="cpu",
                              dtype=dtype,
                              pin_memory=pin_memory)
        eta_cutoffs_t = torch.tensor(eta_cutoffs,
                                     device="cpu",
                                     dtype=dtype,
                                     pin_memory=pin_memory)
        epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
                                         device="cpu",
                                         dtype=dtype,
                                         pin_memory=pin_memory)
        typical_ps_t = torch.tensor(typical_ps,
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
        smoothing_factors_t = torch.tensor(smoothing_factors,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        smoothing_curves_t = torch.tensor(smoothing_curves,
                                          device="cpu",
                                          dtype=dtype,
                                          pin_memory=pin_memory)
        xtc_thresholds_t = torch.tensor(xtc_thresholds,
                                        device="cpu",
                                        dtype=dtype,
                                        pin_memory=pin_memory)
        xtc_probabilities_t = torch.tensor(xtc_probabilities,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        nsigmas_t = torch.tensor(nsigmas,
                                 device="cpu",
                                 dtype=dtype,
                                 pin_memory=pin_memory)
        dry_multipliers_t = torch.tensor(
            dry_multipliers,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dry_bases_t = torch.tensor(
            dry_bases,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dry_allowed_lengths_t = torch.tensor(
            dry_allowed_lengths,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        dry_sequence_breakers_t = torch.tensor(
            dry_sequence_breaker_ids,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        skews_t = torch.tensor(
            skews,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        sample_indices_t = torch.tensor(
            sample_indices,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # Need to transpose and make contiguous to
        # copy the tensor correctly.
        # [batch_size, n_seeds] -> [n_seeds, batch_size]
        sampling_seeds_t = torch.tensor(
            sampling_seeds,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        ).t().contiguous()

        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        # How many seeds the sample operation itself will need.
        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
                                                 non_blocking=True)
        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
        if not extra_seeds_gpu.numel():
            extra_seeds_gpu = None
        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            dynatemp_mins=dynatemp_mins_t.to(device=device,
                                             non_blocking=True),
            dynatemp_maxs=dynatemp_maxs_t.to(device=device,
                                             non_blocking=True),
            dynatemp_exps=dynatemp_exps_t.to(device=device,
                                             non_blocking=True),
            temperature_lasts=temp_lasts_t.to(device=device,
                                              non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            top_as=top_as_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            no_repeat_ngram_sizes=no_repeat_ngram_sizes_t.to(
                device=device, non_blocking=True),
            tfss=tfss_t.to(device=device, non_blocking=True),
            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                 non_blocking=True),
            smoothing_factors=smoothing_factors_t.to(device=device,
                                                     non_blocking=True),
            smoothing_curves=smoothing_curves_t.to(device=device,
                                                   non_blocking=True),
            xtc_thresholds=xtc_thresholds_t.to(device=device,
                                               non_blocking=True),
            xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                     non_blocking=True),
            nsigmas=nsigmas_t.to(device=device, non_blocking=True),
            dry_multipliers=dry_multipliers_t.to(device=device,
                                                 non_blocking=True),
            dry_bases=dry_bases_t.to(device=device, non_blocking=True),
            dry_allowed_lengths=dry_allowed_lengths_t.to(device=device,
                                                         non_blocking=True),
            dry_sequence_breaker_ids=dry_sequence_breakers_t.to(
                device=device, non_blocking=True),
            skews=skews_t.to(device=device, non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
            sampling_seeds=sampling_seeds_gpu,
            sample_indices=sample_indices_t.to(device=device,
                                               non_blocking=True),
            extra_seeds=extra_seeds_gpu,
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: Optional[int],
        *extra_entropy: int,
        seeds_to_generate: int,
        is_greedy: bool,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if not is_greedy:
            if seed is None:
                randint_fn = random.randint
            else:
                generator = random.Random(str((seed, ) + extra_entropy))
                randint_fn = generator.randint
            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
            # If the user (or the RNG) produces seed = 0 but the request
            # should still be sampled, we need to change it to something
            # else. We use a constant in that case.
            # This way we don't need to create and load a bool
            # matrix in the sampling kernel, which reduces CPU
            # overhead and latency.
            seq_seeds = [
                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
                for _ in range(seeds_to_generate)
            ]
        else:
            # For the kernel, seed == 0 means greedy decoding.
            seq_seeds = [0] * seeds_to_generate
        return seq_seeds
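

# A small sanity-check sketch for `_get_sequence_seeds` (illustrative only):
# the same (seed, extra entropy) inputs always produce the same child seeds, a
# zero child seed is replaced with `_SEED_0_REPLACEMENT`, and greedy requests
# get all-zero seeds.
#
#     a = SamplingTensors._get_sequence_seeds(
#         1234, 42, seeds_to_generate=2, is_greedy=False)
#     b = SamplingTensors._get_sequence_seeds(
#         1234, 42, seeds_to_generate=2, is_greedy=False)
#     assert a == b and 0 not in a
#     assert SamplingTensors._get_sequence_seeds(
#         1234, 42, seeds_to_generate=2, is_greedy=True) == [0, 0]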