# sampling_metadata.py
import random
from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (PyObjectCache, async_tensor_h2d,
                                    is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.triton_utils.sample import get_num_triton_sampler_splits
  12. _SAMPLING_EPS = 1e-5
  13. _SEED_0_REPLACEMENT = 3403598558
  14. # Some triton sampler related code is guarded before it is ready.
  15. _USE_TRITON_SAMPLER = False
  16. @dataclass
  17. class SequenceGroupToSample:
  18. # |---------- N-1 iteration --------|
  19. # |---------------- N iteration ---------------------|
  20. # |- tokenA -|......................|-- newTokens ---|
  21. # |---------- context_len ----------|
  22. # |-------------------- seq_len ----------------------|
  23. # |-- query_len ---|
  24. # Sequence ids for the sequence group in a previous step.
  25. seq_ids: List[int]
  26. sampling_params: SamplingParams
  27. # seq_id -> sequence data.
  28. seq_data: Dict[int, SequenceData]
  29. # The length of the sequence (all tokens seen in the past + new token to
  30. # compute attention) of the sequence group. None if it is in a decode
  31. # stage.
  32. seq_len: Optional[int]
  33. # The length of new query tokens to compute in the current step. None if it
  34. # is in a decode stage. The length of query_len <= seq_len if chunked
  35. # prefill is enabled.
  36. query_len: Optional[int]
  37. # A random number generator for sampling.
  38. generator: Optional[torch.Generator]
  39. # True if the sequence group is in prefill stage. False if it is in a
  40. # decode stage.
  41. is_prompt: bool
  42. # Query token indices from logits. to compute prompt logprob. Empty if
  43. # prompt logprob is not required.
  44. prompt_logprob_indices: List[int]
  45. # Sample token indices from logits. Empty if sampling is not required.
  46. sample_indices: List[int]
  47. @property
  48. def do_sample(self):
  49. return len(self.sample_indices) > 0
  50. def __post_init__(self):
  51. if len(self.prompt_logprob_indices) > 0:
  52. assert self.sampling_params.prompt_logprobs is not None
  53. if self.is_prompt:
  54. assert self.seq_len is not None
  55. assert self.query_len is not None
  56. def gen_seq_group_to_sample_builder(num_seqs: int):
  57. return lambda: SequenceGroupToSample(
  58. seq_ids=[0] * num_seqs,
  59. sampling_params=None,
  60. seq_data=None, # type: ignore
  61. seq_len=0,
  62. query_len=0,
  63. generator=None,
  64. is_prompt=True,
  65. prompt_logprob_indices=[],
  66. sample_indices=[])
  67. class SamplingMetadataCache:
  68. """Used to cache SamplingMetadata objects between scheduler iterations
  69. """
  70. def __init__(self):
  71. self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
  72. def get_cached_seq_group_to_sample(self, num_seqs):
  73. if num_seqs not in self._seq_group_to_sample_cache:
  74. self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
  75. gen_seq_group_to_sample_builder(num_seqs))
  76. obj = self._seq_group_to_sample_cache[num_seqs].get_object()
  77. return obj
  78. def reset(self):
  79. for cache in self._seq_group_to_sample_cache.values():
  80. cache.reset()
  81. class SamplingMetadata:
  82. """Metadata for input sequences. Used in sampler.
  83. The usage is as follow;
  84. ```
  85. hidden_states = execute_model(...)
  86. logits = hidden_states[sampling_metadata.selected_token_indices]
  87. sample(logits)
  88. def sample(logits):
  89. # Use categorized_sample_indices for sampling....
  90. ```
  91. Args:
  92. seq_groups: List of batched sequence groups.
  93. selected_token_indices: (num_query_tokens_to_logprob). Indices to find
  94. logits from the initial model output hidden states.
  95. categorized_sample_indices: SamplingType -> token indices to sample.
  96. Each token indices is 2D tensor of (num_indices, num_indices) where
  97. the first item means the sample index within the returned logit
  98. (before pruning padding), and the second item means the sample
  99. index after pruning using selected_token_indices.
  100. For example, if the returned logit is [1, 2, 3], and we select
  101. [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
  102. The first tuple is [1, 2] (sampled index within original logit),
  103. and the second tuple is [0, 1] (sampled index within pruned logit).
  104. num_prompts: Number of prompt sequence groups in seq_groups.
  105. skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
  106. serialization of token outputs.
  107. reuse_sampling_tensors: Indicates if we want to reuse sampling
  108. tensors that are part of the sampler forward pass. Currently,
  109. it is mainly used for multi-step decode.
  110. """
  111. def __init__(
  112. self,
  113. seq_groups: List[SequenceGroupToSample],
  114. selected_token_indices: torch.Tensor,
  115. categorized_sample_indices: Dict[SamplingType, torch.Tensor],
  116. num_prompts: int,
  117. skip_sampler_cpu_output: bool = False,
  118. reuse_sampling_tensors: bool = False,
  119. ) -> None:
  120. self.seq_groups = seq_groups
  121. self.selected_token_indices = selected_token_indices
  122. self.categorized_sample_indices = categorized_sample_indices
  123. self.num_prompts = num_prompts
  124. self.skip_sampler_cpu_output = skip_sampler_cpu_output
  125. self.reuse_sampling_tensors = reuse_sampling_tensors
  126. @staticmethod
  127. def prepare(
  128. seq_group_metadata_list: List[SequenceGroupMetadata],
  129. seq_lens: List[int],
  130. query_lens: Optional[List[int]],
  131. device: str,
  132. pin_memory: bool,
  133. generators: Optional[Dict[str, torch.Generator]] = None,
  134. cache: Optional[SamplingMetadataCache] = None
  135. ) -> "SamplingMetadata":
  136. (
  137. seq_groups,
  138. selected_token_indices,
  139. categorized_sample_indices,
  140. num_prompts,
  141. ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
  142. device, generators, cache)
  143. selected_token_indices = async_tensor_h2d(selected_token_indices,
  144. dtype=torch.long,
  145. target_device=device,
  146. pin_memory=pin_memory)
  147. categorized_sample_indices = {
  148. t: maybe_expand_dim(
  149. async_tensor_h2d(seq_ids,
  150. dtype=torch.int,
  151. target_device=device,
  152. pin_memory=pin_memory), 2, 2)
  153. for t, seq_ids in categorized_sample_indices.items()
  154. }
  155. sampling_metadata = SamplingMetadata(
  156. seq_groups=seq_groups,
  157. selected_token_indices=selected_token_indices,
  158. categorized_sample_indices=categorized_sample_indices,
  159. num_prompts=num_prompts,
  160. )
  161. return sampling_metadata
  162. def __repr__(self) -> str:
  163. return (
  164. "SamplingMetadata("
  165. f"seq_groups={self.seq_groups}, "
  166. f"selected_token_indices={self.selected_token_indices}, "
  167. f"categorized_sample_indices={self.categorized_sample_indices}), ")
  168. def _prepare_seq_groups(
  169. seq_group_metadata_list: List[SequenceGroupMetadata],
  170. seq_lens: List[int],
  171. query_lens: Optional[List[int]],
  172. device: str,
  173. generators: Optional[Dict[str, torch.Generator]] = None,
  174. cache: Optional[SamplingMetadataCache] = None,
  175. ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
  176. SamplingType, List[Tuple[int, int]]], int]:
  177. """Prepare sequence groups and indices for sampling.
  178. Args:
  179. seq_group_metadata_list: A list of sequence group to batch.
  180. seq_lens: A list of sequence lens per sequence group.
  181. Index of prompt len should match with seq_group_metadata_list.
  182. query_lens: A list of query lengths. Prompt lens include the length
  183. of entire prompt tokens, and it could be shorter.
  184. device: A device to use for random number generators,
  185. `SequenceGroupToSample.generator`.
  186. generators: A store of per-request random number generators used
  187. for seeded requests.
  188. Returns:
  189. seq_groups: A list of sequence group to sample.
  190. selected_token_indices: See the definition from `SamplingMetadata`.
  191. categorized_sample_indices: See the definition from `SamplingMetadata`.
  192. num_prompts: Total number of prompts from `seq_group_metadata_list`.
  193. """
  194. # Batched sequence groups for the current model forward stsep.
  195. seq_groups: List[SequenceGroupToSample] = []
  196. # A list of token indices to sample/compute logprob. It is used to
  197. # prune the outcome logits from the model for the performance.
  198. selected_token_indices: List[int] = []
  199. # Used for selected_token_indices.
  200. model_output_idx = 0
  201. # Sampling type -> (
  202. # indices to sample/prompt logprob within pruned output logits,
  203. # indices to sample within pruned logits)
  204. categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
  205. t: []
  206. for t in SamplingType
  207. }
  208. # Index of logits to compute logprob. Logits include both prompt logprob
  209. # and sample logprob indices.
  210. logit_idx = 0
  211. # Index to sample from a sample tensor. It is used by triton sample kernel.
  212. # See `_sample_with_triton_kernel` for more details.
  213. sample_idx = 0
  214. # Total number of prompts from given sequence groups.
  215. num_prompts = 0
  216. for i, seq_group_metadata in enumerate(seq_group_metadata_list):
  217. seq_ids = seq_group_metadata.seq_data.keys()
  218. if cache is not None:
  219. sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))
  220. for j, seq_id in enumerate(seq_ids):
  221. sample_obj.seq_ids[j] = seq_id
  222. sample_obj.prompt_logprob_indices.clear()
  223. sample_obj.sample_indices.clear()
  224. sampling_params = seq_group_metadata.sampling_params
  225. is_prompt = seq_group_metadata.is_prompt
  226. generator: Optional[torch.Generator] = None
  227. # If the current seq group is in decode stage, it is None.
  228. seq_len: Optional[int] = None
  229. query_len: Optional[int] = None
  230. prompt_logprob_indices: List[int] = \
  231. sample_obj.prompt_logprob_indices if cache is not None else []
  232. sample_indices: List[int] = \
  233. sample_obj.sample_indices if cache is not None else []
  234. do_sample = seq_group_metadata.do_sample
  235. if seq_group_metadata.is_prompt:
  236. if sampling_params.seed is not None:
  237. generator = torch.Generator(device=device).manual_seed(
  238. sampling_params.seed)
  239. if generators is not None:
  240. generators[seq_group_metadata.request_id] = generator
  241. num_prompts += 1
  242. num_prefill_sample = len(seq_ids)
  243. assert num_prefill_sample == 1
  244. assert query_lens is not None and seq_lens is not None
  245. query_len, seq_len = query_lens[i], seq_lens[i]
  246. # If we need sampling, exclude num_prefill_sample tokens from
  247. # prompt logprob.
  248. prompt_logprob_len = (query_len - num_prefill_sample
  249. if do_sample else query_len)
  250. sample_len = num_prefill_sample if do_sample else 0
  251. else:
  252. # Decode
  253. prompt_logprob_len = 0
  254. sample_len = len(seq_ids) if do_sample else 0
  255. if sampling_params.seed is not None and generators is not None:
  256. generator = generators.get(seq_group_metadata.request_id)
  257. # Update indices to select from the model output.
  258. """
  259. This blocks computes selected_token_indices which is used in the
  260. following way.
  261. hidden_states = model(...)
  262. logits = hidden_states[selected_token_indices]
  263. """
  264. if sampling_params.prompt_logprobs is not None:
  265. selected_token_indices.extend(
  266. range(model_output_idx, model_output_idx + prompt_logprob_len))
  267. model_output_idx += prompt_logprob_len
  268. if do_sample:
  269. selected_token_indices.extend(
  270. range(model_output_idx, model_output_idx + sample_len))
  271. model_output_idx += sample_len
  272. # We now find indices for logprob computation and sampling.
  273. """
  274. This block computes categorized_sample_indices which is used in the
  275. following way.
  276. hidden_states = model(...)
  277. logits = hidden_states[selected_token_indices]
  278. def sample(logits):
  279. # Use categorized_sample_indices for sampling.
  280. # prompt_logprob_indices to find prompt logprob indices.
  281. # sample_indices to find sample indices.
  282. """
  283. if sampling_params.prompt_logprobs is not None:
  284. prompt_logprob_indices.extend(
  285. range(logit_idx, logit_idx + prompt_logprob_len))
  286. logit_idx += prompt_logprob_len
  287. if do_sample:
  288. sample_indices.extend(range(logit_idx, logit_idx + sample_len))
  289. categorized_sample_indices[sampling_params.sampling_type].extend(
  290. list(
  291. zip(range(logit_idx, logit_idx + sample_len),
  292. range(sample_idx, sample_idx + sample_len))))
  293. logit_idx += sample_len
  294. sample_idx += sample_len
  295. if cache is not None:
  296. sample_obj.sampling_params = sampling_params
  297. sample_obj.seq_data = seq_group_metadata.seq_data
  298. sample_obj.seq_len = seq_len
  299. sample_obj.query_len = query_len
  300. sample_obj.generator = generator
  301. sample_obj.is_prompt = is_prompt
  302. else:
  303. sample_obj = SequenceGroupToSample(
  304. seq_ids=list(seq_ids),
  305. sampling_params=sampling_params,
  306. seq_data=seq_group_metadata.seq_data,
  307. seq_len=seq_len,
  308. query_len=query_len,
  309. generator=generator,
  310. is_prompt=is_prompt,
  311. prompt_logprob_indices=list(prompt_logprob_indices),
  312. sample_indices=list(sample_indices))
  313. seq_groups.append(sample_obj)
  314. if cache is not None:
  315. cache.reset()
  316. return (seq_groups, selected_token_indices, categorized_sample_indices,
  317. num_prompts)
  318. @dataclass
  319. class SamplingTensors:
  320. """Tensors for sampling."""
  321. temperatures: torch.Tensor
  322. temperature_lasts: torch.Tensor
  323. top_ps: torch.Tensor
  324. top_ks: torch.Tensor
  325. top_as: torch.Tensor
  326. min_ps: torch.Tensor
  327. presence_penalties: torch.Tensor
  328. frequency_penalties: torch.Tensor
  329. repetition_penalties: torch.Tensor
  330. rep_pen_ranges: torch.Tensor
  331. tfss: torch.Tensor
  332. eta_cutoffs: torch.Tensor
  333. epsilon_cutoffs: torch.Tensor
  334. typical_ps: torch.Tensor
  335. smoothing_factors: torch.Tensor
  336. smoothing_curves: torch.Tensor
  337. xtc_thresholds: torch.Tensor
  338. xtc_probabilities: torch.Tensor
  339. sampling_seeds: torch.Tensor
  340. sample_indices: torch.Tensor
  341. extra_seeds: Optional[torch.Tensor]
  342. prompt_tokens: torch.Tensor
  343. output_tokens: torch.Tensor
  344. @classmethod
  345. def from_sampling_metadata(
  346. cls,
  347. sampling_metadata: "SamplingMetadata",
  348. vocab_size: int,
  349. device: torch.device,
  350. dtype: torch.dtype,
  351. *,
  352. extra_seeds_to_generate: int = 0,
  353. extra_entropy: Optional[Tuple[int, ...]] = None
  354. ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
  355. bool, bool, bool, bool]:
  356. """
  357. extra_seeds_to_generate: extra seeds to generate using the
  358. user-defined seed for each sequence.
  359. extra_entropy: extra entropy to use when generating seeds.
  360. """
  361. prompt_tokens: List[array] = []
  362. output_tokens: List[array] = []
  363. top_ks: List[int] = []
  364. temperatures: List[float] = []
  365. temperature_lasts: List[bool] = []
  366. top_ps: List[float] = []
  367. top_as: List[float] = []
  368. min_ps: List[float] = []
  369. presence_penalties: List[float] = []
  370. frequency_penalties: List[float] = []
  371. repetition_penalties: List[float] = []
  372. rep_pen_ranges: List[int] = []
  373. tfss: List[float] = []
  374. eta_cutoffs: List[float] = []
  375. epsilon_cutoffs: List[float] = []
  376. typical_ps: List[float] = []
  377. smoothing_factors: List[float] = []
  378. smoothing_curves: List[float] = []
  379. xtc_thresholds: List[float] = []
  380. xtc_probabilities: List[float] = []
  381. sampling_seeds: List[int] = []
  382. sample_indices: List[int] = []
  383. do_penalties = False
  384. do_top_p_top_k = False
  385. do_top_as = False
  386. do_min_p = False
  387. do_tfss = False
  388. do_eta_cutoffs = False
  389. do_epsilon_cutoffs = False
  390. do_typical_ps = False
  391. do_quadratic = False
  392. do_xtc = False
  393. do_temp_last = False
  394. if _USE_TRITON_SAMPLER:
  395. prompt_best_of: List[int] = []
  396. # We need one base seed per Triton slice.
  397. seeds_to_generate = (extra_seeds_to_generate +
  398. get_num_triton_sampler_splits(vocab_size))
  399. assert sampling_metadata.seq_groups is not None
  400. for seq_group in sampling_metadata.seq_groups:
  401. seq_ids = seq_group.seq_ids
  402. sampling_params = seq_group.sampling_params
  403. temperature = sampling_params.temperature
  404. temperature_last = sampling_params.temperature_last
  405. p = sampling_params.presence_penalty
  406. f = sampling_params.frequency_penalty
  407. r = sampling_params.repetition_penalty
  408. rep_pen_range = sampling_params.rep_pen_range
  409. top_p = sampling_params.top_p
  410. top_a = sampling_params.top_a
  411. min_p = sampling_params.min_p
  412. tfs = sampling_params.tfs
  413. eta_cutoff = sampling_params.eta_cutoff
  414. epsilon_cutoff = sampling_params.epsilon_cutoff
  415. typical_p = sampling_params.typical_p
  416. smoothing_factor = sampling_params.smoothing_factor
  417. smoothing_curve = sampling_params.smoothing_curve
  418. xtc_threshold = sampling_params.xtc_threshold
  419. xtc_probability = sampling_params.xtc_probability
  420. # k should not be greater than the vocab size.
  421. top_k = min(sampling_params.top_k, vocab_size)
  422. top_k = vocab_size if top_k == -1 else top_k
  423. if temperature < _SAMPLING_EPS:
  424. # NOTE: Zero temperature means deterministic sampling
  425. # (i.e., greedy sampling or beam search).
  426. # Set the temperature to 1 to avoid division by zero.
  427. temperature = 1.0
  428. if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS
  429. or top_k != vocab_size):
  430. do_top_p_top_k = True
  431. if do_top_as is False and top_a > 0.0:
  432. do_top_as = True
  433. if not do_min_p and min_p > _SAMPLING_EPS:
  434. do_min_p = True
  435. if not do_penalties and (abs(p) >= _SAMPLING_EPS
  436. or abs(f) >= _SAMPLING_EPS
  437. or abs(r - 1.0) >= _SAMPLING_EPS):
  438. do_penalties = True
  439. if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
  440. do_tfss = True
  441. if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
  442. do_eta_cutoffs = True
  443. if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
  444. do_epsilon_cutoffs = True
  445. if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
  446. do_typical_ps = True
  447. if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
  448. or smoothing_curve > 1.0):
  449. do_quadratic = True
  450. if do_xtc is False and xtc_probability > _SAMPLING_EPS:
  451. do_xtc = True
  452. if do_temp_last is False and temperature_last:
  453. do_temp_last = True
  454. is_prompt = seq_group.is_prompt
  455. if (is_prompt and sampling_params.prompt_logprobs is not None):
  456. # For tokens in the prompt that we only need to get
  457. # their logprobs
  458. query_len = seq_group.query_len
  459. assert query_len is not None
  460. prefill_len = len(seq_group.prompt_logprob_indices)
  461. temperatures += [temperature] * prefill_len
  462. temperature_lasts += [temperature_last] * prefill_len
  463. top_ps += [top_p] * prefill_len
  464. top_ks += [top_k] * prefill_len
  465. top_as += [top_a] * prefill_len
  466. min_ps += [min_p] * prefill_len
  467. presence_penalties += [0] * prefill_len
  468. frequency_penalties += [0] * prefill_len
  469. repetition_penalties += [1] * prefill_len
  470. rep_pen_ranges += [rep_pen_range] * prefill_len
  471. tfss += [1] * prefill_len
  472. eta_cutoffs += [0] * prefill_len
  473. epsilon_cutoffs += [0] * prefill_len
  474. typical_ps += [1] * prefill_len
  475. smoothing_factors += [smoothing_factor] * prefill_len
  476. smoothing_curves += [smoothing_curve] * prefill_len
  477. xtc_thresholds += [xtc_threshold] * prefill_len
  478. xtc_probabilities += [xtc_probability] * prefill_len
  479. if seq_group.do_sample:
  480. sample_lens = len(seq_group.sample_indices)
  481. assert sample_lens == len(seq_ids)
  482. temperatures += [temperature] * len(seq_ids)
  483. temperature_lasts += [temperature_last] * len(seq_ids)
  484. top_ps += [top_p] * len(seq_ids)
  485. top_ks += [top_k] * len(seq_ids)
  486. top_as += [top_a] * len(seq_ids)
  487. min_ps += [min_p] * len(seq_ids)
  488. presence_penalties += [p] * len(seq_ids)
  489. frequency_penalties += [f] * len(seq_ids)
  490. repetition_penalties += [r] * len(seq_ids)
  491. rep_pen_ranges += [rep_pen_range] * len(seq_ids)
  492. tfss += [tfs] * len(seq_ids)
  493. eta_cutoffs += [eta_cutoff] * len(seq_ids)
  494. epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
  495. typical_ps += [typical_p] * len(seq_ids)
  496. smoothing_factors += [smoothing_factor] * len(seq_ids)
  497. smoothing_curves += [smoothing_curve] * len(seq_ids)
  498. xtc_thresholds += [xtc_threshold] * len(seq_ids)
  499. xtc_probabilities += [xtc_probability] * len(seq_ids)
  500. if _USE_TRITON_SAMPLER:
  501. if is_prompt:
  502. prompt_best_of.append(sampling_params.best_of)
  503. query_len = seq_group.query_len
  504. assert query_len is not None
  505. seed = sampling_params.seed
  506. is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
  507. for seq_id in seq_ids:
  508. seq_data = seq_group.seq_data[seq_id]
  509. extra_entropy = extra_entropy or ()
  510. seq_seeds = cls._get_sequence_seeds(
  511. seed,
  512. seq_data.get_len(),
  513. *extra_entropy,
  514. seq_id,
  515. seeds_to_generate=seeds_to_generate,
  516. is_greedy=is_greedy)
  517. sampling_seeds.append(seq_seeds)
  518. sample_indices.extend(seq_group.sample_indices)
  519. if do_penalties:
  520. for seq_group in sampling_metadata.seq_groups:
  521. seq_ids = seq_group.seq_ids
  522. if (seq_group.is_prompt
  523. and sampling_params.prompt_logprobs is not None):
  524. prefill_len = len(seq_group.prompt_logprob_indices)
  525. prompt_tokens.extend(
  526. array('l') for _ in range(prefill_len))
  527. output_tokens.extend(
  528. array('l') for _ in range(prefill_len))
  529. if seq_group.do_sample:
  530. for seq_id in seq_ids:
  531. seq_data = seq_group.seq_data[seq_id]
  532. prompt_tokens.append(seq_data.prompt_token_ids_array)
  533. output_tokens.append(seq_data.output_token_ids_array)
  534. sampling_tensors = SamplingTensors.from_lists(
  535. temperatures, temperature_lasts, top_ps, top_ks, top_as, min_ps,
  536. presence_penalties, frequency_penalties, repetition_penalties,
  537. rep_pen_ranges, tfss, eta_cutoffs, epsilon_cutoffs, typical_ps,
  538. smoothing_factors, smoothing_curves, xtc_thresholds,
  539. xtc_probabilities, sampling_seeds, sample_indices, prompt_tokens,
  540. output_tokens, vocab_size, extra_seeds_to_generate, device, dtype)
  541. return (sampling_tensors, do_penalties, do_top_p_top_k, do_top_as,
  542. do_min_p, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
  543. do_typical_ps, do_quadratic, do_xtc, do_temp_last)
  544. @classmethod
  545. def from_lists(cls, temperatures: List[float],
  546. temperature_lasts: List[bool], top_ps: List[float],
  547. top_ks: List[int], top_as: List[float],
  548. min_ps: List[float], presence_penalties: List[float],
  549. frequency_penalties: List[float],
  550. repetition_penalties: List[float],
  551. rep_pen_ranges: List[int], tfss: List[float],
  552. eta_cutoffs: List[float], epsilon_cutoffs: List[float],
  553. typical_ps: List[float], smoothing_factors: List[float],
  554. smoothing_curves: List[float], xtc_thresholds: List[float],
  555. xtc_probabilities: List[float], sampling_seeds: List[int],
  556. sample_indices: List[int], prompt_tokens: List[array],
  557. output_tokens: List[array], vocab_size: int,
  558. extra_seeds_to_generate: int, device: torch.device,
  559. dtype: torch.dtype) -> "SamplingTensors":
  560. # Note that the performance will be very bad without
  561. # pinned memory.
  562. pin_memory = is_pin_memory_available()
  563. do_penalties = prompt_tokens or output_tokens
  564. if do_penalties:
  565. prompt_t = make_tensor_with_pad(
  566. prompt_tokens,
  567. vocab_size,
  568. device="cpu",
  569. dtype=torch.int64,
  570. pin_memory=pin_memory,
  571. )
  572. output_t = make_tensor_with_pad(
  573. output_tokens,
  574. vocab_size,
  575. device="cpu",
  576. dtype=torch.int64,
  577. pin_memory=pin_memory,
  578. )
  579. else:
  580. empty_tensor = torch.empty(0, device=device, dtype=torch.long)
  581. prompt_t = empty_tensor
  582. output_t = empty_tensor
  583. temperatures_t = torch.tensor(
  584. temperatures,
  585. device="cpu",
  586. dtype=dtype,
  587. pin_memory=pin_memory,
  588. )
  589. temp_lasts_t = torch.tensor(
  590. temperature_lasts,
  591. device="cpu",
  592. dtype=torch.bool,
  593. pin_memory=pin_memory,
  594. )
  595. top_ps_t = torch.tensor(
  596. top_ps,
  597. device="cpu",
  598. dtype=dtype,
  599. pin_memory=pin_memory,
  600. )
  601. top_as_t = torch.tensor(top_as,
  602. device="cpu",
  603. dtype=dtype,
  604. pin_memory=pin_memory)
  605. min_ps_t = torch.tensor(
  606. min_ps,
  607. device="cpu",
  608. dtype=dtype,
  609. pin_memory=pin_memory,
  610. )
  611. presence_penalties_t = torch.tensor(
  612. presence_penalties,
  613. device="cpu",
  614. dtype=dtype,
  615. pin_memory=pin_memory,
  616. )
  617. frequency_penalties_t = torch.tensor(
  618. frequency_penalties,
  619. device="cpu",
  620. dtype=dtype,
  621. pin_memory=pin_memory,
  622. )
  623. repetition_penalties_t = torch.tensor(
  624. repetition_penalties,
  625. device="cpu",
  626. dtype=dtype,
  627. pin_memory=pin_memory,
  628. )
  629. rep_pen_ranges_t = torch.tensor(
  630. rep_pen_ranges,
  631. device="cpu",
  632. dtype=torch.int,
  633. pin_memory=pin_memory,
  634. )
  635. top_ks_t = torch.tensor(
  636. top_ks,
  637. device="cpu",
  638. dtype=torch.int,
  639. pin_memory=pin_memory,
  640. )
  641. tfss_t = torch.tensor(tfss,
  642. device="cpu",
  643. dtype=dtype,
  644. pin_memory=pin_memory)
  645. eta_cutoffs_t = torch.tensor(eta_cutoffs,
  646. device="cpu",
  647. dtype=dtype,
  648. pin_memory=pin_memory)
  649. epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
  650. device="cpu",
  651. dtype=dtype,
  652. pin_memory=pin_memory)
  653. typical_ps_t = torch.tensor(typical_ps,
  654. device="cpu",
  655. dtype=dtype,
  656. pin_memory=pin_memory)
  657. smoothing_factors_t = torch.tensor(smoothing_factors,
  658. device="cpu",
  659. dtype=dtype,
  660. pin_memory=pin_memory)
  661. smoothing_curves_t = torch.tensor(smoothing_curves,
  662. device="cpu",
  663. dtype=dtype,
  664. pin_memory=pin_memory)
  665. xtc_thresholds_t = torch.tensor(xtc_thresholds,
  666. device="cpu",
  667. dtype=dtype,
  668. pin_memory=pin_memory)
  669. xtc_probabilities_t = torch.tensor(xtc_probabilities,
  670. device="cpu",
  671. dtype=dtype,
  672. pin_memory=pin_memory)
  673. sample_indices_t = torch.tensor(
  674. sample_indices,
  675. device="cpu",
  676. dtype=torch.long,
  677. pin_memory=pin_memory,
  678. )
  679. # need to transpose and make contiguous to
  680. # copy the tensor correctly.
  681. # [batch_size, n_seeds] -> [n_seeds, batch_size]
  682. sampling_seeds_t = torch.tensor(
  683. sampling_seeds,
  684. device="cpu",
  685. dtype=torch.long,
  686. pin_memory=pin_memory,
  687. ).t().contiguous()
  688. # Because the memory is pinned, we can do non-blocking
  689. # transfer to device.
  690. # How many seeds the sample operation itself will need.
  691. num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
  692. sampling_seeds_gpu = sampling_seeds_t.to(device=device,
  693. non_blocking=True)
  694. extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
  695. if not extra_seeds_gpu.numel():
  696. extra_seeds_gpu = None
  697. sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
  698. return cls(
  699. temperatures=temperatures_t.to(device=device, non_blocking=True),
  700. temperature_lasts=temp_lasts_t.to(device=device, non_blocking=True),
  701. top_ps=top_ps_t.to(device=device, non_blocking=True),
  702. top_ks=top_ks_t.to(device=device, non_blocking=True),
  703. top_as=top_as_t.to(device=device, non_blocking=True),
  704. min_ps=min_ps_t.to(device=device, non_blocking=True),
  705. presence_penalties=presence_penalties_t.to(device=device,
  706. non_blocking=True),
  707. frequency_penalties=frequency_penalties_t.to(device=device,
  708. non_blocking=True),
  709. repetition_penalties=repetition_penalties_t.to(device=device,
  710. non_blocking=True),
  711. rep_pen_ranges=rep_pen_ranges_t.to(device=device,
  712. non_blocking=True),
  713. tfss=tfss_t.to(device=device, non_blocking=True),
  714. eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
  715. epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
  716. non_blocking=True),
  717. smoothing_factors=smoothing_factors_t.to(device=device,
  718. non_blocking=True),
  719. smoothing_curves=smoothing_curves_t.to(device=device,
  720. non_blocking=True),
  721. xtc_thresholds=xtc_thresholds_t.to(device=device,
  722. non_blocking=True),
  723. xtc_probabilities=xtc_probabilities_t.to(device=device,
  724. non_blocking=True),
  725. typical_ps=typical_ps_t.to(device=device, non_blocking=True),
  726. prompt_tokens=prompt_t.to(device=device, non_blocking=True),
  727. output_tokens=output_t.to(device=device, non_blocking=True),
  728. sampling_seeds=sampling_seeds_gpu,
  729. sample_indices=sample_indices_t.to(device=device,
  730. non_blocking=True),
  731. extra_seeds=extra_seeds_gpu,
  732. )
  733. @staticmethod
  734. def _get_sequence_seeds(
  735. seed: int,
  736. *extra_entropy: int,
  737. seeds_to_generate: int,
  738. is_greedy: bool,
  739. ):
  740. """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
  741. if not is_greedy:
  742. if seed is None:
  743. randint_fn = random.randint
  744. else:
  745. generator = random.Random(str((seed, ) + extra_entropy))
  746. randint_fn = generator.randint
  747. lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
  748. # If the user/random sets seed = 0 but request should
  749. # have sampling, we need to change it to something
  750. # else. We use a constant in that case.
  751. # This way we don't need to create and load a bool
  752. # matrix in the sampling kernel, which reduces CPU
  753. # overhead and latency.
  754. seq_seeds = [
  755. randint_fn(lo, hi) or _SEED_0_REPLACEMENT
  756. for _ in range(seeds_to_generate)
  757. ]
  758. else:
  759. # For the kernel, seed == 0 means greedy decoding.
  760. seq_seeds = [0] * seeds_to_generate
  761. return seq_seeds