tpu_model_runner.py

import time
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from loguru import logger

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                       SamplerOutput, SequenceGroupMetadata,
                                       SequenceOutput)
from aphrodite.common.utils import make_tensor_with_pad
from aphrodite.modeling.model_loader import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata

_PAD_SLOT_ID = 0  # FIXME


class TPUModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.vision_language_config = vision_language_config

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            False,
        )

    def load_model(self) -> None:
        self.device = self.device_config.device
        model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=self.device_config,
            parallel_config=self.parallel_config,
            cache_config=self.cache_config,
            scheduler_config=self.scheduler_config,
            vision_language_config=self.vision_language_config,
            lora_config=None,
        )
        xm.wait_device_ops()
        model = ModelWrapper(model)
        self.model = torch.compile(model, backend="openxla", fullgraph=True)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        if is_prompt:
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        # Dummy run.
        self.model(token_ids, position_ids, kv_caches, attn_metadata,
                   input_lens, t, p)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info("batch_size: {}, seq_len: {}", batch_size, seq_len)
                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2
        end = time.time()
        logger.info("Compilation for prefill done in {:.2f} s.", end - start)

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 1
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info("batch_size: {}, seq_len: {}", batch_size, seq_len)
            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
        end = time.time()
        logger.info("Compilation for decode done in {:.2f} s.", end - start)
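
    # Illustrative note (example numbers assumed, not taken from the source):
    # with max_model_len = 2048 and max_num_batched_tokens = 2048, the prefill
    # loop above compiles seq_len = 16, 32, 64, ..., 2048 at batch_size = 1,
    # while the decode loop compiles batch_size = 1, 2, 4, 8, 16, 32, 48, ...
    # up to max_num_seqs, so each distinct input shape is traced and compiled
    # once before serving starts.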

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        prompt_lens: List[int] = []
        slot_mapping: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)
            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            slot_mapping.append([])
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping[-1].append(slot)

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        num_prefill_tokens = sum(prompt_lens)

        # Add paddings to make the shape [batch_size, max_prompt_len] where
        # max_prompt_len is the smallest power of 2 that is greater than or
        # equal to the maximum prompt length.
        # We need the 2D input shape because the Pallas FlashAttention kernel
        # does not support packed 1D inputs.
        # We pad the seq_len to powers of 2 to reduce the compilation overhead.
        max_prompt_len = _get_padded_prefill_len(max(prompt_lens))
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=torch.int32,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=torch.int32,
                                               device=self.device)
        slot_mapping = make_tensor_with_pad(slot_mapping,
                                            max_prompt_len,
                                            pad=_PAD_SLOT_ID,
                                            dtype=torch.int64,
                                            device=self.device)
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device=self.device)
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens
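
    # Example for _prepare_prompt above (illustrative, assumed numbers): with
    # two prompts of lengths 5 and 11, max(prompt_lens) == 11 is padded to
    # max_prompt_len == 16, so input_tokens, input_positions, and slot_mapping
    # all get shape [2, 16]; the padded tail positions hold pad=0 or
    # _PAD_SLOT_ID, and sampling later selects only the last real token per
    # sequence via prompt_lens (passed to the model as input_lens).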

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ):
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
        num_seq_groups = len(seq_group_metadata_list)
        batch_size = _get_padded_batch_size(num_seq_groups)

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[i, :len(block_table)] = block_table

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        num_paddings = batch_size - num_seq_groups
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device=self.device)
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device=self.device)
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device=self.device)
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.sampling_params is not None
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature
                     if sampling_params.temperature >= 1e-5 else 1e-5)
            p.append(sampling_params.top_p)

        num_paddings = padded_batch_size - len(seq_group_metadata_list)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device=self.device)
        p = torch.tensor(p, dtype=torch.float32, device=self.device)
        return t, p

    def prepare_inputs(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ):
        assert seq_group_metadata_list is not None
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        if seq_group_metadata_list[0].is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        padded_batch_size = inputs[0].shape[0]
        sample_inputs = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        return inputs + sample_inputs

    def _execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> List[CompletionSequenceGroupOutput]:
        inputs = self.prepare_inputs(seq_group_metadata_list)
        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
                                    *inputs[2:])
        next_token_ids = next_token_ids.cpu().tolist()

        i = 0
        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []
            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                next_token_id = next_token_ids[i]
                seq_outputs.append(
                    SequenceOutput(seq_id, next_token_id,
                                   {next_token_id: Logprob(0.0)}))
                i += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return sampler_outputs

    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> SamplerOutput:
        assert seq_group_metadata_list is not None
        if seq_group_metadata_list[0].is_prompt:
            # NOTE: To reduce the compilation time, we only compile the
            # prefill inputs with batch size 1. Because the scheduler is not
            # aware of this limitation, we need to handle batch size > 1
            # internally by calling the model multiple times and concatenating
            # the outputs.
            # FIXME: This is a temporary hack to not change the existing
            # scheduler. We need to fix this in the future.
            sampler_outputs = []
            for seq_group_metadata in seq_group_metadata_list:
                sampler_outputs += self._execute_model([seq_group_metadata],
                                                       kv_caches)
        else:
            sampler_outputs = self._execute_model(seq_group_metadata_list,
                                                  kv_caches)
        return SamplerOutput(sampler_outputs)


class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model.eval()

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        base_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = base_indices + input_lens - 1

        # FIXME: This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping
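            # Worked example (illustrative shapes, not from the source): with
            # num_kv_heads=2, num_blocks=4, block_size=16, each head owns
            # num_blocks * block_size = 64 slots in the flattened cache, so
            # head_indices == [0, 64] and a token originally mapped to slot 5
            # is written to flattened slots [5, 69] (one per KV head).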

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
        logits = logits / t.unsqueeze(dim=1)
        # FIXME: Disabled top-p sampling since it's too slow.
        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        # FIXME: best_of > 1 is not supported.
        next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE: The Pallas FlashAttention kernel requires the sequence length to
    # be a multiple of 16. We pad the prompt length to the next power of 2
    # (with a minimum of 16), which satisfies that requirement and is also
    # good for performance.
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()
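

# Example (illustrative): _get_padded_prefill_len(10) == 16,
# _get_padded_prefill_len(17) == 32, and _get_padded_prefill_len(100) == 128.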


def _get_padded_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    elif batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16
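

# Example (illustrative): _get_padded_batch_size(3) == 4,
# _get_padded_batch_size(9) == 16, and _get_padded_batch_size(17) == 32.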


def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits
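

# Worked example for _apply_top_p (illustrative values): for a single row of
# logits [2.0, 1.0, 0.0] and p = 0.9, the sorted softmax probabilities are
# roughly [0.665, 0.245, 0.090] with cumulative sums [0.665, 0.910, 1.000].
# Only the first cumulative sum is < 0.9, so cutoff_index == 1,
# cutoff_logit == 1.0, and the logit 0.0 is masked to -inf, keeping the
# smallest set of top tokens whose cumulative probability reaches p.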