tpu_model_runner.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
  1. import time
  2. from dataclasses import dataclass
  3. from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
  4. Type, Union)
  5. from unittest.mock import patch
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. import torch_xla.core.xla_model as xm
  10. import torch_xla.runtime as xr
  11. from loguru import logger
  12. from aphrodite.attention import AttentionMetadata, get_attn_backend
  13. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  14. ModelConfig, ParallelConfig,
  15. SchedulerConfig)
  16. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  17. IntermediateTensors, Logprob,
  18. SequenceGroupMetadata, SequenceOutput)
  19. from aphrodite.compilation.wrapper import (
  20. TorchCompileWrapperWithCustomDispacther)
  21. from aphrodite.modeling.layers.sampler import SamplerOutput
  22. from aphrodite.modeling.model_loader import get_model
  23. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  24. from aphrodite.task_handler.model_runner_base import (
  25. ModelRunnerBase, ModelRunnerInputBase,
  26. _add_attn_metadata_broadcastable_dict,
  27. _init_attn_metadata_from_tensor_dict)
  28. if TYPE_CHECKING:
  29. from aphrodite.attention.backends.abstract import AttentionBackend
  30. # Here we utilize the behavior that out-of-bound index is ignored.
  31. # FIXME: Find a more reliable way to prevent possible bugs.
  32. _PAD_SLOT_ID = 1_000_000_000
  33. # FIXME: Temporarily disabled top-p sampling since it's too slow.
  34. _ENABLE_TOP_P = False
  35. # FIXME: A temporary hack to support `n > 1`.
  36. # This can significantly affect the performance if too large.
  37. _MAX_NUM_SAMPLES = 128
  38. @dataclass(frozen=True)
  39. class ModelInputForTPU(ModelRunnerInputBase):
  40. token_ids: torch.Tensor
  41. position_ids: torch.Tensor
  42. attn_metadata: AttentionMetadata
  43. input_lens: torch.Tensor
  44. t: torch.Tensor
  45. p: torch.Tensor
  46. num_samples: int
  47. best_of: List[int]
  48. seq_groups: List[List[int]]
  49. virtual_engine: int = 0
  50. async_callback: Optional[Callable] = None
  51. def as_broadcastable_tensor_dict(
  52. self) -> Dict[str, Union[int, torch.Tensor]]:
  53. tensor_dict = {
  54. "token_ids": self.token_ids,
  55. "position_ids": self.position_ids,
  56. "input_lens": self.input_lens,
  57. "t": self.t,
  58. "p": self.p,
  59. "num_samples": self.num_samples,
  60. "best_of": self.best_of,
  61. "seq_groups": self.seq_groups,
  62. "virtual_engine": self.virtual_engine,
  63. }
  64. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  65. return tensor_dict
  66. @classmethod
  67. def from_broadcasted_tensor_dict(
  68. cls: Type["ModelInputForTPU"],
  69. tensor_dict: Dict[str, Any],
  70. attn_backend: Optional["AttentionBackend"] = None,
  71. ) -> "ModelInputForTPU":
  72. if attn_backend is not None:
  73. tensor_dict = _init_attn_metadata_from_tensor_dict(
  74. attn_backend, tensor_dict)
  75. return cls(**tensor_dict)
  76. class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
  77. def __init__(
  78. self,
  79. model_config: ModelConfig,
  80. parallel_config: ParallelConfig,
  81. scheduler_config: SchedulerConfig,
  82. device_config: DeviceConfig,
  83. cache_config: CacheConfig,
  84. load_config: LoadConfig,
  85. is_driver_worker: bool = False,
  86. **kwargs,
  87. ):
  88. self.model_config = model_config
  89. self.parallel_config = parallel_config
  90. self.scheduler_config = scheduler_config
  91. self.device_config = device_config
  92. self.cache_config = cache_config
  93. self.load_config = load_config
  94. self.is_driver_worker = is_driver_worker
  95. self.block_size = self.cache_config.block_size
  96. self.max_num_blocks_per_seq = (self.model_config.max_model_len //
  97. self.block_size)
  98. self.block_tables = np.zeros(
  99. (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
  100. dtype=np.int32)
  101. self.attn_backend = get_attn_backend(
  102. self.model_config.get_head_size(),
  103. self.model_config.get_sliding_window(),
  104. self.model_config.dtype,
  105. self.cache_config.cache_dtype,
  106. self.block_size,
  107. self.model_config.is_attention_free(),
  108. False,
  109. )
  110. def load_model(self) -> None:
  111. self.device = self.device_config.device
  112. # NOTE: While the executor assigns the TP ranks to the worker
  113. # process, the ranks can be different from the ranks internally assigned
  114. # by the xm runtime. Therefore, there is a mismatch in the rank
  115. # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
  116. # This is not a problem in linear layers because all-reduce is
  117. # rank-agnostic. However, it matters for all-gather as the ranks
  118. # determine the order of concatenating the output tensors.
  119. # As a workaround, we use the xm's rank assignment only when loading
  120. # the embedding weights.
  121. xm_tp_rank = xr.global_ordinal()
  122. with patch(
  123. "aphrodite.modeling.layers.vocab_parallel_embedding."
  124. "get_tensor_model_parallel_rank",
  125. return_value=xm_tp_rank):
  126. model = get_model(
  127. model_config=self.model_config,
  128. load_config=self.load_config,
  129. device_config=self.device_config,
  130. parallel_config=self.parallel_config,
  131. cache_config=self.cache_config,
  132. scheduler_config=self.scheduler_config,
  133. lora_config=None,
  134. )
  135. model = model.eval()
  136. xm.wait_device_ops()
  137. self.model = ModelWrapper(model)
  138. def _dummy_run(
  139. self,
  140. batch_size: int,
  141. seq_len: int,
  142. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  143. is_prompt: bool,
  144. ) -> None:
  145. if is_prompt:
  146. seq_len = (seq_len + 15) // 16 * 16
  147. token_ids = torch.zeros((batch_size, seq_len),
  148. dtype=torch.int32,
  149. device=self.device)
  150. position_ids = torch.zeros((batch_size, seq_len),
  151. dtype=torch.int32,
  152. device=self.device)
  153. slot_mapping = torch.zeros((batch_size, seq_len),
  154. dtype=torch.int64,
  155. device=self.device)
  156. attn_metadata = self.attn_backend.make_metadata(
  157. num_prefills=batch_size,
  158. num_prefill_tokens=batch_size * seq_len,
  159. num_decode_tokens=0,
  160. slot_mapping=slot_mapping,
  161. block_tables=None,
  162. context_lens=None,
  163. )
  164. input_lens = torch.ones((batch_size, ),
  165. dtype=torch.int32,
  166. device=self.device)
  167. else:
  168. assert seq_len == 1
  169. token_ids = torch.zeros((batch_size, seq_len),
  170. dtype=torch.int32,
  171. device=self.device)
  172. position_ids = torch.zeros((batch_size, seq_len),
  173. dtype=torch.int32,
  174. device=self.device)
  175. slot_mapping = torch.zeros((batch_size, seq_len),
  176. dtype=torch.int64,
  177. device=self.device)
  178. block_tables = torch.zeros(
  179. (batch_size, self.max_num_blocks_per_seq),
  180. dtype=torch.int32,
  181. device=self.device)
  182. context_lens = torch.ones((batch_size, ),
  183. dtype=torch.int32,
  184. device=self.device)
  185. input_lens = torch.ones((batch_size, ),
  186. dtype=torch.int32,
  187. device=self.device)
  188. attn_metadata = self.attn_backend.make_metadata(
  189. num_prefills=0,
  190. num_prefill_tokens=0,
  191. num_decode_tokens=batch_size * seq_len,
  192. slot_mapping=slot_mapping,
  193. block_tables=block_tables,
  194. context_lens=context_lens,
  195. )
  196. t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
  197. p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
  198. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
  199. # NOTE: There are two stages of compilation: torch.compile and
  200. # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
  201. # overhead by reusing the FX graph for different shapes.
  202. # However, the XLA graph will still require static shapes and needs to
  203. # be re-compiled for every different shapes. This overhead is inevitable
  204. # in the first run, but can be skipped afterwards as we cache the XLA
  205. # graphs in the disk (APHRODITE_XLA_CACHE_PATH).
  206. if is_prompt:
  207. # Prefll
  208. torch._dynamo.mark_dynamic(token_ids, 1)
  209. torch._dynamo.mark_dynamic(position_ids, 1)
  210. torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
  211. else:
  212. # Decode
  213. torch._dynamo.mark_dynamic(token_ids, 0)
  214. torch._dynamo.mark_dynamic(position_ids, 0)
  215. torch._dynamo.mark_dynamic(input_lens, 0)
  216. torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
  217. torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
  218. torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
  219. torch._dynamo.mark_dynamic(t, 0)
  220. torch._dynamo.mark_dynamic(p, 0)
  221. # Dummy run.
  222. self.model(token_ids,
  223. position_ids,
  224. attn_metadata,
  225. input_lens,
  226. t,
  227. p,
  228. num_samples,
  229. kv_caches,
  230. is_prompt=is_prompt)
  231. def warmup_model(
  232. self,
  233. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  234. ) -> None:
  235. # Prefill
  236. logger.info("Compiling the model with different input shapes...")
  237. start = time.time()
  238. for batch_size in [1]:
  239. seq_len = 16
  240. while True:
  241. self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
  242. xm.wait_device_ops()
  243. logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")
  244. if seq_len >= self.model_config.max_model_len:
  245. break
  246. num_tokens = batch_size * seq_len
  247. if num_tokens >= self.scheduler_config.max_num_batched_tokens:
  248. break
  249. seq_len = seq_len * 2
  250. end = time.time()
  251. logger.info(f"Compilation for prefill done in {end - start:.2f} s.")
  252. # Decode
  253. start = time.time()
  254. seq_len = 1
  255. batch_size = 8 # Must be in sync with _get_padded_batch_size()
  256. while True:
  257. self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
  258. xm.wait_device_ops()
  259. logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")
  260. if batch_size >= self.scheduler_config.max_num_seqs:
  261. break
  262. batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
  263. end = time.time()
  264. logger.info(f"Compilation for decode done in {end - start:.2f} s.")
  265. def _prepare_prompt(
  266. self,
  267. seq_group_metadata_list: List[SequenceGroupMetadata],
  268. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
  269. assert len(seq_group_metadata_list) > 0
  270. input_tokens: List[int] = []
  271. input_positions: List[int] = []
  272. prompt_lens: List[int] = []
  273. slot_mapping: List[int] = []
  274. for seq_group_metadata in seq_group_metadata_list:
  275. assert seq_group_metadata.is_prompt
  276. seq_ids = list(seq_group_metadata.seq_data.keys())
  277. assert len(seq_ids) == 1
  278. seq_id = seq_ids[0]
  279. seq_data = seq_group_metadata.seq_data[seq_id]
  280. # Could include output tokens when a request is preempted.
  281. prompt_tokens = seq_data.get_token_ids()
  282. prompt_len = len(prompt_tokens)
  283. prompt_lens.append(prompt_len)
  284. input_tokens.extend(prompt_tokens)
  285. input_positions.extend(list(range(prompt_len)))
  286. assert seq_group_metadata.block_tables is not None
  287. block_table = seq_group_metadata.block_tables[seq_id]
  288. for i in range(prompt_len):
  289. block_number = block_table[i // self.block_size]
  290. block_offset = i % self.block_size
  291. slot = block_number * self.block_size + block_offset
  292. slot_mapping.append(slot)
  293. # Add paddings to EACH prompt to the smallest power of 2 that is
  294. # greater than or equal to the prompt length.
  295. # We pad the seq_len to reduce the compilation overhead.
  296. # We execute each prompt individually (i.e., with batch_size 1)
  297. # because the FlashAttention kernel does not support ragged inputs.
  298. # TODO(woosuk): Use SplashAttention to support ragged inputs.
  299. padded_prompt_len = _get_padded_prefill_len(prompt_len)
  300. num_paddings = padded_prompt_len - prompt_len
  301. input_tokens += [0] * num_paddings
  302. input_positions += [0] * num_paddings
  303. slot_mapping += [_PAD_SLOT_ID] * num_paddings
  304. assert len(prompt_lens) > 0
  305. num_prefills = len(prompt_lens)
  306. input_tokens = torch.tensor(input_tokens,
  307. dtype=torch.int32,
  308. device="cpu")
  309. input_positions = torch.tensor(input_positions,
  310. dtype=torch.int32,
  311. device="cpu")
  312. slot_mapping = torch.tensor(slot_mapping,
  313. dtype=torch.int64,
  314. device="cpu")
  315. prompt_lens = torch.tensor(prompt_lens,
  316. dtype=torch.int32,
  317. device="cpu")
  318. attn_metadata = self.attn_backend.make_metadata(
  319. num_prefills=num_prefills,
  320. num_prefill_tokens=0, # NOTE: This is not used.
  321. num_decode_tokens=0,
  322. slot_mapping=slot_mapping,
  323. block_tables=None,
  324. context_lens=None,
  325. )
  326. return input_tokens, input_positions, attn_metadata, prompt_lens
  327. def _prepare_decode(
  328. self,
  329. seq_group_metadata_list: List[SequenceGroupMetadata],
  330. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
  331. assert len(seq_group_metadata_list) > 0
  332. input_tokens: List[List[int]] = []
  333. input_positions: List[List[int]] = []
  334. slot_mapping: List[List[int]] = []
  335. context_lens: List[int] = []
  336. batch_idx = 0
  337. for seq_group_metadata in seq_group_metadata_list:
  338. assert not seq_group_metadata.is_prompt
  339. seq_ids = list(seq_group_metadata.seq_data.keys())
  340. for seq_id in seq_ids:
  341. seq_data = seq_group_metadata.seq_data[seq_id]
  342. generation_token = seq_data.get_last_token_id()
  343. input_tokens.append([generation_token])
  344. seq_len = seq_data.get_len()
  345. position = seq_len - 1
  346. input_positions.append([position])
  347. context_lens.append(seq_len)
  348. assert seq_group_metadata.block_tables is not None
  349. block_table = seq_group_metadata.block_tables[seq_id]
  350. self.block_tables[batch_idx, :len(block_table)] = block_table
  351. batch_idx += 1
  352. block_number = block_table[position // self.block_size]
  353. block_offset = position % self.block_size
  354. slot = block_number * self.block_size + block_offset
  355. slot_mapping.append([slot])
  356. batch_size = _get_padded_batch_size(batch_idx)
  357. num_paddings = batch_size - batch_idx
  358. input_tokens = input_tokens + [[0]] * num_paddings
  359. input_positions = input_positions + [[0]] * num_paddings
  360. slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
  361. context_lens = context_lens + [0] * num_paddings
  362. input_tokens = torch.tensor(input_tokens,
  363. dtype=torch.int32,
  364. device="cpu")
  365. input_positions = torch.tensor(input_positions,
  366. dtype=torch.int32,
  367. device="cpu")
  368. slot_mapping = torch.tensor(slot_mapping,
  369. dtype=torch.int64,
  370. device="cpu")
  371. context_lens = torch.tensor(context_lens,
  372. dtype=torch.int32,
  373. device="cpu")
  374. block_tables = torch.tensor(self.block_tables[:batch_size],
  375. dtype=torch.int32,
  376. device="cpu")
  377. input_lens = torch.tensor([1] * batch_size,
  378. dtype=torch.int32,
  379. device="cpu")
  380. attn_metadata = self.attn_backend.make_metadata(
  381. num_prefills=0,
  382. num_prefill_tokens=0,
  383. num_decode_tokens=batch_size,
  384. slot_mapping=slot_mapping,
  385. block_tables=block_tables,
  386. context_lens=context_lens,
  387. )
  388. return input_tokens, input_positions, attn_metadata, input_lens
  389. def _prepare_sample(
  390. self,
  391. seq_group_metadata_list: List[SequenceGroupMetadata],
  392. padded_batch_size: int,
  393. ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
  394. assert len(seq_group_metadata_list) > 0
  395. t = []
  396. p = []
  397. best_of = []
  398. for seq_group_metadata in seq_group_metadata_list:
  399. sampling_params = seq_group_metadata.sampling_params
  400. t.append(sampling_params.temperature)
  401. if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
  402. raise NotImplementedError(
  403. "Top-p sampling is currently disabled for the TPU backend "
  404. "due to performance issues.")
  405. p.append(sampling_params.top_p)
  406. if sampling_params.top_k != -1:
  407. raise NotImplementedError(
  408. "Top-k sampling is currently disabled for the TPU backend "
  409. "due to performance issues.")
  410. if sampling_params.best_of > _MAX_NUM_SAMPLES:
  411. raise NotImplementedError(
  412. f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
  413. "backend.")
  414. best_of.append(sampling_params.best_of)
  415. if sampling_params.use_beam_search:
  416. raise NotImplementedError(
  417. "Beam search is not supported by the TPU backend.")
  418. if sampling_params.logprobs is not None:
  419. raise NotImplementedError(
  420. "logprobs is not currently supported by the TPU backend.")
  421. if sampling_params.prompt_logprobs is not None:
  422. raise NotImplementedError(
  423. "prompt_logprobs is not currently supported by the TPU "
  424. "backend.")
  425. # Repeat the sampling params if the seq group has multiple seqs.
  426. num_seqs = len(seq_group_metadata.seq_data)
  427. t += [t[-1]] * (num_seqs - 1)
  428. p += [p[-1]] * (num_seqs - 1)
  429. best_of += [best_of[-1]] * (num_seqs - 1)
  430. num_paddings = padded_batch_size - len(t)
  431. t += [1.0] * num_paddings
  432. p += [1.0] * num_paddings
  433. t = torch.tensor(t, dtype=torch.float32, device="cpu")
  434. p = torch.tensor(p, dtype=torch.float32, device="cpu")
  435. return t, p, best_of
  436. def prepare_model_input(
  437. self,
  438. seq_group_metadata_list: List[SequenceGroupMetadata],
  439. virtual_engine: int = 0,
  440. finished_requests_ids: Optional[List[str]] = None,
  441. ) -> ModelInputForTPU:
  442. del finished_requests_ids # Unused.
  443. assert virtual_engine == 0
  444. assert len(seq_group_metadata_list) > 0
  445. # NOTE: We assume that all sequences in the group are all prompts or
  446. # all decodes.
  447. is_prompt = seq_group_metadata_list[0].is_prompt
  448. if is_prompt:
  449. inputs = self._prepare_prompt(seq_group_metadata_list)
  450. else:
  451. inputs = self._prepare_decode(seq_group_metadata_list)
  452. input_tokens, input_positions, attn_metadata, input_lens = inputs
  453. padded_batch_size = input_tokens.shape[0]
  454. t, p, best_of = self._prepare_sample(seq_group_metadata_list,
  455. padded_batch_size)
  456. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
  457. seq_groups = [
  458. list(metadata.seq_data.keys())
  459. for metadata in seq_group_metadata_list
  460. ]
  461. return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
  462. input_lens, t, p, num_samples, best_of,
  463. seq_groups)
  464. def make_model_input_from_broadcasted_tensor_dict(
  465. self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
  466. model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
  467. tensor_dict, attn_backend=self.attn_backend)
  468. return model_input
  469. @torch.no_grad()
  470. def execute_model(
  471. self,
  472. model_input: ModelInputForTPU,
  473. kv_caches: Optional[List[Any]],
  474. intermediate_tensors: Optional[IntermediateTensors] = None,
  475. num_steps: int = 1,
  476. ) -> List[SamplerOutput]:
  477. assert intermediate_tensors is None
  478. if num_steps > 1:
  479. raise ValueError(
  480. "TPUModelRunner does not support multi-step execution.")
  481. def _execute_model(*args):
  482. """Move input args from CPU to device and execute the model."""
  483. new_args = []
  484. for arg in args:
  485. if isinstance(arg, torch.Tensor):
  486. arg = arg.to(self.device)
  487. elif isinstance(arg, AttentionMetadata):
  488. arg.slot_mapping = arg.slot_mapping.to(self.device)
  489. if getattr(arg, "block_tables", None) is not None:
  490. arg.block_tables = arg.block_tables.to(self.device)
  491. if getattr(arg, "context_lens", None) is not None:
  492. arg.context_lens = arg.context_lens.to(self.device)
  493. new_args.append(arg)
  494. return self.model(*new_args, is_prompt=is_prompt)
  495. num_prefills = model_input.attn_metadata.num_prefills
  496. is_prompt = num_prefills > 0
  497. if is_prompt:
  498. # NOTE: Since the FlashAttention kernel does not support
  499. # ragged inputs, we split the prompts into different batches and
  500. # process them separately. This is a temporary hack that should be
  501. # optimized by using SplashAttention.
  502. next_token_ids = []
  503. orig_slot_mapping = model_input.attn_metadata.slot_mapping
  504. batch_size = model_input.input_lens.shape[0]
  505. start_idx = 0
  506. for i in range(batch_size):
  507. # Get the actual prefill_len.
  508. prefill_len = model_input.input_lens[i:i + 1].item()
  509. prefill_len = _get_padded_prefill_len(prefill_len)
  510. end_idx = start_idx + prefill_len
  511. model_input.attn_metadata.slot_mapping = orig_slot_mapping[
  512. None, start_idx:end_idx]
  513. model_input.attn_metadata.num_prefills = 1
  514. output_token_ids = _execute_model(
  515. model_input.token_ids[None, start_idx:end_idx],
  516. model_input.position_ids[None, start_idx:end_idx],
  517. model_input.attn_metadata, model_input.input_lens[i:i + 1],
  518. model_input.t[i:i + 1], model_input.p[i:i + 1],
  519. model_input.num_samples, kv_caches)
  520. if i == 0 and model_input.async_callback is not None:
  521. model_input.async_callback()
  522. # Retrieve the outputs to CPU.
  523. next_token_ids += output_token_ids.cpu().tolist()
  524. start_idx = end_idx
  525. else:
  526. # Execute the model.
  527. output_token_ids = _execute_model(
  528. model_input.token_ids, model_input.position_ids,
  529. model_input.attn_metadata, model_input.input_lens,
  530. model_input.t, model_input.p, model_input.num_samples,
  531. kv_caches)
  532. if model_input.async_callback is not None:
  533. model_input.async_callback()
  534. # Retrieve the outputs to CPU.
  535. next_token_ids = output_token_ids.cpu().tolist()
  536. # NOTE: Minimal code to construct the sampler outputs.
  537. # The TPU backend does not reuse the sampler, since the TPU backend
  538. # does not support the advanced sampling parameters such as logprobs.
  539. zero_logprob = Logprob(0.0)
  540. batch_idx = 0
  541. sampler_outputs = []
  542. for seq_group in model_input.seq_groups:
  543. seq_ids = seq_group
  544. seq_outputs = []
  545. if is_prompt:
  546. assert len(seq_ids) == 1
  547. seq_id = seq_ids[0]
  548. for i in range(model_input.best_of[batch_idx]):
  549. next_token_id = next_token_ids[batch_idx][i]
  550. seq_outputs.append(
  551. SequenceOutput(seq_id, next_token_id,
  552. {next_token_id: zero_logprob}))
  553. batch_idx += 1
  554. else:
  555. for seq_id in seq_ids:
  556. next_token_id = next_token_ids[batch_idx][0]
  557. seq_outputs.append(
  558. SequenceOutput(seq_id, next_token_id,
  559. {next_token_id: zero_logprob}))
  560. batch_idx += 1
  561. sampler_outputs.append(
  562. CompletionSequenceGroupOutput(seq_outputs, None))
  563. return [SamplerOutput(sampler_outputs)]
  564. class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
  565. def __init__(self, model: nn.Module):
  566. self.model = model
  567. compiled_callable = torch.compile(self.forward,
  568. backend="openxla",
  569. fullgraph=True,
  570. dynamic=False)
  571. super().__init__(compiled_callable)
  572. def __call__(self, *args, is_prompt: bool, **kwargs):
  573. if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
  574. # not fully compiled yet, or not using the custom dispatcher,
  575. # let PyTorch handle it
  576. return self.compiled_callable(*args, **kwargs)
  577. # the 3 compiled codes are:
  578. # 0: for profiling
  579. # 1: for prompt
  580. # 2: for decode
  581. # dispatch to the compiled code directly, skip PyTorch
  582. if is_prompt:
  583. with self.dispatch_to_code(1):
  584. return self.forward(*args, **kwargs)
  585. else:
  586. with self.dispatch_to_code(2):
  587. return self.forward(*args, **kwargs)
  588. def forward(
  589. self,
  590. token_ids: torch.Tensor,
  591. position_ids: torch.Tensor,
  592. attn_metadata: AttentionMetadata,
  593. input_lens: torch.Tensor,
  594. t: torch.Tensor,
  595. p: torch.Tensor,
  596. num_samples: int,
  597. kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
  598. ) -> torch.Tensor:
  599. """Executes the forward pass of the model and samples the next token.
  600. Args:
  601. token_ids: The input token IDs of shape [batch_size, seq_len].
  602. position_ids: The input position IDs of shape [batch_size, seq_len].
  603. attn_metadata: The Pallas attention metadata.
  604. input_lens: The actual input lengths of shape [batch_size].
  605. t: The sampling temperature of shape [batch_size].
  606. p: The top-p probability of shape [batch_size].
  607. num_samples: Number of samples to draw from each logits vector.
  608. kv_caches: The key and value caches. They can be None during the
  609. memory profiling at initialization.
  610. """
  611. batch_size, seq_len = token_ids.shape
  612. # Calculate the positions to sample from.
  613. start_indicies = torch.arange(
  614. batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
  615. logits_indices = start_indicies + input_lens - 1
  616. # FIXME: This is a temporary hack to avoid using the existing
  617. # sampler and sampling metadata.
  618. sampling_metadata = SamplingMetadata(
  619. seq_groups=[],
  620. selected_token_indices=logits_indices,
  621. categorized_sample_indices={},
  622. num_prompts=attn_metadata.num_prefills,
  623. )
  624. # Skip this in memory profiling at initialization.
  625. if kv_caches[0][0] is not None:
  626. # index_copy_(slot_mapping) only works when the inserted dimension
  627. # is 0. However, the KV cache in the Pallas backend has the shape
  628. # [num_kv_heads, num_blocks, block_size, head_size]. To make it
  629. # work, we need to flatten the first three dimensions and modify
  630. # the slot_mapping accordingly.
  631. num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
  632. slot_mapping = attn_metadata.slot_mapping
  633. slot_mapping = slot_mapping.flatten()
  634. head_indicies = torch.arange(0,
  635. num_kv_heads,
  636. device=slot_mapping.device,
  637. dtype=slot_mapping.dtype)
  638. head_indicies *= block_size * num_blocks
  639. slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
  640. -1, num_kv_heads)
  641. slot_mapping = slot_mapping + head_indicies.view(1, -1)
  642. slot_mapping = slot_mapping.flatten()
  643. attn_metadata.slot_mapping = slot_mapping
  644. hidden_states = self.model(
  645. token_ids,
  646. position_ids,
  647. kv_caches,
  648. attn_metadata,
  649. )
  650. hidden_states = hidden_states.flatten(0, 1)
  651. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  652. # Argmax sampling.
  653. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
  654. argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
  655. # Zero temperature means greedy decoding. Avoid division by zero.
  656. nonzero_t = torch.where(t != 0, t, 1.0)
  657. logits = logits / nonzero_t.unsqueeze(dim=1)
  658. if _ENABLE_TOP_P:
  659. logits = _apply_top_p(logits, p.unsqueeze(dim=1))
  660. # Random sampling.
  661. probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
  662. sampled_token_ids = torch.multinomial(probs,
  663. num_samples,
  664. replacement=True)
  665. next_token_ids = torch.where(t != 0, sampled_token_ids,
  666. argmax_token_ids)
  667. return next_token_ids
  668. def _get_padded_prefill_len(x: int) -> int:
  669. # NOTE: The pallas FlashAttention kernel requires the sequence
  670. # length to be a multiple of 16. We pad the prompt length to the nearest
  671. # multiple of 16. This is also good for performance.
  672. if x <= 16:
  673. return 16
  674. return 1 << (x - 1).bit_length()
  675. def _get_padded_batch_size(batch_size: int) -> int:
  676. # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
  677. # To meet this requirement in the simplest way, we set the minimal batch
  678. # size to 8.
  679. if batch_size <= 8:
  680. return 8
  681. else:
  682. return ((batch_size + 15) // 16) * 16
  683. def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
  684. logits_sorted = torch.sort(logits, dim=-1, descending=True).values
  685. sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
  686. cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
  687. cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
  688. logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
  689. return logits