1
0

sampling_metadata.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. import random
  2. from dataclasses import dataclass
  3. from typing import Callable, Dict, List, Optional, Tuple, TypeVar
  4. import torch
  5. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  6. from aphrodite.common.sequence import SequenceData
  7. from aphrodite.common.utils import is_pin_memory_available
  8. from aphrodite.modeling.layers.ops.sample import get_num_triton_sampler_splits
  9. _SEED_0_REPLACEMENT = 3403598558 # chosen by fair roll of a die
  10. class PersistentMetadata:
  11. def __init__(self, metadata: Optional[Dict[int, dict]] = None):
  12. self._metadata: Dict[int, dict] = metadata or {}
  13. def get(self, seq_id: int, key, default=None):
  14. return self._metadata.get(seq_id, {}).get(key, default)
  15. class OutputMetadata():
  16. """Not symmetrical with PersistentMetadata because the process of
  17. sampling can produce unique metadata per sample, per sequence.
  18. The appropriate conversion would be `output[seq][sample](dict)` to
  19. `persist[new_seq_for_sample](dict)`"""
  20. def __init__(self):
  21. self._metadata: Dict[int, Dict[int, dict]] = {}
  22. def add(self, seq_id: int, sample_id: int, key, val) -> None:
  23. (self._metadata.setdefault(seq_id, {}).setdefault(sample_id,
  24. {})[key]) = val
  25. def get(self, seq_id: int, sample_id: int) -> dict:
  26. return self._metadata.get(seq_id, {}).get(sample_id, {})
  27. class SamplingMetadata:
  28. """Metadata for input sequences. Used in sampler.
  29. Args:
  30. seq_groups: List of (seq_ids, sampling_params).
  31. seq_data: Seq_id -> SequenceData.
  32. prompt_lens: Lengths of prompts.
  33. selected_token_indices: Token indices selected for sampling.
  34. categorized_sample_indices: SamplingType -> token indices to sample.
  35. generators: List of torch.Generators to use for seeded sampling
  36. perform_sampling: Whether to perform sampling. This option is used to
  37. make the sampling only happens in the driver worker, and disable
  38. sampling in other worker processes.
  39. persistent_metadata: Metadata that persists across iterations.
  40. output_metadata: the output metadata.
  41. """
  42. def __init__(
  43. self,
  44. seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
  45. seq_data: Optional[Dict[int, SequenceData]],
  46. prompt_lens: Optional[List[int]],
  47. selected_token_indices: torch.Tensor,
  48. categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
  49. generators: Optional[List[torch.Generator]] = None,
  50. perform_sampling: bool = True,
  51. persistent_metadata: Optional[PersistentMetadata] = None,
  52. output_metadata: Optional[OutputMetadata] = None,
  53. ) -> None:
  54. self.seq_groups = seq_groups
  55. self.seq_data = seq_data
  56. self.prompt_lens = prompt_lens
  57. self.selected_token_indices = selected_token_indices
  58. self.categorized_sample_indices = categorized_sample_indices
  59. self.generators = generators
  60. self.perform_sampling = perform_sampling
  61. self.persistent_metadata = persistent_metadata or PersistentMetadata()
  62. self.output_metadata = output_metadata or OutputMetadata()
  63. self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
  64. def __repr__(self) -> str:
  65. return (
  66. "SamplingMetadata("
  67. f"seq_groups={self.seq_groups}, "
  68. f"seq_data={self.seq_data}, "
  69. f"prompt_lens={self.prompt_lens}, "
  70. f"selected_token_indices={self.selected_token_indices}, "
  71. f"categorized_sample_indices={self.categorized_sample_indices}, "
  72. f"perform_sampling={self.perform_sampling}, "
  73. f"persistent_metadata={self.persistent_metadata}, "
  74. f"output_metadata={self.output_metadata}) ")
  75. @dataclass
  76. class SamplingTensors:
  77. """Tensors for sampling."""
  78. temperatures: torch.Tensor
  79. top_ps: torch.Tensor
  80. top_ks: torch.Tensor
  81. top_as: torch.Tensor
  82. min_ps: torch.Tensor
  83. pres_penalties: torch.Tensor
  84. freq_penalties: torch.Tensor
  85. rep_penalties: torch.Tensor
  86. tfss: torch.Tensor
  87. eta_cutoffs: torch.Tensor
  88. epsilon_cutoffs: torch.Tensor
  89. typical_ps: torch.Tensor
  90. miro_taus: torch.Tensor
  91. miro_etas: torch.Tensor
  92. miro_mus: torch.Tensor
  93. miro_indices: torch.Tensor
  94. miro_seqids: List[int] # state writeback done CPU side
  95. dynatemp_mins: torch.Tensor
  96. dynatemp_maxs: torch.Tensor
  97. dynatemp_exps: torch.Tensor
  98. smoothing_indices: torch.Tensor
  99. smoothing_factors: torch.Tensor
  100. smoothing_curves: torch.Tensor
  101. seed_indices: torch.Tensor
  102. seed_transpose: torch.Tensor
  103. extra_seed_transpose: Optional[torch.Tensor]
  104. prompt_tokens: torch.Tensor
  105. output_tokens: torch.Tensor
  106. do_temperatures: bool
  107. do_dynatemps: bool
  108. do_penalties: bool
  109. do_top_ks: bool
  110. do_top_ps: bool
  111. do_top_as: bool
  112. do_min_ps: bool
  113. do_tfss: bool
  114. do_eta_cutoffs: bool
  115. do_epsilon_cutoffs: bool
  116. do_typical_ps: bool
  117. do_quadratic: bool
  118. do_mirostat: bool
  119. @classmethod
  120. def from_sampling_metadata(
  121. cls,
  122. sampling_metadata: "SamplingMetadata",
  123. vocab_size: int,
  124. tgt_device: torch.device,
  125. float_dtype: torch.dtype,
  126. *,
  127. extra_seeds_to_generate: int = 0,
  128. extra_entropy: Optional[Tuple[int,
  129. ...]] = None) -> "SamplingTensors":
  130. prompt_lens = sampling_metadata.prompt_lens or []
  131. groups = sampling_metadata.seq_groups or []
  132. seq_data = sampling_metadata.seq_data or {}
  133. persistent = sampling_metadata.persistent_metadata
  134. extra_entropy = extra_entropy or ()
  135. # Flattened list of (params, sid) matching the logits tensor.
  136. # `sid < 0` implies a prompt seq.
  137. unrolled_seqs: List[Tuple[SamplingParams, int]] = []
  138. group_plens = prompt_lens + [0] * (len(groups) - len(prompt_lens))
  139. for (ids, params), prompt_len in zip(groups, group_plens):
  140. if prompt_len and params.prompt_logprobs is not None:
  141. unrolled_seqs.extend([(params, -1)] * (prompt_len - 1))
  142. unrolled_seqs.extend([(params, sid) for sid in ids])
  143. T = TypeVar('T')
  144. def _unroll(fn_val: Callable[[SamplingParams], T],
  145. prompt: Optional[T] = None) -> List[T]:
  146. """`fn_val` for every seq, with an override for prompt seqs."""
  147. return [
  148. prompt if sid < 0 and prompt is not None else fn_val(p)
  149. for p, sid in unrolled_seqs
  150. ]
  151. def _index(fn_mask: Callable[[SamplingParams], bool],
  152. prompt: Optional[bool] = None) -> List[int]:
  153. """Index for every seq where `fn_mask` is true, with an override
  154. for prompt seqs."""
  155. return [
  156. i for i, (p, sid) in enumerate(unrolled_seqs)
  157. if (fn_mask(p) if prompt is None else (
  158. prompt if sid < 0 else fn_mask(p)))
  159. ]
  160. def _filter(arr: List[T], indices: List[int]) -> List[T]:
  161. """Return only the elements of `arr` accessed by `indices`."""
  162. return [arr[i] for i in indices]
  163. miro_inds = _index(lambda p: p.mirostat_mode == 2, prompt=False)
  164. _miro_seqs = _filter(unrolled_seqs, miro_inds)
  165. quad_inds = _index(lambda p: p.smoothing_factor != 0)
  166. _quad_seqs = _filter(unrolled_seqs, quad_inds)
  167. # We need one base seed per Triton slice.
  168. triton_sampler_splits = get_num_triton_sampler_splits(vocab_size)
  169. n_seeds = triton_sampler_splits + extra_seeds_to_generate
  170. # Sequences get seeds. Prompt "sequences" do not.
  171. seed_indices = _index(lambda p: True, prompt=False)
  172. sampling_seeds = [
  173. cls._get_sequence_seeds(p.seed, n_seeds,
  174. p.sampling_type == SamplingType.GREEDY,
  175. seq_data[sid].get_len(), *extra_entropy,
  176. sid)
  177. for p, sid in _filter(unrolled_seqs, seed_indices)
  178. ]
  179. fvars = { # noqa
  180. "temperatures": _unroll(lambda p: p.temperature),
  181. "top_ps": _unroll(lambda p: p.top_p),
  182. "top_as": _unroll(lambda p: p.top_a),
  183. "min_ps": _unroll(lambda p: p.min_p),
  184. "tfss": _unroll(lambda p: p.tfs, prompt=1),
  185. "eta_cutoffs": _unroll(lambda p: p.eta_cutoff * 1e-4, prompt=0),
  186. "epsilon_cutoffs": _unroll(lambda p: p.epsilon_cutoff * 1e-4, 0),
  187. "typical_ps": _unroll(lambda p: p.typical_p, prompt=1),
  188. "pres_penalties": _unroll(lambda p: p.presence_penalty, prompt=0),
  189. "freq_penalties": _unroll(lambda p: p.frequency_penalty, prompt=0),
  190. "rep_penalties": _unroll(lambda p: p.repetition_penalty, prompt=1),
  191. "dynatemp_mins": _unroll(lambda p: p.dynatemp_min),
  192. "dynatemp_maxs": _unroll(lambda p: p.dynatemp_max),
  193. "dynatemp_exps": _unroll(lambda p: p.dynatemp_exponent),
  194. "miro_taus": [p.mirostat_tau for p, _ in _miro_seqs],
  195. "miro_etas": [p.mirostat_eta for p, _ in _miro_seqs],
  196. "miro_mus": [persistent.get(sid, "miro_mu", p.mirostat_tau * 2)
  197. for p, sid in _miro_seqs],
  198. "smoothing_factors": [p.smoothing_factor for p, _ in _quad_seqs],
  199. "smoothing_curves": [p.smoothing_curve for p, _ in _quad_seqs],
  200. }
  201. ivars = { # noqa
  202. "top_ks": _unroll(lambda p: vocab_size
  203. if p.top_k == -1 else min(p.top_k, vocab_size)),
  204. "miro_indices": miro_inds,
  205. "smoothing_indices": quad_inds,
  206. "seed_indices": seed_indices,
  207. }
  208. prompt_tokens = [[] if sid < 0 else seq_data[sid].prompt_token_ids
  209. for _, sid in unrolled_seqs]
  210. output_tokens = [[] if sid < 0 else seq_data[sid].output_token_ids
  211. for _, sid in unrolled_seqs]
  212. # need to transpose and make contiguous to copy the tensor correctly.
  213. # [batch_size, n_seeds] -> [n_seeds, batch_size]
  214. seeds_transpose = list(map(list, zip(*sampling_seeds)))
  215. seeds_gpu = seeds_transpose[:triton_sampler_splits]
  216. extra_seeds_gpu = seeds_transpose[triton_sampler_splits:] or None
  217. # Note that the performance will be very bad without pinned memory.
  218. # Pinned memory allows non-blocking transfers to device.
  219. pin_memory = is_pin_memory_available()
  220. def _tensor(contents: list, dtype) -> torch.Tensor:
  221. loc_t = torch.tensor(contents,
  222. dtype=dtype,
  223. device="cpu",
  224. pin_memory=pin_memory)
  225. return loc_t.to(device=tgt_device, non_blocking=True)
  226. def _unjagged(arrs: List[List[T]], padval: T) -> List[List[T]]:
  227. max_len = max(len(arr) for arr in arrs)
  228. return [arr + [padval] * (max_len - len(arr)) for arr in arrs]
  229. return cls(
  230. # Flags and non-tensor fields
  231. do_temperatures=any(x != 1 for x in fvars["temperatures"]),
  232. do_dynatemps=(any(fvars["dynatemp_mins"])
  233. or any(fvars["dynatemp_maxs"])),
  234. do_top_ks=any(x != vocab_size for x in ivars["top_ks"]),
  235. do_top_ps=any(x != 1 for x in fvars["top_ps"]),
  236. do_top_as=any(fvars["top_as"]),
  237. do_min_ps=any(fvars["min_ps"]),
  238. do_tfss=any(x != 1 for x in fvars["tfss"]),
  239. do_eta_cutoffs=any(fvars["eta_cutoffs"]),
  240. do_epsilon_cutoffs=any(fvars["epsilon_cutoffs"]),
  241. do_typical_ps=any(x != 1 for x in fvars["typical_ps"]),
  242. do_penalties=(any(fvars["pres_penalties"])
  243. or any(fvars["freq_penalties"])
  244. or any(x != 1 for x in fvars["rep_penalties"])),
  245. do_quadratic=len(quad_inds) > 0,
  246. do_mirostat=len(miro_inds) > 0,
  247. miro_seqids=_filter([s for _, s in unrolled_seqs], miro_inds),
  248. # Float tensors
  249. **{n: _tensor(vals, float_dtype)
  250. for n, vals in fvars.items()},
  251. # Integer tensors
  252. **{n: _tensor(vals, torch.int)
  253. for n, vals in ivars.items()},
  254. # Token ID tensors
  255. prompt_tokens=_tensor(_unjagged(prompt_tokens, vocab_size),
  256. torch.long),
  257. output_tokens=_tensor(_unjagged(output_tokens, vocab_size),
  258. torch.long),
  259. # Seeds (only for triton, though?)
  260. seed_transpose=_tensor(seeds_gpu, torch.long),
  261. extra_seed_transpose=(_tensor(extra_seeds_gpu, torch.long)
  262. if extra_seeds_gpu else None),
  263. )
  264. @staticmethod
  265. def _get_sequence_seeds(
  266. seed: Optional[int],
  267. seeds_to_generate: int,
  268. is_greedy: bool,
  269. *extra_entropy: int,
  270. ):
  271. """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
  272. if is_greedy: # For the kernel, seed == 0 means greedy decoding.
  273. return [0] * seeds_to_generate
  274. if seed is None:
  275. randint_fn = random.randint
  276. else:
  277. randint_fn = random.Random(str((seed, ) + extra_entropy)).randint
  278. lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
  279. # If the user/random sets seed = 0 but request should
  280. # have sampling, we need to change it to something
  281. # else. We use a constant in that case.
  282. # This way we don't need to create and load a bool
  283. # matrix in the sampling kernel, which reduces CPU
  284. # overhead and latency.
  285. return [
  286. randint_fn(lo, hi) or _SEED_0_REPLACEMENT
  287. for _ in range(seeds_to_generate)
  288. ]