sampling_metadata.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. from dataclasses import dataclass
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  5. from aphrodite.common.sequence import SequenceData
  6. from aphrodite.common.utils import in_wsl
  7. _SAMPLING_EPS = 1e-5
  8. class PersistentMetadata:
  9. def __init__(self, metadata: Optional[Dict[int, dict]] = None):
  10. self._metadata: Dict[int, dict] = metadata or {}
  11. def get(self, seq_id: int) -> dict:
  12. return self._metadata.get(seq_id, {})
  13. class OutputMetadata(PersistentMetadata):
  14. def add(self, seq_id: int, key, val) -> None:
  15. if seq_id not in self._metadata:
  16. self._metadata[seq_id] = {}
  17. self._metadata[seq_id][key] = val
  18. class SamplingMetadata:
  19. """Metadata for input sequences. Used in sampler.
  20. Args:
  21. seq_groups: List of (seq_ids, sampling_params).
  22. seq_data: Seq_id -> SequenceData.
  23. prompt_lens: Lengths of prompts.
  24. selected_token_indices: Token indices selected for sampling.
  25. categorized_sample_indices: SamplingType -> token indices to sample.
  26. generators: List of torch.Generators to use for seeded sampling
  27. perform_sampling: Whether to perform sampling. This option is used to
  28. make the sampling only happens in the driver worker, and disable
  29. sampling in other worker processes.
  30. persistent_metadata: Metadata that persists across iterations.
  31. output_metadata: the output metadata.
  32. """
  33. def __init__(
  34. self,
  35. seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
  36. seq_data: Optional[Dict[int, SequenceData]],
  37. prompt_lens: Optional[List[int]],
  38. selected_token_indices: torch.Tensor,
  39. categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
  40. generators: Optional[List[torch.Generator]] = None,
  41. perform_sampling: bool = True,
  42. persistent_metadata: Optional[PersistentMetadata] = None,
  43. output_metadata: Optional[OutputMetadata] = None,
  44. ) -> None:
  45. self.seq_groups = seq_groups
  46. self.seq_data = seq_data
  47. self.prompt_lens = prompt_lens
  48. self.selected_token_indices = selected_token_indices
  49. self.categorized_sample_indices = categorized_sample_indices
  50. self.generators = generators
  51. self.perform_sampling = perform_sampling
  52. self.persistent_metadata = persistent_metadata or PersistentMetadata()
  53. self.output_metadata = output_metadata or OutputMetadata()
  54. self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
  55. def __repr__(self) -> str:
  56. return (
  57. "SamplingMetadata("
  58. f"seq_groups={self.seq_groups}, "
  59. f"seq_data={self.seq_data}, "
  60. f"prompt_lens={self.prompt_lens}, "
  61. f"selected_token_indices={self.selected_token_indices}, "
  62. f"categorized_sample_indices={self.categorized_sample_indices}, "
  63. f"perform_sampling={self.perform_sampling}, "
  64. f"persistent_metadata={self.persistent_metadata}, "
  65. f"output_metadata={self.output_metadata}) ")
  66. @dataclass
  67. class SamplingTensors:
  68. """Tensors for sampling."""
  69. temperatures: torch.Tensor
  70. top_ps: torch.Tensor
  71. top_ks: torch.Tensor
  72. top_as: torch.Tensor
  73. min_ps: torch.Tensor
  74. presence_penalties: torch.Tensor
  75. frequency_penalties: torch.Tensor
  76. repetition_penalties: torch.Tensor
  77. tfss: torch.Tensor
  78. eta_cutoffs: torch.Tensor
  79. epsilon_cutoffs: torch.Tensor
  80. typical_ps: torch.Tensor
  81. miro_taus: torch.Tensor
  82. miro_etas: torch.Tensor
  83. miro_mus: torch.Tensor
  84. miro_indices: torch.Tensor
  85. miro_seqids: List[int] # state writeback done CPU side
  86. dynatemp_mins: torch.Tensor
  87. dynatemp_maxs: torch.Tensor
  88. dynatemp_exps: torch.Tensor
  89. smoothing_factors: torch.Tensor
  90. smoothing_curves: torch.Tensor
  91. prompt_tokens: torch.Tensor
  92. output_tokens: torch.Tensor
  93. @classmethod
  94. def from_sampling_metadata(
  95. cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
  96. device: torch.device, dtype: torch.dtype
  97. ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
  98. bool, bool, bool, bool, bool]:
  99. prompt_tokens: List[List[int]] = []
  100. output_tokens: List[List[int]] = []
  101. top_ks: List[int] = []
  102. temperatures: List[float] = []
  103. top_ps: List[float] = []
  104. top_as: List[float] = []
  105. min_ps: List[float] = []
  106. presence_penalties: List[float] = []
  107. frequency_penalties: List[float] = []
  108. repetition_penalties: List[float] = []
  109. tfss: List[float] = []
  110. eta_cutoffs: List[float] = []
  111. epsilon_cutoffs: List[float] = []
  112. typical_ps: List[float] = []
  113. miro_taus: List[float] = []
  114. miro_etas: List[float] = []
  115. miro_mus: List[float] = []
  116. miro_indices: List[int] = []
  117. miro_seqids: List[int] = []
  118. dynatemp_mins: List[float] = []
  119. dynatemp_maxs: List[float] = []
  120. dynatemp_exps: List[float] = []
  121. smoothing_factors: List[float] = []
  122. smoothing_curves: List[float] = []
  123. index = 0 # temporary, needed for building miro_indices
  124. do_temperatures = False
  125. do_penalties = False
  126. do_topks = False
  127. do_topps = False
  128. do_topas = False
  129. do_minps = False
  130. do_tfss = False
  131. do_eta_cutoffs = False
  132. do_epsilon_cutoffs = False
  133. do_typical_ps = False
  134. do_quadratic = False
  135. do_mirostat = False
  136. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  137. seq_ids, sampling_params = seq_group
  138. temperature = sampling_params.temperature
  139. p = sampling_params.presence_penalty
  140. f = sampling_params.frequency_penalty
  141. r = sampling_params.repetition_penalty
  142. top_p = sampling_params.top_p
  143. # k should not be greater than the vocab size
  144. top_k = min(sampling_params.top_k, vocab_size)
  145. top_k = vocab_size if top_k == -1 else top_k
  146. top_a = sampling_params.top_a
  147. min_p = sampling_params.min_p
  148. tfs = sampling_params.tfs
  149. eta_cutoff = sampling_params.eta_cutoff
  150. epsilon_cutoff = sampling_params.epsilon_cutoff
  151. typical_p = sampling_params.typical_p
  152. miro_tau = sampling_params.mirostat_tau
  153. miro_eta = sampling_params.mirostat_eta
  154. dynatemp_min = sampling_params.dynatemp_min
  155. dynatemp_max = sampling_params.dynatemp_max
  156. dynatemp_exp = sampling_params.dynatemp_exponent
  157. smoothing_factor = sampling_params.smoothing_factor
  158. smoothing_curve = sampling_params.smoothing_curve
  159. if do_temperatures is False and temperature > _SAMPLING_EPS:
  160. do_temperatures = True
  161. if not do_penalties and (abs(p) >= _SAMPLING_EPS
  162. or abs(f) >= _SAMPLING_EPS
  163. or abs(r - 1.0) >= _SAMPLING_EPS):
  164. do_penalties = True
  165. if do_topks is False and top_k != vocab_size:
  166. do_topks = True
  167. if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
  168. do_topps = True
  169. if do_topas is False and top_a > 0.0:
  170. do_topas = True
  171. if do_minps is False and min_p > _SAMPLING_EPS:
  172. do_minps = True
  173. if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
  174. do_tfss = True
  175. if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
  176. do_eta_cutoffs = True
  177. if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
  178. do_epsilon_cutoffs = True
  179. if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
  180. do_typical_ps = True
  181. if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
  182. or smoothing_curve > 1.0):
  183. do_quadratic = True
  184. if do_mirostat is False and sampling_params.mirostat_mode == 2:
  185. do_mirostat = True
  186. if (i < sampling_metadata.num_prompts
  187. and sampling_params.prompt_logprobs is not None):
  188. # For tokens in the prompt that we only need to get their
  189. # logprobs
  190. prompt_len = sampling_metadata.prompt_lens[i]
  191. index += sampling_metadata.prompt_lens[i] - 1
  192. temperatures += [temperature] * (prompt_len - 1)
  193. top_ps += [top_p] * (prompt_len - 1)
  194. top_ks += [top_k] * (prompt_len - 1)
  195. top_as += [top_a] * (prompt_len - 1)
  196. min_ps += [min_p] * (prompt_len - 1)
  197. presence_penalties += [0] * (prompt_len - 1)
  198. frequency_penalties += [0] * (prompt_len - 1)
  199. repetition_penalties += [1] * (prompt_len - 1)
  200. tfss += [1] * (prompt_len - 1)
  201. eta_cutoffs += [0] * (prompt_len - 1)
  202. epsilon_cutoffs += [0] * (prompt_len - 1)
  203. typical_ps += [1] * (prompt_len - 1)
  204. dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
  205. dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
  206. dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
  207. smoothing_factors += [smoothing_factor] * (prompt_len - 1)
  208. smoothing_curves += [smoothing_curve] * (prompt_len - 1)
  209. prompt_tokens.extend([] for _ in range(prompt_len - 1))
  210. output_tokens.extend([] for _ in range(prompt_len - 1))
  211. for seq_id in seq_ids:
  212. seq_data = sampling_metadata.seq_data[seq_id]
  213. prompt_tokens.append(seq_data.prompt_token_ids)
  214. output_tokens.append(seq_data.output_token_ids)
  215. temperatures += [temperature] * len(seq_ids)
  216. top_ps += [top_p] * len(seq_ids)
  217. top_ks += [top_k] * len(seq_ids)
  218. top_as += [top_a] * len(seq_ids)
  219. min_ps += [min_p] * len(seq_ids)
  220. presence_penalties += [p] * len(seq_ids)
  221. frequency_penalties += [f] * len(seq_ids)
  222. repetition_penalties += [r] * len(seq_ids)
  223. tfss += [tfs] * len(seq_ids)
  224. eta_cutoffs += [eta_cutoff] * len(seq_ids)
  225. epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
  226. typical_ps += [typical_p] * len(seq_ids)
  227. dynatemp_mins += [dynatemp_min] * len(seq_ids)
  228. dynatemp_maxs += [dynatemp_max] * len(seq_ids)
  229. dynatemp_exps += [dynatemp_exp] * len(seq_ids)
  230. smoothing_factors += [smoothing_factor] * len(seq_ids)
  231. smoothing_curves += [smoothing_curve] * len(seq_ids)
  232. if sampling_params.mirostat_mode == 2:
  233. miro_indices += [(index + i) for i in range(len(seq_ids))]
  234. miro_seqids += seq_ids
  235. miro_taus += [miro_tau] * len(seq_ids)
  236. miro_etas += [miro_eta] * len(seq_ids)
  237. miro_mus += [
  238. sampling_metadata.persistent_metadata.get(sid).get(
  239. "miro_mu", sampling_params.mirostat_tau * 2)
  240. for sid in seq_ids
  241. ]
  242. index += len(seq_ids)
  243. sampling_tensors = SamplingTensors.from_lists(
  244. temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
  245. frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
  246. epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
  247. dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
  248. miro_seqids, smoothing_factors, smoothing_curves, prompt_tokens,
  249. output_tokens, vocab_size, device, dtype)
  250. return (sampling_tensors, do_temperatures, do_penalties, do_topks,
  251. do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
  252. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
  253. @classmethod
  254. def from_lists(cls, temperatures: List[float], top_ps: List[float],
  255. top_ks: List[int], top_as: List[float], min_ps: List[float],
  256. presence_penalties: List[float],
  257. frequency_penalties: List[float],
  258. repetition_penalties: List[float], tfss: List[float],
  259. eta_cutoffs: List[float], epsilon_cutoffs: List[float],
  260. typical_ps: List[float], dynatemp_mins: List[float],
  261. dynatemp_maxs: List[float], dynatemp_exps: List[float],
  262. miro_taus: List[float], miro_etas: List[float],
  263. miro_mus: List[float], miro_indices: List[int],
  264. miro_seqids: List[int], smoothing_factors: List[float],
  265. smoothing_curves: List[float],
  266. prompt_tokens: List[List[int]],
  267. output_tokens: List[List[int]], vocab_size: int,
  268. device: torch.device,
  269. dtype: torch.dtype) -> "SamplingTensors":
  270. # Note that the performance will be very bad without
  271. # pinned memory.
  272. pin_memory = not in_wsl()
  273. prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
  274. prompt_padded_tokens = [
  275. tokens + [vocab_size] * (prompt_max_len - len(tokens))
  276. for tokens in prompt_tokens
  277. ]
  278. output_max_len = max(len(tokens) for tokens in output_tokens)
  279. output_padded_tokens = [
  280. tokens + [vocab_size] * (output_max_len - len(tokens))
  281. for tokens in output_tokens
  282. ]
  283. temperatures_t = torch.tensor(temperatures,
  284. device="cpu",
  285. dtype=dtype,
  286. pin_memory=pin_memory)
  287. top_ps_t = torch.tensor(top_ps,
  288. device="cpu",
  289. dtype=dtype,
  290. pin_memory=pin_memory)
  291. top_ks_t = torch.tensor(top_ks,
  292. device="cpu",
  293. dtype=torch.int,
  294. pin_memory=pin_memory)
  295. top_as_t = torch.tensor(top_as,
  296. device="cpu",
  297. dtype=dtype,
  298. pin_memory=pin_memory)
  299. min_ps_t = torch.tensor(min_ps,
  300. device="cpu",
  301. dtype=dtype,
  302. pin_memory=pin_memory)
  303. presence_penalties_t = torch.tensor(presence_penalties,
  304. device="cpu",
  305. dtype=dtype,
  306. pin_memory=pin_memory)
  307. frequency_penalties_t = torch.tensor(frequency_penalties,
  308. device="cpu",
  309. dtype=dtype,
  310. pin_memory=pin_memory)
  311. repetition_penalties_t = torch.tensor(repetition_penalties,
  312. device="cpu",
  313. dtype=dtype,
  314. pin_memory=pin_memory)
  315. tfss_t = torch.tensor(tfss,
  316. device="cpu",
  317. dtype=dtype,
  318. pin_memory=pin_memory)
  319. eta_cutoffs_t = torch.tensor(eta_cutoffs,
  320. device="cpu",
  321. dtype=dtype,
  322. pin_memory=pin_memory)
  323. epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
  324. device="cpu",
  325. dtype=dtype,
  326. pin_memory=pin_memory)
  327. typical_ps_t = torch.tensor(typical_ps,
  328. device="cpu",
  329. dtype=dtype,
  330. pin_memory=pin_memory)
  331. dynatemp_mins_t = torch.tensor(dynatemp_mins,
  332. device="cpu",
  333. dtype=dtype,
  334. pin_memory=pin_memory)
  335. dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
  336. device="cpu",
  337. dtype=dtype,
  338. pin_memory=pin_memory)
  339. dynatemp_exps_t = torch.tensor(dynatemp_exps,
  340. device="cpu",
  341. dtype=dtype,
  342. pin_memory=pin_memory)
  343. smoothing_factors_t = torch.tensor(smoothing_factors,
  344. device="cpu",
  345. dtype=dtype,
  346. pin_memory=pin_memory)
  347. smoothing_curves_t = torch.tensor(smoothing_curves,
  348. device="cpu",
  349. dtype=dtype,
  350. pin_memory=pin_memory)
  351. miro_taus_t = torch.tensor(miro_taus,
  352. device="cpu",
  353. dtype=dtype,
  354. pin_memory=pin_memory)
  355. miro_etas_t = torch.tensor(miro_etas,
  356. device="cpu",
  357. dtype=dtype,
  358. pin_memory=pin_memory)
  359. miro_mus_t = torch.tensor(miro_mus,
  360. device="cpu",
  361. dtype=dtype,
  362. pin_memory=pin_memory)
  363. miro_indices_t = torch.tensor(miro_indices,
  364. device="cpu",
  365. dtype=torch.int,
  366. pin_memory=pin_memory)
  367. prompt_tensor = torch.tensor(prompt_padded_tokens,
  368. device=device,
  369. dtype=torch.long,
  370. pin_memory=pin_memory)
  371. output_tensor = torch.tensor(output_padded_tokens,
  372. device=device,
  373. dtype=torch.long,
  374. pin_memory=pin_memory)
  375. # Because the memory is pinned, we can do non-blocking
  376. # transfer to device.
  377. return cls(
  378. temperatures=temperatures_t.to(device=device, non_blocking=True),
  379. top_ps=top_ps_t.to(device=device, non_blocking=True),
  380. top_ks=top_ks_t.to(device=device, non_blocking=True),
  381. top_as=top_as_t.to(device=device, non_blocking=True),
  382. min_ps=min_ps_t.to(device=device, non_blocking=True),
  383. presence_penalties=presence_penalties_t.to(device=device,
  384. non_blocking=True),
  385. frequency_penalties=frequency_penalties_t.to(device=device,
  386. non_blocking=True),
  387. repetition_penalties=repetition_penalties_t.to(device=device,
  388. non_blocking=True),
  389. tfss=tfss_t.to(device=device, non_blocking=True),
  390. eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
  391. epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
  392. non_blocking=True),
  393. dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
  394. dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
  395. dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
  396. smoothing_factors=smoothing_factors_t.to(device=device,
  397. non_blocking=True),
  398. smoothing_curves=smoothing_curves_t.to(device=device,
  399. non_blocking=True),
  400. miro_taus=miro_taus_t.to(device=device, non_blocking=True),
  401. miro_etas=miro_etas_t.to(device=device, non_blocking=True),
  402. miro_mus=miro_mus_t.to(device=device, non_blocking=True),
  403. miro_indices=miro_indices_t.to(device=device, non_blocking=True),
  404. miro_seqids=miro_seqids,
  405. typical_ps=typical_ps_t.to(device=device, non_blocking=True),
  406. prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
  407. output_tokens=output_tensor.to(device=device, non_blocking=True),
  408. )