1
0

sampling_metadata.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. from dataclasses import dataclass
  2. from typing import Dict, List, Tuple, Optional
  3. import random
  4. import torch
  5. from aphrodite.modeling.layers.ops.sample import (get_num_triton_sampler_splits
  6. )
  7. from aphrodite.common.sampling_params import SamplingParams, SamplingType
  8. from aphrodite.common.sequence import SequenceData
  9. from aphrodite.common.utils import is_pin_memory_available
  10. _SAMPLING_EPS = 1e-5
  11. _SEED_0_REPLACEMENT = 3403598558
  12. class PersistentMetadata:
  13. def __init__(self, metadata: Optional[Dict[int, dict]] = None):
  14. self._metadata: Dict[int, dict] = metadata or {}
  15. def get(self, seq_id: int) -> dict:
  16. return self._metadata.get(seq_id, {})
  17. class OutputMetadata(PersistentMetadata):
  18. def add(self, seq_id: int, key, val) -> None:
  19. if seq_id not in self._metadata:
  20. self._metadata[seq_id] = {}
  21. self._metadata[seq_id][key] = val
  22. class SamplingMetadata:
  23. """Metadata for input sequences. Used in sampler.
  24. Args:
  25. seq_groups: List of (seq_ids, sampling_params).
  26. seq_data: Seq_id -> SequenceData.
  27. prompt_lens: Lengths of prompts.
  28. selected_token_indices: Token indices selected for sampling.
  29. categorized_sample_indices: SamplingType -> token indices to sample.
  30. generators: List of torch.Generators to use for seeded sampling
  31. perform_sampling: Whether to perform sampling. This option is used to
  32. make the sampling only happens in the driver worker, and disable
  33. sampling in other worker processes.
  34. persistent_metadata: Metadata that persists across iterations.
  35. output_metadata: the output metadata.
  36. """
  37. def __init__(
  38. self,
  39. seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
  40. seq_data: Optional[Dict[int, SequenceData]],
  41. prompt_lens: Optional[List[int]],
  42. selected_token_indices: torch.Tensor,
  43. categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
  44. generators: Optional[List[torch.Generator]] = None,
  45. perform_sampling: bool = True,
  46. persistent_metadata: Optional[PersistentMetadata] = None,
  47. output_metadata: Optional[OutputMetadata] = None,
  48. ) -> None:
  49. self.seq_groups = seq_groups
  50. self.seq_data = seq_data
  51. self.prompt_lens = prompt_lens
  52. self.selected_token_indices = selected_token_indices
  53. self.categorized_sample_indices = categorized_sample_indices
  54. self.generators = generators
  55. self.perform_sampling = perform_sampling
  56. self.persistent_metadata = persistent_metadata or PersistentMetadata()
  57. self.output_metadata = output_metadata or OutputMetadata()
  58. self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
  59. def __repr__(self) -> str:
  60. return (
  61. "SamplingMetadata("
  62. f"seq_groups={self.seq_groups}, "
  63. f"seq_data={self.seq_data}, "
  64. f"prompt_lens={self.prompt_lens}, "
  65. f"selected_token_indices={self.selected_token_indices}, "
  66. f"categorized_sample_indices={self.categorized_sample_indices}, "
  67. f"perform_sampling={self.perform_sampling}, "
  68. f"persistent_metadata={self.persistent_metadata}, "
  69. f"output_metadata={self.output_metadata}) ")
  70. @dataclass
  71. class SamplingTensors:
  72. """Tensors for sampling."""
  73. temperatures: torch.Tensor
  74. top_ps: torch.Tensor
  75. top_ks: torch.Tensor
  76. top_as: torch.Tensor
  77. min_ps: torch.Tensor
  78. presence_penalties: torch.Tensor
  79. frequency_penalties: torch.Tensor
  80. repetition_penalties: torch.Tensor
  81. tfss: torch.Tensor
  82. eta_cutoffs: torch.Tensor
  83. epsilon_cutoffs: torch.Tensor
  84. typical_ps: torch.Tensor
  85. miro_taus: torch.Tensor
  86. miro_etas: torch.Tensor
  87. miro_mus: torch.Tensor
  88. miro_indices: torch.Tensor
  89. miro_seqids: List[int] # state writeback done CPU side
  90. dynatemp_mins: torch.Tensor
  91. dynatemp_maxs: torch.Tensor
  92. dynatemp_exps: torch.Tensor
  93. smoothing_factors: torch.Tensor
  94. smoothing_curves: torch.Tensor
  95. sampling_seeds: torch.Tensor
  96. sample_indices: torch.Tensor
  97. extra_seeds: Optional[torch.Tensor]
  98. prompt_tokens: torch.Tensor
  99. output_tokens: torch.Tensor
  100. @classmethod
  101. def from_sampling_metadata(
  102. cls,
  103. sampling_metadata: "SamplingMetadata",
  104. vocab_size: int,
  105. device: torch.device,
  106. dtype: torch.dtype,
  107. *,
  108. extra_seeds_to_generate: int = 0,
  109. extra_entropy: Optional[Tuple[int, ...]] = None
  110. ) -> Tuple["SamplingTensors", bool, bool, bool, bool, bool, bool, bool,
  111. bool, bool, bool, bool, bool]:
  112. prompt_tokens: List[List[int]] = []
  113. output_tokens: List[List[int]] = []
  114. top_ks: List[int] = []
  115. temperatures: List[float] = []
  116. top_ps: List[float] = []
  117. top_as: List[float] = []
  118. min_ps: List[float] = []
  119. presence_penalties: List[float] = []
  120. frequency_penalties: List[float] = []
  121. repetition_penalties: List[float] = []
  122. tfss: List[float] = []
  123. eta_cutoffs: List[float] = []
  124. epsilon_cutoffs: List[float] = []
  125. typical_ps: List[float] = []
  126. miro_taus: List[float] = []
  127. miro_etas: List[float] = []
  128. miro_mus: List[float] = []
  129. miro_indices: List[int] = []
  130. miro_seqids: List[int] = []
  131. dynatemp_mins: List[float] = []
  132. dynatemp_maxs: List[float] = []
  133. dynatemp_exps: List[float] = []
  134. smoothing_factors: List[float] = []
  135. smoothing_curves: List[float] = []
  136. sampling_seeds: List[int] = []
  137. sample_indices: List[int] = []
  138. prompt_best_of: List[int] = []
  139. index = 0 # temporary, needed for building miro_indices
  140. do_temperatures = False
  141. do_penalties = False
  142. do_topks = False
  143. do_topps = False
  144. do_topas = False
  145. do_minps = False
  146. do_tfss = False
  147. do_eta_cutoffs = False
  148. do_epsilon_cutoffs = False
  149. do_typical_ps = False
  150. do_quadratic = False
  151. do_mirostat = False
  152. # We need one base seed per Triton slice.
  153. seeds_to_generate = (extra_seeds_to_generate +
  154. get_num_triton_sampler_splits(vocab_size))
  155. sample_indices_start_idx = 0
  156. for i, seq_group in enumerate(sampling_metadata.seq_groups):
  157. seq_ids, sampling_params = seq_group
  158. temperature = sampling_params.temperature
  159. p = sampling_params.presence_penalty
  160. f = sampling_params.frequency_penalty
  161. r = sampling_params.repetition_penalty
  162. top_p = sampling_params.top_p
  163. # k should not be greater than the vocab size
  164. top_k = min(sampling_params.top_k, vocab_size)
  165. top_k = vocab_size if top_k == -1 else top_k
  166. top_a = sampling_params.top_a
  167. min_p = sampling_params.min_p
  168. tfs = sampling_params.tfs
  169. eta_cutoff = sampling_params.eta_cutoff
  170. epsilon_cutoff = sampling_params.epsilon_cutoff
  171. typical_p = sampling_params.typical_p
  172. miro_tau = sampling_params.mirostat_tau
  173. miro_eta = sampling_params.mirostat_eta
  174. dynatemp_min = sampling_params.dynatemp_min
  175. dynatemp_max = sampling_params.dynatemp_max
  176. dynatemp_exp = sampling_params.dynatemp_exponent
  177. smoothing_factor = sampling_params.smoothing_factor
  178. smoothing_curve = sampling_params.smoothing_curve
  179. seed = sampling_params.seed
  180. is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
  181. if do_temperatures is False and temperature > _SAMPLING_EPS:
  182. do_temperatures = True
  183. if not do_penalties and (abs(p) >= _SAMPLING_EPS
  184. or abs(f) >= _SAMPLING_EPS
  185. or abs(r - 1.0) >= _SAMPLING_EPS):
  186. do_penalties = True
  187. if do_topks is False and top_k != vocab_size:
  188. do_topks = True
  189. if do_topps is False and top_p < 1.0 - _SAMPLING_EPS:
  190. do_topps = True
  191. if do_topas is False and top_a > 0.0:
  192. do_topas = True
  193. if do_minps is False and min_p > _SAMPLING_EPS:
  194. do_minps = True
  195. if do_tfss is False and tfs < 1.0 - _SAMPLING_EPS:
  196. do_tfss = True
  197. if do_eta_cutoffs is False and eta_cutoff > _SAMPLING_EPS:
  198. do_eta_cutoffs = True
  199. if do_epsilon_cutoffs is False and epsilon_cutoff > _SAMPLING_EPS:
  200. do_epsilon_cutoffs = True
  201. if do_typical_ps is False and typical_p < 1.0 - _SAMPLING_EPS:
  202. do_typical_ps = True
  203. if do_quadratic is False and (smoothing_factor > _SAMPLING_EPS
  204. or smoothing_curve > 1.0):
  205. do_quadratic = True
  206. if do_mirostat is False and sampling_params.mirostat_mode == 2:
  207. do_mirostat = True
  208. if (i < sampling_metadata.num_prompts
  209. and sampling_params.prompt_logprobs is not None):
  210. # For tokens in the prompt that we only need to get their
  211. # logprobs
  212. prompt_len = sampling_metadata.prompt_lens[i]
  213. index += sampling_metadata.prompt_lens[i] - 1
  214. temperatures += [temperature] * (prompt_len - 1)
  215. top_ps += [top_p] * (prompt_len - 1)
  216. top_ks += [top_k] * (prompt_len - 1)
  217. top_as += [top_a] * (prompt_len - 1)
  218. min_ps += [min_p] * (prompt_len - 1)
  219. presence_penalties += [0] * (prompt_len - 1)
  220. frequency_penalties += [0] * (prompt_len - 1)
  221. repetition_penalties += [1] * (prompt_len - 1)
  222. tfss += [1] * (prompt_len - 1)
  223. eta_cutoffs += [0] * (prompt_len - 1)
  224. epsilon_cutoffs += [0] * (prompt_len - 1)
  225. typical_ps += [1] * (prompt_len - 1)
  226. dynatemp_mins += [dynatemp_min] * (prompt_len - 1)
  227. dynatemp_maxs += [dynatemp_max] * (prompt_len - 1)
  228. dynatemp_exps += [dynatemp_exp] * (prompt_len - 1)
  229. smoothing_factors += [smoothing_factor] * (prompt_len - 1)
  230. smoothing_curves += [smoothing_curve] * (prompt_len - 1)
  231. prompt_tokens.extend([] for _ in range(prompt_len - 1))
  232. output_tokens.extend([] for _ in range(prompt_len - 1))
  233. for seq_id in seq_ids:
  234. seq_data = sampling_metadata.seq_data[seq_id]
  235. prompt_tokens.append(seq_data.prompt_token_ids)
  236. output_tokens.append(seq_data.output_token_ids)
  237. temperatures += [temperature] * len(seq_ids)
  238. top_ps += [top_p] * len(seq_ids)
  239. top_ks += [top_k] * len(seq_ids)
  240. top_as += [top_a] * len(seq_ids)
  241. min_ps += [min_p] * len(seq_ids)
  242. presence_penalties += [p] * len(seq_ids)
  243. frequency_penalties += [f] * len(seq_ids)
  244. repetition_penalties += [r] * len(seq_ids)
  245. tfss += [tfs] * len(seq_ids)
  246. eta_cutoffs += [eta_cutoff] * len(seq_ids)
  247. epsilon_cutoffs += [epsilon_cutoff] * len(seq_ids)
  248. typical_ps += [typical_p] * len(seq_ids)
  249. dynatemp_mins += [dynatemp_min] * len(seq_ids)
  250. dynatemp_maxs += [dynatemp_max] * len(seq_ids)
  251. dynatemp_exps += [dynatemp_exp] * len(seq_ids)
  252. smoothing_factors += [smoothing_factor] * len(seq_ids)
  253. smoothing_curves += [smoothing_curve] * len(seq_ids)
  254. if sampling_params.mirostat_mode == 2:
  255. miro_indices += [(index + i) for i in range(len(seq_ids))]
  256. miro_seqids += seq_ids
  257. miro_taus += [miro_tau] * len(seq_ids)
  258. miro_etas += [miro_eta] * len(seq_ids)
  259. miro_mus += [
  260. sampling_metadata.persistent_metadata.get(sid).get(
  261. "miro_mu", sampling_params.mirostat_tau * 2)
  262. for sid in seq_ids
  263. ]
  264. index += len(seq_ids)
  265. is_prompt = i < sampling_metadata.num_prompts
  266. if is_prompt:
  267. prompt_best_of.append(sampling_params.best_of)
  268. prompt_len = sampling_metadata.prompt_lens[i]
  269. if sampling_params.prompt_logprobs is not None:
  270. # NOTE: the sampling position is the last token
  271. # in the prompt
  272. sample_indices_start_idx += prompt_len - 1
  273. for seq_id in seq_ids:
  274. seq_data = sampling_metadata.seq_data[seq_id]
  275. extra_entropy = extra_entropy or ()
  276. seq_seeds = cls._get_sequence_seeds(
  277. seed,
  278. seq_data.get_len(),
  279. *extra_entropy,
  280. seq_id,
  281. seeds_to_generate=seeds_to_generate,
  282. is_greedy=is_greedy)
  283. sampling_seeds.append(seq_seeds)
  284. sample_indices.append(sample_indices_start_idx)
  285. sample_indices_start_idx += 1
  286. sampling_tensors = SamplingTensors.from_lists(
  287. temperatures, top_ps, top_ks, top_as, min_ps, presence_penalties,
  288. frequency_penalties, repetition_penalties, tfss, eta_cutoffs,
  289. epsilon_cutoffs, typical_ps, dynatemp_mins, dynatemp_maxs,
  290. dynatemp_exps, miro_taus, miro_etas, miro_mus, miro_indices,
  291. miro_seqids, smoothing_factors, smoothing_curves, sampling_seeds,
  292. sample_indices, prompt_tokens, output_tokens, vocab_size,
  293. extra_seeds_to_generate, device, dtype)
  294. return (sampling_tensors, do_temperatures, do_penalties, do_topks,
  295. do_topps, do_topas, do_minps, do_tfss, do_eta_cutoffs,
  296. do_epsilon_cutoffs, do_typical_ps, do_quadratic, do_mirostat)
  297. @classmethod
  298. def from_lists(cls, temperatures: List[float], top_ps: List[float],
  299. top_ks: List[int], top_as: List[float], min_ps: List[float],
  300. presence_penalties: List[float],
  301. frequency_penalties: List[float],
  302. repetition_penalties: List[float], tfss: List[float],
  303. eta_cutoffs: List[float], epsilon_cutoffs: List[float],
  304. typical_ps: List[float], dynatemp_mins: List[float],
  305. dynatemp_maxs: List[float], dynatemp_exps: List[float],
  306. miro_taus: List[float], miro_etas: List[float],
  307. miro_mus: List[float], miro_indices: List[int],
  308. miro_seqids: List[int], smoothing_factors: List[float],
  309. smoothing_curves: List[float], sampling_seeds: List[int],
  310. sample_indices: List[int], prompt_tokens: List[List[int]],
  311. output_tokens: List[List[int]], vocab_size: int,
  312. extra_seeds_to_generate: int, device: torch.device,
  313. dtype: torch.dtype) -> "SamplingTensors":
  314. # Note that the performance will be very bad without
  315. # pinned memory.
  316. pin_memory = is_pin_memory_available()
  317. prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
  318. prompt_padded_tokens = [
  319. tokens + [vocab_size] * (prompt_max_len - len(tokens))
  320. for tokens in prompt_tokens
  321. ]
  322. output_max_len = max(len(tokens) for tokens in output_tokens)
  323. output_padded_tokens = [
  324. tokens + [vocab_size] * (output_max_len - len(tokens))
  325. for tokens in output_tokens
  326. ]
  327. temperatures_t = torch.tensor(temperatures,
  328. device="cpu",
  329. dtype=dtype,
  330. pin_memory=pin_memory)
  331. top_ps_t = torch.tensor(top_ps,
  332. device="cpu",
  333. dtype=dtype,
  334. pin_memory=pin_memory)
  335. top_ks_t = torch.tensor(top_ks,
  336. device="cpu",
  337. dtype=torch.int,
  338. pin_memory=pin_memory)
  339. top_as_t = torch.tensor(top_as,
  340. device="cpu",
  341. dtype=dtype,
  342. pin_memory=pin_memory)
  343. min_ps_t = torch.tensor(min_ps,
  344. device="cpu",
  345. dtype=dtype,
  346. pin_memory=pin_memory)
  347. presence_penalties_t = torch.tensor(presence_penalties,
  348. device="cpu",
  349. dtype=dtype,
  350. pin_memory=pin_memory)
  351. frequency_penalties_t = torch.tensor(frequency_penalties,
  352. device="cpu",
  353. dtype=dtype,
  354. pin_memory=pin_memory)
  355. repetition_penalties_t = torch.tensor(repetition_penalties,
  356. device="cpu",
  357. dtype=dtype,
  358. pin_memory=pin_memory)
  359. tfss_t = torch.tensor(tfss,
  360. device="cpu",
  361. dtype=dtype,
  362. pin_memory=pin_memory)
  363. eta_cutoffs_t = torch.tensor(eta_cutoffs,
  364. device="cpu",
  365. dtype=dtype,
  366. pin_memory=pin_memory)
  367. epsilon_cutoffs_t = torch.tensor(epsilon_cutoffs,
  368. device="cpu",
  369. dtype=dtype,
  370. pin_memory=pin_memory)
  371. typical_ps_t = torch.tensor(typical_ps,
  372. device="cpu",
  373. dtype=dtype,
  374. pin_memory=pin_memory)
  375. dynatemp_mins_t = torch.tensor(dynatemp_mins,
  376. device="cpu",
  377. dtype=dtype,
  378. pin_memory=pin_memory)
  379. dynatemp_maxs_t = torch.tensor(dynatemp_maxs,
  380. device="cpu",
  381. dtype=dtype,
  382. pin_memory=pin_memory)
  383. dynatemp_exps_t = torch.tensor(dynatemp_exps,
  384. device="cpu",
  385. dtype=dtype,
  386. pin_memory=pin_memory)
  387. smoothing_factors_t = torch.tensor(smoothing_factors,
  388. device="cpu",
  389. dtype=dtype,
  390. pin_memory=pin_memory)
  391. smoothing_curves_t = torch.tensor(smoothing_curves,
  392. device="cpu",
  393. dtype=dtype,
  394. pin_memory=pin_memory)
  395. miro_taus_t = torch.tensor(miro_taus,
  396. device="cpu",
  397. dtype=dtype,
  398. pin_memory=pin_memory)
  399. miro_etas_t = torch.tensor(miro_etas,
  400. device="cpu",
  401. dtype=dtype,
  402. pin_memory=pin_memory)
  403. miro_mus_t = torch.tensor(miro_mus,
  404. device="cpu",
  405. dtype=dtype,
  406. pin_memory=pin_memory)
  407. miro_indices_t = torch.tensor(miro_indices,
  408. device="cpu",
  409. dtype=torch.int,
  410. pin_memory=pin_memory)
  411. sample_indices_t = torch.tensor(sample_indices,
  412. device="cpu",
  413. dtype=torch.int,
  414. pin_memory=pin_memory)
  415. prompt_tensor = torch.tensor(prompt_padded_tokens,
  416. device=device,
  417. dtype=torch.long,
  418. pin_memory=pin_memory)
  419. output_tensor = torch.tensor(output_padded_tokens,
  420. device=device,
  421. dtype=torch.long,
  422. pin_memory=pin_memory)
  423. # need to transpose and make contiguous to
  424. # copy the tensor correctly.
  425. # [batch_size, n_seeds] -> [n_seeds, batch_size]
  426. sampling_seeds_t = torch.tensor(
  427. sampling_seeds,
  428. device="cpu",
  429. dtype=torch.long,
  430. pin_memory=pin_memory,
  431. ).T.contiguous()
  432. # Because the memory is pinned, we can do non-blocking
  433. # transfer to device.
  434. # How many seeds the sample operation itself will need.
  435. num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
  436. sampling_seeds_gpu = sampling_seeds_t.to(device=device,
  437. non_blocking=True)
  438. extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
  439. if not extra_seeds_gpu.numel():
  440. extra_seeds_gpu = None
  441. sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
  442. return cls(
  443. temperatures=temperatures_t.to(device=device, non_blocking=True),
  444. top_ps=top_ps_t.to(device=device, non_blocking=True),
  445. top_ks=top_ks_t.to(device=device, non_blocking=True),
  446. top_as=top_as_t.to(device=device, non_blocking=True),
  447. min_ps=min_ps_t.to(device=device, non_blocking=True),
  448. presence_penalties=presence_penalties_t.to(device=device,
  449. non_blocking=True),
  450. frequency_penalties=frequency_penalties_t.to(device=device,
  451. non_blocking=True),
  452. repetition_penalties=repetition_penalties_t.to(device=device,
  453. non_blocking=True),
  454. tfss=tfss_t.to(device=device, non_blocking=True),
  455. eta_cutoffs=eta_cutoffs_t.to(device=device, non_blocking=True),
  456. epsilon_cutoffs=epsilon_cutoffs_t.to(device=device,
  457. non_blocking=True),
  458. dynatemp_mins=dynatemp_mins_t.to(device=device, non_blocking=True),
  459. dynatemp_maxs=dynatemp_maxs_t.to(device=device, non_blocking=True),
  460. dynatemp_exps=dynatemp_exps_t.to(device=device, non_blocking=True),
  461. smoothing_factors=smoothing_factors_t.to(device=device,
  462. non_blocking=True),
  463. smoothing_curves=smoothing_curves_t.to(device=device,
  464. non_blocking=True),
  465. miro_taus=miro_taus_t.to(device=device, non_blocking=True),
  466. miro_etas=miro_etas_t.to(device=device, non_blocking=True),
  467. miro_mus=miro_mus_t.to(device=device, non_blocking=True),
  468. miro_indices=miro_indices_t.to(device=device, non_blocking=True),
  469. miro_seqids=miro_seqids,
  470. typical_ps=typical_ps_t.to(device=device, non_blocking=True),
  471. prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
  472. output_tokens=output_tensor.to(device=device, non_blocking=True),
  473. sampling_seeds=sampling_seeds_gpu,
  474. sample_indices=sample_indices_t.to(device=device,
  475. non_blocking=True),
  476. extra_seeds=extra_seeds_gpu,
  477. )
  478. @staticmethod
  479. def _get_sequence_seeds(
  480. seed: int,
  481. *extra_entropy: int,
  482. seeds_to_generate: int,
  483. is_greedy: bool,
  484. ):
  485. """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
  486. if not is_greedy:
  487. if seed is None:
  488. randint_fn = random.randint
  489. else:
  490. generator = random.Random(str((seed, ) + extra_entropy))
  491. randint_fn = generator.randint
  492. lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
  493. # If the user/random sets seed = 0 but request should
  494. # have sampling, we need to change it to something
  495. # else. We use a constant in that case.
  496. # This way we don't need to create and load a bool
  497. # matrix in the sampling kernel, which reduces CPU
  498. # overhead and latency.
  499. seq_seeds = [
  500. randint_fn(lo, hi) or _SEED_0_REPLACEMENT
  501. for _ in range(seeds_to_generate)
  502. ]
  503. else:
  504. # For the kernel, seed == 0 means greedy decoding.
  505. seq_seeds = [0] * seeds_to_generate
  506. return seq_seeds