@@ -5,7 +5,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from aphrodite.common.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
+from aphrodite.common.config import (DeviceConfig, ModelConfig, LoRAConfig,
+                                     ParallelConfig, SchedulerConfig)
 from aphrodite.common.logger import init_logger
 from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
 from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict)
@@ -37,6 +38,7 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
@@ -51,7 +53,9 @@
         # FIXME: This is a hack to make the tests work. Refactor this.
         self.sliding_window = (model_config.get_sliding_window()
                                if model_config is not None else None)
-        self.device = torch.device(torch.cuda.current_device())
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
         self.model = None
         self.block_size = None  # Set after initial profiling.
         self.lora_manager = None
@@ -74,7 +78,8 @@
         self.kv_cache_dtype = kv_cache_dtype
 
     def load_model(self) -> None:
-        self.model = get_model(self.model_config, self.lora_config)
+        self.model = get_model(self.model_config, self.device_config,
+                               self.lora_config)
 
         vocab_size = self.model.config.vocab_size
 
@@ -184,22 +189,25 @@
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_prompt_len,
                                                 pad=0,
-                                                dtype=torch.long)
+                                                dtype=torch.long,
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         lora_index_mapping = [
             _pad_to_max(mapping, max_prompt_len, pad=0)
             for mapping in lora_index_mapping
         ]
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
-                                           device="cuda")
+                                           device=self.device)
         # Prepare prefix block tables
         max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
         block_tables = _make_tensor_with_pad(
@@ -207,15 +215,16 @@
             max_len=max_prompt_block_table_len,
             pad=0,
             dtype=torch.int,
+            device=self.device,
         )
         start_loc_tensor = torch.arange(0,
                                         len(prompt_lens) * max_prompt_len,
                                         max_prompt_len,
                                         dtype=torch.long,
-                                        device="cuda")
+                                        device=self.device)
         prompt_lens_tensor = torch.tensor(prompt_lens,
                                           dtype=torch.long,
-                                          device="cuda")
+                                          device=self.device)
 
         input_metadata = InputMetadata(
             is_prompt=True,
@@ -307,20 +316,20 @@
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device="cuda")
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device="cuda")
+                                    device=self.device)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -329,7 +338,7 @@
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device="cuda")
+            block_tables = torch.tensor(input_block_tables, device=self.device)
         else:
             max_block_table_len = max(
                 len(block_table) for block_table in block_tables)
@@ -338,7 +347,7 @@
                 max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
-                device="cuda",
+                device=self.device,
             )
 
         lora_index_mapping = [
@@ -413,9 +422,13 @@
 
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
+                                            target_device=self.device,
                                             pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
+            t: _async_h2d(seq_ids,
+                          dtype=torch.int,
+                          target_device=self.device,
+                          pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -801,14 +814,10 @@ def _make_tensor_with_pad(
     max_len: int,
     pad: int,
     dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
+    device: Optional[Union[str, torch.device]],
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x,
-                        dtype=dtype,
-                        device=device,
-                        pin_memory=pin_memory and str(device) == "cpu")
+    return torch.tensor(padded_x, dtype=dtype, device=device)
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
@@ -820,6 +829,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
     return (batch_size + 7) // 8 * 8
 
 
-def _async_h2d(data: list, dtype, pin_memory):
-    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
-    return t.to(device="cuda", non_blocking=True)
+def _async_h2d(
+    data: list,
+    dtype: torch.dtype,
+    target_device: Union[str, torch.device],
+    pin_memory: bool,
+) -> torch.Tensor:
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    return t.to(device=target_device, non_blocking=True)