# sampling_metadata.py

import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                    is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits

_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER = False


@dataclass
class SequenceGroupToSample:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the sequence (all tokens seen in the past + new token to
    # compute attention) of the sequence group. None if it is in a decode
    # stage.
    seq_len: Optional[int]
    # The length of new query tokens to compute in the current step. None if
    # it is in a decode stage. query_len <= seq_len if chunked prefill is
    # enabled.
    query_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits used to compute prompt logprobs. Empty
    # if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.seq_len is not None
            assert self.query_len is not None
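

# Illustrative sketch (not part of the module): how the fields above relate
# for a prefill-stage group. All values are made up for illustration only.
#
#     group = SequenceGroupToSample(
#         seq_ids=[0],
#         sampling_params=SamplingParams(temperature=0.7, prompt_logprobs=5),
#         seq_data={0: ...},          # seq_id -> SequenceData
#         seq_len=8,                  # context + new tokens this step
#         query_len=8,                # full prompt processed this step
#         generator=None,
#         is_prompt=True,
#         prompt_logprob_indices=[0, 1, 2, 3, 4, 5, 6],
#         sample_indices=[7],         # the last prompt position is sampled
#     )
#     assert group.do_sample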


def gen_seq_group_to_sample_builder(num_seqs: int):
    return lambda: SequenceGroupToSample(
        seq_ids=[0] * num_seqs,
        sampling_params=None,
        seq_data=None,  # type: ignore
        seq_len=0,
        query_len=0,
        generator=None,
        is_prompt=True,
        prompt_logprob_indices=[],
        sample_indices=[])


class SamplingMetadataCache:
    """Used to cache SamplingMetadata objects between scheduler iterations."""

    def __init__(self):
        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}

    def get_cached_seq_group_to_sample(self, num_seqs):
        if num_seqs not in self._seq_group_to_sample_cache:
            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
                gen_seq_group_to_sample_builder(num_seqs))

        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
        return obj

    def reset(self):
        for cache in self._seq_group_to_sample_cache.values():
            cache.reset()
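

# Illustrative sketch (not part of the module): reusing pooled
# SequenceGroupToSample objects across scheduler iterations. The call
# pattern below is an assumption about how a model runner would hold the
# cache; only the methods defined above are used.
#
#     cache = SamplingMetadataCache()
#     for _ in range(num_scheduler_steps):
#         sample_obj = cache.get_cached_seq_group_to_sample(num_seqs=1)
#         ...  # fill sample_obj in place, as _prepare_seq_groups does
#         cache.reset()  # return the objects to the pool for the next step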


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)
    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each value is a 2D tensor of shape (num_indices, 2): the first
            column is the token's index within the pruned logits (i.e. after
            selecting rows with selected_token_indices), and the second
            column is its index within the flattened sample tensor (counting
            sampled tokens only).
            For example, if logits are returned for positions [0, 1, 2] and
            positions [1, 2] are selected for sampling, the pruned logits
            contain two rows, and the sampled positions get pruned-logit
            indices [0, 1].
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        seq_lens: List[int],
        query_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
        generators: Optional[Dict[str, torch.Generator]] = None,
        cache: Optional[SamplingMetadataCache] = None
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
                                device, generators, cache)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}), ")
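

# Illustrative sketch (not part of the module): how a model runner would
# typically build and consume SamplingMetadata. `execute_model`,
# `compute_logits` and `sampler` are hypothetical stand-ins for the
# caller's own objects.
#
#     metadata = SamplingMetadata.prepare(
#         seq_group_metadata_list,
#         seq_lens,
#         query_lens,
#         device="cuda",
#         pin_memory=is_pin_memory_available(),
#     )
#     hidden_states = execute_model(...)
#     logits = compute_logits(hidden_states[metadata.selected_token_indices])
#     output = sampler(logits, metadata)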


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    seq_lens: List[int],
    query_lens: Optional[List[int]],
    device: str,
    generators: Optional[Dict[str, torch.Generator]] = None,
    cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        seq_lens: A list of sequence lengths per sequence group. The index of
            each length should match seq_group_metadata_list.
        query_lens: A list of query lengths. A prompt length covers the
            entire prompt, while a query length can be shorter (e.g. when
            chunked prefill is enabled).
        device: A device to use for random number generators,
            `SequenceGroupToSample.generator`.
        generators: A store of per-request random number generators used
            for seeded requests.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition from `SamplingMetadata`.
        categorized_sample_indices: See the definition from `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprob. It is used to
    # prune the outcome logits from the model for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    # indices to sample/prompt logprob within pruned output logits,
    # indices to sample within pruned logits)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprob. Logits include both prompt logprob
    # and sample logprob indices.
    logit_idx = 0
    # Index to sample from a sample tensor. It is used by triton sample kernel.
    # See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = seq_group_metadata.seq_data.keys()

        if cache is not None:
            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))

            for j, seq_id in enumerate(seq_ids):
                sample_obj.seq_ids[j] = seq_id

            sample_obj.prompt_logprob_indices.clear()
            sample_obj.sample_indices.clear()

        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, it is None.
        seq_len: Optional[int] = None
        query_len: Optional[int] = None
        prompt_logprob_indices: List[int] = \
            sample_obj.prompt_logprob_indices if cache is not None else []
        sample_indices: List[int] = \
            sample_obj.sample_indices if cache is not None else []
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                generator = torch.Generator(device=device).manual_seed(
                    sampling_params.seed)
                if generators is not None:
                    generators[seq_group_metadata.request_id] = generator

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert query_lens is not None and seq_lens is not None
            query_len, seq_len = query_lens[i], seq_lens[i]

            # If we need sampling, exclude num_prefill_sample tokens from
            # prompt logprob.
            prompt_logprob_len = (query_len - num_prefill_sample
                                  if do_sample else query_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

            if sampling_params.seed is not None and generators is not None:
                generator = generators.get(seq_group_metadata.request_id)

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs is not None:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices which is used in the
        following way.

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]

        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
            logit_idx += sample_len
            sample_idx += sample_len

        if cache is not None:
            sample_obj.sampling_params = sampling_params
            sample_obj.seq_data = seq_group_metadata.seq_data
            sample_obj.seq_len = seq_len
            sample_obj.query_len = query_len
            sample_obj.generator = generator
            sample_obj.is_prompt = is_prompt
        else:
            sample_obj = SequenceGroupToSample(
                seq_ids=list(seq_ids),
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                seq_len=seq_len,
                query_len=query_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices))

        seq_groups.append(sample_obj)

    if cache is not None:
        cache.reset()

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
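

# Illustrative sketch (not part of the module): the index bookkeeping above
# for a made-up batch of one prefill group (query_len=4, prompt_logprobs
# requested, sampling enabled) followed by one decode group with a single
# sequence. All numbers are for illustration only.
#
#   prefill group: prompt_logprob_len = 3, sample_len = 1
#       selected_token_indices += [0, 1, 2]   (prompt logprob positions)
#       selected_token_indices += [3]         (sampled position)
#       prompt_logprob_indices  = [0, 1, 2], sample_indices = [3]
#       categorized_sample_indices[sampling_type] += [(3, 0)]
#   decode group:  prompt_logprob_len = 0, sample_len = 1
#       selected_token_indices += [4]
#       sample_indices = [4]
#       categorized_sample_indices[sampling_type] += [(4, 1)]
#
# i.e. selected_token_indices == [0, 1, 2, 3, 4], and each tuple in
# categorized_sample_indices is (index within the pruned logits, index
# within the flattened sample tensor).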


@dataclass
class SamplingTensors:
    """Tensors for sampling."""

    temperatures: torch.Tensor
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    temperature_lasts: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    xtc_thresholds: torch.Tensor
    xtc_probabilities: torch.Tensor
    kl_thresholds: torch.Tensor
    jsd_thresholds: torch.Tensor
    min_typical_ps: torch.Tensor
    max_typical_ps: torch.Tensor
    sampling_seeds: torch.Tensor
    sample_indices: torch.Tensor
    extra_seeds: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    @classmethod
    def from_sampling_metadata(
        cls,
        sampling_metadata: "SamplingMetadata",
        vocab_size: int,
        device: torch.device,
        dtype: torch.dtype,
        *,
        extra_seeds_to_generate: int = 0,
        extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool, bool, bool, bool, bool, bool, bool]:
        """
        extra_seeds_to_generate: extra seeds to generate using the
            user-defined seed for each sequence.
        extra_entropy: extra entropy to use when generating seeds.
        """
        prompt_tokens: List[array] = []
        output_tokens: List[array] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        dynatemp_mins: List[float] = []
        dynatemp_maxs: List[float] = []
        dynatemp_exps: List[float] = []
        temperature_lasts: List[bool] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        xtc_thresholds: List[float] = []
        xtc_probabilities: List[float] = []
        kl_thresholds: List[float] = []
        jsd_thresholds: List[float] = []
        min_typical_ps: List[float] = []
        max_typical_ps: List[float] = []
        sampling_seeds: List[int] = []
        sample_indices: List[int] = []
        do_penalties = False
        do_temperatures = False
        do_top_p_top_k = False
        do_top_as = False
        do_min_p = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False
        do_xtc = False
        do_kl_threshold = False
        do_jsd_threshold = False
        do_dynatypical_p = False
        do_temp_last = False

        if _USE_TRITON_SAMPLER:
            prompt_best_of: List[int] = []

            # We need one base seed per Triton slice.
            seeds_to_generate = (extra_seeds_to_generate +
                                 get_num_triton_sampler_splits(vocab_size))

        assert sampling_metadata.seq_groups is not None
        for seq_group in sampling_metadata.seq_groups:
            seq_ids = seq_group.seq_ids
            sampling_params = seq_group.sampling_params
            temperature = sampling_params.temperature
            dynatemp_min = sampling_params.dynatemp_min
            dynatemp_max = sampling_params.dynatemp_max
            dynatemp_exp = sampling_params.dynatemp_exponent
            temperature_last = sampling_params.temperature_last
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            top_a = sampling_params.top_a
            min_p = sampling_params.min_p
            tfs = sampling_params.tfs
            eta_cutoff = sampling_params.eta_cutoff
            epsilon_cutoff = sampling_params.epsilon_cutoff
            typical_p = sampling_params.typical_p
            smoothing_factor = sampling_params.smoothing_factor
            smoothing_curve = sampling_params.smoothing_curve
            xtc_threshold = sampling_params.xtc_threshold
            xtc_probability = sampling_params.xtc_probability
            kl_threshold = sampling_params.kl_threshold
            jsd_threshold = sampling_params.jsd_threshold
            min_typical_p = sampling_params.min_typical_p
            max_typical_p = sampling_params.max_typical_p

            # k should not be greater than the vocab size.
            top_k = min(sampling_params.top_k, vocab_size)
            top_k = vocab_size if top_k == -1 else top_k
            if temperature < _SAMPLING_EPS:
                # NOTE: Zero temperature means deterministic sampling
                # (i.e., greedy sampling or beam search).
                # Set the temperature to 1 to avoid division by zero.
                temperature = 1.0
            if not do_temperatures and temperature != 1.0:
                do_temperatures = True
            if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
                                       or top_k != vocab_size):
                do_top_p_top_k = True
            if do_top_as is False and top_a > 0.0:
                do_top_as = True
            if not do_min_p and min_p > _SAMPLING_EPS:
                do_min_p = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                do_tfss = True
            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
                do_eta_cutoffs = True
            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                do_epsilon_cutoffs = True
            if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
                do_typical_ps = True
            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                          or smoothing_curve > 1.0):
                do_quadratic = True
            if do_xtc is False and xtc_probability > _SAMPLING_EPS:
                do_xtc = True
            if do_kl_threshold is False and kl_threshold > _SAMPLING_EPS:
                do_kl_threshold = True
            if do_jsd_threshold is False and jsd_threshold > _SAMPLING_EPS:
                do_jsd_threshold = True
            if do_dynatypical_p is False and (
                    min_typical_p < 1.0 - _SAMPLING_EPS
                    or max_typical_p < 1.0 - _SAMPLING_EPS):
                do_dynatypical_p = True
            if do_temp_last is False and temperature_last:
                do_temp_last = True

            is_prompt = seq_group.is_prompt
            if is_prompt and sampling_params.prompt_logprobs is not None:
                # For prompt tokens for which we only need logprobs.
                query_len = seq_group.query_len
                assert query_len is not None
                prefill_len = len(seq_group.prompt_logprob_indices)
                temperatures += [temperature] * prefill_len
                dynatemp_mins += [dynatemp_min] * prefill_len
                dynatemp_maxs += [dynatemp_max] * prefill_len
                dynatemp_exps += [dynatemp_exp] * prefill_len
                temperature_lasts += [temperature_last] * prefill_len
                top_ps += [top_p] * prefill_len
                top_ks += [top_k] * prefill_len
                top_as += [top_a] * prefill_len
                min_ps += [min_p] * prefill_len
                presence_penalties += [0] * prefill_len
                frequency_penalties += [0] * prefill_len
                repetition_penalties += [1] * prefill_len
                tfss += [1] * prefill_len
                eta_cutoffs += [0] * prefill_len
                epsilon_cutoffs += [0] * prefill_len
                typical_ps += [1] * prefill_len
                smoothing_factors += [smoothing_factor] * prefill_len
                smoothing_curves += [smoothing_curve] * prefill_len
                xtc_thresholds += [xtc_threshold] * prefill_len
                xtc_probabilities += [xtc_probability] * prefill_len
                kl_thresholds += [kl_threshold] * prefill_len
                jsd_thresholds += [jsd_threshold] * prefill_len
                min_typical_ps += [min_typical_p] * prefill_len
                max_typical_ps += [max_typical_p] * prefill_len
            if seq_group.do_sample:
                sample_lens = len(seq_group.sample_indices)
                assert sample_lens == len(seq_ids)
                temperatures += [temperature] * len(seq_ids)
                dynatemp_mins += [dynatemp_min] * len(seq_ids)
                dynatemp_maxs += [dynatemp_max] * len(seq_ids)
                dynatemp_exps += [dynatemp_exp] * len(seq_ids)
                temperature_lasts += [temperature_last] * len(seq_ids)
                top_ps += [top_p] * len(seq_ids)
                top_ks += [top_k] * len(seq_ids)
                top_as += [top_a] * len(seq_ids)
                min_ps += [min_p] * len(seq_ids)
                presence_penalties += [p] * len(seq_ids)
                frequency_penalties += [f] * len(seq_ids)
                repetition_penalties += [r] * len(seq_ids)
                tfss += [tfs] * len(seq_ids)
                eta_cutoffs += [eta_cutoff] * len(seq_ids)
                epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
                typical_ps += [typical_p] * len(seq_ids)
                smoothing_factors += [smoothing_factor] * len(seq_ids)
                smoothing_curves += [smoothing_curve] * len(seq_ids)
                xtc_thresholds += [xtc_threshold] * len(seq_ids)
                xtc_probabilities += [xtc_probability] * len(seq_ids)
                kl_thresholds += [kl_threshold] * len(seq_ids)
                jsd_thresholds += [jsd_threshold] * len(seq_ids)
                min_typical_ps += [min_typical_p] * len(seq_ids)
                max_typical_ps += [max_typical_p] * len(seq_ids)

            if _USE_TRITON_SAMPLER:
                if is_prompt:
                    prompt_best_of.append(sampling_params.best_of)
                    query_len = seq_group.query_len
                    assert query_len is not None

                seed = sampling_params.seed
                is_greedy = (sampling_params.sampling_type ==
                             SamplingType.GREEDY)

                for seq_id in seq_ids:
                    seq_data = seq_group.seq_data[seq_id]
                    extra_entropy = extra_entropy or ()
                    seq_seeds = cls._get_sequence_seeds(
                        seed,
                        seq_data.get_len(),
                        *extra_entropy,
                        seq_id,
                        seeds_to_generate=seeds_to_generate,
                        is_greedy=is_greedy)
                    sampling_seeds.append(seq_seeds)
                sample_indices.extend(seq_group.sample_indices)

        if do_penalties:
            for seq_group in sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                if (seq_group.is_prompt and
                        seq_group.sampling_params.prompt_logprobs is not None):
                    prefill_len = len(seq_group.prompt_logprob_indices)
                    prompt_tokens.extend(
                        array('l') for _ in range(prefill_len))
                    output_tokens.extend(
                        array('l') for _ in range(prefill_len))
                if seq_group.do_sample:
                    for seq_id in seq_ids:
                        seq_data = seq_group.seq_data[seq_id]
                        prompt_tokens.append(seq_data.prompt_token_ids_array)
                        output_tokens.append(seq_data.output_token_ids_array)

        sampling_tensors = SamplingTensors.from_lists(
            temperatures, dynatemp_mins, dynatemp_maxs, dynatemp_exps,
            temperature_lasts, top_ps, top_ks, top_as, min_ps,
            presence_penalties, frequency_penalties, repetition_penalties,
            tfss, eta_cutoffs, epsilon_cutoffs, typical_ps, smoothing_factors,
            smoothing_curves, xtc_thresholds, xtc_probabilities, kl_thresholds,
            jsd_thresholds, min_typical_ps, max_typical_ps,
            sampling_seeds, sample_indices, prompt_tokens, output_tokens,
            vocab_size, extra_seeds_to_generate, device, dtype)
        return (sampling_tensors, do_penalties, do_temperatures,
                do_top_p_top_k, do_top_as, do_min_p, do_tfss, do_eta_cutoffs,
                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_xtc,
                do_kl_threshold, do_jsd_threshold, do_dynatypical_p,
                do_temp_last)
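
    # Illustrative sketch (not part of the module): how a sampler would
    # typically consume the tuple returned above. `_apply_top_k_top_p` is a
    # hypothetical helper; only the ordering of the returned flags matters.
    #
    #     (sampling_tensors, do_penalties, do_temperatures, do_top_p_top_k,
    #      do_top_as, do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
    #      do_typical_ps, do_quadratic, do_xtc, do_kl_threshold,
    #      do_jsd_threshold, do_dynatypical_p,
    #      do_temp_last) = SamplingTensors.from_sampling_metadata(
    #          sampling_metadata, vocab_size, logits.device, logits.dtype)
    #     if do_top_p_top_k:
    #         logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
    #                                     sampling_tensors.top_ks)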

    @classmethod
    def from_lists(cls, temperatures: List[float], dynatemp_mins: List[float],
                   dynatemp_maxs: List[float], dynatemp_exps: List[float],
                   temperature_lasts: List[bool], top_ps: List[float],
                   top_ks: List[int], top_as: List[float],
                   min_ps: List[float], presence_penalties: List[float],
                   frequency_penalties: List[float],
                   repetition_penalties: List[float], tfss: List[float],
                   eta_cutoffs: List[float], epsilon_cutoffs: List[float],
                   typical_ps: List[float], smoothing_factors: List[float],
                   smoothing_curves: List[float], xtc_thresholds: List[float],
                   xtc_probabilities: List[float], kl_thresholds: List[float],
                   jsd_thresholds: List[float], min_typical_ps: List[float],
                   max_typical_ps: List[float], sampling_seeds: List[int],
                   sample_indices: List[int], prompt_tokens: List[array],
                   output_tokens: List[array], vocab_size: int,
                   extra_seeds_to_generate: int, device: torch.device,
                   dtype: torch.dtype) -> "SamplingTensors":
        # Note that the performance will be very bad without
        # pinned memory.
        pin_memory = is_pin_memory_available()
        do_penalties = prompt_tokens or output_tokens

        if do_penalties:
            prompt_t = make_tensor_with_pad(
                prompt_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
            output_t = make_tensor_with_pad(
                output_tokens,
                vocab_size,
                device="cpu",
                dtype=torch.int64,
                pin_memory=pin_memory,
            )
        else:
            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
            prompt_t = empty_tensor
            output_t = empty_tensor

        temperatures_t = torch.tensor(
            temperatures,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_mins_t = torch.tensor(
            dynatemp_mins,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_maxs_t = torch.tensor(
            dynatemp_maxs,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        dynatemp_exps_t = torch.tensor(
            dynatemp_exps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        temp_lasts_t = torch.tensor(
            temperature_lasts,
            device="cpu",
            dtype=torch.bool,
            pin_memory=pin_memory,
        )
        top_ps_t = torch.tensor(
            top_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_as_t = torch.tensor(top_as,
                                device="cpu",
                                dtype=dtype,
                                pin_memory=pin_memory)
        min_ps_t = torch.tensor(
            min_ps,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        presence_penalties_t = torch.tensor(
            presence_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        frequency_penalties_t = torch.tensor(
            frequency_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        repetition_penalties_t = torch.tensor(
            repetition_penalties,
            device="cpu",
            dtype=dtype,
            pin_memory=pin_memory,
        )
        top_ks_t = torch.tensor(
            top_ks,
            device="cpu",
            dtype=torch.int,
            pin_memory=pin_memory,
        )
        tfss_t = torch.tensor(tfss,
                              device="cpu",
                              dtype=dtype,
                              pin_memory=pin_memory)
        eta_cutoffs_t = torch.tensor(eta_cutoffs,
                                     device="cpu",
                                     dtype=dtype,
                                     pin_memory=pin_memory)
        epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
                                         device="cpu",
                                         dtype=dtype,
                                         pin_memory=pin_memory)
        typical_ps_t = torch.tensor(typical_ps,
                                    device="cpu",
                                    dtype=dtype,
                                    pin_memory=pin_memory)
        smoothing_factors_t = torch.tensor(smoothing_factors,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        smoothing_curves_t = torch.tensor(smoothing_curves,
                                          device="cpu",
                                          dtype=dtype,
                                          pin_memory=pin_memory)
        xtc_thresholds_t = torch.tensor(xtc_thresholds,
                                        device="cpu",
                                        dtype=dtype,
                                        pin_memory=pin_memory)
        xtc_probabilities_t = torch.tensor(xtc_probabilities,
                                           device="cpu",
                                           dtype=dtype,
                                           pin_memory=pin_memory)
        kl_thresholds_t = torch.tensor(kl_thresholds,
                                       device="cpu",
                                       dtype=dtype,
                                       pin_memory=pin_memory)
        jsd_thresholds_t = torch.tensor(jsd_thresholds,
                                        device="cpu",
                                        dtype=dtype,
                                        pin_memory=pin_memory)
        min_typical_ps_t = torch.tensor(min_typical_ps,
                                        device="cpu",
                                        dtype=dtype,
                                        pin_memory=pin_memory)
        max_typical_ps_t = torch.tensor(max_typical_ps,
                                        device="cpu",
                                        dtype=dtype,
                                        pin_memory=pin_memory)
        sample_indices_t = torch.tensor(
            sample_indices,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        )
        # Need to transpose and make contiguous to copy the tensor correctly.
        # [batch_size, n_seeds] -> [n_seeds, batch_size]
        sampling_seeds_t = torch.tensor(
            sampling_seeds,
            device="cpu",
            dtype=torch.long,
            pin_memory=pin_memory,
        ).t().contiguous()

        # Because the memory is pinned, we can do non-blocking
        # transfer to device.

        # How many seeds the sample operation itself will need.
        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
                                                 non_blocking=True)
        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
        if not extra_seeds_gpu.numel():
            extra_seeds_gpu = None
        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]

        return cls(
            temperatures=temperatures_t.to(device=device, non_blocking=True),
            dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
            dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
            dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
            temperature_lasts=temp_lasts_t.to(device=device,
                                              non_blocking=True),
            top_ps=top_ps_t.to(device=device, non_blocking=True),
            top_ks=top_ks_t.to(device=device, non_blocking=True),
            top_as=top_as_t.to(device=device, non_blocking=True),
            min_ps=min_ps_t.to(device=device, non_blocking=True),
            presence_penalties=presence_penalties_t.to(device=device,
                                                       non_blocking=True),
            frequency_penalties=frequency_penalties_t.to(device=device,
                                                         non_blocking=True),
            repetition_penalties=repetition_penalties_t.to(device=device,
                                                           non_blocking=True),
            tfss=tfss_t.to(device=device, non_blocking=True),
            eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
            epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
                                                 non_blocking=True),
            smoothing_factors=smoothing_factors_t.to(device=device,
                                                     non_blocking=True),
            smoothing_curves=smoothing_curves_t.to(device=device,
                                                   non_blocking=True),
            xtc_thresholds=xtc_thresholds_t.to(device=device,
                                               non_blocking=True),
            xtc_probabilities=xtc_probabilities_t.to(device=device,
                                                     non_blocking=True),
            kl_thresholds=kl_thresholds_t.to(device=device,
                                             non_blocking=True),
            jsd_thresholds=jsd_thresholds_t.to(device=device,
                                               non_blocking=True),
            min_typical_ps=min_typical_ps_t.to(device=device,
                                               non_blocking=True),
            max_typical_ps=max_typical_ps_t.to(device=device,
                                               non_blocking=True),
            typical_ps=typical_ps_t.to(device=device, non_blocking=True),
            prompt_tokens=prompt_t.to(device=device, non_blocking=True),
            output_tokens=output_t.to(device=device, non_blocking=True),
            sampling_seeds=sampling_seeds_gpu,
            sample_indices=sample_indices_t.to(device=device,
                                               non_blocking=True),
            extra_seeds=extra_seeds_gpu,
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: int,
        *extra_entropy: int,
        seeds_to_generate: int,
        is_greedy: bool,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if not is_greedy:
            if seed is None:
                randint_fn = random.randint
            else:
                generator = random.Random(str((seed, ) + extra_entropy))
                randint_fn = generator.randint
            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
            # If the user or the RNG produces seed = 0 but the request should
            # still be sampled, replace it with a constant.
            # This way we don't need to create and load a bool matrix in the
            # sampling kernel, which reduces CPU overhead and latency.
            seq_seeds = [
                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
                for _ in range(seeds_to_generate)
            ]
        else:
            # For the kernel, seed == 0 means greedy decoding.
            seq_seeds = [0] * seeds_to_generate
        return seq_seeds
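

# Illustrative sketch (not part of the module): how _get_sequence_seeds
# behaves under the three seeding modes. The extra-entropy values and split
# counts below are made up for illustration only.
#
#     # Greedy: the kernel treats seed == 0 as argmax, so all zeros.
#     SamplingTensors._get_sequence_seeds(
#         42, seeds_to_generate=3, is_greedy=True)        # -> [0, 0, 0]
#
#     # Seeded request: child seeds are a deterministic function of
#     # (seed, *extra_entropy), so repeated calls agree.
#     SamplingTensors._get_sequence_seeds(
#         42, 16, 7, seeds_to_generate=3, is_greedy=False)
#
#     # Unseeded request: seed is None, so random.randint supplies entropy,
#     # and any draw of exactly 0 is replaced by _SEED_0_REPLACEMENT.
#     SamplingTensors._get_sequence_seeds(
#         None, 16, 7, seeds_to_generate=3, is_greedy=False)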