sampling_metadata.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. from dataclasses import dataclass
  2. from typing import Dict, List, Tuple, Optional
  3. import torch
  4. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  5. from aphrodite.common.sequence import SequenceData
  6. from aphrodite.common.utils import in_wsl
  7. _SAMPLING_EPS = 1e-5
  8. class PersistentMetadata:
  9. def __init__(self, metadata: Optional[Dict[int, dict]] = None):
  10. self._metadata: Dict[int, dict] = metadata or {}
  11. def get(self, seq_id: int) -> dict:
  12. return self._metadata.get(seq_id, {})
  13. class OutputMetadata(PersistentMetadata):
  14. def add(self, seq_id: int, key, val) -> None:
  15. if seq_id not in self._metadata:
  16. self._metadata[seq_id] = {}
  17. self._metadata[seq_id][key] = val
  18. class SamplingMetadata:
  19. """Metadata for input sequences. Used in sampler.
  20. Args:
  21. seq_groups: List of (seq_ids, sampling_params).
  22. seq_data: Seq_id -> SequenceData.
  23. prompt_lens: Lengths of prompts.
  24. selected_token_indices: Token indices selected for sampling.
  25. categorized_sample_indices: SamplingType -> token indices to sample.
  26. generators: List of torch.Generators to use for seeded sampling
  27. perform_sampling: Whether to perform sampling. This option is used to
  28. make the sampling only happens in the driver worker, and disable
  29. sampling in other worker processes.
  30. persistent_metadata: Metadata that persists across iterations.
  31. output_metadata: the output metadata.
  32. """
  33. def __init__(
  34. self,
  35. seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
  36. seq_data: Optional[Dict[int, SequenceData]],
  37. prompt_lens: Optional[List[int]],
  38. selected_token_indices: torch.Tensor,
  39. categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
  40. generators: Optional[List[torch.Generator]] = None,
  41. perform_sampling: bool = True,
  42. persistent_metadata: Optional[PersistentMetadata] = None,
  43. output_metadata: Optional[OutputMetadata] = None,
  44. ) -> None:
  45. self.seq_groups = seq_groups
  46. self.seq_data = seq_data
  47. self.prompt_lens = prompt_lens
  48. self.selected_token_indices = selected_token_indices
  49. self.categorized_sample_indices = categorized_sample_indices
  50. self.generators = generators
  51. self.perform_sampling = perform_sampling
  52. self.persistent_metadata = persistent_metadata or PersistentMetadata()
  53. self.output_metadata = output_metadata or OutputMetadata()
  54. self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
  55. def __repr__(self) -> str:
  56. return (
  57. "SamplingMetadata("
  58. f"seq_groups={self.seq_groups}, "
  59. f"seq_data={self.seq_data}, "
  60. f"prompt_lens={self.prompt_lens}, "
  61. f"selected_token_indices={self.selected_token_indices}, "
  62. f"categorized_sample_indices={self.categorized_sample_indices}, "
  63. f"perform_sampling={self.perform_sampling}, "
  64. f"persistent_metadata={self.persistent_metadata}, "
  65. f"output_metadata={self.output_metadata}) ")
@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    # Per-row sampling parameters; rows parallel the flattened token order
    # built by ``from_sampling_metadata``.
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    # Penalty terms applied against prompt/output token histories.
    presence_penalties: torch.Tensor
    frequency_penalties: torch.Tensor
    repetition_penalties: torch.Tensor
    # Distribution-truncation parameters (tail-free, eta/epsilon cutoffs,
    # typical sampling).
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    typical_p_sigmas: torch.Tensor
    # Mirostat parameters plus the row indices they apply to; mu is the
    # per-sequence running state.
    miro_taus: torch.Tensor
    miro_etas: torch.Tensor
    miro_mus: torch.Tensor
    miro_indices: torch.Tensor
    miro_seqids: List[int]  # state writeback done CPU side
    # Dynamic-temperature range/exponent and quadratic-smoothing parameters.
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    # Padded token-id matrices (padded with vocab_size) for the penalties.
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor
    @classmethod
    def from_sampling_metadata(
        cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
        device: torch.device, dtype: torch.dtype
    ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
               bool, bool, bool, bool, bool]:
        """Flatten per-group sampling params into per-row lists and tensors.

        Builds one entry per logits row (prompt-logprob rows first within a
        prompt group, then one row per sequence) and latches a ``do_*`` flag
        for every sampling technique that at least one group actually needs,
        so the sampler can skip unused steps.

        Returns:
            A tuple of (SamplingTensors, do_temperatures, do_penalties,
            do_topks, do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
            do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat).
        """
        prompt_tokens: List[List[int]] = []
        output_tokens: List[List[int]] = []
        top_ks: List[int] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        top_as: List[float] = []
        min_ps: List[float] = []
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        repetition_penalties: List[float] = []
        tfss: List[float] = []
        eta_cutoffs: List[float] = []
        epsilon_cutoffs: List[float] = []
        typical_ps: List[float] = []
        typical_p_sigmas: List[float] = []
        miro_taus: List[float] = []
        miro_etas: List[float] = []
        miro_mus: List[float] = []
        miro_indices: List[int] = []
        miro_seqids: List[int] = []
        dynatemp_mins: List[float] = []
        dynatemp_maxs: List[float] = []
        dynatemp_exps: List[float] = []
        smoothing_factors: List[float] = []
        smoothing_curves: List[float] = []
        # Running logits-row index; only needed to build miro_indices.
        index = 0  # temporary, needed for building miro_indices
        do_temperatures = False
        do_penalties = False
        do_topks = False
        do_topps = False
        do_topas = False
        do_minps = False
        do_tfss = False
        do_eta_cutoffs = False
        do_epsilon_cutoffs = False
        do_typical_ps = False
        do_quadratic = False
        do_mirostat = False
        for i, seq_group in enumerate(sampling_metadata.seq_groups):
            seq_ids, sampling_params = seq_group
            temperature = sampling_params.temperature
            p = sampling_params.presence_penalty
            f = sampling_params.frequency_penalty
            r = sampling_params.repetition_penalty
            top_p = sampling_params.top_p
            # k should not be greater than the vocab size
            top_k = min(sampling_params.top_k, vocab_size)
            # top_k == -1 means "disabled", i.e. consider the whole vocab.
            top_k = vocab_size if top_k == -1 else top_k
            top_a = sampling_params.top_a
            min_p = sampling_params.min_p
            tfs = sampling_params.tfs
            eta_cutoff = sampling_params.eta_cutoff
            epsilon_cutoff = sampling_params.epsilon_cutoff
            typical_p = sampling_params.typical_p
            typical_p_sigma = sampling_params.typical_p_sigma
            miro_tau = sampling_params.mirostat_tau
            miro_eta = sampling_params.mirostat_eta
            dynatemp_min = sampling_params.dynatemp_min
            dynatemp_max = sampling_params.dynatemp_max
            dynatemp_exp = sampling_params.dynatemp_exponent
            smoothing_factor = sampling_params.smoothing_factor
            smoothing_curve = sampling_params.smoothing_curve
            # Latch each do_* flag the first time any group's params deviate
            # from the technique's no-op value (within _SAMPLING_EPS).
            if do_temperatures is False and temperature > _SAMPLING_EPS:
                do_temperatures = True
            if not do_penalties and (abs(p) >= _SAMPLING_EPS
                                     or abs(f) >= _SAMPLING_EPS
                                     or abs(r - 1.0) >= _SAMPLING_EPS):
                do_penalties = True
            if do_topks is False and top_k != vocab_size:
                do_topks = True
            if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
                do_topps = True
            if do_topas is False and top_a > 0.0:
                do_topas = True
            if do_minps is False and min_p > _SAMPLING_EPS:
                do_minps = True
            if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
                do_tfss = True
            if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
                do_eta_cutoffs = True
            if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
                do_epsilon_cutoffs = True
            if do_typical_ps is False and (typical_p < 1.0 - _SAMPLING_EPS
                                           or typical_p_sigma > 0.0):
                do_typical_ps = True
            if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
                                          or smoothing_curve > 1.0):
                do_quadratic = True
            # NOTE: only mirostat_mode == 2 is recognized here.
            if do_mirostat is False and sampling_params.mirostat_mode == 2:
                do_mirostat = True
            if (i < sampling_metadata.num_prompts
                    and sampling_params.prompt_logprobs is not None):
                # For tokens in the prompt that we only need to get their
                # logprobs
                prompt_len = sampling_metadata.prompt_lens[i]
                # Advance the row counter past the prompt-logprob rows so
                # miro_indices (built below) stays aligned.
                index += sampling_metadata.prompt_lens[i] - 1
                # Prompt-logprob rows reuse the group's shaping params but
                # get neutral penalty/cutoff values (0 or 1 as appropriate).
                temperatures += [temperature] * (prompt_len - 1)
                top_ps += [top_p] * (prompt_len - 1)
                top_ks += [top_k] * (prompt_len - 1)
                top_as += [top_a] * (prompt_len - 1)
                min_ps += [min_p] * (prompt_len - 1)
                presence_penalties += [0] * (prompt_len - 1)
                frequency_penalties += [0] * (prompt_len - 1)
                repetition_penalties += [1] * (prompt_len - 1)
                tfss += [1] * (prompt_len - 1)
                eta_cutoffs += [0] * (prompt_len - 1)
                epsilon_cutoffs += [0] * (prompt_len - 1)
                typical_ps += [1] * (prompt_len - 1)
                typical_p_sigmas += [0] * (prompt_len - 1)
                dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
                dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
                dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
                smoothing_factors += [smoothing_factor] * (prompt_len - 1)
                smoothing_curves += [smoothing_curve] * (prompt_len - 1)
                # Empty histories: no penalties apply to prompt-logprob rows.
                prompt_tokens.extend([] for _ in range(prompt_len - 1))
                output_tokens.extend([] for _ in range(prompt_len - 1))
            # One row per sequence in the group, all sharing the group's
            # sampling params.
            for seq_id in seq_ids:
                seq_data = sampling_metadata.seq_data[seq_id]
                prompt_tokens.append(seq_data.prompt_token_ids)
                output_tokens.append(seq_data.output_token_ids)
            temperatures += [temperature] * len(seq_ids)
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
            top_as += [top_a] * len(seq_ids)
            min_ps += [min_p] * len(seq_ids)
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
            repetition_penalties += [r] * len(seq_ids)
            tfss += [tfs] * len(seq_ids)
            eta_cutoffs += [eta_cutoff] * len(seq_ids)
            epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
            typical_ps += [typical_p] * len(seq_ids)
            typical_p_sigmas += [typical_p_sigma] * len(seq_ids)
            dynatemp_mins += [dynatemp_min] * len(seq_ids)
            dynatemp_maxs += [dynatemp_max] * len(seq_ids)
            dynatemp_exps += [dynatemp_exp] * len(seq_ids)
            smoothing_factors += [smoothing_factor] * len(seq_ids)
            smoothing_curves += [smoothing_curve] * len(seq_ids)
            if sampling_params.mirostat_mode == 2:
                # Record which logits rows belong to mirostat sequences, and
                # restore each sequence's running mu from persistent
                # metadata (default: 2 * tau on first use).
                miro_indices += [(index + i) for i in range(len(seq_ids))]
                miro_seqids += seq_ids
                miro_taus += [miro_tau] * len(seq_ids)
                miro_etas += [miro_eta] * len(seq_ids)
                miro_mus += [
                    sampling_metadata.persistent_metadata.get(sid).get(
                        "miro_mu", sampling_params.mirostat_tau * 2)
                    for sid in seq_ids
                ]
            index += len(seq_ids)
        sampling_tensors = SamplingTensors.from_lists(
            temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
            frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
            epsilon_cutoffs, typical_ps, typical_p_sigmas, dynatemp_mins,
            dynatemp_maxs, dynatemp_exps, smoothing_factors, smoothing_curves,
            miro_taus, miro_etas, miro_mus, miro_indices, miro_seqids,
            prompt_tokens, output_tokens, vocab_size, device, dtype)
        return (sampling_tensors, do_temperatures, do_penalties, do_topks,
                do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
                do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
  259. @classmethod
  260. def from_lists(cls, temperatures: List[float], top_ps: List[float],
  261. top_ks: List[int], top_as: List[float], min_ps: List[float],
  262. presence_penalties: List[float],
  263. frequency_penalties: List[float],
  264. repetition_penalties: List[float], tfss: List[float],
  265. eta_cutoffs: List[float], epsilon_cutoffs: List[float],
  266. typical_ps: List[float], typical_p_sigmas: List[float],
  267. dynatemp_mins: List[float], dynatemp_maxs: List[float],
  268. dynatemp_exps: List[float], smoothing_factors: List[float],
  269. smoothing_curves: List[float], miro_taus: List[float],
  270. miro_etas: List[float], miro_mus: List[float],
  271. miro_indices: List[int], miro_seqids: List[int],
  272. prompt_tokens: List[List[int]],
  273. output_tokens: List[List[int]], vocab_size: int,
  274. device: torch.device,
  275. dtype: torch.dtype) -> "SamplingTensors":
  276. # Note that the performance will be very bad without
  277. # pinned memory.
  278. pin_memory = not in_wsl()
  279. prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
  280. prompt_padded_tokens = [
  281. tokens + [vocab_size] * (prompt_max_len - len(tokens))
  282. for tokens in prompt_tokens
  283. ]
  284. output_max_len = max(len(tokens) for tokens in output_tokens)
  285. output_padded_tokens = [
  286. tokens + [vocab_size] * (output_max_len - len(tokens))
  287. for tokens in output_tokens
  288. ]
  289. temperatures_t = torch.tensor(temperatures,
  290. device="cpu",
  291. dtype=dtype,
  292. pin_memory=pin_memory)
  293. top_ps_t = torch.tensor(top_ps,
  294. device="cpu",
  295. dtype=dtype,
  296. pin_memory=pin_memory)
  297. top_ks_t = torch.tensor(top_ks,
  298. device="cpu",
  299. dtype=torch.int,
  300. pin_memory=pin_memory)
  301. top_as_t = torch.tensor(top_as,
  302. device="cpu",
  303. dtype=dtype,
  304. pin_memory=pin_memory)
  305. min_ps_t = torch.tensor(min_ps,
  306. device="cpu",
  307. dtype=dtype,
  308. pin_memory=pin_memory)
  309. presence_penalties_t = torch.tensor(presence_penalties,
  310. device="cpu",
  311. dtype=dtype,
  312. pin_memory=pin_memory)
  313. frequency_penalties_t = torch.tensor(frequency_penalties,
  314. device="cpu",
  315. dtype=dtype,
  316. pin_memory=pin_memory)
  317. repetition_penalties_t = torch.tensor(repetition_penalties,
  318. device="cpu",
  319. dtype=dtype,
  320. pin_memory=pin_memory)
  321. tfss_t = torch.tensor(tfss,
  322. device="cpu",
  323. dtype=dtype,
  324. pin_memory=pin_memory)
  325. eta_cutoffs_t = torch.tensor(eta_cutoffs,
  326. device="cpu",
  327. dtype=dtype,
  328. pin_memory=pin_memory)
  329. epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
  330. device="cpu",
  331. dtype=dtype,
  332. pin_memory=pin_memory)
  333. typical_ps_t = torch.tensor(typical_ps,
  334. device="cpu",
  335. dtype=dtype,
  336. pin_memory=pin_memory)
  337. typical_p_sigmas_t = torch.tensor(typical_p_sigmas,
  338. device="cpu",
  339. dtype=dtype,
  340. pin_memory=pin_memory)
  341. dynatemp_mins_t = torch.tensor(dynatemp_mins,
  342. device="cpu",
  343. dtype=dtype,
  344. pin_memory=pin_memory)
  345. dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
  346. device="cpu",
  347. dtype=dtype,
  348. pin_memory=pin_memory)
  349. dynatemp_exps_t = torch.tensor(dynatemp_exps,
  350. device="cpu",
  351. dtype=dtype,
  352. pin_memory=pin_memory)
  353. smoothing_factors_t = torch.tensor(smoothing_factors,
  354. device="cpu",
  355. dtype=dtype,
  356. pin_memory=pin_memory)
  357. smoothing_curves_t = torch.tensor(smoothing_curves,
  358. device="cpu",
  359. dtype=dtype,
  360. pin_memory=pin_memory)
  361. miro_taus_t = torch.tensor(miro_taus,
  362. device="cpu",
  363. dtype=dtype,
  364. pin_memory=pin_memory)
  365. miro_etas_t = torch.tensor(miro_etas,
  366. device="cpu",
  367. dtype=dtype,
  368. pin_memory=pin_memory)
  369. miro_mus_t = torch.tensor(miro_mus,
  370. device="cpu",
  371. dtype=dtype,
  372. pin_memory=pin_memory)
  373. miro_indices_t = torch.tensor(miro_indices,
  374. device="cpu",
  375. dtype=torch.int,
  376. pin_memory=pin_memory)
  377. prompt_tensor = torch.tensor(prompt_padded_tokens,
  378. device=device,
  379. dtype=torch.long,
  380. pin_memory=pin_memory)
  381. output_tensor = torch.tensor(output_padded_tokens,
  382. device=device,
  383. dtype=torch.long,
  384. pin_memory=pin_memory)
  385. # Because the memory is pinned, we can do non-blocking
  386. # transfer to device.
  387. return cls(
  388. temperatures=temperatures_t.to(device=device, non_blocking=True),
  389. top_ps=top_ps_t.to(device=device, non_blocking=True),
  390. top_ks=top_ks_t.to(device=device, non_blocking=True),
  391. top_as=top_as_t.to(device=device, non_blocking=True),
  392. min_ps=min_ps_t.to(device=device, non_blocking=True),
  393. presence_penalties=presence_penalties_t.to(device=device,
  394. non_blocking=True),
  395. frequency_penalties=frequency_penalties_t.to(device=device,
  396. non_blocking=True),
  397. repetition_penalties=repetition_penalties_t.to(device=device,
  398. non_blocking=True),
  399. tfss=tfss_t.to(device=device, non_blocking=True),
  400. eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
  401. epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
  402. non_blocking=True),
  403. typical_ps=typical_ps_t.to(device=device, non_blocking=True),
  404. typical_p_sigmas=typical_p_sigmas_t.to(device=device,
  405. non_blocking=True),
  406. dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
  407. dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
  408. dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
  409. smoothing_factors=smoothing_factors_t.to(device=device,
  410. non_blocking=True),
  411. smoothing_curves=smoothing_curves_t.to(device=device,
  412. non_blocking=True),
  413. miro_taus=miro_taus_t.to(device=device, non_blocking=True),
  414. miro_etas=miro_etas_t.to(device=device, non_blocking=True),
  415. miro_mus=miro_mus_t.to(device=device, non_blocking=True),
  416. miro_indices=miro_indices_t.to(device=device, non_blocking=True),
  417. miro_seqids=miro_seqids,
  418. prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
  419. output_tokens=output_tensor.to(device=device, non_blocking=True),
  420. )