# sampling_metadata.py
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, TypeVar, Callable
import random

import torch

from aphrodite.modeling.layers.ops.sample import get_num_triton_sampler_splits
from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData
from aphrodite.common.utils import is_pin_memory_available

_SEED_0_REPLACEMENT = 3403598558  # chosen by fair roll of a die


class PersistentMetadata:

    def __init__(self, metadata: Optional[Dict[int, dict]] = None):
        self._metadata: Dict[int, dict] = metadata or {}

    def get(self, seq_id: int, key, default=None):
        return self._metadata.get(seq_id, {}).get(key, default)


class OutputMetadata:
    """Not symmetrical with PersistentMetadata because the process of
    sampling can produce unique metadata per sample, per sequence.

    The appropriate conversion would be `output[seq][sample](dict)` to
    `persist[new_seq_for_sample](dict)`."""

    def __init__(self):
        self._metadata: Dict[int, Dict[int, dict]] = {}

    def add(self, seq_id: int, sample_id: int, key, val) -> None:
        self._metadata.setdefault(seq_id, {}).setdefault(sample_id,
                                                         {})[key] = val

    def get(self, seq_id: int, sample_id: int) -> dict:
        return self._metadata.get(seq_id, {}).get(sample_id, {})
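

# Illustrative only (hypothetical ids, not part of the module): sampling
# twice from sequence 7 yields per-sample metadata, which is later re-keyed
# under each sample's new child sequence id when persisted:
#
#     out = OutputMetadata()
#     out.add(seq_id=7, sample_id=0, key="miro_mu", val=3.0)
#     out.get(7, 0)  # -> {"miro_mu": 3.0}
#     PersistentMetadata({new_child_seq_id: out.get(7, 0)})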


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
        generators: List of torch.Generators to use for seeded sampling.
        perform_sampling: Whether to perform sampling. This option is used
            so that sampling happens only in the driver worker and is
            disabled in the other worker processes.
        persistent_metadata: Metadata that persists across iterations.
        output_metadata: The output metadata.
    """

    def __init__(
        self,
        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
        seq_data: Optional[Dict[int, SequenceData]],
        prompt_lens: Optional[List[int]],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Optional[Dict[SamplingType,
                                                  torch.Tensor]],
        generators: Optional[List[torch.Generator]] = None,
        perform_sampling: bool = True,
        persistent_metadata: Optional[PersistentMetadata] = None,
        output_metadata: Optional[OutputMetadata] = None,
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.generators = generators
        self.perform_sampling = perform_sampling
        self.persistent_metadata = persistent_metadata or PersistentMetadata()
        self.output_metadata = output_metadata or OutputMetadata()
        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"seq_data={self.seq_data}, "
            f"prompt_lens={self.prompt_lens}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}, "
            f"perform_sampling={self.perform_sampling}, "
            f"persistent_metadata={self.persistent_metadata}, "
            f"output_metadata={self.output_metadata})")


@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    pres_penalties: torch.Tensor
    freq_penalties: torch.Tensor
    rep_penalties: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    miro_taus: torch.Tensor
    miro_etas: torch.Tensor
    miro_mus: torch.Tensor
    miro_indices: torch.Tensor
    miro_seqids: List[int]  # state writeback done CPU side
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    smoothing_indices: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    seed_indices: torch.Tensor
    seed_transpose: torch.Tensor
    extra_seed_transpose: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor
    do_temperatures: bool
    do_dynatemps: bool
    do_penalties: bool
    do_top_ks: bool
    do_top_ps: bool
    do_top_as: bool
    do_min_ps: bool
    do_tfss: bool
    do_eta_cutoffs: bool
    do_epsilon_cutoffs: bool
    do_typical_ps: bool
    do_quadratic: bool
    do_mirostat: bool
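
    # Note: the `do_*` flags above let the sampler skip a technique's
    # entire pass when no sequence in the batch enables it.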

    @classmethod
    def from_sampling_metadata(
            cls,
            sampling_metadata: "SamplingMetadata",
            vocab_size: int,
            tgt_device: torch.device,
            float_dtype: torch.dtype,
            *,
            extra_seeds_to_generate: int = 0,
            extra_entropy: Optional[Tuple[int, ...]] = None
    ) -> "SamplingTensors":
        prompt_lens = sampling_metadata.prompt_lens or []
        groups = sampling_metadata.seq_groups or []
        seq_data = sampling_metadata.seq_data or {}
        persistent = sampling_metadata.persistent_metadata
        extra_entropy = extra_entropy or ()

        # Flattened list of (params, sid) matching the logits tensor.
        # `sid < 0` implies a prompt seq.
        unrolled_seqs: List[Tuple[SamplingParams, int]] = []
        group_plens = prompt_lens + [0] * (len(groups) - len(prompt_lens))
        for (ids, params), prompt_len in zip(groups, group_plens):
            if prompt_len and params.prompt_logprobs is not None:
                unrolled_seqs.extend([(params, -1)] * (prompt_len - 1))
            unrolled_seqs.extend([(params, sid) for sid in ids])
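
        # For a hypothetical group with prompt_len=3, prompt_logprobs
        # requested, and seq ids [4, 5], the loop above unrolls to
        #     [(params, -1), (params, -1), (params, 4), (params, 5)]
        # i.e. prompt_len - 1 prompt rows followed by one row per sequence.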

        T = TypeVar('T')

        def _unroll(fn_val: Callable[[SamplingParams], T],
                    prompt: Optional[T] = None) -> List[T]:
            """`fn_val` for every seq, with an override for prompt seqs."""
            return [
                prompt if sid < 0 and prompt is not None else fn_val(p)
                for p, sid in unrolled_seqs
            ]

        def _index(fn_mask: Callable[[SamplingParams], bool],
                   prompt: Optional[bool] = None) -> List[int]:
            """Index for every seq where `fn_mask` is true, with an override
            for prompt seqs."""
            return [
                i for i, (p, sid) in enumerate(unrolled_seqs)
                if (fn_mask(p) if prompt is None else (
                    prompt if sid < 0 else fn_mask(p)))
            ]

        def _filter(arr: List[T], indices: List[int]) -> List[T]:
            """Return only the elements of `arr` accessed by `indices`."""
            return [arr[i] for i in indices]
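
        # Sketch of the helpers on the unrolled batch above, assuming the
        # group's params use top_p=0.9 (hypothetical values):
        #     _unroll(lambda p: p.top_p, prompt=1)        # [1, 1, 0.9, 0.9]
        #     _index(lambda p: p.top_p < 1, prompt=False) # [2, 3]
        #     _filter(unrolled_seqs, [2, 3])              # the decode rows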

        miro_inds = _index(lambda p: p.mirostat_mode == 2, prompt=False)
        _miro_seqs = _filter(unrolled_seqs, miro_inds)

        quad_inds = _index(lambda p: p.smoothing_factor != 0)
        _quad_seqs = _filter(unrolled_seqs, quad_inds)

        # We need one base seed per Triton slice.
        triton_sampler_splits = get_num_triton_sampler_splits(vocab_size)
        n_seeds = triton_sampler_splits + extra_seeds_to_generate

        # Sequences get seeds. Prompt "sequences" do not.
        seed_indices = _index(lambda p: True, prompt=False)
        sampling_seeds = [
            cls._get_sequence_seeds(p.seed, n_seeds,
                                    p.sampling_type == SamplingType.GREEDY,
                                    seq_data[sid].get_len(), *extra_entropy,
                                    sid)
            for p, sid in _filter(unrolled_seqs, seed_indices)
        ]
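
        # Each decode row gets `n_seeds` child seeds; the sequence length
        # and seq id are mixed in as extra entropy so that sequences
        # sharing a user-provided seed still draw distinct seeds.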

        fvars = {  # noqa
            "temperatures": _unroll(lambda p: p.temperature),
            "top_ps": _unroll(lambda p: p.top_p),
            "top_as": _unroll(lambda p: p.top_a),
            "min_ps": _unroll(lambda p: p.min_p),
            "tfss": _unroll(lambda p: p.tfs, prompt=1),
            "eta_cutoffs": _unroll(lambda p: p.eta_cutoff * 1e-4, prompt=0),
            "epsilon_cutoffs": _unroll(lambda p: p.epsilon_cutoff * 1e-4,
                                       prompt=0),
            "typical_ps": _unroll(lambda p: p.typical_p, prompt=1),
            "pres_penalties": _unroll(lambda p: p.presence_penalty, prompt=0),
            "freq_penalties": _unroll(lambda p: p.frequency_penalty, prompt=0),
            "rep_penalties": _unroll(lambda p: p.repetition_penalty, prompt=1),
            "dynatemp_mins": _unroll(lambda p: p.dynatemp_min),
            "dynatemp_maxs": _unroll(lambda p: p.dynatemp_max),
            "dynatemp_exps": _unroll(lambda p: p.dynatemp_exponent),
            "miro_taus": [p.mirostat_tau for p, _ in _miro_seqs],
            "miro_etas": [p.mirostat_eta for p, _ in _miro_seqs],
            "miro_mus": [
                persistent.get(sid, "miro_mu", p.mirostat_tau * 2)
                for p, sid in _miro_seqs
            ],
            "smoothing_factors": [p.smoothing_factor for p, _ in _quad_seqs],
            "smoothing_curves": [p.smoothing_curve for p, _ in _quad_seqs],
        }
        ivars = {  # noqa
            "top_ks": _unroll(lambda p: vocab_size
                              if p.top_k == -1 else min(p.top_k, vocab_size)),
            "miro_indices": miro_inds,
            "smoothing_indices": quad_inds,
            "seed_indices": seed_indices,
        }

        prompt_tokens = [[] if sid < 0 else seq_data[sid].prompt_token_ids
                         for _, sid in unrolled_seqs]
        output_tokens = [[] if sid < 0 else seq_data[sid].output_token_ids
                         for _, sid in unrolled_seqs]

        # Need to transpose and make contiguous to copy the tensor correctly.
        # [batch_size, n_seeds] -> [n_seeds, batch_size]
        seeds_transpose = list(map(list, zip(*sampling_seeds)))
        seeds_gpu = seeds_transpose[:triton_sampler_splits]
        extra_seeds_gpu = seeds_transpose[triton_sampler_splits:] or None

        # Note that the performance will be very bad without pinned memory.
        # Pinned memory allows non-blocking transfers to device.
        pin_memory = is_pin_memory_available()

        def _tensor(contents: list, dtype) -> torch.Tensor:
            loc_t = torch.tensor(contents,
                                 dtype=dtype,
                                 device="cpu",
                                 pin_memory=pin_memory)
            return loc_t.to(device=tgt_device, non_blocking=True)

        def _unjagged(arrs: List[List[T]], padval: T) -> List[List[T]]:
            max_len = max(len(arr) for arr in arrs)
            return [arr + [padval] * (max_len - len(arr)) for arr in arrs]
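
        # e.g. _unjagged([[1], [2, 3]], padval=0) -> [[1, 0], [2, 3]].
        # The token id rows below are padded with `vocab_size`, which is
        # out of range for real token ids.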

        return cls(
            # Flags and non-tensor fields
            do_temperatures=any(x != 1 for x in fvars["temperatures"]),
            do_dynatemps=(any(fvars["dynatemp_mins"])
                          or any(fvars["dynatemp_maxs"])),
            do_top_ks=any(x != vocab_size for x in ivars["top_ks"]),
            do_top_ps=any(x != 1 for x in fvars["top_ps"]),
            do_top_as=any(fvars["top_as"]),
            do_min_ps=any(fvars["min_ps"]),
            do_tfss=any(x != 1 for x in fvars["tfss"]),
            do_eta_cutoffs=any(fvars["eta_cutoffs"]),
            do_epsilon_cutoffs=any(fvars["epsilon_cutoffs"]),
            do_typical_ps=any(x != 1 for x in fvars["typical_ps"]),
            do_penalties=(any(fvars["pres_penalties"])
                          or any(fvars["freq_penalties"])
                          or any(x != 1 for x in fvars["rep_penalties"])),
            do_quadratic=len(quad_inds) > 0,
            do_mirostat=len(miro_inds) > 0,
            miro_seqids=_filter([s for _, s in unrolled_seqs], miro_inds),
            # Float tensors
            **{n: _tensor(vals, float_dtype)
               for n, vals in fvars.items()},
            # Integer tensors
            **{n: _tensor(vals, torch.int)
               for n, vals in ivars.items()},
            # Token ID tensors
            prompt_tokens=_tensor(_unjagged(prompt_tokens, vocab_size),
                                  torch.long),
            output_tokens=_tensor(_unjagged(output_tokens, vocab_size),
                                  torch.long),
            # Seeds (only for triton, though?)
            seed_transpose=_tensor(seeds_gpu, torch.long),
            extra_seed_transpose=(_tensor(extra_seeds_gpu, torch.long)
                                  if extra_seeds_gpu else None),
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: Optional[int],
        seeds_to_generate: int,
        is_greedy: bool,
        *extra_entropy: int,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra
        entropy."""
        if is_greedy:  # For the kernel, seed == 0 means greedy decoding.
            return [0] * seeds_to_generate

        if seed is None:
            randint_fn = random.randint
        else:
            randint_fn = random.Random(str((seed, ) + extra_entropy)).randint
        lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
        # If the user (or the RNG) produces seed = 0 for a request that
        # should use sampling, replace it with a constant. This avoids
        # creating and loading a bool matrix in the sampling kernel,
        # which reduces CPU overhead and latency.
        return [
            randint_fn(lo, hi) or _SEED_0_REPLACEMENT
            for _ in range(seeds_to_generate)
        ]
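
# Illustrative behaviour of _get_sequence_seeds (hypothetical values):
#     _get_sequence_seeds(None, 3, True)        # -> [0, 0, 0] (greedy)
#     _get_sequence_seeds(42, 2, False, 17, 4)  # -> two seeds derived
#         # deterministically from (42, 17, 4); a draw of 0 is replaced
#         # with _SEED_0_REPLACEMENT, keeping 0 reserved for greedy.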