tpu_model_runner.py

import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from loguru import logger

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                       IntermediateTensors, Logprob,
                                       SequenceGroupMetadata, SequenceOutput)
from aphrodite.compilation.wrapper import (
    TorchCompileWrapperWithCustomDispacther)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME: Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME: Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME: A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
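    """Padded, flattened model inputs for a single TPU execution step.

    The tensors are built on the CPU by `TPUModelRunner.prepare_model_input`
    and moved to the TPU in `execute_model`. `t` and `p` hold the per-sequence
    sampling temperature and top-p values.
    """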

    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    best_of: List[int]
    seq_groups: List[List[int]]
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "best_of": self.best_of,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
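    """Model runner for the TPU (torch_xla) backend.

    It builds padded CPU inputs, pre-compiles the model for a fixed set of
    input shapes via `warmup_model`, and executes prefill and decode steps on
    the TPU through `ModelWrapper`.
    """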

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        is_driver_worker: bool = False,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free(),
            False,
        )
        self.cached_step_outputs: List[torch.Tensor] = []

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE: While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally
        # assigned by the xm runtime. Therefore, there is a mismatch in the
        # rank assignment between the gloo (cpu) runtime and the xm (tpu)
        # runtime. This is not a problem in linear layers because all-reduce
        # is rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "aphrodite.modeling.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=self.device_config,
                parallel_config=self.parallel_config,
                cache_config=self.cache_config,
                scheduler_config=self.scheduler_config,
                lora_config=None,
            )
        model = model.eval()
        xm.wait_device_ops()
        self.model = ModelWrapper(model)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
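        """Run the model once with dummy inputs of the given shape.

        The run only serves to trigger torch.compile and XLA compilation for
        the (batch_size, seq_len) combination; its outputs are discarded.
        """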
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        # NOTE: There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph still requires static shapes and needs to be
        # re-compiled for every different shape. This overhead is inevitable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (APHRODITE_XLA_CACHE_PATH).
        if is_prompt:
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        # Dummy run.
        self.model(token_ids,
                   position_ids,
                   attn_metadata,
                   input_lens,
                   t,
                   p,
                   num_samples,
                   kv_caches,
                   is_prompt=is_prompt)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
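        """Pre-compile the model for the shapes used at serving time.

        Prefill is compiled for power-of-two sequence lengths starting at 16,
        and decode for the padded batch sizes produced by
        `_get_padded_batch_size`, so that (beyond the first run, whose XLA
        graphs are cached on disk) no compilation happens while serving.
        """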
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info(f"Compilation for prefill done in {end - start:.2f} s.")

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info(f"Compilation for decode done in {end - start:.2f} s.")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
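        """Build flattened prefill inputs on the CPU.

        Each prompt is padded to a power-of-two length (at least 16) and the
        padded prompts are concatenated into 1D token/position/slot tensors.
        Returns (input_tokens, input_positions, attn_metadata, prompt_lens).
        """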
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        slot_mapping: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

            # Pad EACH prompt to the smallest power of 2 that is greater than
            # or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
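        """Build decode inputs on the CPU.

        Each running sequence contributes exactly one token, and the batch is
        padded to the size returned by `_get_padded_batch_size`. Returns
        (input_tokens, input_positions, attn_metadata, input_lens).
        """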
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
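        """Collect per-sequence sampling parameters.

        Only temperature, top-p, and `best_of` are handled; any other sampling
        feature raises NotImplementedError. The temperature and top-p lists
        are padded with 1.0 up to `padded_batch_size`.
        """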
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            best_of += [best_of[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, best_of

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
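        """Convert the scheduler output into a `ModelInputForTPU`.

        The batch must be homogeneous: either all prefills or all decodes.
        """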
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume the batch is homogeneous, i.e., the sequences are
        # either all prompts or all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        padded_batch_size = input_tokens.shape[0]
        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, best_of,
                                seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)
        return model_input

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
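        """Execute the model on the TPU and build the sampler outputs.

        Prefills are executed one prompt at a time because the Pallas
        FlashAttention kernel cannot handle ragged batches. Decodes run as a
        single padded batch; when `num_steps > 1`, the outputs of intermediate
        steps are kept in `self.cached_step_outputs` and only converted to
        `SamplerOutput`s on the last step.
        """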
        assert intermediate_tensors is None
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False)
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
        if is_prompt:
            assert num_steps == 1
            # NOTE: Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            next_token_ids = []
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
                output_token_ids = self.model(token_ids,
                                              position_ids,
                                              attn_metadata,
                                              input_lens,
                                              t,
                                              p,
                                              model_input.num_samples,
                                              kv_caches,
                                              is_prompt=True)
                next_token_ids.append(output_token_ids[0])
                start_idx = end_idx

            if model_input.async_callback is not None:
                model_input.async_callback()
            # Transfer the sampled tokens to the CPU.
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE: Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                seq_outputs = []
                for j in range(model_input.best_of[i]):
                    next_token_id = next_token_ids[i][j]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
                output_token_ids = self.model(token_ids,
                                              position_ids,
                                              attn_metadata,
                                              input_lens,
                                              t,
                                              p,
                                              model_input.num_samples,
                                              kv_caches,
                                              is_prompt=False)
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Transfer the sampled tokens to the CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]


class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
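    """Wraps the model in a torch.compile-d callable (openxla backend).

    One compiled code object is kept per compilation (profiling, prompt,
    decode). Once all three exist, calls are dispatched directly to the
    prompt or decode code, bypassing PyTorch's own dispatching.
    """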

    def __init__(self, model: nn.Module):
        self.model = model
        compiled_callable = torch.compile(self.forward,
                                          backend="openxla",
                                          fullgraph=True,
                                          dynamic=False)
        super().__init__(compiled_callable)

    def __call__(self, *args, is_prompt: bool, **kwargs):
        if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
            # Not fully compiled yet, or not using the custom dispatcher:
            # let PyTorch handle the dispatching.
            return self.compiled_callable(*args, **kwargs)

        # The three compiled code objects are:
        # 0: for profiling
        # 1: for prompt
        # 2: for decode
        # Dispatch to the compiled code directly and skip PyTorch.
        if is_prompt:
            with self.dispatch_to_code(1):
                return self.forward(*args, **kwargs)
        else:
            with self.dispatch_to_code(2):
                return self.forward(*args, **kwargs)

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1

        # FIXME: This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)
        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        next_token_ids = torch.where(t != 0, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
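    """Return the padded prompt length: the next power of two, at least 16.

    For example, 1..16 -> 16, 17 -> 32, 100 -> 128.
    """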
    # NOTE: The Pallas FlashAttention kernel requires the sequence length to
    # be a multiple of 16. We pad the prompt length to the next power of two
    # (which is also a multiple of 16) to limit the number of shapes to
    # compile. This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


def _get_padded_batch_size(batch_size: int) -> int:
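    """Return the padded decode batch size: 8, or the next multiple of 16.

    For example, 1..8 -> 8, 9..16 -> 16, 17..32 -> 32.
    """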
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
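    """Mask out logits that fall outside the top-p budget.

    The logits are sorted in descending order; the cutoff is the first sorted
    logit whose cumulative softmax probability reaches `p`, and every logit
    strictly smaller than the cutoff is set to -inf in place.
    """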
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits


def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    zero_logprob = Logprob(0.0)
    sampler_outputs = []
    batch_idx = 0
    for seq_group in seq_groups:
        seq_ids = seq_group
        seq_outputs = []
        for seq_id in seq_ids:
            next_token_id = next_token_ids[batch_idx]
            seq_outputs.append(
                SequenceOutput(seq_id, next_token_id,
                               {next_token_id: zero_logprob}))
            batch_idx += 1
        sampler_outputs.append(CompletionSequenceGroupOutput(
            seq_outputs, None))
    return SamplerOutput(sampler_outputs)