# sampling_metadata.py
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData
from aphrodite.common.utils import in_wsl

# Tolerance for treating a sampling parameter as "disabled": values within
# this epsilon of their neutral point (e.g. top_p of 1.0, temperature of 0.0)
# are considered no-ops by the do_* checks below.
_SAMPLING_EPS = 1e-5
  8. class PersistentMetadata:
  9. def __init__(self, metadata: Optional[Dict[int, dict]] = None):
  10. self._metadata: Dict[int, dict] = metadata or {}
  11. def get(self, seq_id: int) -> dict:
  12. return self._metadata.get(seq_id, {})
  13. class OutputMetadata(PersistentMetadata):
  14. def add(self, seq_id: int, key, val) -> None:
  15. if seq_id not in self._metadata:
  16. self._metadata[seq_id] = {}
  17. self._metadata[seq_id][key] = val
  18. class SamplingMetadata:
  19. """Metadata for input sequences. Used in sampler.
  20. Args:
  21. seq_groups: List of (seq_ids, sampling_params).
  22. seq_data: Seq_id -> SequenceData.
  23. prompt_lens: Lengths of prompts.
  24. selected_token_indices: Token indices selected for sampling.
  25. categorized_sample_indices: SamplingType -> token indices to sample.
  26. perform_sampling: Whether to perform sampling. This option is used to
  27. make the sampling only happens in the driver worker, and disable
  28. sampling in other worker processes.
  29. persistent_metadata: Metadata that persists across iterations.
  30. output_metadata: the output metadata.
  31. """
  32. def __init__(
  33. self,
  34. seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
  35. seq_data: Optional[Dict[int, SequenceData]],
  36. prompt_lens: Optional[List[int]],
  37. selected_token_indices: torch.Tensor,
  38. categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
  39. perform_sampling: bool = True,
  40. persistent_metadata: Optional[PersistentMetadata] = None,
  41. output_metadata: Optional[OutputMetadata] = None,
  42. ) -> None:
  43. self.seq_groups = seq_groups
  44. self.seq_data = seq_data
  45. self.prompt_lens = prompt_lens
  46. self.selected_token_indices = selected_token_indices
  47. self.categorized_sample_indices = categorized_sample_indices
  48. self.perform_sampling = perform_sampling
  49. self.persistent_metadata = persistent_metadata or PersistentMetadata()
  50. self.output_metadata = output_metadata or OutputMetadata()
  51. self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
  52. def __repr__(self) -> str:
  53. return (
  54. "SamplingMetadata("
  55. f"seq_groups={self.seq_groups}, "
  56. f"seq_data={self.seq_data}, "
  57. f"prompt_lens={self.prompt_lens}, "
  58. f"selected_token_indices={self.selected_token_indices}, "
  59. f"categorized_sample_indices={self.categorized_sample_indices}, "
  60. f"perform_sampling={self.perform_sampling}, "
  61. f"persistent_metadata={self.persistent_metadata}, "
  62. f"output_metadata={self.output_metadata}) ")
  63. @dataclass
  64. class SamplingTensors:
  65. """Tensors for sampling."""
  66. temperatures: torch.Tensor
  67. top_ps: torch.Tensor
  68. top_ks: torch.Tensor
  69. top_as: torch.Tensor
  70. min_ps: torch.Tensor
  71. presence_penalties: torch.Tensor
  72. frequency_penalties: torch.Tensor
  73. repetition_penalties: torch.Tensor
  74. tfss: torch.Tensor
  75. eta_cutoffs: torch.Tensor
  76. epsilon_cutoffs: torch.Tensor
  77. typical_ps: torch.Tensor
  78. miro_taus: torch.Tensor
  79. miro_etas: torch.Tensor
  80. miro_mus: torch.Tensor
  81. miro_indices: torch.Tensor
  82. miro_seqids: List[int] # state writeback done CPU side
  83. dynatemp_ranges: torch.Tensor
  84. dynatemp_exps: torch.Tensor
  85. smoothing_factors: torch.Tensor
  86. prompt_tokens: torch.Tensor
  87. output_tokens: torch.Tensor
  88. @classmethod
  89. def from_sampling_metadata(
  90. cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
  91. device: torch.device, dtype: torch.dtype
  92. ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
  93. bool, bool, bool, bool, bool]:
  94. prompt_tokens: List[List[int]] = []
  95. output_tokens: List[List[int]] = []
  96. top_ks: List[int] = []
  97. temperatures: List[float] = []
  98. top_ps: List[float] = []
  99. top_as: List[float] = []
  100. min_ps: List[float] = []
  101. presence_penalties: List[float] = []
  102. frequency_penalties: List[float] = []
  103. repetition_penalties: List[float] = []
  104. tfss: List[float] = []
  105. eta_cutoffs: List[float] = []
  106. epsilon_cutoffs: List[float] = []
  107. typical_ps: List[float] = []
  108. miro_taus: List[float] = []
  109. miro_etas: List[float] = []
  110. miro_mus: List[float] = []
  111. miro_indices: List[int] = []
  112. miro_seqids: List[int] = []
  113. dynatemp_ranges: List[float] = []
  114. dynatemp_exps: List[float] = []
  115. smoothing_factors: List[float] = []
  116. index = 0 # temporary, needed for building miro_indices
  117. do_temperatures = False
  118. do_penalties = False
  119. do_topks = False
  120. do_topps = False
  121. do_topas = False
  122. do_minps = False
  123. do_tfss = False
  124. do_eta_cutoffs = False
  125. do_epsilon_cutoffs = False
  126. do_typical_ps = False
  127. do_quadratic = False
  128. do_mirostat = False
  129. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  130. seq_ids, sampling_params = seq_group
  131. temperature = sampling_params.temperature
  132. p = sampling_params.presence_penalty
  133. f = sampling_params.frequency_penalty
  134. r = sampling_params.repetition_penalty
  135. top_p = sampling_params.top_p
  136. # k should not be greater than the vocab size
  137. top_k = min(sampling_params.top_k, vocab_size)
  138. top_k = vocab_size if top_k == -1 else top_k
  139. top_a = sampling_params.top_a
  140. min_p = sampling_params.min_p
  141. tfs = sampling_params.tfs
  142. eta_cutoff = sampling_params.eta_cutoff
  143. epsilon_cutoff = sampling_params.epsilon_cutoff
  144. typical_p = sampling_params.typical_p
  145. miro_tau = sampling_params.mirostat_tau
  146. miro_eta = sampling_params.mirostat_eta
  147. dynatemp_range = sampling_params.dynatemp_range
  148. dynatemp_exp = sampling_params.dynatemp_exponent
  149. smoothing_factor = sampling_params.smoothing_factor
  150. if do_temperatures is False and temperature > _SAMPLING_EPS:
  151. do_temperatures = True
  152. if not do_penalties and (abs(p) >= _SAMPLING_EPS
  153. or abs(f) >= _SAMPLING_EPS
  154. or abs(r - 1.0) >= _SAMPLING_EPS):
  155. do_penalties = True
  156. if do_topks is False and top_k != vocab_size:
  157. do_topks = True
  158. if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
  159. do_topps = True
  160. if do_topas is False and top_a > 0.0:
  161. do_topas = True
  162. if do_minps is False and min_p > _SAMPLING_EPS:
  163. do_minps = True
  164. if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
  165. do_tfss = True
  166. if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
  167. do_eta_cutoffs = True
  168. if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
  169. do_epsilon_cutoffs = True
  170. if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
  171. do_typical_ps = True
  172. if do_quadratic is False and smoothing_factor > _SAMPLING_EPS:
  173. do_quadratic = True
  174. if do_mirostat is False and sampling_params.mirostat_mode == 2:
  175. do_mirostat = True
  176. if (i < sampling_metadata.num_prompts
  177. and sampling_params.prompt_logprobs is not None):
  178. # For tokens in the prompt that we only need to get their
  179. # logprobs
  180. prompt_len = sampling_metadata.prompt_lens[i]
  181. index += sampling_metadata.prompt_lens[i] - 1
  182. temperatures += [temperature] * (prompt_len - 1)
  183. top_ps += [top_p] * (prompt_len - 1)
  184. top_ks += [top_k] * (prompt_len - 1)
  185. top_as += [top_a] * (prompt_len - 1)
  186. min_ps += [min_p] * (prompt_len - 1)
  187. presence_penalties += [0] * (prompt_len - 1)
  188. frequency_penalties += [0] * (prompt_len - 1)
  189. repetition_penalties += [1] * (prompt_len - 1)
  190. tfss += [1] * (prompt_len - 1)
  191. eta_cutoffs += [0] * (prompt_len - 1)
  192. epsilon_cutoffs += [0] * (prompt_len - 1)
  193. typical_ps += [1] * (prompt_len - 1)
  194. dynatemp_ranges += [dynatemp_range] * (prompt_len - 1)
  195. dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
  196. smoothing_factors += [smoothing_factor] * (prompt_len - 1)
  197. prompt_tokens.extend([] for _ in range(prompt_len - 1))
  198. output_tokens.extend([] for _ in range(prompt_len - 1))
  199. for seq_id in seq_ids:
  200. seq_data = sampling_metadata.seq_data[seq_id]
  201. prompt_tokens.append(seq_data.prompt_token_ids)
  202. output_tokens.append(seq_data.output_token_ids)
  203. temperatures += [temperature] * len(seq_ids)
  204. top_ps += [top_p] * len(seq_ids)
  205. top_ks += [top_k] * len(seq_ids)
  206. top_as += [top_a] * len(seq_ids)
  207. min_ps += [min_p] * len(seq_ids)
  208. presence_penalties += [p] * len(seq_ids)
  209. frequency_penalties += [f] * len(seq_ids)
  210. repetition_penalties += [r] * len(seq_ids)
  211. tfss += [tfs] * len(seq_ids)
  212. eta_cutoffs += [eta_cutoff] * len(seq_ids)
  213. epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
  214. typical_ps += [typical_p] * len(seq_ids)
  215. dynatemp_ranges += [dynatemp_range] * len(seq_ids)
  216. dynatemp_exps += [dynatemp_exp] * len(seq_ids)
  217. smoothing_factors += [smoothing_factor] * len(seq_ids)
  218. if sampling_params.mirostat_mode == 2:
  219. miro_indices += [(index + i) for i in range(len(seq_ids))]
  220. miro_seqids += seq_ids
  221. miro_taus += [miro_tau] * len(seq_ids)
  222. miro_etas += [miro_eta] * len(seq_ids)
  223. miro_mus += [
  224. sampling_metadata.persistent_metadata.get(sid).get(
  225. "miro_mu", sampling_params.mirostat_tau * 2)
  226. for sid in seq_ids
  227. ]
  228. index += len(seq_ids)
  229. sampling_tensors = SamplingTensors.from_lists(
  230. temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
  231. frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
  232. epsilon_cutoffs, typical_ps, dynatemp_ranges, dynatemp_exps,
  233. miro_taus, miro_etas, miro_mus, miro_indices, miro_seqids,
  234. smoothing_factors, prompt_tokens, output_tokens, vocab_size,
  235. device, dtype)
  236. return (sampling_tensors, do_temperatures, do_penalties, do_topks,
  237. do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
  238. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
  239. @classmethod
  240. def from_lists(cls, temperatures: List[float], top_ps: List[float],
  241. top_ks: List[int], top_as: List[float], min_ps: List[float],
  242. presence_penalties: List[float],
  243. frequency_penalties: List[float],
  244. repetition_penalties: List[float], tfss: List[float],
  245. eta_cutoffs: List[float], epsilon_cutoffs: List[float],
  246. typical_ps: List[float], dynatemp_ranges: List[float],
  247. dynatemp_exps: List[float], miro_taus: List[float],
  248. miro_etas: List[float], miro_mus: List[float],
  249. miro_indices: List[int], miro_seqids: List[int],
  250. smoothing_factors: List[float],
  251. prompt_tokens: List[List[int]],
  252. output_tokens: List[List[int]], vocab_size: int,
  253. device: torch.device,
  254. dtype: torch.dtype) -> "SamplingTensors":
  255. # Note that the performance will be very bad without
  256. # pinned memory.
  257. pin_memory = not in_wsl()
  258. prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
  259. prompt_padded_tokens = [
  260. tokens + [vocab_size] * (prompt_max_len - len(tokens))
  261. for tokens in prompt_tokens
  262. ]
  263. output_max_len = max(len(tokens) for tokens in output_tokens)
  264. output_padded_tokens = [
  265. tokens + [vocab_size] * (output_max_len - len(tokens))
  266. for tokens in output_tokens
  267. ]
  268. temperatures_t = torch.tensor(temperatures,
  269. device="cpu",
  270. dtype=dtype,
  271. pin_memory=pin_memory)
  272. top_ps_t = torch.tensor(top_ps,
  273. device="cpu",
  274. dtype=dtype,
  275. pin_memory=pin_memory)
  276. top_ks_t = torch.tensor(top_ks,
  277. device="cpu",
  278. dtype=torch.int,
  279. pin_memory=pin_memory)
  280. top_as_t = torch.tensor(top_as,
  281. device="cpu",
  282. dtype=dtype,
  283. pin_memory=pin_memory)
  284. min_ps_t = torch.tensor(min_ps,
  285. device="cpu",
  286. dtype=dtype,
  287. pin_memory=pin_memory)
  288. presence_penalties_t = torch.tensor(presence_penalties,
  289. device="cpu",
  290. dtype=dtype,
  291. pin_memory=pin_memory)
  292. frequency_penalties_t = torch.tensor(frequency_penalties,
  293. device="cpu",
  294. dtype=dtype,
  295. pin_memory=pin_memory)
  296. repetition_penalties_t = torch.tensor(repetition_penalties,
  297. device="cpu",
  298. dtype=dtype,
  299. pin_memory=pin_memory)
  300. tfss_t = torch.tensor(tfss,
  301. device="cpu",
  302. dtype=dtype,
  303. pin_memory=pin_memory)
  304. eta_cutoffs_t = torch.tensor(eta_cutoffs,
  305. device="cpu",
  306. dtype=dtype,
  307. pin_memory=pin_memory)
  308. epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
  309. device="cpu",
  310. dtype=dtype,
  311. pin_memory=pin_memory)
  312. typical_ps_t = torch.tensor(typical_ps,
  313. device="cpu",
  314. dtype=dtype,
  315. pin_memory=pin_memory)
  316. dynatemp_ranges_t = torch.tensor(dynatemp_ranges,
  317. device="cpu",
  318. dtype=dtype,
  319. pin_memory=pin_memory)
  320. dynatemp_exps_t = torch.tensor(dynatemp_exps,
  321. device="cpu",
  322. dtype=dtype,
  323. pin_memory=pin_memory)
  324. smoothing_factors_t = torch.tensor(smoothing_factors,
  325. device="cpu",
  326. dtype=dtype,
  327. pin_memory=pin_memory)
  328. miro_taus_t = torch.tensor(miro_taus,
  329. device="cpu",
  330. dtype=dtype,
  331. pin_memory=pin_memory)
  332. miro_etas_t = torch.tensor(miro_etas,
  333. device="cpu",
  334. dtype=dtype,
  335. pin_memory=pin_memory)
  336. miro_mus_t = torch.tensor(miro_mus,
  337. device="cpu",
  338. dtype=dtype,
  339. pin_memory=pin_memory)
  340. miro_indices_t = torch.tensor(miro_indices,
  341. device="cpu",
  342. dtype=torch.int,
  343. pin_memory=pin_memory)
  344. prompt_tensor = torch.tensor(prompt_padded_tokens,
  345. device=device,
  346. dtype=torch.long,
  347. pin_memory=pin_memory)
  348. output_tensor = torch.tensor(output_padded_tokens,
  349. device=device,
  350. dtype=torch.long,
  351. pin_memory=pin_memory)
  352. # Because the memory is pinned, we can do non-blocking
  353. # transfer to device.
  354. return cls(
  355. temperatures=temperatures_t.to(device=device, non_blocking=True),
  356. top_ps=top_ps_t.to(device=device, non_blocking=True),
  357. top_ks=top_ks_t.to(device=device, non_blocking=True),
  358. top_as=top_as_t.to(device=device, non_blocking=True),
  359. min_ps=min_ps_t.to(device=device, non_blocking=True),
  360. presence_penalties=presence_penalties_t.to(device=device,
  361. non_blocking=True),
  362. frequency_penalties=frequency_penalties_t.to(device=device,
  363. non_blocking=True),
  364. repetition_penalties=repetition_penalties_t.to(device=device,
  365. non_blocking=True),
  366. tfss=tfss_t.to(device=device, non_blocking=True),
  367. eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
  368. epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
  369. non_blocking=True),
  370. dynatemp_ranges=dynatemp_ranges_t.to(device=device,
  371. non_blocking=True),
  372. dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
  373. smoothing_factors=smoothing_factors_t.to(device=device,
  374. non_blocking=True),
  375. miro_taus=miro_taus_t.to(device=device, non_blocking=True),
  376. miro_etas=miro_etas_t.to(device=device, non_blocking=True),
  377. miro_mus=miro_mus_t.to(device=device, non_blocking=True),
  378. miro_indices=miro_indices_t.to(device=device, non_blocking=True),
  379. miro_seqids=miro_seqids,
  380. typical_ps=typical_ps_t.to(device=device, non_blocking=True),
  381. prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
  382. output_tokens=output_tensor.to(device=device, non_blocking=True),
  383. )