import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from loguru import logger

from aphrodite.attention import AttentionMetadata, get_attn_backend
from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig)
from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
                                       IntermediateTensors, Logprob,
                                       SequenceGroupMetadata, SequenceOutput)
from aphrodite.compilation.wrapper import (
    TorchCompileWrapperWithCustomDispacther)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader import get_model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from aphrodite.attention.backends.abstract import AttentionBackend

# Here we utilize the behavior that an out-of-bound index is ignored.
# FIXME: Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
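# For illustration (hypothetical sizes): with block_size=16 and 2048 physical
# KV-cache blocks, valid slots span [0, 32768); padded tokens are assigned
# slot 1_000_000_000, so their KV writes land out of range and are dropped.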
# FIXME: Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME: A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128

@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    best_of: List[int]
    seq_groups: List[List[int]]
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "best_of": self.best_of,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)

class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        is_driver_worker: bool = False,
        **kwargs,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free(),
            False,
        )
        self.cached_step_outputs: List[torch.Tensor] = []

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE: While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally
        # assigned by the xm runtime. Therefore, there is a mismatch in the
        # rank assignment between the gloo (cpu) runtime and the xm (tpu)
        # runtime. This is not a problem in linear layers because all-reduce
        # is rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "aphrodite.modeling.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(
                model_config=self.model_config,
                load_config=self.load_config,
                device_config=self.device_config,
                parallel_config=self.parallel_config,
                cache_config=self.cache_config,
                scheduler_config=self.scheduler_config,
                lora_config=None,
            )
        model = model.eval()
        xm.wait_device_ops()
        self.model = ModelWrapper(model)

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        is_prompt: bool,
    ) -> None:
        if is_prompt:
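            # Round seq_len up to the next multiple of 16 (e.g. 100 -> 112,
            # 16 -> 16); the Pallas attention kernel expects such lengths.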
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=batch_size,
                num_prefill_tokens=batch_size * seq_len,
                num_decode_tokens=0,
                slot_mapping=slot_mapping,
                block_tables=None,
                context_lens=None,
            )
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        # NOTE: There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph still requires static shapes and has to be
        # re-compiled for every different shape. This overhead is unavoidable
        # in the first run, but can be skipped afterwards as we cache the XLA
        # graphs on disk (APHRODITE_XLA_CACHE_PATH).
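        # `torch._dynamo.mark_dynamic(tensor, dim)` marks that dimension as
        # symbolic, so the FX graph traced below is reused when the marked
        # dimension (seq_len for prefill, batch_size for decode) changes.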
        if is_prompt:
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        # Dummy run.
        self.model(token_ids,
                   position_ids,
                   attn_metadata,
                   input_lens,
                   t,
                   p,
                   num_samples,
                   kv_caches,
                   is_prompt=is_prompt)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
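        # Warm up XLA with dummy prefills (batch_size=1, seq_len doubling from
        # 16 up to max_model_len or max_num_batched_tokens) and dummy decodes
        # (seq_len=1, batch sizes 8, 16, 32, 48, ... up to max_num_seqs), so
        # the compiled graphs are cached before serving; see the loops below.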
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while True:
                self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True)
                xm.wait_device_ops()
                logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

                if seq_len >= self.model_config.max_model_len:
                    break
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info(f"Compilation for prefill done in {end - start:.2f} s.")

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False)
            xm.wait_device_ops()
            logger.info(f"batch_size: {batch_size}, seq_len: {seq_len}")

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info(f"Compilation for decode done in {end - start:.2f} s.")

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        slot_mapping: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
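            # Slot arithmetic, for illustration with block_size=16: token
            # position 37 lives in logical block 2 at offset 5, so its KV
            # entry goes to physical slot block_table[2] * 16 + 5.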

            # Pad EACH prompt to the smallest power of 2 that is greater than
            # or equal to the prompt length. We pad the seq_len to reduce the
            # compilation overhead. We execute each prompt individually
            # (i.e., with batch_size 1) because the FlashAttention kernel
            # does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            block_tables=None,
            context_lens=None,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        best_of = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k != -1:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            best_of += [best_of[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, best_of

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume the scheduler hands us a batch that is either all
        # prompts or all decodes, never a mix of the two.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        padded_batch_size = input_tokens.shape[0]
        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
                                             padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, best_of,
                                seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)
        return model_input

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False)
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
        if is_prompt:
            assert num_steps == 1
            # NOTE: Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            next_token_ids = []
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
                output_token_ids = self.model(token_ids,
                                              position_ids,
                                              attn_metadata,
                                              input_lens,
                                              t,
                                              p,
                                              model_input.num_samples,
                                              kv_caches,
                                              is_prompt=True)
                next_token_ids.append(output_token_ids[0])
                start_idx = end_idx

            if model_input.async_callback is not None:
                model_input.async_callback()

            # Retrieve the outputs to CPU.
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE: Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU
            # backend does not support advanced sampling parameters such as
            # logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                seq_outputs = []
                for j in range(model_input.best_of[i]):
                    next_token_id = next_token_ids[i][j]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
                output_token_ids = self.model(token_ids,
                                              position_ids,
                                              attn_metadata,
                                              input_lens,
                                              t,
                                              p,
                                              model_input.num_samples,
                                              kv_caches,
                                              is_prompt=False)
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
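                    # The new token of each sequence goes to slot
                    # block_tables[pos // block_size] * block_size +
                    # pos % block_size; padded rows keep _PAD_SLOT_ID so
                    # their KV writes stay out of range and are dropped.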
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]

class ModelWrapper(TorchCompileWrapperWithCustomDispacther):

    def __init__(self, model: nn.Module):
        self.model = model
        compiled_callable = torch.compile(self.forward,
                                          backend="openxla",
                                          fullgraph=True,
                                          dynamic=False)
        super().__init__(compiled_callable)

    def __call__(self, *args, is_prompt: bool, **kwargs):
        if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher:
            # Not fully compiled yet, or not using the custom dispatcher:
            # let PyTorch handle it.
            return self.compiled_callable(*args, **kwargs)

        # The 3 compiled codes are:
        # 0: for profiling
        # 1: for prompt
        # 2: for decode
        # Dispatch to the compiled code directly, skipping PyTorch.
        if is_prompt:
            with self.dispatch_to_code(1):
                return self.forward(*args, **kwargs)
        else:
            with self.dispatch_to_code(2):
                return self.forward(*args, **kwargs)

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[Optional[torch.Tensor],
                              Optional[torch.Tensor]]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape
                [batch_size, seq_len].
            attn_metadata: The Pallas attention metadata.
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1

        # FIXME: This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0] is not None:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
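            # For illustration (hypothetical sizes): with 8 KV heads, 512
            # blocks, and block_size 16, the flattened cache has
            # 8 * 512 * 16 slots per layer, and a token at slot s is written
            # at s + h * (512 * 16) for each head h in [0, 8).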
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(
            token_ids,
            position_ids,
            kv_caches,
            attn_metadata,
        )
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)
        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        next_token_ids = torch.where(t != 0, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids

def _get_padded_prefill_len(x: int) -> int:
    # NOTE: The Pallas FlashAttention kernel requires the sequence length to
    # be a multiple of 16. We pad the prompt length to the next power of 2
    # (at least 16), which satisfies that requirement and also limits the
    # number of distinct shapes that need to be compiled.
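    # Example mappings: 1 -> 16, 17 -> 32, 100 -> 128, 129 -> 256.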
    if x <= 16:
        return 16
    return 1 << (x - 1).bit_length()

def _get_padded_batch_size(batch_size: int) -> int:
    # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
    # To meet this requirement in the simplest way, we set the minimal batch
    # size to 8.
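    # Example mappings: 1 -> 8, 9 -> 16, 17 -> 32, 33 -> 48.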
    if batch_size <= 8:
        return 8
    else:
        return ((batch_size + 15) // 16) * 16

def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
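    # Keep only the smallest set of logits whose cumulative (sorted) softmax
    # probability reaches p; everything below that cutoff logit is masked to
    # -inf. Here p has shape [batch_size, 1] (see the unsqueeze at the call
    # site in ModelWrapper.forward).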
    logits_sorted = torch.sort(logits, dim=-1, descending=True).values
    sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
    cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
    cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
    logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
    return logits

def _make_decode_output(
    next_token_ids: List[int],
    seq_groups: List[List[int]],
) -> SamplerOutput:
    zero_logprob = Logprob(0.0)
    sampler_outputs = []
    batch_idx = 0
    for seq_group in seq_groups:
        seq_ids = seq_group
        seq_outputs = []
        for seq_id in seq_ids:
            next_token_id = next_token_ids[batch_idx]
            seq_outputs.append(
                SequenceOutput(seq_id, next_token_id,
                               {next_token_id: zero_logprob}))
            batch_idx += 1
        sampler_outputs.append(CompletionSequenceGroupOutput(
            seq_outputs, None))
    return SamplerOutput(sampler_outputs)
|