@@ -0,0 +1,271 @@
+from typing import Dict, List, Tuple, Any
+
+import torch
+
+from aphrodite.common.config import (
+    CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig)
+from aphrodite.modeling import get_model, InputMetadata, set_random_seed
+from aphrodite.modeling.megatron.parallel_state import (
+    initialize_model_parallel, initialize_all_reduce_launcher)
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import (
+    SequenceData, SequenceGroupMetadata, SequenceOutputs)
+from aphrodite.task_handler.cache_engine import CacheEngine
+from aphrodite.common.utils import get_gpu_memory
+
+
+class Worker:
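+    """A worker that executes (a partition of) the model on a single GPU.
+
+    The worker loads the model, profiles how many KV cache blocks fit in
+    GPU and CPU memory, owns the CacheEngine, and runs the model forward
+    pass for each scheduled batch of sequence groups.
+    """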
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        rank: int,
+        distributed_init_method: str,
+    ) -> None:
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+
+        _init_distributed_environment(parallel_config, rank, distributed_init_method)
+
+        set_random_seed(self.model_config.seed)
+        self.model = get_model(model_config)
+        initialize_all_reduce_launcher(
+            self.scheduler_config.max_num_batched_tokens,
+            self.model_config.get_hidden_size(),
+            self.model_config.dtype,
+        )
+
+        # These will be initialized by self.init_cache_engine().
+        self.cache_config = None
+        self.block_size = None
+        self.cache_engine = None
+        self.cache_events = None
+        self.gpu_cache = None
+
+    @torch.inference_mode()
+    def profile_num_available_blocks(
+        self,
+        block_size: int,
+        gpu_memory_utilization: float,
+        cpu_swap_space: int,
+    ) -> Tuple[int, int]:
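+        """Profile peak memory usage and return the number of GPU and CPU
+        KV cache blocks that can be allocated.
+
+        A forward pass over a worst-case dummy batch (without KV caches)
+        records the peak memory; the KV cache is sized from the memory that
+        remains under the requested gpu_memory_utilization.
+        """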
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
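+        # Build max_num_seqs dummy prompt sequences whose lengths add up to
+        # max_num_batched_tokens, i.e. the largest batch the scheduler can
+        # emit. Top-p/top-k sampling is enabled so that the sampler's memory
+        # usage is captured as well.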
+        sampling_params = SamplingParams(top_p=0.99,
+                                         top_k=self.model.config.vocab_size - 1)
+        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+        seqs = []
+        for group_id in range(max_num_seqs):
+            seq_len = (max_num_batched_tokens // max_num_seqs +
+                       (group_id < max_num_batched_tokens % max_num_seqs))
+            seq_data = SequenceData([0] * seq_len)
+            seq = SequenceGroupMetadata(
+                request_id=str(group_id),
+                is_prompt=True,
+                seq_data={group_id: seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+            )
+            seqs.append(seq)
+
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
+
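+        # Run a single forward pass without KV caches to record the peak
+        # activation memory.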
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        self.model(
+            input_ids=input_tokens,
+            positions=input_positions,
+            kv_caches=[(None, None)] * num_layers,
+            input_metadata=input_metadata,
+            cache_events=None,
+        )
+
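+        # Size the KV cache from the GPU memory that remains below the
+        # requested utilization cap; the CPU swap space bounds the number of
+        # CPU blocks.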
+        torch.cuda.synchronize()
+        peak_memory = torch.cuda.max_memory_allocated()
+        total_gpu_memory = get_gpu_memory()
+        cache_block_size = CacheEngine.get_cache_block_size(
+            block_size, self.model_config, self.parallel_config)
+        num_gpu_blocks = int(
+            (total_gpu_memory * gpu_memory_utilization - peak_memory) //
+            cache_block_size)
+        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
+        num_gpu_blocks = max(num_gpu_blocks, 0)
+        num_cpu_blocks = max(num_cpu_blocks, 0)
+        torch.cuda.empty_cache()
+
+        set_random_seed(self.model_config.seed)
+        return num_gpu_blocks, num_cpu_blocks
+
+    def init_cache_engine(self, cache_config: CacheConfig) -> None:
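+        """Create the CacheEngine, which allocates the GPU and CPU KV caches
+        according to the given cache configuration."""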
+        self.cache_config = cache_config
+        self.block_size = cache_config.block_size
+        self.cache_engine = CacheEngine(
+            self.cache_config, self.model_config, self.parallel_config)
+        self.cache_events = self.cache_engine.events
+        self.gpu_cache = self.cache_engine.gpu_cache
+
+    def _prepare_inputs(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
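+        """Flatten the scheduled sequence groups into token and position
+        tensors plus the InputMetadata consumed by the attention layers."""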
+        seq_groups: List[Tuple[List[int], SamplingParams]] = []
+        input_tokens: List[int] = []
+        input_positions: List[int] = []
+        slot_mapping: List[int] = []
+
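+        # First pass: add the prompt tokens of sequence groups that are in the
+        # prompt (prefill) phase.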
+        prompt_lens: List[int] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            if not seq_group_metadata.is_prompt:
+                continue
+
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            sampling_params = seq_group_metadata.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+
+            seq_id = seq_ids[0]  # Use any sequence in the group.
+
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            prompt_tokens = seq_data.get_token_ids()
+            prompt_len = len(prompt_tokens)
+            prompt_lens.append(prompt_len)
+
+            input_tokens.extend(prompt_tokens)
+            # Assume the prompt starts at position 0, i.e. no tokens precede
+            # it in the sequence.
+            input_positions.extend(range(prompt_len))
+
|
|
|
+ if seq_group_metadata.block_tables is None:
|
|
|
+ slot_mapping.extend([0] * prompt_len)
|
|
|
+ continue
|
|
|
+
|
|
|
+ block_table = seq_group_metadata.block_tables[seq_id]
|
|
|
+ for i in range(prompt_len):
|
|
|
+ block_number = block_table[i // self.block_size]
|
|
|
+ block_offset = i % self.block_size
|
|
|
+ slot = block_number * self.block_size + block_offset
|
|
|
+ slot_mapping.append(slot)
|
|
|
+
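+        # Second pass: add the last generated token, context length, and block
+        # table of sequence groups that are in the generation (decoding) phase.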
+        max_context_len = 0
+        max_num_blocks_per_seq = 0
+        context_lens: List[int] = []
+        generation_block_tables: List[List[int]] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            if seq_group_metadata.is_prompt:
+                continue
+
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            sampling_params = seq_group_metadata.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+
+            for seq_id in seq_ids:
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
+                input_tokens.append(generation_token)
+
+                context_len = seq_data.get_len()
+                position = context_len - 1
+                input_positions.append(position)
+
+                block_table = seq_group_metadata.block_tables[seq_id]
+                generation_block_tables.append(block_table)
+
+                max_context_len = max(max_context_len, context_len)
+                max_num_blocks_per_seq = max(
+                    max_num_blocks_per_seq, len(block_table))
+                context_lens.append(context_len)
+
+                block_number = block_table[position // self.block_size]
+                block_offset = position % self.block_size
+                slot = block_number * self.block_size + block_offset
+                slot_mapping.append(slot)
+
+        # NVIDIA Tensor Cores require the input length to be a multiple of 8.
+        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
+        input_positions = _pad_to_alignment(input_positions, multiple_of=8)
+
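+        # Move the flattened inputs to the GPU as tensors.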
+        tokens_tensor = torch.cuda.LongTensor(input_tokens)
+        positions_tensor = torch.cuda.LongTensor(input_positions)
+        slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
+        context_lens_tensor = torch.cuda.IntTensor(context_lens)
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq)
+            for block_table in generation_block_tables]
+        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
+
+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
+        input_metadata = InputMetadata(
+            seq_groups=seq_groups,
+            seq_data=seq_data,
+            prompt_lens=prompt_lens,
+            slot_mapping=slot_mapping_tensor,
+            context_lens=context_lens_tensor,
+            max_context_len=max_context_len,
+            block_tables=block_tables_tensor,
+        )
+        return tokens_tensor, positions_tensor, input_metadata
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> Dict[int, SequenceOutputs]:
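+        """Issue the scheduled cache swap/copy operations, run a single model
+        step over the given sequence groups, and return the sampled outputs."""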
+        issued_cached_op = False
+        if blocks_to_swap_in:
+            self.cache_engine.swap_in(blocks_to_swap_in)
+            issued_cached_op = True
+        if blocks_to_swap_out:
+            self.cache_engine.swap_out(blocks_to_swap_out)
+            issued_cached_op = True
+        if blocks_to_copy:
+            self.cache_engine.copy(blocks_to_copy)
+            issued_cached_op = True
+
+        if issued_cached_op:
+            cache_events = self.cache_events
+        else:
+            cache_events = None
+
+        # If there is nothing to run, just wait for any outstanding cache
+        # operations and return.
+        if not seq_group_metadata_list:
+            if cache_events is not None:
+                for event in cache_events:
+                    event.wait()
+            return {}
+
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(
+            seq_group_metadata_list)
+
+        output = self.model(
+            input_ids=input_tokens,
+            positions=input_positions,
+            kv_caches=self.gpu_cache,
+            input_metadata=input_metadata,
+            cache_events=cache_events,
+        )
+        return output
+
+
+def _init_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: str,
+) -> None:
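+    """Initialize the torch.distributed process group (NCCL backend) and the
+    megatron-style model parallel groups for this worker."""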
+    torch.distributed.init_process_group(
+        backend="nccl",
+        world_size=parallel_config.world_size,
+        rank=rank,
+        init_method=distributed_init_method,
+    )
+    # A small all-reduce to warm up the NCCL communicator.
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
+    initialize_model_parallel(parallel_config.tensor_parallel_size,
+                              parallel_config.pipeline_parallel_size)
+
+
+def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
+    return x + [0] * ((-len(x)) % multiple_of)
+
+
+def _pad_to_max(x: List[int], max_len: int) -> List[int]:
+    return x + [0] * (max_len - len(x))