tpu_model_runner.py

import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from loguru import logger

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                       IntermediateTensors, Logprob,
                                       SamplerOutput, SequenceGroupMetadata,
                                       SequenceOutput)
from aphrodite.modeling.model_loader import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.task_handler.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

# Here we utilize the behavior that an out-of-bound index is ignored.
# FIXME: Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME: Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME: A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    best_of: List[int]
    seq_groups: List[List[int]]

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "best_of": self.best_of,
            "seq_groups": self.seq_groups,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig] = None,
        is_driver_worker: bool = False,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free(),
            False,
        )

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE: While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally
        # assigned by the xm runtime. Therefore, there is a mismatch in the
        # rank assignment between the gloo (cpu) runtime and the xm (tpu)
        # runtime. This is not a problem in linear layers because all-reduce
        # is rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "aphrodite.modeling.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=self.device_config,
                parallel_config=self.parallel_config,
                cache_config=self.cache_config,
                scheduler_config=self.scheduler_config,
                multimodal_config=self.multimodal_config,
                lora_config=None,
            )
        model = model.eval()
        xm.wait_device_ops()
        self.model = CompiledModelWrapper(model)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

        # Dummy run.
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
                   num_samples, kv_caches)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
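            # As an illustration, assuming max_model_len = 2048 and
            # max_num_batched_tokens >= 2048, the loop below compiles prefill
            # graphs for seq_len = 16, 32, 64, ..., 2048, doubling each
            # iteration while batch_size stays at 1.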
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info(f"Compilation for prefill done in {end - start:.2f} s.")

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
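        # As an illustration, assuming max_num_seqs = 256, the loop below
        # compiles decode graphs at seq_len = 1 for batch_size = 8, 16, 32,
        # 48, 64, ..., 256: doubling up to 16, then stepping by 16.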
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info(f"Compilation for decode done in {end - start:.2f} s.")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        slot_mapping: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

            # Pad each prompt to the smallest power of 2 that is greater than
            # or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
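            # For example, a 5-token prompt with block_size = 16 and
            # block_table = [7] maps to slots [112, 113, 114, 115, 116] above
            # and is padded below to 16 tokens, with the extra positions
            # pointing at the ignored _PAD_SLOT_ID slot.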
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings
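        # For example, 9 live decode sequences are padded up to a batch of 16
        # (see _get_padded_batch_size); the 7 padding rows use token 0,
        # position 0, a zero context length, and _PAD_SLOT_ID, so they do not
        # disturb the KV cache.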
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            best_of += [best_of[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings
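        # The padded entries use the neutral values t = 1.0 and p = 1.0;
        # best_of is intentionally left unpadded because execute_model only
        # reads the entries that correspond to real sequence groups.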
        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, best_of

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        padded_batch_size = input_tokens.shape[0]
        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, best_of,
                                seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)
        return model_input

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
        if num_steps > 1:
            raise ValueError(
                "TPUModelRunner does not support multi-step execution.")

        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
            """Move input args from CPU to device and execute the model."""

            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
                if clone:
                    # When x is a slice of a CPU tensor, XLA may copy the whole
                    # original tensor to TPU instead of only copying x.
                    # To avoid this, we copy x after cloning.
                    x = x.clone()
                return x.to(self.device)

            new_args = []
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg = _copy_to_device(arg)
                elif isinstance(arg, AttentionMetadata):
                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
                    if getattr(arg, "block_tables", None) is not None:
                        arg.block_tables = _copy_to_device(arg.block_tables)
                    if getattr(arg, "context_lens", None) is not None:
                        arg.context_lens = _copy_to_device(arg.context_lens)
                new_args.append(arg)
            return self.model(*new_args)

        num_prefills = model_input.attn_metadata.num_prefills
        is_prompt = num_prefills > 0
        if is_prompt:
            # NOTE: Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
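            # For example, three prompts of 5, 20, and 40 tokens are padded to
            # 16, 32, and 64 tokens in the flattened input, so the loop below
            # runs them one at a time over the slices [0:16], [16:48], and
            # [48:112].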
            next_token_ids = []
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                model_input.attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx]
                model_input.attn_metadata.num_prefills = 1
                output_token_ids = _execute_model(
                    model_input.token_ids[None, start_idx:end_idx],
                    model_input.position_ids[None, start_idx:end_idx],
                    model_input.attn_metadata,
                    model_input.input_lens[i:i + 1],
                    model_input.t[i:i + 1],
                    model_input.p[i:i + 1],
                    model_input.num_samples,
                    kv_caches,
                    clone=True)
                # Retrieve the outputs to CPU.
                next_token_ids += output_token_ids.cpu().tolist()
                start_idx = end_idx
        else:
            # Execute the model.
            output_token_ids = _execute_model(
                model_input.token_ids, model_input.position_ids,
                model_input.attn_metadata, model_input.input_lens,
                model_input.t, model_input.p, model_input.num_samples,
                kv_caches)
            # Retrieve the outputs to CPU.
            next_token_ids = output_token_ids.cpu().tolist()

        # NOTE: Minimal code to construct the sampler outputs.
        # The TPU backend does not reuse the sampler, since the TPU backend
        # does not support the advanced sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        batch_idx = 0
        sampler_outputs = []
        for seq_group in model_input.seq_groups:
            seq_ids = seq_group
            seq_outputs = []
            if is_prompt:
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                for i in range(model_input.best_of[batch_idx]):
                    next_token_id = next_token_ids[batch_idx][i]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                batch_idx += 1
            else:
                for seq_id in seq_ids:
                    next_token_id = next_token_ids[batch_idx][0]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                    batch_idx += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return [SamplerOutput(sampler_outputs)]


class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        start_indicies = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indicies + input_lens - 1
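        # For example, if token_ids had shape [2, 16] and input_lens were
        # [5, 7], the selected indices would be [4, 22], i.e. the last real
        # token of each sequence in the flattened [batch_size * seq_len, ...]
        # layout.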

        # FIXME: This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
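            # For example, with num_kv_heads = 2, num_blocks = 4, and
            # block_size = 16, the per-head offsets below are [0, 64], so a
            # token written to slot 5 lands at flattened indices 5 (head 0)
            # and 69 (head 1).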
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indicies = torch.arange(0,
                                         num_kv_heads,
                                         device=slot_mapping.device,
                                         dtype=slot_mapping.dtype)
            head_indicies *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indicies.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)
        next_token_ids = torch.where(t != 0, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids


class CompiledModelWrapper:

    def __init__(self, model: nn.Module):
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def __call__(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        # NOTE: There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph still requires static shapes and needs to be
        # recompiled for every different shape. This overhead is inevitable in
        # the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (APHRODITE_XLA_CACHE_PATH).
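        # For example, prefill inputs of shape [1, 16], [1, 32], [1, 64], ...
        # can reuse one FX graph because dim 1 is marked dynamic below, and
        # decode inputs of shape [8, 1], [16, 1], [32, 1], ... reuse another
        # with dim 0 marked dynamic; each shape still triggers its own XLA
        # compilation on first use.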
        if attn_metadata.num_prefills > 0:
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        return self.model(token_ids, position_ids, attn_metadata, input_lens,
                          t, p, num_samples, kv_caches)


def _get_padded_prefill_len(x: int) -> int:
    # NOTE: The Pallas FlashAttention kernel requires the sequence length to
    # be a multiple of 16. We pad short prompts to 16 tokens and longer ones
    # to the next power of two, which is always a multiple of 16. This is
    # also good for performance.
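    # For example: 1 -> 16, 16 -> 16, 17 -> 32, 100 -> 128, 1025 -> 2048.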
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
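    # For example: 5 -> 8, 9 -> 16, 16 -> 16, 17 -> 32, 33 -> 48.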
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
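    # For example, if the sorted probabilities of a row are
    # [0.5, 0.3, 0.15, 0.05] and p = 0.9, the cumulative sums are
    # [0.5, 0.8, 0.95, 1.0]; two of them are below 0.9, so the cutoff is the
    # third-largest logit and every logit strictly below it is masked to -inf.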
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits