# tpu_model_runner.py
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from loguru import logger

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                       IntermediateTensors, Logprob,
                                       SamplerOutput, SequenceGroupMetadata,
                                       SequenceOutput)
from aphrodite.modeling.model_loader import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.task_handler.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

_PAD_SLOT_ID = -1  # NOTE: In PyTorch XLA, index -1 is ignored.
# FIXME: Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME: A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    best_of: List[int]
    seq_groups: List[List[int]]

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig] = None,
        is_driver_worker: bool = False,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            False,
        )

    def load_model(self) -> None:
        self.device = self.device_config.device

        model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=self.parallel_config,
            cache_config=self.cache_config,
            scheduler_config=self.scheduler_config,
            multimodal_config=self.multimodal_config,
            lora_config=None,
        )
        model = model.eval()
        xm.wait_device_ops()

        model = ModelWrapper(model)
        self.model = torch.compile(model, backend="openxla", fullgraph=True)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

        # Dummy run.
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        self.model(token_ids, position_ids, attn_metadata, input_lens, t, p,
                   num_samples, kv_caches)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info(f"Compilation for prefill done in {end - start:.2f} s.")

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info(f"Compilation for decode done in {end - start:.2f} s.")

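    # Illustrative warmup schedule (a sketch, not part of the original runner;
    # assumes max_model_len=2048, max_num_batched_tokens=2048, and
    # max_num_seqs=256 -- actual values depend on the engine configuration):
    #   prefill: (batch_size=1, seq_len=16), (1, 32), (1, 64), ..., (1, 2048)
    #   decode:  (batch_size=8, seq_len=1), (16, 1), (32, 1), (48, 1), ...,
    #            (256, 1)
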
    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        slot_mapping: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

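    # Illustrative slot computation for _prepare_prompt (a sketch with made-up
    # values, not part of the original runner): with block_size=16 and
    # block_table=[7, 12], token position i=20 falls in
    # block_table[20 // 16] == 12 at offset 20 % 16 == 4, so its slot is
    # 12 * 16 + 4 == 196.
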
    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

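    # Illustrative padding in _prepare_decode (a sketch, not part of the
    # original runner): with 3 running sequences, _get_padded_batch_size(3)
    # returns 8, so 5 dummy rows are appended with token 0, position 0,
    # slot _PAD_SLOT_ID, and context length 0.
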
    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params

            # NOTE: Here we mimic argmax sampling by applying a very
            # low temperature. This is not accurate.
            t.append(sampling_params.temperature
                     if sampling_params.temperature >= 1e-5 else 1e-5)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            best_of += [best_of[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, best_of

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        padded_batch_size = input_tokens.shape[0]
        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, best_of,
                                seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)
        return model_input

    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
        if num_steps > 1:
            raise ValueError(
                "TPUModelRunner does not support multi-step execution.")

        def _execute_model(*args, clone: bool = False) -> torch.Tensor:
            """Move input args from CPU to device and execute the model."""

            def _copy_to_device(x: torch.Tensor) -> torch.Tensor:
                if clone:
                    # When x is a slice of a CPU tensor, XLA may copy the whole
                    # original tensor to TPU instead of only copying x.
                    # To avoid this, we copy x after cloning.
                    x = x.clone()
                return x.to(self.device)

            new_args = []
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg = _copy_to_device(arg)
                elif isinstance(arg, AttentionMetadata):
                    arg.slot_mapping = _copy_to_device(arg.slot_mapping)
                    if getattr(arg, "block_tables", None) is not None:
                        arg.block_tables = _copy_to_device(arg.block_tables)
                    if getattr(arg, "context_lens", None) is not None:
                        arg.context_lens = _copy_to_device(arg.context_lens)
                new_args.append(arg)
            return self.model(*new_args)

        num_prefills = model_input.attn_metadata.num_prefills
        is_prompt = num_prefills > 0
        if is_prompt:
            # NOTE: Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            next_token_ids = []
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                model_input.attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx]
                model_input.attn_metadata.num_prefills = 1
                output_token_ids = _execute_model(
                    model_input.token_ids[None, start_idx:end_idx],
                    model_input.position_ids[None, start_idx:end_idx],
                    model_input.attn_metadata,
                    model_input.input_lens[i:i + 1],
                    model_input.t[i:i + 1],
                    model_input.p[i:i + 1],
                    model_input.num_samples,
                    kv_caches,
                    clone=True)
                # Retrieve the outputs to CPU.
                next_token_ids += output_token_ids.cpu().tolist()
                start_idx = end_idx
        else:
            # Execute the model.
            output_token_ids = _execute_model(
                model_input.token_ids, model_input.position_ids,
                model_input.attn_metadata, model_input.input_lens,
                model_input.t, model_input.p, model_input.num_samples,
                kv_caches)
            # Retrieve the outputs to CPU.
            next_token_ids = output_token_ids.cpu().tolist()

        # NOTE: Minimal code to construct the sampler outputs.
        # The TPU backend does not reuse the sampler, since the TPU backend
        # does not support the advanced sampling parameters such as logprobs.
        zero_logprob = Logprob(0.0)
        batch_idx = 0
        sampler_outputs = []
        for seq_group in model_input.seq_groups:
            seq_ids = seq_group
            seq_outputs = []
            if is_prompt:
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                for i in range(model_input.best_of[batch_idx]):
                    next_token_id = next_token_ids[batch_idx][i]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                batch_idx += 1
            else:
                for seq_id in seq_ids:
                    next_token_id = next_token_ids[batch_idx][0]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                    batch_idx += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return [SamplerOutput(sampler_outputs)]


class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape
                [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1

        # FIXME: This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        logits = logits / t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        next_token_ids = torch.multinomial(probs,
                                           num_samples,
                                           replacement=True)
        return next_token_ids


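# Illustrative slot_mapping flattening in ModelWrapper.forward (a sketch with
# made-up sizes, not part of the original runner): with num_kv_heads=2,
# num_blocks=4, and block_size=8, head_indices == [0, 32], so a token assigned
# slot 5 is written at flattened positions [5, 37] -- i.e., slot 5 of head 0
# and slot 5 of head 1 in the flattened
# [num_kv_heads * num_blocks * block_size, head_size] view of the cache.
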
def _get_padded_prefill_len(x: int) -> int:
    # NOTE: The Pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to 16 if it is
    # shorter, and otherwise to the next power of two (which is always a
    # multiple of 16). This is also good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()


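# Illustrative values for _get_padded_prefill_len (derived from the function
# above, shown here only as a sketch):
#   _get_padded_prefill_len(1)   -> 16
#   _get_padded_prefill_len(16)  -> 16
#   _get_padded_prefill_len(17)  -> 32
#   _get_padded_prefill_len(100) -> 128
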
def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16


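# Illustrative values for _get_padded_batch_size (derived from the function
# above, shown here only as a sketch):
#   _get_padded_batch_size(3)  -> 8
#   _get_padded_batch_size(9)  -> 16
#   _get_padded_batch_size(17) -> 32
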
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
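

# Illustrative behavior of _apply_top_p (a sketch, not part of the original
# runner; note that top-p is currently disabled via _ENABLE_TOP_P above):
#   logits = torch.tensor([[2.0, 1.0, 0.1]])
#   p = torch.tensor([[0.7]])
#   _apply_top_p(logits, p)
# The cumulative probability of the top token (~0.66) stays below 0.7, so the
# cutoff falls on the second token; the lowest logit (0.1) lies outside the
# nucleus and is masked to -inf, leaving [[2.0, 1.0, -inf]].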