tpu_model_runner.py

import time
from typing import List, Mapping, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from loguru import logger

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SamplerOutput, SequenceGroupMetadata,
                                       SequenceOutput)
from aphrodite.common.utils import make_tensor_with_pad
from aphrodite.modeling.model_loader import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                                  MultiModalInputs)

_PAD_SLOT_ID = -1  # NOTE: In PyTorch XLA, index -1 is ignored.
# FIXME: Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME: A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


class TPUModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig] = None,
        is_driver_worker: bool = False,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            False,
        )

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

    def load_model(self) -> None:
        self.device = self.device_config.device

        model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=self.parallel_config,
            cache_config=self.cache_config,
            scheduler_config=self.scheduler_config,
            multimodal_config=self.multimodal_config,
            lora_config=None,
        )
        xm.wait_device_ops()

        model = ModelWrapper(model)
        self.model = torch.compile(model, backend="openxla", fullgraph=True)
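
    # NOTE: Explanatory comment on the compile step above: `backend="openxla"`
    # routes the captured graph through PyTorch/XLA so it is lowered to an XLA
    # program for the TPU, and `fullgraph=True` makes torch.compile raise an
    # error on graph breaks instead of silently splitting the graph and running
    # the unsupported parts eagerly.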

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

        # Dummy run.
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        self.model(token_ids, position_ids, kv_caches, attn_metadata,
                   input_lens, None, t, p, num_samples)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2
        end = time.time()
        logger.info(f"Compilation for prefill done in {end - start:.2f} s.")

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 1
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
        end = time.time()
        logger.info(f"Compilation for decode done in {end - start:.2f} s.")
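
    # Illustrative sketch of the shapes compiled above, assuming (purely as an
    # example) max_model_len = 2048, max_num_batched_tokens = 2048 and
    # max_num_seqs = 32:
    #   prefill: (batch_size=1, seq_len=16), (1, 32), (1, 64), ..., (1, 2048)
    #   decode:  (batch_size=1, seq_len=1), (2, 1), (4, 1), (8, 1), (16, 1),
    #            (32, 1)
    # Real requests are padded to one of these shapes at runtime, so no new
    # XLA compilation should be triggered after warmup.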

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            slot_mapping.append([])
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        num_prefill_tokens = sum(prompt_lens)

        # Add paddings to make the shape [batch_size, max_prompt_len], where
        # max_prompt_len is the smallest power of 2 that is greater than or
        # equal to the maximum prompt length.
        # We need the 2D input shape because the Pallas FlashAttention kernel
        # does not support packed 1D inputs.
        # We pad the seq_len to powers of 2 to reduce the compilation overhead.
        max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=torch.int32,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=torch.int32,
                                               device=self.device)
        slot_mapping = make_tensor_with_pad(slot_mapping,
                                            max_prompt_len,
                                            pad=_PAD_SLOT_ID,
                                            dtype=torch.int64,
                                            device=self.device)
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device=self.device)
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)
        return (input_tokens, input_positions, attn_metadata, prompt_lens,
                multi_modal_kwargs)
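
    # Worked example of the slot mapping used in _prepare_prompt and
    # _prepare_decode (the numbers are illustrative, not from a real run):
    # with block_size = 16 and block_table = [7, 9], token position i = 20
    # lives in logical block 20 // 16 = 1, i.e. physical block 9, at offset
    # 20 % 16 = 4, so its slot is 9 * 16 + 4 = 148. Padded positions get
    # _PAD_SLOT_ID (-1), which PyTorch XLA ignores.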

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor,
               Mapping[str, BatchedTensors]]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

            mm_data = seq_group_metadata.multi_modal_data
            if mm_data:
                mm_kwargs = self.multi_modal_input_mapper(mm_data)
                multi_modal_inputs_list.append(mm_kwargs)

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device=self.device)
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device=self.device)
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device=self.device)
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=self.device)
        return (input_tokens, input_positions, attn_metadata, input_lens,
                multi_modal_kwargs)
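
    # Illustrative padding example for _prepare_decode (hypothetical numbers):
    # with 5 running sequences, _get_padded_batch_size(5) returns 8, so 3
    # padding rows are appended with token 0, position 0, slot _PAD_SLOT_ID
    # and context_len 0; the padded rows write nowhere in the KV cache, and
    # the tokens sampled for them are never read back in _execute_model.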

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params

            # NOTE: Here we mimic argmax sampling by applying a very
            # low temperature. This is not accurate.
            t.append(sampling_params.temperature
                     if sampling_params.temperature >= 1e-5 else 1e-5)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            best_of += [best_of[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device=self.device)
        p = torch.tensor(p, dtype=torch.float32, device=self.device)
        return t, p, best_of

    def _execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> List[CompletionSequenceGroupOutput]:
        # Prepare inputs.
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        padded_batch_size = inputs[0].shape[0]
        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        # Execute the model.
        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
                                    *inputs[2:], t, p, num_samples)
        # Retrieve the outputs to CPU.
        next_token_ids = next_token_ids.cpu().tolist()

        # NOTE: Minimal code to construct the sampler outputs.
        # The TPU backend does not reuse the sampler, since the TPU backend
        # does not support the advanced sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        batch_idx = 0
        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []
            seq_ids = list(seq_group_metadata.seq_data.keys())
            if is_prompt:
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                for i in range(best_of[batch_idx]):
                    next_token_id = next_token_ids[batch_idx][i]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                batch_idx += 1
            else:
                for seq_id in seq_ids:
                    next_token_id = next_token_ids[batch_idx][0]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                    batch_idx += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return sampler_outputs

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        if num_steps > 1:
            raise ValueError(
                "TPUModelRunner does not support multi-step execution.")

        assert seq_group_metadata_list is not None
        assert len(seq_group_metadata_list) > 0
        if seq_group_metadata_list[0].is_prompt:
            # NOTE: To reduce the compilation time, we only compile the
            # prefill inputs with batch size 1. Because the scheduler is not
            # aware of this limitation, we need to handle batch size > 1
            # internally by calling the model multiple times and concatenating
            # the outputs.
            # FIXME: This is a temporary hack to not change the existing
            # scheduler. We need to fix this in the future.
            sampler_outputs = []
            for seq_group_metadata in seq_group_metadata_list:
                sampler_outputs += self._execute_model([seq_group_metadata],
                                                       kv_caches)
        else:
            sampler_outputs = self._execute_model(seq_group_metadata_list,
                                                  kv_caches)
        return [SamplerOutput(sampler_outputs)]


class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model.eval()

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]],
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape
                [batch_size, seq_len].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            multi_modal_kwargs: Keyword arguments from multi-modal data to
                pass to the model.
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: The number of samples to draw per sequence.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        base_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = base_indices + input_lens - 1

        # FIXME: This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping
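
            # Illustrative example of the remapping above (hypothetical
            # shapes): with num_kv_heads = 2, num_blocks = 4, block_size = 16,
            # each head's cache spans 4 * 16 = 64 flattened slots, so
            # head_indices = [0, 64]. A token originally assigned slot 5 is
            # expanded to [5, 69]: slot 5 of head 0 and slot 5 of head 1 in
            # the flattened [num_kv_heads * num_blocks * block_size, ...]
            # view of the cache.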

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
            **(multi_modal_kwargs or {}),
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        logits = logits / t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        next_token_ids = torch.multinomial(probs,
                                           num_samples,
                                           replacement=True)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE: The Pallas FlashAttention kernel requires the sequence length to
    # be a multiple of 16. We pad the prompt length to the next power of two
    # that is at least 16, which satisfies that requirement and also keeps
    # the number of compiled shapes small.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()
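
# For example: _get_padded_prefill_len(1) == 16,
# _get_padded_prefill_len(17) == 32, and _get_padded_prefill_len(100) == 128.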


def _get_padded_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    elif batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16
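
# For example: _get_padded_batch_size(3) == 4, _get_padded_batch_size(9) == 16,
# and _get_padded_batch_size(17) == 32.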


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
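

# Worked example for _apply_top_p (illustrative numbers): for a logits row
# [2.0, 1.0, 0.5], the sorted softmax probabilities are roughly
# [0.63, 0.23, 0.14] and their cumulative sums are [0.63, 0.86, 1.00].
# With p = 0.8, only the first cumulative sum is below p, so cutoff_index is 1,
# the cutoff logit is 1.0, and the entry 0.5 is masked to -inf, keeping the
# two most likely tokens.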