tpu_model_runner.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. import time
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
  4. from unittest.mock import patch
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch_xla.core.xla_model as xm
  9. import torch_xla.runtime as xr
  10. from loguru import logger
  11. from aphrodite.attention import AttentionMetadata, get_attn_backend
  12. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  13. ModelConfig, MultiModalConfig,
  14. ParallelConfig, SchedulerConfig)
  15. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  16. IntermediateTensors, Logprob,
  17. SamplerOutput, SequenceGroupMetadata,
  18. SequenceOutput)
  19. from aphrodite.modeling.model_loader import get_model
  20. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  21. from aphrodite.task_handler.model_runner_base import (
  22. ModelRunnerBase, ModelRunnerInputBase,
  23. _add_attn_metadata_broadcastable_dict,
  24. _init_attn_metadata_from_tensor_dict)
  25. if TYPE_CHECKING:
  26. from aphrodite.attention.backends.abstract import AttentionBackend
# Slot id written into the slot mapping for padding tokens.
# Here we utilize the behavior that out-of-bound index is ignored, so the
# value is deliberately far larger than any real KV-cache slot.
# FIXME: Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME: Temporarily disabled top-p sampling since it's too slow.
# While this is False, any request with top_p != 1 is rejected.
_ENABLE_TOP_P = False
# FIXME: A temporary hack to support `n > 1`.
# Number of samples drawn per prompt during prefill, and the upper bound
# accepted for `best_of`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128
  35. @dataclass(frozen=True)
  36. class ModelInputForTPU(ModelRunnerInputBase):
  37. token_ids: torch.Tensor
  38. position_ids: torch.Tensor
  39. attn_metadata: AttentionMetadata
  40. input_lens: torch.Tensor
  41. t: torch.Tensor
  42. p: torch.Tensor
  43. num_samples: int
  44. best_of: List[int]
  45. seq_groups: List[List[int]]
  46. def as_broadcastable_tensor_dict(
  47. self) -> Dict[str, Union[int, torch.Tensor]]:
  48. tensor_dict = {
  49. "token_ids": self.token_ids,
  50. "position_ids": self.position_ids,
  51. "input_lens": self.input_lens,
  52. "t": self.t,
  53. "p": self.p,
  54. "num_samples": self.num_samples,
  55. "best_of": self.best_of,
  56. "seq_groups": self.seq_groups,
  57. "virtual_engine": self.virtual_engine,
  58. }
  59. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  60. return tensor_dict
  61. @classmethod
  62. def from_broadcasted_tensor_dict(
  63. cls: Type["ModelInputForTPU"],
  64. tensor_dict: Dict[str, Any],
  65. attn_backend: Optional["AttentionBackend"] = None,
  66. ) -> "ModelInputForTPU":
  67. if attn_backend is not None:
  68. tensor_dict = _init_attn_metadata_from_tensor_dict(
  69. attn_backend, tensor_dict)
  70. return cls(**tensor_dict)
  71. class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
  72. def __init__(
  73. self,
  74. model_config: ModelConfig,
  75. parallel_config: ParallelConfig,
  76. scheduler_config: SchedulerConfig,
  77. device_config: DeviceConfig,
  78. cache_config: CacheConfig,
  79. load_config: LoadConfig,
  80. multimodal_config: Optional[MultiModalConfig] = None,
  81. is_driver_worker: bool = False,
  82. ):
  83. self.model_config = model_config
  84. self.parallel_config = parallel_config
  85. self.scheduler_config = scheduler_config
  86. self.device_config = device_config
  87. self.cache_config = cache_config
  88. self.load_config = load_config
  89. self.multimodal_config = multimodal_config
  90. self.is_driver_worker = is_driver_worker
  91. self.block_size = self.cache_config.block_size
  92. self.max_num_blocks_per_seq = (self.model_config.max_model_len //
  93. self.block_size)
  94. self.block_tables = np.zeros(
  95. (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
  96. dtype=np.int32)
  97. self.attn_backend = get_attn_backend(
  98. self.model_config.get_head_size(),
  99. self.model_config.get_sliding_window(),
  100. self.model_config.dtype,
  101. self.cache_config.cache_dtype,
  102. self.block_size,
  103. self.model_config.is_attention_free(),
  104. False,
  105. )
  106. def load_model(self) -> None:
  107. self.device = self.device_config.device
  108. # NOTE: While the executor assigns the TP ranks to the worker
  109. # process, the ranks can be different from the ranks internally assigned
  110. # by the xm runtime. Therefore, there is a mismatch in the rank
  111. # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
  112. # This is not a problem in linear layers because all-reduce is
  113. # rank-agnostic. However, it matters for all-gather as the ranks
  114. # determine the order of concatenating the output tensors.
  115. # As a workaround, we use the xm's rank assignment only when loading
  116. # the embedding weights.
  117. xm_tp_rank = xr.global_ordinal()
  118. with patch(
  119. "aphrodite.modeling.layers.vocab_parallel_embedding."
  120. "get_tensor_model_parallel_rank",
  121. return_value=xm_tp_rank):
  122. model = get_model(
  123. model_config=self.model_config,
  124. load_config=self.load_config,
  125. device_config=self.device_config,
  126. parallel_config=self.parallel_config,
  127. cache_config=self.cache_config,
  128. scheduler_config=self.scheduler_config,
  129. multimodal_config=self.multimodal_config,
  130. lora_config=None,
  131. )
  132. model = model.eval()
  133. xm.wait_device_ops()
  134. self.model = CompiledModelWrapper(model)
  135. def _dummy_run(
  136. self,
  137. batch_size: int,
  138. seq_len: int,
  139. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  140. is_prompt: bool,
  141. ) -> None:
  142. if is_prompt:
  143. seq_len = (seq_len + 15) // 16 * 16
  144. token_ids = torch.zeros((batch_size, seq_len),
  145. dtype=torch.int32,
  146. device=self.device)
  147. position_ids = torch.zeros((batch_size, seq_len),
  148. dtype=torch.int32,
  149. device=self.device)
  150. slot_mapping = torch.zeros((batch_size, seq_len),
  151. dtype=torch.int64,
  152. device=self.device)
  153. attn_metadata = self.attn_backend.make_metadata(
  154. num_prefills=batch_size,
  155. num_prefill_tokens=batch_size * seq_len,
  156. num_decode_tokens=0,
  157. slot_mapping=slot_mapping,
  158. block_tables=None,
  159. context_lens=None,
  160. )
  161. input_lens = torch.ones((batch_size, ),
  162. dtype=torch.int32,
  163. device=self.device)
  164. else:
  165. assert seq_len == 1
  166. token_ids = torch.zeros((batch_size, seq_len),
  167. dtype=torch.int32,
  168. device=self.device)
  169. position_ids = torch.zeros((batch_size, seq_len),
  170. dtype=torch.int32,
  171. device=self.device)
  172. slot_mapping = torch.zeros((batch_size, seq_len),
  173. dtype=torch.int64,
  174. device=self.device)
  175. block_tables = torch.zeros(
  176. (batch_size, self.max_num_blocks_per_seq),
  177. dtype=torch.int32,
  178. device=self.device)
  179. context_lens = torch.ones((batch_size, ),
  180. dtype=torch.int32,
  181. device=self.device)
  182. input_lens = torch.ones((batch_size, ),
  183. dtype=torch.int32,
  184. device=self.device)
  185. attn_metadata = self.attn_backend.make_metadata(
  186. num_prefills=0,
  187. num_prefill_tokens=0,
  188. num_decode_tokens=batch_size * seq_len,
  189. slot_mapping=slot_mapping,
  190. block_tables=block_tables,
  191. context_lens=context_lens,
  192. )
  193. t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
  194. p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
  195. # Dummy run.
  196. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
  197. self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
  198. num_samples, kv_caches)
  199. def warmup_model(
  200. self,
  201. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  202. ) -> None:
  203. # Prefill
  204. logger.info("Compiling the model with different input shapes...")
  205. start = time.time()
  206. for batch_size in [1]:
  207. seq_len = 16
  208. while True:
  209. self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
  210. xm.wait_device_ops()
  211. logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")
  212. if seq_len >= self.model_config.max_model_len:
  213. break
  214. num_tokens = batch_size * seq_len
  215. if num_tokens >= self.scheduler_config.max_num_batched_tokens:
  216. break
  217. seq_len = seq_len * 2
  218. end = time.time()
  219. logger.info(f"Compilation for prefill done in {end - start:.2f} s.")
  220. # Decode
  221. start = time.time()
  222. seq_len = 1
  223. batch_size = 8 # Must be in sync with _get_padded_batch_size()
  224. while True:
  225. self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
  226. xm.wait_device_ops()
  227. logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")
  228. if batch_size >= self.scheduler_config.max_num_seqs:
  229. break
  230. batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
  231. end = time.time()
  232. logger.info(f"Compilation for decode done in {end - start:.2f} s.")
  233. def _prepare_prompt(
  234. self,
  235. seq_group_metadata_list: List[SequenceGroupMetadata],
  236. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
  237. assert len(seq_group_metadata_list) > 0
  238. input_tokens: List[int] = []
  239. input_positions: List[int] = []
  240. prompt_lens: List[int] = []
  241. slot_mapping: List[int] = []
  242. for seq_group_metadata in seq_group_metadata_list:
  243. assert seq_group_metadata.is_prompt
  244. seq_ids = list(seq_group_metadata.seq_data.keys())
  245. assert len(seq_ids) == 1
  246. seq_id = seq_ids[0]
  247. seq_data = seq_group_metadata.seq_data[seq_id]
  248. # Could include output tokens when a request is preempted.
  249. prompt_tokens = seq_data.get_token_ids()
  250. prompt_len = len(prompt_tokens)
  251. prompt_lens.append(prompt_len)
  252. input_tokens.extend(prompt_tokens)
  253. input_positions.extend(list(range(prompt_len)))
  254. assert seq_group_metadata.block_tables is not None
  255. block_table = seq_group_metadata.block_tables[seq_id]
  256. for i in range(prompt_len):
  257. block_number = block_table[i // self.block_size]
  258. block_offset = i % self.block_size
  259. slot = block_number * self.block_size + block_offset
  260. slot_mapping.append(slot)
  261. # Add paddings to EACH prompt to the smallest power of 2 that is
  262. # greater than or equal to the prompt length.
  263. # We pad the seq_len to reduce the compilation overhead.
  264. # We execute each prompt individually (i.e., with batch_size 1)
  265. # because the FlashAttention kernel does not support ragged inputs.
  266. # TODO(woosuk): Use SplashAttention to support ragged inputs.
  267. padded_prompt_len = _get_padded_prefill_len(prompt_len)
  268. num_paddings = padded_prompt_len - prompt_len
  269. input_tokens += [0] * num_paddings
  270. input_positions += [0] * num_paddings
  271. slot_mapping += [_PAD_SLOT_ID] * num_paddings
  272. assert len(prompt_lens) > 0
  273. num_prefills = len(prompt_lens)
  274. input_tokens = torch.tensor(input_tokens,
  275. dtype=torch.int32,
  276. device="cpu")
  277. input_positions = torch.tensor(input_positions,
  278. dtype=torch.int32,
  279. device="cpu")
  280. slot_mapping = torch.tensor(slot_mapping,
  281. dtype=torch.int64,
  282. device="cpu")
  283. prompt_lens = torch.tensor(prompt_lens,
  284. dtype=torch.int32,
  285. device="cpu")
  286. attn_metadata = self.attn_backend.make_metadata(
  287. num_prefills=num_prefills,
  288. num_prefill_tokens=0, # NOTE: This is not used.
  289. num_decode_tokens=0,
  290. slot_mapping=slot_mapping,
  291. block_tables=None,
  292. context_lens=None,
  293. )
  294. return input_tokens, input_positions, attn_metadata, prompt_lens
  295. def _prepare_decode(
  296. self,
  297. seq_group_metadata_list: List[SequenceGroupMetadata],
  298. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
  299. assert len(seq_group_metadata_list) > 0
  300. input_tokens: List[List[int]] = []
  301. input_positions: List[List[int]] = []
  302. slot_mapping: List[List[int]] = []
  303. context_lens: List[int] = []
  304. batch_idx = 0
  305. for seq_group_metadata in seq_group_metadata_list:
  306. assert not seq_group_metadata.is_prompt
  307. seq_ids = list(seq_group_metadata.seq_data.keys())
  308. for seq_id in seq_ids:
  309. seq_data = seq_group_metadata.seq_data[seq_id]
  310. generation_token = seq_data.get_last_token_id()
  311. input_tokens.append([generation_token])
  312. seq_len = seq_data.get_len()
  313. position = seq_len - 1
  314. input_positions.append([position])
  315. context_lens.append(seq_len)
  316. assert seq_group_metadata.block_tables is not None
  317. block_table = seq_group_metadata.block_tables[seq_id]
  318. self.block_tables[batch_idx, :len(block_table)] = block_table
  319. batch_idx += 1
  320. block_number = block_table[position // self.block_size]
  321. block_offset = position % self.block_size
  322. slot = block_number * self.block_size + block_offset
  323. slot_mapping.append([slot])
  324. batch_size = _get_padded_batch_size(batch_idx)
  325. num_paddings = batch_size - batch_idx
  326. input_tokens = input_tokens + [[0]] * num_paddings
  327. input_positions = input_positions + [[0]] * num_paddings
  328. slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
  329. context_lens = context_lens + [0] * num_paddings
  330. input_tokens = torch.tensor(input_tokens,
  331. dtype=torch.int32,
  332. device="cpu")
  333. input_positions = torch.tensor(input_positions,
  334. dtype=torch.int32,
  335. device="cpu")
  336. slot_mapping = torch.tensor(slot_mapping,
  337. dtype=torch.int64,
  338. device="cpu")
  339. context_lens = torch.tensor(context_lens,
  340. dtype=torch.int32,
  341. device="cpu")
  342. block_tables = torch.tensor(self.block_tables[:batch_size],
  343. dtype=torch.int32,
  344. device="cpu")
  345. input_lens = torch.tensor([1] * batch_size,
  346. dtype=torch.int32,
  347. device="cpu")
  348. attn_metadata = self.attn_backend.make_metadata(
  349. num_prefills=0,
  350. num_prefill_tokens=0,
  351. num_decode_tokens=batch_size,
  352. slot_mapping=slot_mapping,
  353. block_tables=block_tables,
  354. context_lens=context_lens,
  355. )
  356. return input_tokens, input_positions, attn_metadata, input_lens
  357. def _prepare_sample(
  358. self,
  359. seq_group_metadata_list: List[SequenceGroupMetadata],
  360. padded_batch_size: int,
  361. ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
  362. assert len(seq_group_metadata_list) > 0
  363. t = []
  364. p = []
  365. best_of = []
  366. for seq_group_metadata in seq_group_metadata_list:
  367. sampling_params = seq_group_metadata.sampling_params
  368. t.append(sampling_params.temperature)
  369. if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
  370. raise NotImplementedError(
  371. "Top-p sampling is currently disabled for the TPU backend "
  372. "due to performance issues.")
  373. p.append(sampling_params.top_p)
  374. if sampling_params.top_k != -1:
  375. raise NotImplementedError(
  376. "Top-k sampling is currently disabled for the TPU backend "
  377. "due to performance issues.")
  378. if sampling_params.best_of > _MAX_NUM_SAMPLES:
  379. raise NotImplementedError(
  380. f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
  381. "backend.")
  382. best_of.append(sampling_params.best_of)
  383. if sampling_params.use_beam_search:
  384. raise NotImplementedError(
  385. "Beam search is not supported by the TPU backend.")
  386. if sampling_params.logprobs is not None:
  387. raise NotImplementedError(
  388. "logprobs is not currently supported by the TPU backend.")
  389. if sampling_params.prompt_logprobs is not None:
  390. raise NotImplementedError(
  391. "prompt_logprobs is not currently supported by the TPU "
  392. "backend.")
  393. # Repeat the sampling params if the seq group has multiple seqs.
  394. num_seqs = len(seq_group_metadata.seq_data)
  395. t += [t[-1]] * (num_seqs - 1)
  396. p += [p[-1]] * (num_seqs - 1)
  397. best_of += [best_of[-1]] * (num_seqs - 1)
  398. num_paddings = padded_batch_size - len(t)
  399. t += [1.0] * num_paddings
  400. p += [1.0] * num_paddings
  401. t = torch.tensor(t, dtype=torch.float32, device="cpu")
  402. p = torch.tensor(p, dtype=torch.float32, device="cpu")
  403. return t, p, best_of
  404. def prepare_model_input(
  405. self,
  406. seq_group_metadata_list: List[SequenceGroupMetadata],
  407. virtual_engine: int = 0,
  408. finished_requests_ids: Optional[List[str]] = None,
  409. ) -> ModelInputForTPU:
  410. del finished_requests_ids # Unused.
  411. assert virtual_engine == 0
  412. assert len(seq_group_metadata_list) > 0
  413. # NOTE: We assume that all sequences in the group are all prompts or
  414. # all decodes.
  415. is_prompt = seq_group_metadata_list[0].is_prompt
  416. if is_prompt:
  417. inputs = self._prepare_prompt(seq_group_metadata_list)
  418. else:
  419. inputs = self._prepare_decode(seq_group_metadata_list)
  420. input_tokens, input_positions, attn_metadata, input_lens = inputs
  421. padded_batch_size = input_tokens.shape[0]
  422. t, p, best_of = self._prepare_sample(seq_group_metadata_list,
  423. padded_batch_size)
  424. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
  425. seq_groups = [
  426. list(metadata.seq_data.keys())
  427. for metadata in seq_group_metadata_list
  428. ]
  429. return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
  430. input_lens, t, p, num_samples, best_of,
  431. seq_groups)
  432. def make_model_input_from_broadcasted_tensor_dict(
  433. self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
  434. model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
  435. tensor_dict, attn_backend=self.attn_backend)
  436. return model_input
  437. @torch.no_grad()
  438. def execute_model(
  439. self,
  440. model_input: ModelInputForTPU,
  441. kv_caches: Optional[List[Any]],
  442. intermediate_tensors: Optional[IntermediateTensors] = None,
  443. num_steps: int = 1,
  444. ) -> List[SamplerOutput]:
  445. assert intermediate_tensors is None
  446. if num_steps > 1:
  447. raise ValueError(
  448. "TPUModelRunner does not support multi-step execution.")
  449. def _execute_model(*args, clone: bool = False) -> torch.Tensor:
  450. """Move input args from CPU to device and execute the model."""
  451. def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
  452. if clone:
  453. # When x is a slice of a CPU tensor, XLA may copy the whole
  454. # original tensor to TPU instead of only copying x.
  455. # To avoid this, we copy x after cloning.
  456. x = x.clone()
  457. return x.to(self.device)
  458. new_args = []
  459. for arg in args:
  460. if isinstance(arg, torch.Tensor):
  461. arg = _copy_to_device(arg)
  462. elif isinstance(arg, AttentionMetadata):
  463. arg.slot_mapping = _copy_to_device(arg.slot_mapping)
  464. if getattr(arg, "block_tables", None) is not None:
  465. arg.block_tables = _copy_to_device(arg.block_tables)
  466. if getattr(arg, "context_lens", None) is not None:
  467. arg.context_lens = _copy_to_device(arg.context_lens)
  468. new_args.append(arg)
  469. return self.model(*new_args)
  470. num_prefills = model_input.attn_metadata.num_prefills
  471. is_prompt = num_prefills > 0
  472. if is_prompt:
  473. # NOTE: Since the FlashAttention kernel does not support
  474. # ragged inputs, we split the prompts into different batches and
  475. # process them separately. This is a temporary hack that should be
  476. # optimized by using SplashAttention.
  477. next_token_ids = []
  478. orig_slot_mapping = model_input.attn_metadata.slot_mapping
  479. batch_size = model_input.input_lens.shape[0]
  480. start_idx = 0
  481. for i in range(batch_size):
  482. # Get the actual prefill_len.
  483. prefill_len = model_input.input_lens[i:i + 1].item()
  484. prefill_len = _get_padded_prefill_len(prefill_len)
  485. end_idx = start_idx + prefill_len
  486. model_input.attn_metadata.slot_mapping = orig_slot_mapping[
  487. None, start_idx:end_idx]
  488. model_input.attn_metadata.num_prefills = 1
  489. output_token_ids = _execute_model(
  490. model_input.token_ids[None, start_idx:end_idx],
  491. model_input.position_ids[None, start_idx:end_idx],
  492. model_input.attn_metadata,
  493. model_input.input_lens[i:i + 1],
  494. model_input.t[i:i + 1],
  495. model_input.p[i:i + 1],
  496. model_input.num_samples,
  497. kv_caches,
  498. clone=True)
  499. # Retrieve the outputs to CPU.
  500. next_token_ids += output_token_ids.cpu().tolist()
  501. start_idx = end_idx
  502. else:
  503. # Execute the model.
  504. output_token_ids = _execute_model(
  505. model_input.token_ids, model_input.position_ids,
  506. model_input.attn_metadata, model_input.input_lens,
  507. model_input.t, model_input.p, model_input.num_samples,
  508. kv_caches)
  509. # Retrieve the outputs to CPU.
  510. next_token_ids = output_token_ids.cpu().tolist()
  511. # NOTE: Minimal code to construct the sampler outputs.
  512. # The TPU backend does not reuse the sampler, since the TPU backend
  513. # does not support the advanced sampling parameters such as logprobs.
  514. zero_logprob = Logprob(0.0)
  515. batch_idx = 0
  516. sampler_outputs = []
  517. for seq_group in model_input.seq_groups:
  518. seq_ids = seq_group
  519. seq_outputs = []
  520. if is_prompt:
  521. assert len(seq_ids) == 1
  522. seq_id = seq_ids[0]
  523. for i in range(model_input.best_of[batch_idx]):
  524. next_token_id = next_token_ids[batch_idx][i]
  525. seq_outputs.append(
  526. SequenceOutput(seq_id, next_token_id,
  527. {next_token_id: zero_logprob}))
  528. batch_idx += 1
  529. else:
  530. for seq_id in seq_ids:
  531. next_token_id = next_token_ids[batch_idx][0]
  532. seq_outputs.append(
  533. SequenceOutput(seq_id, next_token_id,
  534. {next_token_id: zero_logprob}))
  535. batch_idx += 1
  536. sampler_outputs.append(
  537. CompletionSequenceGroupOutput(seq_outputs, None))
  538. return [SamplerOutput(sampler_outputs)]
  539. class ModelWrapper(nn.Module):
  540. def __init__(self, model: nn.Module):
  541. super().__init__()
  542. self.model = model
  543. def forward(
  544. self,
  545. token_ids: torch.Tensor,
  546. position_ids: torch.Tensor,
  547. attn_metadata: AttentionMetadata,
  548. input_lens: torch.Tensor,
  549. t: torch.Tensor,
  550. p: torch.Tensor,
  551. num_samples: int,
  552. kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
  553. ) -> torch.Tensor:
  554. """Executes the forward pass of the model and samples the next token.
  555. Args:
  556. token_ids: The input token IDs of shape [batch_size, seq_len].
  557. position_ids: The input position IDs of shape [batch_size, seq_len].
  558. attn_metadata: The Pallas attention metadata.
  559. input_lens: The actual input lengths of shape [batch_size].
  560. t: The sampling temperature of shape [batch_size].
  561. p: The top-p probability of shape [batch_size].
  562. num_samples: Number of samples to draw from each logits vector.
  563. kv_caches: The key and value caches. They can be None during the
  564. memory profiling at initialization.
  565. """
  566. batch_size, seq_len = token_ids.shape
  567. # Calculate the positions to sample from.
  568. start_indicies = torch.arange(
  569. batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
  570. logits_indices = start_indicies + input_lens - 1
  571. # FIXME: This is a temporary hack to avoid using the existing
  572. # sampler and sampling metadata.
  573. sampling_metadata = SamplingMetadata(
  574. seq_groups=[],
  575. selected_token_indices=logits_indices,
  576. categorized_sample_indices={},
  577. num_prompts=attn_metadata.num_prefills,
  578. )
  579. # Skip this in memory profiling at initialization.
  580. if kv_caches[0][0] is not None:
  581. # index_copy_(slot_mapping) only works when the inserted dimension
  582. # is 0. However, the KV cache in the Pallas backend has the shape
  583. # [num_kv_heads, num_blocks, block_size, head_size]. To make it
  584. # work, we need to flatten the first three dimensions and modify
  585. # the slot_mapping accordingly.
  586. num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
  587. slot_mapping = attn_metadata.slot_mapping
  588. slot_mapping = slot_mapping.flatten()
  589. head_indicies = torch.arange(0,
  590. num_kv_heads,
  591. device=slot_mapping.device,
  592. dtype=slot_mapping.dtype)
  593. head_indicies *= block_size * num_blocks
  594. slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
  595. -1, num_kv_heads)
  596. slot_mapping = slot_mapping + head_indicies.view(1, -1)
  597. slot_mapping = slot_mapping.flatten()
  598. attn_metadata.slot_mapping = slot_mapping
  599. hidden_states = self.model(
  600. token_ids,
  601. position_ids,
  602. kv_caches,
  603. attn_metadata,
  604. )
  605. hidden_states = hidden_states.flatten(0, 1)
  606. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  607. # Argmax sampling.
  608. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
  609. argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
  610. # Zero temperature means greedy decoding. Avoid division by zero.
  611. nonzero_t = torch.where(t != 0, t, 1.0)
  612. logits = logits / nonzero_t.unsqueeze(dim=1)
  613. if _ENABLE_TOP_P:
  614. logits = _apply_top_p(logits, p.unsqueeze(dim=1))
  615. # Random sampling.
  616. probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
  617. sampled_token_ids = torch.multinomial(probs,
  618. num_samples,
  619. replacement=True)
  620. next_token_ids = torch.where(t != 0, sampled_token_ids,
  621. argmax_token_ids)
  622. return next_token_ids
  623. class CompiledModelWrapper:
  624. def __init__(self, model: nn.Module):
  625. model = ModelWrapper(model)
  626. self.model = torch.compile(model,
  627. backend="openxla",
  628. fullgraph=True,
  629. dynamic=False)
  630. def __call__(
  631. self,
  632. token_ids: torch.Tensor,
  633. position_ids: torch.Tensor,
  634. attn_metadata: AttentionMetadata,
  635. input_lens: torch.Tensor,
  636. t: torch.Tensor,
  637. p: torch.Tensor,
  638. num_samples: int,
  639. kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
  640. ) -> torch.Tensor:
  641. # NOTE: There are two stages of compilation: torch.compile and
  642. # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
  643. # overhead by reusing the FX graph for different shapes.
  644. # However, the XLA graph will still require static shapes and needs to
  645. # be re-compiled for every different shapes. This overhead is inevitable
  646. # in the first run, but can be skipped afterwards as we cache the XLA
  647. # graphs in the disk (APHRODITE_XLA_CACHE_PATH).
  648. if attn_metadata.num_prefills > 0:
  649. # Prefll
  650. torch._dynamo.mark_dynamic(token_ids, 1)
  651. torch._dynamo.mark_dynamic(position_ids, 1)
  652. torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
  653. else:
  654. # Decode
  655. torch._dynamo.mark_dynamic(token_ids, 0)
  656. torch._dynamo.mark_dynamic(position_ids, 0)
  657. torch._dynamo.mark_dynamic(input_lens, 0)
  658. torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
  659. torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
  660. torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
  661. torch._dynamo.mark_dynamic(t, 0)
  662. torch._dynamo.mark_dynamic(p, 0)
  663. return self.model(token_ids, position_ids, attn_metadata, input_lens,
  664. t, p, num_samples, kv_caches)
  665. def _get_padded_prefill_len(x: int) -> int:
  666. # NOTE: The pallas FlashAttention kernel requires the sequence
  667. # length to be a multiple of 16. We pad the prompt length to the nearest
  668. # multiple of 16. This is also good for performance.
  669. if x <= 16:
  670. return 16
  671. return 1 << (x - 1).bit_length()
  672. def _get_padded_batch_size(batch_size: int) -> int:
  673. # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
  674. # To meet this requirement in the simplest way, we set the minimal batch
  675. # size to 8.
  676. if batch_size <= 8:
  677. return 8
  678. else:
  679. return ((batch_size + 15) // 16) * 16
  680. def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
  681. logits_sorted = torch.sort(logits, dim=-1, descending=True).values
  682. sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
  683. cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
  684. cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
  685. logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
  686. return logits