# sampling_metadata.py
import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, TypeVar

import torch

from aphrodite.common.sampling_params import SamplingParams, SamplingType
from aphrodite.common.sequence import SequenceData, SequenceGroupMetadata
from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
                                    maybe_expand_dim)
from aphrodite.modeling.layers.ops.sample import get_num_triton_sampler_splits

_SEED_0_REPLACEMENT = 3403598558  # chosen by fair roll of a die


class PersistentMetadata:

    def __init__(self, metadata: Optional[Dict[int, dict]] = None):
        self._metadata: Dict[int, dict] = metadata or {}

    def get(self, seq_id: int, key, default=None):
        return self._metadata.get(seq_id, {}).get(key, default)


class OutputMetadata:
    """Not symmetrical with PersistentMetadata because the process of
    sampling can produce unique metadata per sample, per sequence.

    The appropriate conversion would be `output[seq][sample](dict)` to
    `persist[new_seq_for_sample](dict)`."""

    def __init__(self):
        self._metadata: Dict[int, Dict[int, dict]] = {}

    def add(self, seq_id: int, sample_id: int, key, val) -> None:
        (self._metadata.setdefault(seq_id, {}).setdefault(sample_id,
                                                          {})[key]) = val

    def get(self, seq_id: int, sample_id: int) -> dict:
        return self._metadata.get(seq_id, {}).get(sample_id, {})
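

# Illustrative round trip (not part of the module API; ids and values are
# made up): the sampler records per-sample state through `OutputMetadata.add`,
# and on a later iteration that state is fed back in via `PersistentMetadata`.
# The "miro_mu" key matches its use in `SamplingTensors.from_sampling_metadata`.
#
#   out = OutputMetadata()
#   out.add(seq_id=7, sample_id=0, key="miro_mu", val=4.2)
#   assert out.get(7, 0) == {"miro_mu": 4.2}
#
#   persist = PersistentMetadata({7: {"miro_mu": 4.2}})
#   assert persist.get(7, "miro_mu") == 4.2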


@dataclass
class SequenceGroupToSample:
    # Sequence ids for the sequence group in a previous step.
    seq_ids: List[int]
    sampling_params: SamplingParams
    # seq_id -> sequence data.
    seq_data: Dict[int, SequenceData]
    # The length of the prompt of the sequence group. None if it is in a
    # decode stage.
    prompt_len: Optional[int]
    # The length of the query tokens to compute in the current step. None if
    # it is in a decode stage. Always subquery_len <= prompt_len.
    subquery_len: Optional[int]
    # A random number generator for sampling.
    generator: Optional[torch.Generator]
    # True if the sequence group is in prefill stage. False if it is in a
    # decode stage.
    is_prompt: bool
    # Query token indices from logits used to compute prompt logprobs. Empty
    # if prompt logprobs are not required.
    prompt_logprob_indices: List[int]
    # Sample token indices from logits. Empty if sampling is not required.
    sample_indices: List[int]

    @property
    def do_sample(self):
        return len(self.sample_indices) > 0

    def __post_init__(self):
        if len(self.prompt_logprob_indices) > 0:
            assert self.sampling_params.prompt_logprobs is not None
        if self.is_prompt:
            assert self.prompt_len is not None
            assert self.subquery_len is not None
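

# Illustrative construction (values are made up, and the SequenceData
# constructor call is an assumption about its signature): a prefill group
# with a 5-token prompt where prompt logprobs are requested for the first
# four positions and only the final position is sampled.
#
#   group = SequenceGroupToSample(
#       seq_ids=[0],
#       sampling_params=SamplingParams(prompt_logprobs=1),
#       seq_data={0: SequenceData([1, 2, 3, 4, 5])},
#       prompt_len=5,
#       subquery_len=5,
#       generator=None,
#       is_prompt=True,
#       prompt_logprob_indices=[0, 1, 2, 3],
#       sample_indices=[4])
#   assert group.do_sample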


class SamplingMetadata:
    """Metadata for input sequences. Used in the sampler.

    The usage is as follows:
    ```
    hidden_states = execute_model(...)
    logits = hidden_states[sampling_metadata.selected_token_indices]
    sample(logits)

    def sample(logits):
        # Use categorized_sample_indices for sampling....
    ```

    Args:
        seq_groups: List of batched sequence groups.
        selected_token_indices: (num_query_tokens_to_logprob). Indices to find
            logits from the initial model output hidden states.
        categorized_sample_indices: SamplingType -> token indices to sample.
            Each value is a 2D tensor of shape (num_indices, 2), where the
            first item means the sample index within the returned logits
            (before pruning) and the second item means the sample index
            after pruning using selected_token_indices.
            For example, if the returned logits are [1, 2, 3] and we select
            [1, 2] for sampling, the pruned logits will be [2, 3]. In this
            case, the first column is [1, 2] (sample indices within the
            original logits) and the second column is [0, 1] (sample indices
            within the pruned logits).
        num_prompts: Number of prompt sequence groups in seq_groups.
        persistent_metadata: Metadata that persists across iterations.
        output_metadata: the output metadata.
    """

    def __init__(
        self,
        seq_groups: List[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        persistent_metadata: Optional[PersistentMetadata] = None,
        output_metadata: Optional[OutputMetadata] = None,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.persistent_metadata = persistent_metadata or PersistentMetadata()
        self.output_metadata = output_metadata or OutputMetadata()

    @staticmethod
    def prepare(
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
        subquery_lens: Optional[List[int]],
        device: str,
        pin_memory: bool,
    ) -> "SamplingMetadata":
        (
            seq_groups,
            selected_token_indices,
            categorized_sample_indices,
            num_prompts,
        ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens,
                                subquery_lens, device)
        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=device,
                                                  pin_memory=pin_memory)
        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=device,
                                 pin_memory=pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            num_prompts=num_prompts,
        )
        return sampling_metadata

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices}, "
            f"persistent_metadata={self.persistent_metadata}, "
            f"output_metadata={self.output_metadata})")
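

# Worked example (a sketch; exact numbers depend on the batch) tying the
# docstring above to `SamplingMetadata.prepare`. Assume one prefill group
# with subquery_len=3 that requests prompt logprobs and samples its last
# token, followed by one decode group with a single sequence, both using
# random sampling:
#
#   selected_token_indices = [0, 1, 2, 3]
#       # rows 0-1: prompt logprobs, row 2: prefill sample, row 3: decode
#   categorized_sample_indices[SamplingType.RANDOM] = [[2, 0], [3, 1]]
#       # pairs of (logit row, sample-tensor row) for the two sampled tokens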


def _prepare_seq_groups(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    prompt_lens: List[int],
    subquery_lens: Optional[List[int]],
    device: str,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
        SamplingType, List[Tuple[int, int]]], int]:
    """Prepare sequence groups and indices for sampling.

    Args:
        seq_group_metadata_list: A list of sequence groups to batch.
        prompt_lens: A list of prompt lengths per sequence group. The index of
            each prompt length matches seq_group_metadata_list.
        subquery_lens: A list of query lengths per sequence group. Prompt
            lengths cover the entire prompt, so a subquery length can be
            shorter.
        device: The device to use for the random number generators stored in
            `SequenceGroupToSample.generator`.

    Returns:
        seq_groups: A list of sequence groups to sample.
        selected_token_indices: See the definition in `SamplingMetadata`.
        categorized_sample_indices: See the definition in `SamplingMetadata`.
        num_prompts: Total number of prompts from `seq_group_metadata_list`.
    """
    # Batched sequence groups for the current model forward step.
    seq_groups: List[SequenceGroupToSample] = []
    # A list of token indices to sample/compute logprobs. It is used to
    # prune the output logits from the model for performance.
    selected_token_indices: List[int] = []
    # Used for selected_token_indices.
    model_output_idx = 0

    # Sampling type -> (
    #     indices to sample/compute prompt logprobs within the pruned logits,
    #     indices to sample within the sample tensor)
    categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = {
        t: []
        for t in SamplingType
    }
    # Index of logits to compute logprobs for. Logits include both prompt
    # logprob and sample logprob indices.
    logit_idx = 0
    # Index to sample from the sample tensor. It is used by the triton sample
    # kernel. See `_sample_with_triton_kernel` for more details.
    sample_idx = 0
    # Total number of prompts from the given sequence groups.
    num_prompts = 0

    for i, seq_group_metadata in enumerate(seq_group_metadata_list):
        seq_ids = list(seq_group_metadata.seq_data.keys())
        sampling_params = seq_group_metadata.sampling_params
        is_prompt = seq_group_metadata.is_prompt
        generator: Optional[torch.Generator] = None
        # If the current seq group is in decode stage, these are None.
        prompt_len: Optional[int] = None
        subquery_len: Optional[int] = None
        prompt_logprob_indices: List[int] = []
        sample_indices: List[int] = []
        do_sample = seq_group_metadata.do_sample

        if seq_group_metadata.is_prompt:
            if sampling_params.seed is not None:
                seq_group_metadata.state.generator = torch.Generator(
                    device=device).manual_seed(sampling_params.seed)

            num_prompts += 1
            num_prefill_sample = len(seq_ids)
            assert num_prefill_sample == 1
            assert subquery_lens is not None and prompt_lens is not None
            subquery_len, prompt_len = subquery_lens[i], prompt_lens[i]
            # If we need sampling, exclude num_prefill_sample tokens from
            # the prompt logprobs.
            prompt_logprob_len = (subquery_len - num_prefill_sample
                                  if do_sample else subquery_len)
            sample_len = num_prefill_sample if do_sample else 0
        else:
            # Decode
            prompt_logprob_len = 0
            sample_len = len(seq_ids) if do_sample else 0

        # Update indices to select from the model output.
        """
        This block computes selected_token_indices, which is used in the
        following way:

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]
        """

        if sampling_params.prompt_logprobs:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + prompt_logprob_len))
        model_output_idx += prompt_logprob_len
        if do_sample:
            selected_token_indices.extend(
                range(model_output_idx, model_output_idx + sample_len))
        model_output_idx += sample_len

        # We now find indices for logprob computation and sampling.
        """
        This block computes categorized_sample_indices, which is used in the
        following way:

        hidden_states = model(...)
        logits = hidden_states[selected_token_indices]

        def sample(logits):
            # Use categorized_sample_indices for sampling.
            # prompt_logprob_indices to find prompt logprob indices.
            # sample_indices to find sample indices.
        """

        if sampling_params.prompt_logprobs is not None:
            prompt_logprob_indices.extend(
                range(logit_idx, logit_idx + prompt_logprob_len))
            logit_idx += prompt_logprob_len
        if do_sample:
            sample_indices.extend(range(logit_idx, logit_idx + sample_len))
            categorized_sample_indices[sampling_params.sampling_type].extend(
                list(
                    zip(range(logit_idx, logit_idx + sample_len),
                        range(sample_idx, sample_idx + sample_len))))
            logit_idx += sample_len
            sample_idx += sample_len

        if sampling_params.seed is not None:
            generator = seq_group_metadata.state.generator

        seq_groups.append(
            SequenceGroupToSample(
                seq_ids=seq_ids,
                sampling_params=sampling_params,
                seq_data=seq_group_metadata.seq_data,
                prompt_len=prompt_len,
                subquery_len=subquery_len,
                generator=generator,
                is_prompt=is_prompt,
                prompt_logprob_indices=list(prompt_logprob_indices),
                sample_indices=list(sample_indices)))

    return (seq_groups, selected_token_indices, categorized_sample_indices,
            num_prompts)
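

# Hedged sketch of the bookkeeping above for a chunked prefill: when a prefill
# group has `do_sample=False` (its final prompt token is not processed in this
# step), every query token is treated as a prompt-logprob position and the
# group emits no sample indices.
#
#   subquery_len = 4, do_sample = False, prompt_logprobs requested
#   -> prompt_logprob_len = 4, sample_len = 0
#   -> prompt_logprob_indices grows by 4, sample_indices stays empty,
#      and categorized_sample_indices is untouched for this group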


@dataclass
class SamplingTensors:
    """Tensors for sampling."""
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    top_as: torch.Tensor
    min_ps: torch.Tensor
    pres_penalties: torch.Tensor
    freq_penalties: torch.Tensor
    rep_penalties: torch.Tensor
    tfss: torch.Tensor
    eta_cutoffs: torch.Tensor
    epsilon_cutoffs: torch.Tensor
    typical_ps: torch.Tensor
    miro_taus: torch.Tensor
    miro_etas: torch.Tensor
    miro_mus: torch.Tensor
    miro_indices: torch.Tensor
    miro_seqids: List[int]  # state writeback done CPU side
    dynatemp_mins: torch.Tensor
    dynatemp_maxs: torch.Tensor
    dynatemp_exps: torch.Tensor
    smoothing_indices: torch.Tensor
    smoothing_factors: torch.Tensor
    smoothing_curves: torch.Tensor
    seed_indices: torch.Tensor
    seed_transpose: torch.Tensor
    extra_seed_transpose: Optional[torch.Tensor]
    prompt_tokens: torch.Tensor
    output_tokens: torch.Tensor

    do_temperatures: bool
    do_dynatemps: bool
    do_penalties: bool
    do_top_ks: bool
    do_top_ps: bool
    do_top_as: bool
    do_min_ps: bool
    do_tfss: bool
    do_eta_cutoffs: bool
    do_epsilon_cutoffs: bool
    do_typical_ps: bool
    do_quadratic: bool
    do_mirostat: bool

    @classmethod
    def from_sampling_metadata(
            cls,
            sampling_metadata: "SamplingMetadata",
            vocab_size: int,
            tgt_device: torch.device,
            float_dtype: torch.dtype,
            *,
            extra_seeds_to_generate: int = 0,
            extra_entropy: Optional[Tuple[int, ...]] = None,
    ) -> "SamplingTensors":
        prompt_lens = sampling_metadata.prompt_lens or []
        groups = sampling_metadata.seq_groups or []
        seq_data = sampling_metadata.seq_data or {}
        persistent = sampling_metadata.persistent_metadata
        extra_entropy = extra_entropy or ()

        # Flattened list of (params, sid) matching the logits tensor.
        # `sid < 0` implies a prompt seq.
        unrolled_seqs: List[Tuple[SamplingParams, int]] = []
        group_plens = prompt_lens + [0] * (len(groups) - len(prompt_lens))
        for (ids, params), prompt_len in zip(groups, group_plens):
            if prompt_len and params.prompt_logprobs is not None:
                unrolled_seqs.extend([(params, -1)] * (prompt_len - 1))
            unrolled_seqs.extend([(params, sid) for sid in ids])
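
        # Illustrative shape of `unrolled_seqs` (ids made up): for a prefill
        # group with seq ids [5] whose 4-token prompt requests prompt
        # logprobs, followed by a decode group with seq ids [8, 9], the
        # flattened list lines up with the logits rows as:
        #   [(params_a, -1), (params_a, -1), (params_a, -1),  # 3 prompt rows
        #    (params_a, 5),                                   # prefill sample
        #    (params_b, 8), (params_b, 9)]                    # decode samples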

        T = TypeVar('T')

        def _unroll(fn_val: Callable[[SamplingParams], T],
                    prompt: Optional[T] = None) -> List[T]:
            """`fn_val` for every seq, with an override for prompt seqs."""
            return [
                prompt if sid < 0 and prompt is not None else fn_val(p)
                for p, sid in unrolled_seqs
            ]

        def _index(fn_mask: Callable[[SamplingParams], bool],
                   prompt: Optional[bool] = None) -> List[int]:
            """Index for every seq where `fn_mask` is true, with an override
            for prompt seqs."""
            return [
                i for i, (p, sid) in enumerate(unrolled_seqs)
                if (fn_mask(p) if prompt is None else (
                    prompt if sid < 0 else fn_mask(p)))
            ]

        def _filter(arr: List[T], indices: List[int]) -> List[T]:
            """Return only the elements of `arr` accessed by `indices`."""
            return [arr[i] for i in indices]

        miro_inds = _index(lambda p: p.mirostat_mode == 2, prompt=False)
        _miro_seqs = _filter(unrolled_seqs, miro_inds)

        quad_inds = _index(lambda p: p.smoothing_factor != 0)
        _quad_seqs = _filter(unrolled_seqs, quad_inds)

        # We need one base seed per Triton slice.
        triton_sampler_splits = get_num_triton_sampler_splits(vocab_size)
        n_seeds = triton_sampler_splits + extra_seeds_to_generate

        # Sequences get seeds. Prompt "sequences" do not.
        seed_indices = _index(lambda p: True, prompt=False)
        sampling_seeds = [
            cls._get_sequence_seeds(p.seed, n_seeds,
                                    p.sampling_type == SamplingType.GREEDY,
                                    seq_data[sid].get_len(), *extra_entropy,
                                    sid)
            for p, sid in _filter(unrolled_seqs, seed_indices)
        ]
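
        # Sketch of the resulting seed layout (assuming n_seeds == 2): each
        # sampled sequence contributes one row of child seeds, e.g.
        #   sampling_seeds = [[s0_a, s1_a],   # seq A, random sampling
        #                     [0, 0],         # seq B, greedy -> all zeros
        #                     [s0_c, s1_c]]   # seq C, random sampling
        # Greedy rows are zero because the Triton kernel treats seed == 0 as
        # greedy decoding (see `_get_sequence_seeds`).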

        fvars = {  # noqa
            "temperatures": _unroll(lambda p: p.temperature),
            "top_ps": _unroll(lambda p: p.top_p),
            "top_as": _unroll(lambda p: p.top_a),
            "min_ps": _unroll(lambda p: p.min_p),
            "tfss": _unroll(lambda p: p.tfs, prompt=1),
            "eta_cutoffs": _unroll(lambda p: p.eta_cutoff * 1e-4, prompt=0),
            "epsilon_cutoffs": _unroll(lambda p: p.epsilon_cutoff * 1e-4, 0),
            "typical_ps": _unroll(lambda p: p.typical_p, prompt=1),
            "pres_penalties": _unroll(lambda p: p.presence_penalty, prompt=0),
            "freq_penalties": _unroll(lambda p: p.frequency_penalty, prompt=0),
            "rep_penalties": _unroll(lambda p: p.repetition_penalty, prompt=1),
            "dynatemp_mins": _unroll(lambda p: p.dynatemp_min),
            "dynatemp_maxs": _unroll(lambda p: p.dynatemp_max),
            "dynatemp_exps": _unroll(lambda p: p.dynatemp_exponent),
            "miro_taus": [p.mirostat_tau for p, _ in _miro_seqs],
            "miro_etas": [p.mirostat_eta for p, _ in _miro_seqs],
            "miro_mus": [
                persistent.get(sid, "miro_mu", p.mirostat_tau * 2)
                for p, sid in _miro_seqs
            ],
            "smoothing_factors": [p.smoothing_factor for p, _ in _quad_seqs],
            "smoothing_curves": [p.smoothing_curve for p, _ in _quad_seqs],
        }
        ivars = {  # noqa
            "top_ks": _unroll(lambda p: vocab_size
                              if p.top_k == -1 else min(p.top_k, vocab_size)),
            "miro_indices": miro_inds,
            "smoothing_indices": quad_inds,
            "seed_indices": seed_indices,
        }

        prompt_tokens = [[] if sid < 0 else seq_data[sid].prompt_token_ids
                         for _, sid in unrolled_seqs]
        output_tokens = [[] if sid < 0 else seq_data[sid].output_token_ids
                         for _, sid in unrolled_seqs]

        # We need to transpose and make the tensor contiguous to copy it
        # correctly: [batch_size, n_seeds] -> [n_seeds, batch_size]
        seeds_transpose = list(map(list, zip(*sampling_seeds)))
        seeds_gpu = seeds_transpose[:triton_sampler_splits]
        extra_seeds_gpu = seeds_transpose[triton_sampler_splits:] or None

        # Note that performance will be very bad without pinned memory.
        # Pinned memory allows non-blocking transfers to the device.
        pin_memory = is_pin_memory_available()

        def _tensor(contents: list, dtype) -> torch.Tensor:
            loc_t = torch.tensor(contents,
                                 dtype=dtype,
                                 device="cpu",
                                 pin_memory=pin_memory)
            return loc_t.to(device=tgt_device, non_blocking=True)

        def _unjagged(arrs: List[List[T]], padval: T) -> List[List[T]]:
            max_len = max(len(arr) for arr in arrs)
            return [arr + [padval] * (max_len - len(arr)) for arr in arrs]
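
        # Example of the padding helper (illustrative values): jagged token id
        # lists are right-padded into a rectangle so they can form one tensor.
        # `vocab_size` is used as the pad id below because it can never
        # collide with a real token id.
        #   _unjagged([[3, 1, 4], [5]], padval=0) -> [[3, 1, 4], [5, 0, 0]]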

        return cls(
            # Flags and non-tensor fields
            do_temperatures=any(x != 1 for x in fvars["temperatures"]),
            do_dynatemps=(any(fvars["dynatemp_mins"])
                          or any(fvars["dynatemp_maxs"])),
            do_top_ks=any(x != vocab_size for x in ivars["top_ks"]),
            do_top_ps=any(x != 1 for x in fvars["top_ps"]),
            do_top_as=any(fvars["top_as"]),
            do_min_ps=any(fvars["min_ps"]),
            do_tfss=any(x != 1 for x in fvars["tfss"]),
            do_eta_cutoffs=any(fvars["eta_cutoffs"]),
            do_epsilon_cutoffs=any(fvars["epsilon_cutoffs"]),
            do_typical_ps=any(x != 1 for x in fvars["typical_ps"]),
            do_penalties=(any(fvars["pres_penalties"])
                          or any(fvars["freq_penalties"])
                          or any(x != 1 for x in fvars["rep_penalties"])),
            do_quadratic=len(quad_inds) > 0,
            do_mirostat=len(miro_inds) > 0,
            miro_seqids=_filter([s for _, s in unrolled_seqs], miro_inds),
            # Float tensors
            **{n: _tensor(vals, float_dtype)
               for n, vals in fvars.items()},
            # Integer tensors
            **{n: _tensor(vals, torch.int)
               for n, vals in ivars.items()},
            # Token ID tensors
            prompt_tokens=_tensor(_unjagged(prompt_tokens, vocab_size),
                                  torch.long),
            output_tokens=_tensor(_unjagged(output_tokens, vocab_size),
                                  torch.long),
            # Seeds (only for triton, though?)
            seed_transpose=_tensor(seeds_gpu, torch.long),
            extra_seed_transpose=(_tensor(extra_seeds_gpu, torch.long)
                                  if extra_seeds_gpu else None),
        )

    @staticmethod
    def _get_sequence_seeds(
        seed: Optional[int],
        seeds_to_generate: int,
        is_greedy: bool,
        *extra_entropy: int,
    ):
        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
        if is_greedy:  # For the kernel, seed == 0 means greedy decoding.
            return [0] * seeds_to_generate

        if seed is None:
            randint_fn = random.randint
        else:
            randint_fn = random.Random(str((seed, ) + extra_entropy)).randint
        lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
        # If the user (or the RNG) produces seed = 0 but the request should
        # use sampling, we replace it with a constant. This way we don't need
        # to create and load a bool matrix in the sampling kernel, which
        # reduces CPU overhead and latency.
        return [
            randint_fn(lo, hi) or _SEED_0_REPLACEMENT
            for _ in range(seeds_to_generate)
        ]
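

# Usage sketch for `SamplingTensors._get_sequence_seeds` (illustrative
# values): a fixed user seed plus extra entropy (e.g. sequence length and id)
# deterministically derives the per-split child seeds, while greedy requests
# always map to zeros.
#
#   SamplingTensors._get_sequence_seeds(1234, 2, False, 17, 42)
#   # -> two reproducible non-zero 64-bit seeds
#   SamplingTensors._get_sequence_seeds(None, 2, True)
#   # -> [0, 0]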