from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                    is_pin_memory_available,
                                    make_tensor_with_pad)
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE

_SAMPLING_EPS = 1e-5


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new tokens
    # to compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if
    # it is in a decode stage. query_len <= seq_len when chunked prefill is
    # enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices into the logits, used to compute prompt logprobs.
    # Empty if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices into the logits. Empty if sampling is not
    # required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
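
# A minimal construction sketch (hypothetical values: a single-sequence
# prefill group of 6 query tokens whose final position is sampled; the
# SamplingParams arguments and seq_data object are illustrative only):
#
#   group = SequenceGroupToSample(
#       seq_ids=[0],
#       sampling_params=SamplingParams(temperature=0.7),
#       seq_data={0: seq_data},          # hypothetical SequenceData
#       seq_len=16,                      # 10 tokens of context + 6 new
#       query_len=6,
#       generator=None,
#       is_prompt=True,
#       prompt_logprob_indices=[],
#       sample_indices=[5],              # sample from the last query token
#   )
#   assert group.do_sample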


def gen_seq_group_to_sample_builder(num_seqs: int):
    return lambda: SequenceGroupToSample(
        seq_ids=[0] * num_seqs,
        sampling_params=None,
        seq_data=None,  # type: ignore
        seq_len=0,
        query_len=0,
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[])


class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations."""

    def __init__(self):
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
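
# A usage sketch for the cache (hypothetical driver loop;
# `fill_fields_for_step` stands in for the real per-step population done by
# _prepare_seq_groups below):
#
#   cache = SamplingMetadataCache()
#   for step in range(num_steps):
#       obj = cache.get_cached_seq_group_to_sample(num_seqs=1)
#       fill_fields_for_step(obj)   # hypothetical helper
#       ...
#       cache.reset()               # recycle every handed-out object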


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            The indices refer to positions within the pruned logits, i.e.
            after the model output has been indexed by
            selected_token_indices. For example, if the returned logits are
            [1, 2, 3] and we select [1, 2] for sampling, the pruned logits
            are [2, 3] and the stored sample indices are [0, 1].
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        categorized_sample_indices = {
            t: async_tensor_h2d(
                seq_ids,
                dtype=torch.int,
                target_device=device,
                pin_memory=pin_memory,
            )
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata
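
    # A call sketch (hypothetical arguments; in practice the scheduler
    # supplies seq_group_metadata_list and the model runner supplies
    # seq_lens/query_lens, device, and pin_memory):
    #
    #   metadata = SamplingMetadata.prepare(
    #       seq_group_metadata_list,
    #       seq_lens=[16],
    #       query_lens=[16],
    #       device="cuda",
    #       pin_memory=True,
    #   )
    #   logits = hidden_states[metadata.selected_token_indices]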

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType,
                                                        List[int]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths per sequence group, indexed in
            the same order as seq_group_metadata_list.
        query_lens: A list of query lengths per sequence group. A sequence
            length covers the entire prompt, while the corresponding query
            length can be shorter (e.g. with chunked prefill).
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprobs. It is used to
    # prune the model's output logits for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> indices to sample within the pruned output logits.
    categorized_sample_indices: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
                                             if cache is not None else [])
        sample_indices: List[int] = (sample_obj.sample_indices
                                     if cache is not None else [])
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # the prompt logprobs.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """
        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]

        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """
        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(range(logit_idx, logit_idx + sample_len)))
            logit_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices),
            )

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
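
# A worked example of the index bookkeeping above (hypothetical batch): one
# prompt group with query_len=3, prompt_logprobs enabled and sampling
# enabled, followed by one single-sequence decode group.
#
#   prompt group:  prompt_logprob_len = 2, sample_len = 1
#       selected_token_indices -> [0, 1, 2]
#       prompt_logprob_indices -> [0, 1], sample_indices -> [2]
#   decode group:  prompt_logprob_len = 0, sample_len = 1
#       selected_token_indices -> [0, 1, 2, 3]
#       sample_indices -> [3]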


@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    temperature_lasts: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    no_repeat_ngram_sizes: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    xtc_thresholds: torch.Tensor
    xtc_probabilities: torch.Tensor
    nsigmas: torch.Tensor
    dry_multipliers: torch.Tensor
    dry_bases: torch.Tensor
    dry_allowed_lengths: torch.Tensor
    dry_sequence_breaker_ids: torch.Tensor
    dry_ranges: torch.Tensor
    dry_max_ngram: torch.Tensor
    dry_max_occurrences: torch.Tensor
    dry_early_exit_match_len: torch.Tensor
    skews: torch.Tensor
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool, bool, bool, bool, bool, bool, bool, bool]:
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        dynatemp_mins: List[float] = []
        dynatemp_maxs: List[float] = []
        dynatemp_exps: List[float] = []
        temperature_lasts: List[bool] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        no_repeat_ngram_sizes: List[int] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        xtc_thresholds: List[float] = []
        xtc_probabilities: List[float] = []
        nsigmas: List[float] = []
        dry_multipliers: List[float] = []
        dry_bases: List[float] = []
        dry_allowed_lengths: List[int] = []
        dry_sequence_breaker_ids: List[List[int]] = []
        dry_ranges: List[int] = []
        dry_max_ngram: List[int] = []
        dry_max_occurrences: List[int] = []
        dry_early_exit_match_len: List[int] = []
        skews: List[float] = []

        do_penalties = False
        do_no_repeat_ngrams = False
        do_temperatures = False
        do_top_p_top_k = False
        do_top_as = False
        do_min_p = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False
        do_xtc = False
        do_nsigmas = False
        do_dry = False
        do_skews = False
        do_temp_last = False

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            params = seq_group.sampling_params
            # k should not be greater than the vocab size.
            top_k = min(params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            temperature = params.temperature
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0

            do_temperatures |= (temperature != 1.0 or
                                params.dynatemp_min > _SAMPLING_EPS or
                                params.dynatemp_max > _SAMPLING_EPS)
            do_top_p_top_k |= (params.top_p < 1.0 - _SAMPLING_EPS or
                               top_k != vocab_size)
            do_top_as |= params.top_a > 0.0
            do_min_p |= params.min_p > _SAMPLING_EPS
            do_penalties |= (abs(params.presence_penalty) >= _SAMPLING_EPS or
                             abs(params.frequency_penalty) >= _SAMPLING_EPS or
                             params.repetition_penalty > 1.0)
            do_no_repeat_ngrams |= params.no_repeat_ngram_size > 0
            do_tfss |= params.tfs < 1.0 - _SAMPLING_EPS
            do_eta_cutoffs |= params.eta_cutoff > _SAMPLING_EPS
            do_epsilon_cutoffs |= params.epsilon_cutoff > _SAMPLING_EPS
            do_typical_ps |= params.typical_p < 1.0 - _SAMPLING_EPS
            do_quadratic |= (params.smoothing_factor > _SAMPLING_EPS or
                             params.smoothing_curve > 1.0)
            do_xtc |= params.xtc_probability > _SAMPLING_EPS
            do_nsigmas |= params.nsigma > _SAMPLING_EPS
            do_dry |= params.dry_multiplier > _SAMPLING_EPS
            do_skews |= abs(params.skew) > _SAMPLING_EPS
            do_temp_last |= params.temperature_last

            wants_prompt_logprobs = params.prompt_logprobs is not None

            n_seqs = 0
            if seq_group.is_prompt and wants_prompt_logprobs:
                assert seq_group.query_len is not None
                n_seqs += len(seq_group.prompt_logprob_indices)

            if seq_group.do_sample:
                assert len(seq_group.sample_indices) == len(seq_ids)
                n_seqs += len(seq_ids)

            temperatures += [temperature] * n_seqs
            dynatemp_mins += [params.dynatemp_min] * n_seqs
            dynatemp_maxs += [params.dynatemp_max] * n_seqs
            dynatemp_exps += [params.dynatemp_exponent] * n_seqs
            temperature_lasts += [params.temperature_last] * n_seqs
            top_ps += [params.top_p] * n_seqs
            top_ks += [top_k] * n_seqs
            top_as += [params.top_a] * n_seqs
            min_ps += [params.min_p] * n_seqs
            presence_penalties += [params.presence_penalty] * n_seqs
            frequency_penalties += [params.frequency_penalty] * n_seqs
            repetition_penalties += [params.repetition_penalty] * n_seqs
            no_repeat_ngram_sizes += [params.no_repeat_ngram_size] * n_seqs
            tfss += [params.tfs] * n_seqs
            eta_cutoffs += [params.eta_cutoff] * n_seqs
            epsilon_cutoffs += [params.epsilon_cutoff] * n_seqs
            typical_ps += [params.typical_p] * n_seqs
            smoothing_factors += [params.smoothing_factor] * n_seqs
            smoothing_curves += [params.smoothing_curve] * n_seqs
            xtc_thresholds += [params.xtc_threshold] * n_seqs
            xtc_probabilities += [params.xtc_probability] * n_seqs
            nsigmas += [params.nsigma] * n_seqs
            dry_multipliers += [params.dry_multiplier] * n_seqs
            dry_bases += [params.dry_base] * n_seqs
            dry_allowed_lengths += [params.dry_allowed_length] * n_seqs
            dry_sequence_breaker_ids += (
                [params.dry_sequence_breaker_ids] * n_seqs)
            dry_ranges += [params.dry_range] * n_seqs
            dry_max_ngram += [params.dry_max_ngram] * n_seqs
            dry_max_occurrences += [params.dry_max_occurrences] * n_seqs
            dry_early_exit_match_len += ([params.dry_early_exit_match_len] *
                                         n_seqs)
            skews += [params.skew] * n_seqs

        if do_penalties or do_dry or do_no_repeat_ngrams:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                # Use this group's own params here; reusing the `params`
                # binding left over from the loop above would read the last
                # group's settings.
                if (seq_group.is_prompt
                        and seq_group.sampling_params.prompt_logprobs
                        is not None):
                    # Pad with empty token-id arrays for the prompt-logprob
                    # positions so rows stay aligned with the logits.
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                    output_tokens.extend(
                        array(APHRODITE_TOKEN_ID_ARRAY_TYPE)
                        for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures,
            dynatemp_mins,
            dynatemp_maxs,
            dynatemp_exps,
            temperature_lasts,
            top_ps,
            top_ks,
            top_as,
            min_ps,
            presence_penalties,
            frequency_penalties,
            repetition_penalties,
            no_repeat_ngram_sizes,
            tfss,
            eta_cutoffs,
            epsilon_cutoffs,
            typical_ps,
            smoothing_factors,
            smoothing_curves,
            xtc_thresholds,
            xtc_probabilities,
            nsigmas,
            dry_multipliers,
            dry_bases,
            dry_allowed_lengths,
            dry_sequence_breaker_ids,
            dry_ranges,
            dry_max_ngram,
            dry_max_occurrences,
            dry_early_exit_match_len,
            skews,
            prompt_tokens,
            output_tokens,
            vocab_size,
            device,
            dtype)
        return (
            sampling_tensors,
            do_penalties,
            do_no_repeat_ngrams,
            do_temperatures,
            do_top_p_top_k,
            do_top_as,
            do_min_p,
            do_tfss,
            do_eta_cutoffs,
            do_epsilon_cutoffs,
            do_typical_ps,
            do_quadratic,
            do_xtc,
            do_nsigmas,
            do_dry,
            do_skews,
            do_temp_last)
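
    # An unpacking sketch for the 17-element tuple above (hypothetical call
    # site; `apply_top_k_top_p` is an illustrative helper, not part of this
    # module):
    #
    #   (tensors, do_penalties, do_no_repeat_ngrams, do_temperatures,
    #    do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
    #    do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
    #    do_nsigmas, do_dry, do_skews, do_temp_last) = (
    #        SamplingTensors.from_sampling_metadata(
    #            sampling_metadata, vocab_size, logits.device, logits.dtype))
    #   if do_top_p_top_k:
    #       logits = apply_top_k_top_p(logits, tensors.top_ps,
    #                                  tensors.top_ks)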

    @classmethod
    def from_lists(
        cls,
        temperatures: List[float],
        dynatemp_mins: List[float],
        dynatemp_maxs: List[float],
        dynatemp_exps: List[float],
        temperature_lasts: List[bool],
        top_ps: List[float],
        top_ks: List[int],
        top_as: List[float],
        min_ps: List[float],
        presence_penalties: List[float],
        frequency_penalties: List[float],
        repetition_penalties: List[float],
        no_repeat_ngram_sizes: List[int],
        tfss: List[float],
        eta_cutoffs: List[float],
        epsilon_cutoffs: List[float],
        typical_ps: List[float],
        smoothing_factors: List[float],
        smoothing_curves: List[float],
        xtc_thresholds: List[float],
        xtc_probabilities: List[float],
        nsigmas: List[float],
        dry_multipliers: List[float],
        dry_bases: List[float],
        dry_allowed_lengths: List[int],
        dry_sequence_breaker_ids: List[List[int]],
        dry_ranges: List[int],
        dry_max_ngram: List[int],
        dry_max_occurrences: List[int],
        dry_early_exit_match_len: List[int],
        skews: List[float],
        prompt_tokens: List[array],
        output_tokens: List[array],
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()
        do_penalties = bool(prompt_tokens or output_tokens)

        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor
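
        # The pattern used below (build each tensor in pinned CPU memory,
        # then copy with non_blocking=True) lets the host-to-device
        # transfers overlap with the remaining CPU work, e.g.:
        #
        #   t = torch.tensor(vals, device="cpu", pin_memory=True)
        #   t_gpu = t.to(device="cuda", non_blocking=True)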
        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_mins_t = torch.tensor(
            dynatemp_mins,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_maxs_t = torch.tensor(
            dynatemp_maxs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_exps_t = torch.tensor(
            dynatemp_exps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        temp_lasts_t = torch.tensor(
            temperature_lasts,
            device="cpu",
            dtype=torch.bool,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_as_t = torch.tensor(
            top_as,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        no_repeat_ngram_sizes_t = torch.tensor(
            no_repeat_ngram_sizes,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        tfss_t = torch.tensor(
            tfss,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        eta_cutoffs_t = torch.tensor(
            eta_cutoffs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        epsilon_cutoffs_t = torch.tensor(
            epsilon_cutoffs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        typical_ps_t = torch.tensor(
            typical_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        smoothing_factors_t = torch.tensor(
            smoothing_factors,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        smoothing_curves_t = torch.tensor(
            smoothing_curves,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        xtc_thresholds_t = torch.tensor(
            xtc_thresholds,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        xtc_probabilities_t = torch.tensor(
            xtc_probabilities,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        nsigmas_t = torch.tensor(
            nsigmas,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dry_multipliers_t = torch.tensor(
            dry_multipliers,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dry_bases_t = torch.tensor(
            dry_bases,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dry_allowed_lengths_t = torch.tensor(
            dry_allowed_lengths,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        # Right-pad each breaker id list with zeros to the longest list so
        # they form one rectangular tensor. `default=0` guards against an
        # empty batch, where max() would otherwise raise.
        max_breaker_len = max(
            (len(seq) for seq in dry_sequence_breaker_ids), default=0)
        dry_sequence_breakers_t = torch.tensor(
            [seq + [0] * (max_breaker_len - len(seq))
             for seq in dry_sequence_breaker_ids],
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
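
        # Padding example (hypothetical values): breaker id lists
        # [[13], [13, 198]] become the rectangular tensor
        # [[13, 0], [13, 198]].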
        dry_ranges_t = torch.tensor(
            dry_ranges,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        dry_max_ngram_t = torch.tensor(
            dry_max_ngram,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        dry_max_occurrences_t = torch.tensor(
            dry_max_occurrences,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        dry_early_exit_match_len_t = torch.tensor(
            dry_early_exit_match_len,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        skews_t = torch.tensor(
            skews,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        # Because the memory is pinned, we can do non-blocking
        # transfer to device.
        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            dynatemp_mins=dynatemp_mins_t.to(device=device,
                                             non_blocking=True),
            dynatemp_maxs=dynatemp_maxs_t.to(device=device,
                                             non_blocking=True),
            dynatemp_exps=dynatemp_exps_t.to(device=device,
                                             non_blocking=True),
            temperature_lasts=temp_lasts_t.to(device=device,
                                              non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            top_as=top_as_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            no_repeat_ngram_sizes=no_repeat_ngram_sizes_t.to(
                device=device, non_blocking=True),
            tfss=tfss_t.to(device=device, non_blocking=True),
            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                 non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            smoothing_factors=smoothing_factors_t.to(device=device,
                                                     non_blocking=True),
            smoothing_curves=smoothing_curves_t.to(device=device,
                                                   non_blocking=True),
            xtc_thresholds=xtc_thresholds_t.to(device=device,
                                               non_blocking=True),
            xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                     non_blocking=True),
            nsigmas=nsigmas_t.to(device=device, non_blocking=True),
            dry_multipliers=dry_multipliers_t.to(device=device,
                                                 non_blocking=True),
            dry_bases=dry_bases_t.to(device=device, non_blocking=True),
            dry_allowed_lengths=dry_allowed_lengths_t.to(device=device,
                                                         non_blocking=True),
            dry_sequence_breaker_ids=dry_sequence_breakers_t.to(
                device=device, non_blocking=True),
            dry_ranges=dry_ranges_t.to(device=device, non_blocking=True),
            dry_max_ngram=dry_max_ngram_t.to(device=device,
                                             non_blocking=True),
            dry_max_occurrences=dry_max_occurrences_t.to(device=device,
                                                         non_blocking=True),
            dry_early_exit_match_len=dry_early_exit_match_len_t.to(
                device=device, non_blocking=True),
            skews=skews_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
        )