from typing import Dict, List, Optional, Tuple

import torch
from loguru import logger
from torch import nn

from aphrodite.common.sequence import (SamplerOutput, SequenceData,
                                       SequenceGroupMetadata)
from aphrodite.common.utils import (async_tensor_h2d, is_pin_memory_available,
                                    make_tensor_with_pad, maybe_expand_dim)
from aphrodite.common.config import (DeviceConfig, ModelConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.modeling import SamplingMetadata
from aphrodite.modeling.model_loader.neuron import get_neuron_model
from aphrodite.common.sampling_params import SamplingParams, SamplingType


class NeuronModelRunner:
    """Prepares batched inputs, runs the model, and samples next tokens on
    AWS Neuron devices."""

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
    ):
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config

        if model_config is not None and model_config.get_sliding_window():
            logger.warning("Sliding window is not supported on Neuron. "
                           "The model will run without sliding window.")
        self.device_config = (device_config if device_config is not None else
                              DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Lazy initialization.
        self.model: nn.Module  # initialized by load_model.

    def load_model(self) -> None:
        self.model = get_neuron_model(self.model_config,
                                      parallel_config=self.parallel_config,
                                      scheduler_config=self.scheduler_config)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
        """Builds padded token/position tensors and block ids for prefill."""
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []

        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.append(prompt_tokens)
            input_positions.append(list(range(prompt_len)))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            assert len(block_table) == 1
            input_block_ids.append(block_table[0])

        max_prompt_len = max(prompt_lens)
        assert max_prompt_len > 0
        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_prompt_len,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_prompt_len,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Builds single-token inputs and block ids for decode steps."""
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        input_block_ids: List[int] = []
        context_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                assert len(block_table) == 1
                input_block_ids.append(block_table[0])

        input_tokens = make_tensor_with_pad(input_tokens,
                                            max_len=1,
                                            pad=0,
                                            dtype=torch.long,
                                            device=self.device)
        input_positions = make_tensor_with_pad(input_positions,
                                               max_len=1,
                                               pad=0,
                                               dtype=torch.long,
                                               device=self.device)
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int,
                                    device=self.device)
        input_block_ids = torch.tensor(input_block_ids,
                                       dtype=torch.long,
                                       device=self.device)

        return input_tokens, input_positions, input_block_ids

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        prompt_lens: List[int],
    ) -> SamplingMetadata:
        """Builds the SamplingMetadata (selected token indices, per-type
        sample indices, and per-sequence generators) for this batch."""
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        selected_token_indices: List[int] = []
        generators: List[torch.Generator] = []
        selected_token_start_idx = 0
        categorized_sample_indices: Dict[SamplingType,
                                         List[Tuple[int, int]]] = {
                                             t: []
                                             for t in SamplingType
                                         }
        categorized_sample_indices_start_idx = 0
        categorized_sampled_token_indices_start_idx = 0

        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            if seq_group_metadata.is_prompt:
                assert len(seq_ids) == 1
                assert prompt_lens is not None
                prompt_len = prompt_lens[i]
                if sampling_params.prompt_logprobs is not None:
                    # NOTE: prompt token positions do not need to be sampled;
                    # skip them.
                    categorized_sample_indices_start_idx += prompt_len - 1

                categorized_sample_indices[
                    sampling_params.sampling_type].append(
                        (categorized_sample_indices_start_idx,
                         categorized_sampled_token_indices_start_idx))
                categorized_sample_indices_start_idx += 1
                categorized_sampled_token_indices_start_idx += 1

                if sampling_params.prompt_logprobs is not None:
                    selected_token_indices.extend(
                        range(selected_token_start_idx,
                              selected_token_start_idx + prompt_len - 1))
                selected_token_indices.append(selected_token_start_idx +
                                              prompt_len - 1)
                selected_token_start_idx += prompt_len

                if sampling_params.seed is not None:
                    seq_group_metadata.state.generator = torch.Generator(
                        device=self.device).manual_seed(sampling_params.seed)
            else:
                num_seqs = len(seq_ids)
                selected_token_indices.extend(
                    range(selected_token_start_idx,
                          selected_token_start_idx + num_seqs))
                selected_token_start_idx += num_seqs

                categorized_sample_indices[
                    sampling_params.sampling_type].extend(
                        zip(
                            range(
                                categorized_sample_indices_start_idx,
                                categorized_sample_indices_start_idx +
                                num_seqs),
                            range(
                                categorized_sampled_token_indices_start_idx,
                                categorized_sampled_token_indices_start_idx +
                                num_seqs)))
                categorized_sample_indices_start_idx += num_seqs
                categorized_sampled_token_indices_start_idx += num_seqs

            if sampling_params.seed is not None:
                generators.append(seq_group_metadata.state.generator)

        selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                  dtype=torch.long,
                                                  target_device=self.device,
                                                  pin_memory=self.pin_memory)

        categorized_sample_indices = {
            t: maybe_expand_dim(
                async_tensor_h2d(seq_ids,
                                 dtype=torch.int,
                                 target_device=self.device,
                                 pin_memory=self.pin_memory), 2, 2)
            for t, seq_ids in categorized_sample_indices.items()
        }

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        sampling_metadata = SamplingMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=categorized_sample_indices,
            generators=generators,
        )
        return sampling_metadata

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, input_block_ids,
             prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
            prompt_lens = []
        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
                                                 prompt_lens)

        return (input_tokens, input_positions, input_block_ids,
                sampling_metadata)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, input_block_ids,
         sampling_metadata) = self.prepare_input_tensors(
             seq_group_metadata_list)

        hidden_states = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            input_block_ids=input_block_ids,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()
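

# Usage sketch (illustrative only, not part of the runner): the Neuron worker
# in the engine is expected to drive this class roughly as follows. The config
# objects and the `seq_group_metadata_list` batches are assumed to be supplied
# by the surrounding engine/scheduler and are not constructed here.
#
#     runner = NeuronModelRunner(model_config, parallel_config,
#                                scheduler_config, device_config)
#     runner.load_model()
#     # for each scheduler step:
#     sampler_output = runner.execute_model(seq_group_metadata_list)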