tpu_model_runner.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735
  1. import time
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
  4. from unittest.mock import patch
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch_xla.core.xla_model as xm
  9. import torch_xla.runtime as xr
  10. from loguru import logger
  11. from aphrodite.attention import AttentionMetadata, get_attn_backend
  12. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  13. ModelConfig, ParallelConfig,
  14. SchedulerConfig)
  15. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  16. IntermediateTensors, Logprob,
  17. SamplerOutput, SequenceGroupMetadata,
  18. SequenceOutput)
  19. from aphrodite.modeling.model_loader import get_model
  20. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  21. from aphrodite.task_handler.model_runner_base import (
  22. ModelRunnerBase, ModelRunnerInputBase,
  23. _add_attn_metadata_broadcastable_dict,
  24. _init_attn_metadata_from_tensor_dict)
  25. if TYPE_CHECKING:
  26. from aphrodite.attention.backends.abstract import AttentionBackend
  27. # Here we utilize the behavior that out-of-bound index is ignored.
  28. # FIXME: Find a more reliable way to prevent possible bugs.
  29. _PAD_SLOT_ID = 1_000_000_000
  30. # FIXME: Temporarily disabled top-p sampling since it's too slow.
  31. _ENABLE_TOP_P = False
  32. # FIXME: A temporary hack to support `n > 1`.
  33. # This can significantly affect the performance if too large.
  34. _MAX_NUM_SAMPLES = 128
  35. @dataclass(frozen=True)
  36. class ModelInputForTPU(ModelRunnerInputBase):
  37. token_ids: torch.Tensor
  38. position_ids: torch.Tensor
  39. attn_metadata: AttentionMetadata
  40. input_lens: torch.Tensor
  41. t: torch.Tensor
  42. p: torch.Tensor
  43. num_samples: int
  44. best_of: List[int]
  45. seq_groups: List[List[int]]
  46. def as_broadcastable_tensor_dict(
  47. self) -> Dict[str, Union[int, torch.Tensor]]:
  48. tensor_dict = {
  49. "token_ids": self.token_ids,
  50. "position_ids": self.position_ids,
  51. "input_lens": self.input_lens,
  52. "t": self.t,
  53. "p": self.p,
  54. "num_samples": self.num_samples,
  55. "best_of": self.best_of,
  56. "seq_groups": self.seq_groups,
  57. "virtual_engine": self.virtual_engine,
  58. }
  59. _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
  60. return tensor_dict
  61. @classmethod
  62. def from_broadcasted_tensor_dict(
  63. cls: Type["ModelInputForTPU"],
  64. tensor_dict: Dict[str, Any],
  65. attn_backend: Optional["AttentionBackend"] = None,
  66. ) -> "ModelInputForTPU":
  67. if attn_backend is not None:
  68. tensor_dict = _init_attn_metadata_from_tensor_dict(
  69. attn_backend, tensor_dict)
  70. return cls(**tensor_dict)
  71. class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
  72. def __init__(
  73. self,
  74. model_config: ModelConfig,
  75. parallel_config: ParallelConfig,
  76. scheduler_config: SchedulerConfig,
  77. device_config: DeviceConfig,
  78. cache_config: CacheConfig,
  79. load_config: LoadConfig,
  80. is_driver_worker: bool = False,
  81. **kwargs,
  82. ):
  83. self.model_config = model_config
  84. self.parallel_config = parallel_config
  85. self.scheduler_config = scheduler_config
  86. self.device_config = device_config
  87. self.cache_config = cache_config
  88. self.load_config = load_config
  89. self.is_driver_worker = is_driver_worker
  90. self.block_size = self.cache_config.block_size
  91. self.max_num_blocks_per_seq = (self.model_config.max_model_len //
  92. self.block_size)
  93. self.block_tables = np.zeros(
  94. (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
  95. dtype=np.int32)
  96. self.attn_backend = get_attn_backend(
  97. self.model_config.get_head_size(),
  98. self.model_config.get_sliding_window(),
  99. self.model_config.dtype,
  100. self.cache_config.cache_dtype,
  101. self.block_size,
  102. self.model_config.is_attention_free(),
  103. False,
  104. )
  105. def load_model(self) -> None:
  106. self.device = self.device_config.device
  107. # NOTE: While the executor assigns the TP ranks to the worker
  108. # process, the ranks can be different from the ranks internally assigned
  109. # by the xm runtime. Therefore, there is a mismatch in the rank
  110. # assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
  111. # This is not a problem in linear layers because all-reduce is
  112. # rank-agnostic. However, it matters for all-gather as the ranks
  113. # determine the order of concatenating the output tensors.
  114. # As a workaround, we use the xm's rank assignment only when loading
  115. # the embedding weights.
  116. xm_tp_rank = xr.global_ordinal()
  117. with patch(
  118. "aphrodite.modeling.layers.vocab_parallel_embedding."
  119. "get_tensor_model_parallel_rank",
  120. return_value=xm_tp_rank):
  121. model = get_model(
  122. model_config=self.model_config,
  123. load_config=self.load_config,
  124. device_config=self.device_config,
  125. parallel_config=self.parallel_config,
  126. cache_config=self.cache_config,
  127. scheduler_config=self.scheduler_config,
  128. lora_config=None,
  129. )
  130. model = model.eval()
  131. xm.wait_device_ops()
  132. model = ModelWrapper(model)
  133. self.model = torch.compile(model,
  134. backend="openxla",
  135. fullgraph=True,
  136. dynamic=False)
  137. def _dummy_run(
  138. self,
  139. batch_size: int,
  140. seq_len: int,
  141. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  142. is_prompt: bool,
  143. ) -> None:
  144. if is_prompt:
  145. seq_len = (seq_len + 15) // 16 * 16
  146. token_ids = torch.zeros((batch_size, seq_len),
  147. dtype=torch.int32,
  148. device=self.device)
  149. position_ids = torch.zeros((batch_size, seq_len),
  150. dtype=torch.int32,
  151. device=self.device)
  152. slot_mapping = torch.zeros((batch_size, seq_len),
  153. dtype=torch.int64,
  154. device=self.device)
  155. attn_metadata = self.attn_backend.make_metadata(
  156. num_prefills=batch_size,
  157. num_prefill_tokens=batch_size * seq_len,
  158. num_decode_tokens=0,
  159. slot_mapping=slot_mapping,
  160. block_tables=None,
  161. context_lens=None,
  162. )
  163. input_lens = torch.ones((batch_size, ),
  164. dtype=torch.int32,
  165. device=self.device)
  166. else:
  167. assert seq_len == 1
  168. token_ids = torch.zeros((batch_size, seq_len),
  169. dtype=torch.int32,
  170. device=self.device)
  171. position_ids = torch.zeros((batch_size, seq_len),
  172. dtype=torch.int32,
  173. device=self.device)
  174. slot_mapping = torch.zeros((batch_size, seq_len),
  175. dtype=torch.int64,
  176. device=self.device)
  177. block_tables = torch.zeros(
  178. (batch_size, self.max_num_blocks_per_seq),
  179. dtype=torch.int32,
  180. device=self.device)
  181. context_lens = torch.ones((batch_size, ),
  182. dtype=torch.int32,
  183. device=self.device)
  184. input_lens = torch.ones((batch_size, ),
  185. dtype=torch.int32,
  186. device=self.device)
  187. attn_metadata = self.attn_backend.make_metadata(
  188. num_prefills=0,
  189. num_prefill_tokens=0,
  190. num_decode_tokens=batch_size * seq_len,
  191. slot_mapping=slot_mapping,
  192. block_tables=block_tables,
  193. context_lens=context_lens,
  194. )
  195. t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
  196. p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
  197. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
  198. # NOTE: There are two stages of compilation: torch.compile and
  199. # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
  200. # overhead by reusing the FX graph for different shapes.
  201. # However, the XLA graph will still require static shapes and needs to
  202. # be re-compiled for every different shapes. This overhead is inevitable
  203. # in the first run, but can be skipped afterwards as we cache the XLA
  204. # graphs in the disk (APHRODITE_XLA_CACHE_PATH).
  205. if is_prompt:
  206. # Prefll
  207. torch._dynamo.mark_dynamic(token_ids, 1)
  208. torch._dynamo.mark_dynamic(position_ids, 1)
  209. torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
  210. else:
  211. # Decode
  212. torch._dynamo.mark_dynamic(token_ids, 0)
  213. torch._dynamo.mark_dynamic(position_ids, 0)
  214. torch._dynamo.mark_dynamic(input_lens, 0)
  215. torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
  216. torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
  217. torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
  218. torch._dynamo.mark_dynamic(t, 0)
  219. torch._dynamo.mark_dynamic(p, 0)
  220. # Dummy run.
  221. self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
  222. num_samples, kv_caches)
  223. def warmup_model(
  224. self,
  225. kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
  226. ) -> None:
  227. # Prefill
  228. logger.info("Compiling the model with different input shapes...")
  229. start = time.time()
  230. for batch_size in [1]:
  231. seq_len = 16
  232. while True:
  233. self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
  234. xm.wait_device_ops()
  235. logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")
  236. if seq_len >= self.model_config.max_model_len:
  237. break
  238. num_tokens = batch_size * seq_len
  239. if num_tokens >= self.scheduler_config.max_num_batched_tokens:
  240. break
  241. seq_len = seq_len * 2
  242. end = time.time()
  243. logger.info(f"Compilation for prefill done in {end - start:.2f} s.")
  244. # Decode
  245. start = time.time()
  246. seq_len = 1
  247. batch_size = 8 # Must be in sync with _get_padded_batch_size()
  248. while True:
  249. self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
  250. xm.wait_device_ops()
  251. logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")
  252. if batch_size >= self.scheduler_config.max_num_seqs:
  253. break
  254. batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
  255. end = time.time()
  256. logger.info(f"Compilation for decode done in {end - start:.2f} s.")
  257. def _prepare_prompt(
  258. self,
  259. seq_group_metadata_list: List[SequenceGroupMetadata],
  260. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
  261. assert len(seq_group_metadata_list) > 0
  262. input_tokens: List[int] = []
  263. input_positions: List[int] = []
  264. prompt_lens: List[int] = []
  265. slot_mapping: List[int] = []
  266. for seq_group_metadata in seq_group_metadata_list:
  267. assert seq_group_metadata.is_prompt
  268. seq_ids = list(seq_group_metadata.seq_data.keys())
  269. assert len(seq_ids) == 1
  270. seq_id = seq_ids[0]
  271. seq_data = seq_group_metadata.seq_data[seq_id]
  272. # Could include output tokens when a request is preempted.
  273. prompt_tokens = seq_data.get_token_ids()
  274. prompt_len = len(prompt_tokens)
  275. prompt_lens.append(prompt_len)
  276. input_tokens.extend(prompt_tokens)
  277. input_positions.extend(list(range(prompt_len)))
  278. assert seq_group_metadata.block_tables is not None
  279. block_table = seq_group_metadata.block_tables[seq_id]
  280. for i in range(prompt_len):
  281. block_number = block_table[i // self.block_size]
  282. block_offset = i % self.block_size
  283. slot = block_number * self.block_size + block_offset
  284. slot_mapping.append(slot)
  285. # Add paddings to EACH prompt to the smallest power of 2 that is
  286. # greater than or equal to the prompt length.
  287. # We pad the seq_len to reduce the compilation overhead.
  288. # We execute each prompt individually (i.e., with batch_size 1)
  289. # because the FlashAttention kernel does not support ragged inputs.
  290. # TODO(woosuk): Use SplashAttention to support ragged inputs.
  291. padded_prompt_len = _get_padded_prefill_len(prompt_len)
  292. num_paddings = padded_prompt_len - prompt_len
  293. input_tokens += [0] * num_paddings
  294. input_positions += [0] * num_paddings
  295. slot_mapping += [_PAD_SLOT_ID] * num_paddings
  296. assert len(prompt_lens) > 0
  297. num_prefills = len(prompt_lens)
  298. input_tokens = torch.tensor(input_tokens,
  299. dtype=torch.int32,
  300. device="cpu")
  301. input_positions = torch.tensor(input_positions,
  302. dtype=torch.int32,
  303. device="cpu")
  304. slot_mapping = torch.tensor(slot_mapping,
  305. dtype=torch.int64,
  306. device="cpu")
  307. prompt_lens = torch.tensor(prompt_lens,
  308. dtype=torch.int32,
  309. device="cpu")
  310. attn_metadata = self.attn_backend.make_metadata(
  311. num_prefills=num_prefills,
  312. num_prefill_tokens=0, # NOTE: This is not used.
  313. num_decode_tokens=0,
  314. slot_mapping=slot_mapping,
  315. block_tables=None,
  316. context_lens=None,
  317. )
  318. return input_tokens, input_positions, attn_metadata, prompt_lens
  319. def _prepare_decode(
  320. self,
  321. seq_group_metadata_list: List[SequenceGroupMetadata],
  322. ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
  323. assert len(seq_group_metadata_list) > 0
  324. input_tokens: List[List[int]] = []
  325. input_positions: List[List[int]] = []
  326. slot_mapping: List[List[int]] = []
  327. context_lens: List[int] = []
  328. batch_idx = 0
  329. for seq_group_metadata in seq_group_metadata_list:
  330. assert not seq_group_metadata.is_prompt
  331. seq_ids = list(seq_group_metadata.seq_data.keys())
  332. for seq_id in seq_ids:
  333. seq_data = seq_group_metadata.seq_data[seq_id]
  334. generation_token = seq_data.get_last_token_id()
  335. input_tokens.append([generation_token])
  336. seq_len = seq_data.get_len()
  337. position = seq_len - 1
  338. input_positions.append([position])
  339. context_lens.append(seq_len)
  340. assert seq_group_metadata.block_tables is not None
  341. block_table = seq_group_metadata.block_tables[seq_id]
  342. self.block_tables[batch_idx, :len(block_table)] = block_table
  343. batch_idx += 1
  344. block_number = block_table[position // self.block_size]
  345. block_offset = position % self.block_size
  346. slot = block_number * self.block_size + block_offset
  347. slot_mapping.append([slot])
  348. batch_size = _get_padded_batch_size(batch_idx)
  349. num_paddings = batch_size - batch_idx
  350. input_tokens = input_tokens + [[0]] * num_paddings
  351. input_positions = input_positions + [[0]] * num_paddings
  352. slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
  353. context_lens = context_lens + [0] * num_paddings
  354. input_tokens = torch.tensor(input_tokens,
  355. dtype=torch.int32,
  356. device="cpu")
  357. input_positions = torch.tensor(input_positions,
  358. dtype=torch.int32,
  359. device="cpu")
  360. slot_mapping = torch.tensor(slot_mapping,
  361. dtype=torch.int64,
  362. device="cpu")
  363. context_lens = torch.tensor(context_lens,
  364. dtype=torch.int32,
  365. device="cpu")
  366. block_tables = torch.tensor(self.block_tables[:batch_size],
  367. dtype=torch.int32,
  368. device="cpu")
  369. input_lens = torch.tensor([1] * batch_size,
  370. dtype=torch.int32,
  371. device="cpu")
  372. attn_metadata = self.attn_backend.make_metadata(
  373. num_prefills=0,
  374. num_prefill_tokens=0,
  375. num_decode_tokens=batch_size,
  376. slot_mapping=slot_mapping,
  377. block_tables=block_tables,
  378. context_lens=context_lens,
  379. )
  380. return input_tokens, input_positions, attn_metadata, input_lens
  381. def _prepare_sample(
  382. self,
  383. seq_group_metadata_list: List[SequenceGroupMetadata],
  384. padded_batch_size: int,
  385. ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
  386. assert len(seq_group_metadata_list) > 0
  387. t = []
  388. p = []
  389. best_of = []
  390. for seq_group_metadata in seq_group_metadata_list:
  391. sampling_params = seq_group_metadata.sampling_params
  392. t.append(sampling_params.temperature)
  393. if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
  394. raise NotImplementedError(
  395. "Top-p sampling is currently disabled for the TPU backend "
  396. "due to performance issues.")
  397. p.append(sampling_params.top_p)
  398. if sampling_params.top_k != -1:
  399. raise NotImplementedError(
  400. "Top-k sampling is currently disabled for the TPU backend "
  401. "due to performance issues.")
  402. if sampling_params.best_of > _MAX_NUM_SAMPLES:
  403. raise NotImplementedError(
  404. f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
  405. "backend.")
  406. best_of.append(sampling_params.best_of)
  407. if sampling_params.use_beam_search:
  408. raise NotImplementedError(
  409. "Beam search is not supported by the TPU backend.")
  410. if sampling_params.logprobs is not None:
  411. raise NotImplementedError(
  412. "logprobs is not currently supported by the TPU backend.")
  413. if sampling_params.prompt_logprobs is not None:
  414. raise NotImplementedError(
  415. "prompt_logprobs is not currently supported by the TPU "
  416. "backend.")
  417. # Repeat the sampling params if the seq group has multiple seqs.
  418. num_seqs = len(seq_group_metadata.seq_data)
  419. t += [t[-1]] * (num_seqs - 1)
  420. p += [p[-1]] * (num_seqs - 1)
  421. best_of += [best_of[-1]] * (num_seqs - 1)
  422. num_paddings = padded_batch_size - len(t)
  423. t += [1.0] * num_paddings
  424. p += [1.0] * num_paddings
  425. t = torch.tensor(t, dtype=torch.float32, device="cpu")
  426. p = torch.tensor(p, dtype=torch.float32, device="cpu")
  427. return t, p, best_of
  428. def prepare_model_input(
  429. self,
  430. seq_group_metadata_list: List[SequenceGroupMetadata],
  431. virtual_engine: int = 0,
  432. finished_requests_ids: Optional[List[str]] = None,
  433. ) -> ModelInputForTPU:
  434. del finished_requests_ids # Unused.
  435. assert virtual_engine == 0
  436. assert len(seq_group_metadata_list) > 0
  437. # NOTE: We assume that all sequences in the group are all prompts or
  438. # all decodes.
  439. is_prompt = seq_group_metadata_list[0].is_prompt
  440. if is_prompt:
  441. inputs = self._prepare_prompt(seq_group_metadata_list)
  442. else:
  443. inputs = self._prepare_decode(seq_group_metadata_list)
  444. input_tokens, input_positions, attn_metadata, input_lens = inputs
  445. padded_batch_size = input_tokens.shape[0]
  446. t, p, best_of = self._prepare_sample(seq_group_metadata_list,
  447. padded_batch_size)
  448. num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
  449. seq_groups = [
  450. list(metadata.seq_data.keys())
  451. for metadata in seq_group_metadata_list
  452. ]
  453. return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
  454. input_lens, t, p, num_samples, best_of,
  455. seq_groups)
  456. def make_model_input_from_broadcasted_tensor_dict(
  457. self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
  458. model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
  459. tensor_dict, attn_backend=self.attn_backend)
  460. return model_input
  461. @torch.no_grad()
  462. def execute_model(
  463. self,
  464. model_input: ModelInputForTPU,
  465. kv_caches: Optional[List[Any]],
  466. intermediate_tensors: Optional[IntermediateTensors] = None,
  467. num_steps: int = 1,
  468. ) -> List[SamplerOutput]:
  469. assert intermediate_tensors is None
  470. if num_steps > 1:
  471. raise ValueError(
  472. "TPUModelRunner does not support multi-step execution.")
  473. def _execute_model(*args, clone: bool = False) -> torch.Tensor:
  474. """Move input args from CPU to device and execute the model."""
  475. def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
  476. if clone:
  477. # When x is a slice of a CPU tensor, XLA may copy the whole
  478. # original tensor to TPU instead of only copying x.
  479. # To avoid this, we copy x after cloning.
  480. x = x.clone()
  481. return x.to(self.device)
  482. new_args = []
  483. for arg in args:
  484. if isinstance(arg, torch.Tensor):
  485. arg = _copy_to_device(arg)
  486. elif isinstance(arg, AttentionMetadata):
  487. arg.slot_mapping = _copy_to_device(arg.slot_mapping)
  488. if getattr(arg, "block_tables", None) is not None:
  489. arg.block_tables = _copy_to_device(arg.block_tables)
  490. if getattr(arg, "context_lens", None) is not None:
  491. arg.context_lens = _copy_to_device(arg.context_lens)
  492. new_args.append(arg)
  493. return self.model(*new_args)
  494. num_prefills = model_input.attn_metadata.num_prefills
  495. is_prompt = num_prefills > 0
  496. if is_prompt:
  497. # NOTE: Since the FlashAttention kernel does not support
  498. # ragged inputs, we split the prompts into different batches and
  499. # process them separately. This is a temporary hack that should be
  500. # optimized by using SplashAttention.
  501. next_token_ids = []
  502. orig_slot_mapping = model_input.attn_metadata.slot_mapping
  503. batch_size = model_input.input_lens.shape[0]
  504. start_idx = 0
  505. for i in range(batch_size):
  506. # Get the actual prefill_len.
  507. prefill_len = model_input.input_lens[i:i + 1].item()
  508. prefill_len = _get_padded_prefill_len(prefill_len)
  509. end_idx = start_idx + prefill_len
  510. model_input.attn_metadata.slot_mapping = orig_slot_mapping[
  511. None, start_idx:end_idx]
  512. model_input.attn_metadata.num_prefills = 1
  513. output_token_ids = _execute_model(
  514. model_input.token_ids[None, start_idx:end_idx],
  515. model_input.position_ids[None, start_idx:end_idx],
  516. model_input.attn_metadata,
  517. model_input.input_lens[i:i + 1],
  518. model_input.t[i:i + 1],
  519. model_input.p[i:i + 1],
  520. model_input.num_samples,
  521. kv_caches,
  522. clone=True)
  523. # Retrieve the outputs to CPU.
  524. next_token_ids += output_token_ids.cpu().tolist()
  525. start_idx = end_idx
  526. else:
  527. # Execute the model.
  528. output_token_ids = _execute_model(
  529. model_input.token_ids, model_input.position_ids,
  530. model_input.attn_metadata, model_input.input_lens,
  531. model_input.t, model_input.p, model_input.num_samples,
  532. kv_caches)
  533. # Retrieve the outputs to CPU.
  534. next_token_ids = output_token_ids.cpu().tolist()
  535. # NOTE: Minimal code to construct the sampler outputs.
  536. # The TPU backend does not reuse the sampler, since the TPU backend
  537. # does not support the advanced sampling parameters such as logprobs.
  538. zero_logprob = Logprob(0.0)
  539. batch_idx = 0
  540. sampler_outputs = []
  541. for seq_group in model_input.seq_groups:
  542. seq_ids = seq_group
  543. seq_outputs = []
  544. if is_prompt:
  545. assert len(seq_ids) == 1
  546. seq_id = seq_ids[0]
  547. for i in range(model_input.best_of[batch_idx]):
  548. next_token_id = next_token_ids[batch_idx][i]
  549. seq_outputs.append(
  550. SequenceOutput(seq_id, next_token_id,
  551. {next_token_id: zero_logprob}))
  552. batch_idx += 1
  553. else:
  554. for seq_id in seq_ids:
  555. next_token_id = next_token_ids[batch_idx][0]
  556. seq_outputs.append(
  557. SequenceOutput(seq_id, next_token_id,
  558. {next_token_id: zero_logprob}))
  559. batch_idx += 1
  560. sampler_outputs.append(
  561. CompletionSequenceGroupOutput(seq_outputs, None))
  562. return [SamplerOutput(sampler_outputs)]
  563. class ModelWrapper(nn.Module):
  564. def __init__(self, model: nn.Module):
  565. super().__init__()
  566. self.model = model
  567. def forward(
  568. self,
  569. token_ids: torch.Tensor,
  570. position_ids: torch.Tensor,
  571. attn_metadata: AttentionMetadata,
  572. input_lens: torch.Tensor,
  573. t: torch.Tensor,
  574. p: torch.Tensor,
  575. num_samples: int,
  576. kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
  577. ) -> torch.Tensor:
  578. """Executes the forward pass of the model and samples the next token.
  579. Args:
  580. token_ids: The input token IDs of shape [batch_size, seq_len].
  581. position_ids: The input position IDs of shape [batch_size, seq_len].
  582. attn_metadata: The Pallas attention metadata.
  583. input_lens: The actual input lengths of shape [batch_size].
  584. t: The sampling temperature of shape [batch_size].
  585. p: The top-p probability of shape [batch_size].
  586. num_samples: Number of samples to draw from each logits vector.
  587. kv_caches: The key and value caches. They can be None during the
  588. memory profiling at initialization.
  589. """
  590. batch_size, seq_len = token_ids.shape
  591. # Calculate the positions to sample from.
  592. start_indicies = torch.arange(
  593. batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
  594. logits_indices = start_indicies + input_lens - 1
  595. # FIXME: This is a temporary hack to avoid using the existing
  596. # sampler and sampling metadata.
  597. sampling_metadata = SamplingMetadata(
  598. seq_groups=[],
  599. selected_token_indices=logits_indices,
  600. categorized_sample_indices={},
  601. num_prompts=attn_metadata.num_prefills,
  602. )
  603. # Skip this in memory profiling at initialization.
  604. if kv_caches[0][0] is not None:
  605. # index_copy_(slot_mapping) only works when the inserted dimension
  606. # is 0. However, the KV cache in the Pallas backend has the shape
  607. # [num_kv_heads, num_blocks, block_size, head_size]. To make it
  608. # work, we need to flatten the first three dimensions and modify
  609. # the slot_mapping accordingly.
  610. num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
  611. slot_mapping = attn_metadata.slot_mapping
  612. slot_mapping = slot_mapping.flatten()
  613. head_indicies = torch.arange(0,
  614. num_kv_heads,
  615. device=slot_mapping.device,
  616. dtype=slot_mapping.dtype)
  617. head_indicies *= block_size * num_blocks
  618. slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
  619. -1, num_kv_heads)
  620. slot_mapping = slot_mapping + head_indicies.view(1, -1)
  621. slot_mapping = slot_mapping.flatten()
  622. attn_metadata.slot_mapping = slot_mapping
  623. hidden_states = self.model(
  624. token_ids,
  625. position_ids,
  626. kv_caches,
  627. attn_metadata,
  628. )
  629. hidden_states = hidden_states.flatten(0, 1)
  630. logits = self.model.compute_logits(hidden_states, sampling_metadata)
  631. # Argmax sampling.
  632. argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
  633. argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
  634. # Zero temperature means greedy decoding. Avoid division by zero.
  635. nonzero_t = torch.where(t != 0, t, 1.0)
  636. logits = logits / nonzero_t.unsqueeze(dim=1)
  637. if _ENABLE_TOP_P:
  638. logits = _apply_top_p(logits, p.unsqueeze(dim=1))
  639. # Random sampling.
  640. probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
  641. sampled_token_ids = torch.multinomial(probs,
  642. num_samples,
  643. replacement=True)
  644. next_token_ids = torch.where(t != 0, sampled_token_ids,
  645. argmax_token_ids)
  646. return next_token_ids
  647. def _get_padded_prefill_len(x: int) -> int:
  648. # NOTE: The pallas FlashAttention kernel requires the sequence
  649. # length to be a multiple of 16. We pad the prompt length to the nearest
  650. # multiple of 16. This is also good for performance.
  651. if x <= 16:
  652. return 16
  653. return 1 << (x - 1).bit_length()
  654. def _get_padded_batch_size(batch_size: int) -> int:
  655. # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
  656. # To meet this requirement in the simplest way, we set the minimal batch
  657. # size to 8.
  658. if batch_size <= 8:
  659. return 8
  660. else:
  661. return ((batch_size + 15) // 16) * 16
  662. def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
  663. logits_sorted = torch.sort(logits, dim=-1, descending=True).values
  664. sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
  665. cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
  666. cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
  667. logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
  668. return logits