|
@@ -1,19 +1,5 @@
|
|
|
-import copy
|
|
|
-from collections import defaultdict
|
|
|
-import os
|
|
|
import time
|
|
|
-import pickle
|
|
|
-import importlib
|
|
|
-from typing import (
|
|
|
- TYPE_CHECKING,
|
|
|
- Any,
|
|
|
- Dict,
|
|
|
- Iterable,
|
|
|
- List,
|
|
|
- Optional,
|
|
|
- Tuple,
|
|
|
- Union,
|
|
|
-)
|
|
|
+from typing import Dict, Iterable, List, Optional, Tuple, Type, Union
|
|
|
from loguru import logger
|
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
|
@@ -29,12 +15,9 @@ from aphrodite.common.config import (
|
|
|
)
|
|
|
from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
|
|
|
from aphrodite.engine.args_tools import EngineArgs
|
|
|
+from aphrodite.executor.executor_base import ExecutorBase
|
|
|
from aphrodite.engine.metrics import StatLogger, Stats
|
|
|
-from aphrodite.engine.ray_tools import (
|
|
|
- RayWorkerAphrodite,
|
|
|
- initialize_cluster,
|
|
|
- ray,
|
|
|
-)
|
|
|
+from aphrodite.engine.ray_tools import (initialize_ray_cluster)
|
|
|
from aphrodite.common.outputs import RequestOutput
|
|
|
from aphrodite.common.sampling_params import SamplingParams
|
|
|
from aphrodite.common.sequence import (
|
|
@@ -51,33 +34,11 @@ from aphrodite.transformers_utils.tokenizer import (
|
|
|
TokenizerGroup,
|
|
|
)
|
|
|
from aphrodite.common.utils import (
|
|
|
- Counter,
|
|
|
- set_cuda_visible_devices,
|
|
|
- get_ip,
|
|
|
- get_open_port,
|
|
|
- get_distributed_init_method,
|
|
|
-)
|
|
|
+ Counter, )
|
|
|
from aphrodite.common.logger import setup_logger
|
|
|
|
|
|
-if ray:
|
|
|
- from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
|
|
-
|
|
|
-if TYPE_CHECKING:
|
|
|
- from ray.util.placement_group import PlacementGroup
|
|
|
-
|
|
|
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
|
|
|
|
|
-# A map between the device type (in device config) to its worker module.
|
|
|
-DEVICE_TO_WORKER_MODULE_MAP = {
|
|
|
- "cuda": "aphrodite.task_handler.worker",
|
|
|
- "neuron": "aphrodite.task_handler.neuron_worker",
|
|
|
-}
|
|
|
-
|
|
|
-# If the env var is set, it uses the Ray's compiled DAG API
|
|
|
-# which optimizes the control plane overhead.
|
|
|
-# Run APHRODITE with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
|
|
|
-USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
|
|
|
-
|
|
|
|
|
|
class AphroditeEngine:
|
|
|
"""An LLM engine that receives requests and generates texts.
|
|
@@ -103,8 +64,8 @@ class AphroditeEngine:
|
|
|
scheduler_config: The configuration related to the request scheduler.
|
|
|
device_config: The configuration related to the device.
|
|
|
lora_config: The configuration related to LoRA.
|
|
|
- placement_group: Ray placement group for distributed execution.
|
|
|
- Required for distributed execution.
|
|
|
+ executor_class: The model executor class for managing distributed
|
|
|
+ execution.
|
|
|
log_stats: Whether to log statistics.
|
|
|
"""
|
|
|
|
|
@@ -116,7 +77,7 @@ class AphroditeEngine:
|
|
|
scheduler_config: SchedulerConfig,
|
|
|
device_config: DeviceConfig,
|
|
|
lora_config: Optional[LoRAConfig],
|
|
|
- placement_group: Optional["PlacementGroup"],
|
|
|
+ executor_class: Type[ExecutorBase],
|
|
|
log_stats: bool,
|
|
|
) -> None:
|
|
|
logger.info(
|
|
@@ -148,33 +109,13 @@ class AphroditeEngine:
|
|
|
self._init_tokenizer()
|
|
|
self.seq_counter = Counter()
|
|
|
|
|
|
- # Create the parallel GPU workers.
|
|
|
- if self.parallel_config.worker_use_ray:
|
|
|
- # Disable Ray usage stats collection.
|
|
|
- ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
|
|
- if ray_usage != "1":
|
|
|
- os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
|
|
- # Pass additional arguments to initialize the worker
|
|
|
- additional_ray_args = {}
|
|
|
- if self.parallel_config.ray_workers_use_nsight:
|
|
|
- logger.info("Configuring Ray workers to use nsight.")
|
|
|
- additional_ray_args = {
|
|
|
- "runtime_env": {
|
|
|
- "nsight": {
|
|
|
- "t": "cuda,cudnn,cublas",
|
|
|
- "o": "'worker_process_%p'",
|
|
|
- "cuda-graph-trace": "node",
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- self._init_workers_ray(placement_group, **additional_ray_args)
|
|
|
- else:
|
|
|
- self._init_workers()
|
|
|
-
|
|
|
- # Profile the memory usage and initialize the cache.
|
|
|
- self._init_cache()
|
|
|
+ self.model_executor = executor_class(model_config, cache_config,
|
|
|
+ parallel_config, scheduler_config,
|
|
|
+ device_config, lora_config)
|
|
|
|
|
|
# Create the scheduler.
|
|
|
+ # NOTE: the cache_config here have been updated with the numbers of
|
|
|
+ # GPU and CPU blocks, which are profiled in the distributed executor.
|
|
|
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
|
|
|
|
|
# Metric Logging.
|
|
@@ -185,9 +126,29 @@ class AphroditeEngine:
|
|
|
)
|
|
|
self.stat_logger.info("cache_config", self.cache_config)
|
|
|
|
|
|
- self.forward_dag = None
|
|
|
- if USE_RAY_COMPILED_DAG:
|
|
|
- self.forward_dag = self._compiled_ray_dag()
|
|
|
+ @classmethod
|
|
|
+ def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
|
|
|
+ """Creates an LLM engine from the engine arguments."""
|
|
|
+ # Create the engine configs.
|
|
|
+ engine_configs = engine_args.create_engine_configs()
|
|
|
+ parallel_config = engine_configs[2]
|
|
|
+
|
|
|
+ # Initialize the cluster and specify the executor class.
|
|
|
+ if parallel_config.worker_use_ray:
|
|
|
+ initialize_ray_cluster(parallel_config)
|
|
|
+ from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
|
|
|
+ executor_class = RayGPUExecutor
|
|
|
+ else:
|
|
|
+ assert parallel_config.world_size == 1, (
|
|
|
+ "Ray is required if parallel_config.world_size > 1.")
|
|
|
+ from aphrodite.executor.gpu_executor import GPUExecutor
|
|
|
+ executor_class = GPUExecutor
|
|
|
+
|
|
|
+ # Create the LLM engine.
|
|
|
+ engine = cls(*engine_configs,
|
|
|
+ executor_class=executor_class,
|
|
|
+ log_stats=not engine_args.disable_log_stats)
|
|
|
+ return engine
|
|
|
|
|
|
def __reduce__(self):
|
|
|
# This is to ensure that the AphroditeEngine is not referenced in
|
|
@@ -201,40 +162,6 @@ class AphroditeEngine:
|
|
|
sequence: Sequence) -> "PreTrainedTokenizer":
|
|
|
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
|
|
|
|
|
- def _dispatch_worker(self):
|
|
|
- worker_module = DEVICE_TO_WORKER_MODULE_MAP[
|
|
|
- self.device_config.device_type]
|
|
|
- imported_worker = importlib.import_module(worker_module)
|
|
|
- Worker = imported_worker.Worker
|
|
|
- return Worker
|
|
|
-
|
|
|
- def _init_workers(self):
|
|
|
- # Lazy import the Worker to avoid importing torch.cuda/xformers
|
|
|
- # before CUDA_VISIBLE_DEVICES is set in the Worker
|
|
|
- Worker = self._dispatch_worker()
|
|
|
-
|
|
|
- assert (self.parallel_config.world_size == 1
|
|
|
- ), "Ray is required if parallel_config.world_size > 1."
|
|
|
-
|
|
|
- self.workers: List[Worker] = []
|
|
|
- distributed_init_method = get_distributed_init_method(
|
|
|
- get_ip(), get_open_port())
|
|
|
- self.driver_worker = Worker(
|
|
|
- self.model_config,
|
|
|
- self.parallel_config,
|
|
|
- self.scheduler_config,
|
|
|
- self.device_config,
|
|
|
- local_rank=0,
|
|
|
- rank=0,
|
|
|
- distributed_init_method=distributed_init_method,
|
|
|
- lora_config=self.lora_config,
|
|
|
- kv_cache_dtype=self.cache_config.cache_dtype,
|
|
|
- # kv_quant_params_path=(self.cache_config.cache_quant_params_path),
|
|
|
- is_driver_worker=True,
|
|
|
- )
|
|
|
- self._run_workers("init_model")
|
|
|
- self._run_workers("load_model")
|
|
|
-
|
|
|
def _init_tokenizer(self, **tokenizer_init_kwargs):
|
|
|
init_kwargs = dict(
|
|
|
enable_lora=bool(self.lora_config),
|
|
@@ -248,131 +175,6 @@ class AphroditeEngine:
|
|
|
self.tokenizer: TokenizerGroup = TokenizerGroup(
|
|
|
self.model_config.tokenizer, **init_kwargs)
|
|
|
|
|
|
- def _init_workers_ray(self, placement_group: "PlacementGroup",
|
|
|
- **ray_remote_kwargs):
|
|
|
- if self.parallel_config.tensor_parallel_size == 1:
|
|
|
- num_gpus = self.cache_config.gpu_memory_utilization
|
|
|
- else:
|
|
|
- num_gpus = 1
|
|
|
-
|
|
|
- self.driver_dummy_worker: RayWorkerAphrodite = None
|
|
|
- self.workers: List[RayWorkerAphrodite] = []
|
|
|
-
|
|
|
- driver_ip = get_ip()
|
|
|
- for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
|
|
- if not bundle.get("GPU", 0):
|
|
|
- continue
|
|
|
- scheduling_strategy = PlacementGroupSchedulingStrategy(
|
|
|
- placement_group=placement_group,
|
|
|
- placement_group_capture_child_tasks=True,
|
|
|
- placement_group_bundle_index=bundle_id,
|
|
|
- )
|
|
|
- worker = ray.remote(
|
|
|
- num_cpus=0,
|
|
|
- num_gpus=num_gpus,
|
|
|
- scheduling_strategy=scheduling_strategy,
|
|
|
- **ray_remote_kwargs,
|
|
|
- )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
|
|
|
-
|
|
|
- worker_ip = ray.get(worker.get_node_ip.remote())
|
|
|
- if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
|
|
- # If the worker is on the same node as the driver, we use it
|
|
|
- # as the resource holder for the driver process.
|
|
|
- self.driver_dummy_worker = worker
|
|
|
- else:
|
|
|
- self.workers.append(worker)
|
|
|
-
|
|
|
- if self.driver_dummy_worker is None:
|
|
|
- raise ValueError(
|
|
|
- "Ray does not allocate any GPUs on the driver node. Consider "
|
|
|
- "adjusting the Ray placement group or running the driver on a "
|
|
|
- "GPU node.")
|
|
|
-
|
|
|
- driver_node_id, driver_gpu_ids = ray.get(
|
|
|
- self.driver_dummy_worker.get_node_and_gpu_ids.remote())
|
|
|
- worker_node_and_gpu_ids = ray.get(
|
|
|
- [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
|
|
|
-
|
|
|
- node_workers = defaultdict(list)
|
|
|
- node_gpus = defaultdict(list)
|
|
|
-
|
|
|
- node_workers[driver_node_id].append(0)
|
|
|
- node_gpus[driver_node_id].extend(driver_gpu_ids)
|
|
|
- for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
|
|
|
- start=1):
|
|
|
- node_workers[node_id].append(i)
|
|
|
- node_gpus[node_id].extend(gpu_ids)
|
|
|
- for node_id, gpu_ids in node_gpus.items():
|
|
|
- node_gpus[node_id] = sorted(gpu_ids)
|
|
|
-
|
|
|
- # Set CUDA_VISIBLE_DEVICES for the driver.
|
|
|
- set_cuda_visible_devices(node_gpus[driver_node_id])
|
|
|
- for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
|
|
- worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
|
|
-
|
|
|
- distributed_init_method = get_distributed_init_method(
|
|
|
- driver_ip, get_open_port())
|
|
|
-
|
|
|
- # Lazy import the Worker to avoid importing torch.cuda/xformers
|
|
|
- # before CUDA_VISIBLE_DEVICES is set in the Worker
|
|
|
- Worker = self._dispatch_worker()
|
|
|
-
|
|
|
- # Initialize torch distributed process group for the workers.
|
|
|
- model_config = copy.deepcopy(self.model_config)
|
|
|
- parallel_config = copy.deepcopy(self.parallel_config)
|
|
|
- scheduler_config = copy.deepcopy(self.scheduler_config)
|
|
|
- device_config = copy.deepcopy(self.device_config)
|
|
|
- lora_config = copy.deepcopy(self.lora_config)
|
|
|
- kv_cache_dtype = self.cache_config.cache_dtype
|
|
|
- # kv_quant_params_path = self.cache_config.cache_quant_params_path
|
|
|
-
|
|
|
- for rank, (worker, (node_id,
|
|
|
- _)) in enumerate(zip(self.workers,
|
|
|
- worker_node_and_gpu_ids),
|
|
|
- start=1):
|
|
|
- local_rank = node_workers[node_id].index(rank)
|
|
|
- worker.init_worker.remote(
|
|
|
- lambda rank=rank, local_rank=local_rank: Worker(
|
|
|
- model_config,
|
|
|
- parallel_config,
|
|
|
- scheduler_config,
|
|
|
- device_config,
|
|
|
- local_rank,
|
|
|
- rank,
|
|
|
- distributed_init_method,
|
|
|
- lora_config=lora_config,
|
|
|
- kv_cache_dtype=kv_cache_dtype,
|
|
|
- # kv_quant_params_path=kv_quant_params_path,
|
|
|
- ))
|
|
|
-
|
|
|
- driver_rank = 0
|
|
|
- driver_local_rank = node_workers[driver_node_id].index(driver_rank)
|
|
|
- self.driver_worker = Worker(
|
|
|
- self.model_config,
|
|
|
- self.parallel_config,
|
|
|
- self.scheduler_config,
|
|
|
- self.device_config,
|
|
|
- driver_local_rank,
|
|
|
- driver_rank,
|
|
|
- distributed_init_method,
|
|
|
- lora_config=self.lora_config,
|
|
|
- kv_cache_dtype=kv_cache_dtype,
|
|
|
- # kv_quant_params_path=kv_quant_params_path,
|
|
|
- is_driver_worker=True,
|
|
|
- )
|
|
|
-
|
|
|
- # don't use cupy for eager mode
|
|
|
- self._run_workers(
|
|
|
- "init_model",
|
|
|
- cupy_port=get_open_port()
|
|
|
- if not model_config.enforce_eager else None,
|
|
|
- )
|
|
|
- self._run_workers(
|
|
|
- "load_model",
|
|
|
- max_concurrent_workers=self.parallel_config.
|
|
|
- max_parallel_loading_workers,
|
|
|
- )
|
|
|
-
|
|
|
def _verify_args(self) -> None:
|
|
|
self.model_config.verify_with_parallel_config(self.parallel_config)
|
|
|
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
|
@@ -381,90 +183,6 @@ class AphroditeEngine:
|
|
|
self.lora_config.verify_with_scheduler_config(
|
|
|
self.scheduler_config)
|
|
|
|
|
|
- def _init_cache(self) -> None:
|
|
|
- # ruff: noqa: E501
|
|
|
- """Profiles the memory usage and initializes the KV cache.
|
|
|
-
|
|
|
- The engine will first conduct a profiling of the existing memory usage.
|
|
|
- Then, it calculate the maximum possible number of GPU and CPU blocks
|
|
|
- that can be allocated with the remaining free memory.
|
|
|
- More details can be found in the
|
|
|
- :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
|
|
|
- from class :class:`~aphrodite.task_handler.Worker`.
|
|
|
-
|
|
|
- Afterwards, as there may be multiple workers,
|
|
|
- we take the minimum number of blocks across all workers
|
|
|
- to ensure this can be applied to all of them.
|
|
|
-
|
|
|
- Finally, the engine will initialize the KV cache
|
|
|
- with the calculated number of blocks.
|
|
|
-
|
|
|
- .. tip::
|
|
|
- You may limit the usage of GPU memory
|
|
|
- by adjusting the `gpu_memory_utilization` parameters.
|
|
|
- """
|
|
|
- # Get the maximum number of blocks that can be allocated on GPU and CPU.
|
|
|
- num_blocks = self._run_workers(
|
|
|
- "profile_num_available_blocks",
|
|
|
- block_size=self.cache_config.block_size,
|
|
|
- gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
|
|
|
- cpu_swap_space=self.cache_config.swap_space_bytes,
|
|
|
- cache_dtype=self.cache_config.cache_dtype,
|
|
|
- )
|
|
|
-
|
|
|
- # Since we use a shared centralized controller, we take the minimum
|
|
|
- # number of blocks across all workers to make sure all the memory
|
|
|
- # operators can be applied to all workers.
|
|
|
- num_gpu_blocks = min(b[0] for b in num_blocks)
|
|
|
- num_cpu_blocks = min(b[1] for b in num_blocks)
|
|
|
- # FIXME: Change to debug log.
|
|
|
- logger.info(f"# GPU blocks: {num_gpu_blocks}, "
|
|
|
- f"# CPU blocks: {num_cpu_blocks}")
|
|
|
-
|
|
|
- logger.info(
|
|
|
- f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"
|
|
|
- )
|
|
|
-
|
|
|
- if num_gpu_blocks <= 0:
|
|
|
- raise ValueError("No available memory for the cache blocks. "
|
|
|
- "Try increasing `gpu_memory_utilization` when "
|
|
|
- "initializing the engine.")
|
|
|
- max_seq_len = self.cache_config.block_size * num_gpu_blocks
|
|
|
- logger.info(f"Maximum sequence length allowed in the cache: "
|
|
|
- f"{max_seq_len}")
|
|
|
- if self.model_config.max_model_len > max_seq_len:
|
|
|
- raise ValueError(
|
|
|
- f"The model's max seq len ({self.model_config.max_model_len}) "
|
|
|
- "is larger than the maximum number of tokens that can be "
|
|
|
- f"stored in KV cache ({max_seq_len}). Try increasing "
|
|
|
- "`gpu_memory_utilization` or decreasing `max_model_len` when "
|
|
|
- "initializing the engine.")
|
|
|
-
|
|
|
- self.cache_config.num_gpu_blocks = num_gpu_blocks
|
|
|
- self.cache_config.num_cpu_blocks = num_cpu_blocks
|
|
|
-
|
|
|
- # Initialize the cache.
|
|
|
- self._run_workers("init_cache_engine", cache_config=self.cache_config)
|
|
|
- # Warm up the model. This includes capturing the model into CUDA graph
|
|
|
- # if enforce_eager is False.
|
|
|
- self._run_workers("warm_up_model")
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def from_engine_args(cls, engine_args: EngineArgs) -> "AphroditeEngine":
|
|
|
- """Creates an LLM engine from the engine arguments."""
|
|
|
- # Create the engine configs.
|
|
|
- engine_configs = engine_args.create_engine_configs()
|
|
|
- parallel_config = engine_configs[2]
|
|
|
- # Initialize the cluster.
|
|
|
- placement_group = initialize_cluster(parallel_config)
|
|
|
- # Create the LLM engine.
|
|
|
- engine = cls(
|
|
|
- *engine_configs,
|
|
|
- placement_group,
|
|
|
- log_stats=not engine_args.disable_log_stats,
|
|
|
- )
|
|
|
- return engine
|
|
|
-
|
|
|
def encode_request(
|
|
|
self,
|
|
|
request_id: str,
|
|
@@ -894,7 +612,7 @@ class AphroditeEngine:
|
|
|
- A Sequence Group (SG) refer to a group of sequences
|
|
|
that are generated from the same prompt.
|
|
|
|
|
|
- - Step 2: Calls the workers to execute the model.
|
|
|
+ - Step 2: Calls the distributed executor to execute the model.
|
|
|
- Step 3: Processes the model output. This mainly includes:
|
|
|
|
|
|
- Decodes the relevant outputs.
|
|
@@ -930,20 +648,10 @@ class AphroditeEngine:
|
|
|
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
|
|
|
|
|
if not scheduler_outputs.is_empty():
|
|
|
- # Execute the model.
|
|
|
- all_outputs = self._run_workers(
|
|
|
- "execute_model",
|
|
|
- driver_kwargs={
|
|
|
- "seq_group_metadata_list": seq_group_metadata_list,
|
|
|
- "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
|
|
|
- "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
|
|
|
- "blocks_to_copy": scheduler_outputs.blocks_to_copy,
|
|
|
- },
|
|
|
- use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
|
|
|
- )
|
|
|
-
|
|
|
- # Only the driver worker returns the sampling results.
|
|
|
- output = all_outputs[0]
|
|
|
+ output = self.model_executor.execute_model(
|
|
|
+ seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
|
|
|
+ scheduler_outputs.blocks_to_swap_out,
|
|
|
+ scheduler_outputs.blocks_to_copy)
|
|
|
else:
|
|
|
output = []
|
|
|
|
|
@@ -999,7 +707,8 @@ class AphroditeEngine:
|
|
|
# Latency Timings.
|
|
|
time_last_iters = []
|
|
|
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
|
|
- # Time since last token. (n.b. updates seq_group.metrics.last_token_time)
|
|
|
+ # Time since last token.
|
|
|
+ # (n.b. updates seq_group.metrics.last_token_time)
|
|
|
time_last_iters.append(seq_group.get_last_latency(now))
|
|
|
# Time since arrival for all finished requests.
|
|
|
if seq_group.is_finished():
|
|
@@ -1123,119 +832,16 @@ class AphroditeEngine:
|
|
|
seq.output_text = seq.output_text[:-len(stop_string)]
|
|
|
|
|
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
|
|
- assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
|
|
|
- return self._run_workers(
|
|
|
- "add_lora",
|
|
|
- lora_request=lora_request,
|
|
|
- )
|
|
|
+ return self.model_executor.add_lora(lora_request)
|
|
|
|
|
|
def remove_lora(self, lora_id: int) -> bool:
|
|
|
- assert lora_id > 0, "lora_id must be greater than 0."
|
|
|
- return self._run_workers(
|
|
|
- "remove_lora",
|
|
|
- lora_id=lora_id,
|
|
|
- )
|
|
|
+ return self.model_executor.remove_lora(lora_id)
|
|
|
|
|
|
def list_loras(self) -> List[int]:
|
|
|
- return self._run_workers("list_loras")
|
|
|
-
|
|
|
- def _run_workers(
|
|
|
- self,
|
|
|
- method: str,
|
|
|
- *args,
|
|
|
- driver_args: Optional[List[Any]] = None,
|
|
|
- driver_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
- max_concurrent_workers: Optional[int] = None,
|
|
|
- use_ray_compiled_dag: bool = False,
|
|
|
- **kwargs,
|
|
|
- ) -> Any:
|
|
|
- """Runs the given method on all workers."""
|
|
|
-
|
|
|
- if max_concurrent_workers:
|
|
|
- raise NotImplementedError(
|
|
|
- "max_concurrent_workers is not supported yet.")
|
|
|
-
|
|
|
- if use_ray_compiled_dag:
|
|
|
- # Right now, compiled DAG can only accept a single
|
|
|
- # input.
|
|
|
- # TODO: Fix it.
|
|
|
- output_channels = self.forward_dag.execute(1)
|
|
|
- else:
|
|
|
- # Start the ray workers first.
|
|
|
- ray_worker_outputs = [
|
|
|
- worker.execute_method.remote(method, *args, **kwargs)
|
|
|
- for worker in self.workers
|
|
|
- ]
|
|
|
-
|
|
|
- if driver_args is None:
|
|
|
- driver_args = args
|
|
|
- if driver_kwargs is None:
|
|
|
- driver_kwargs = kwargs
|
|
|
-
|
|
|
- # Start the driver worker after all the ray workers.
|
|
|
- driver_worker_output = getattr(self.driver_worker,
|
|
|
- method)(*driver_args, **driver_kwargs)
|
|
|
-
|
|
|
- # Get the results of the ray workers.
|
|
|
- if self.workers:
|
|
|
- if use_ray_compiled_dag:
|
|
|
- try:
|
|
|
- ray_worker_outputs = [
|
|
|
- pickle.loads(chan.begin_read())
|
|
|
- for chan in output_channels
|
|
|
- ]
|
|
|
- finally:
|
|
|
- # Has to call end_read in order to reuse the DAG.
|
|
|
- for chan in output_channels:
|
|
|
- chan.end_read()
|
|
|
- else:
|
|
|
- ray_worker_outputs = ray.get(ray_worker_outputs)
|
|
|
-
|
|
|
- return [driver_worker_output] + ray_worker_outputs
|
|
|
-
|
|
|
- def _compiled_ray_dag(self):
|
|
|
- from packaging import version
|
|
|
- import pkg_resources
|
|
|
-
|
|
|
- required_version = "2.9"
|
|
|
- current_version = pkg_resources.get_distribution("ray").version
|
|
|
-
|
|
|
- if version.parse(current_version) < version.parse(required_version):
|
|
|
- raise ValueError(f"Ray version {required_version} or greater is "
|
|
|
- f"required, but found {current_version}")
|
|
|
-
|
|
|
- from ray.dag import MultiOutputNode, InputNode
|
|
|
-
|
|
|
- assert self.parallel_config.worker_use_ray
|
|
|
-
|
|
|
- # Right now, compiled DAG requires at least 1 arg. We send
|
|
|
- # a dummy value for now. It will be fixed soon.
|
|
|
- with InputNode() as input_data:
|
|
|
- forward_dag = MultiOutputNode([
|
|
|
- worker.execute_model_compiled_dag_remote.bind(input_data)
|
|
|
- for worker in self.workers
|
|
|
- ])
|
|
|
- return forward_dag.experimental_compile()
|
|
|
+ return self.model_executor.list_loras()
|
|
|
|
|
|
def check_health(self) -> None:
|
|
|
- """Raises an error if engine is unhealthy."""
|
|
|
- self._check_if_any_actor_is_dead()
|
|
|
-
|
|
|
- def _check_if_any_actor_is_dead(self):
|
|
|
- if not self.parallel_config.worker_use_ray:
|
|
|
- return
|
|
|
-
|
|
|
- if not self.workers:
|
|
|
- return
|
|
|
-
|
|
|
- dead_actors = []
|
|
|
- for actor in self.workers:
|
|
|
- actor_state = ray.state.actors(actor._ray_actor_id.hex())
|
|
|
- if actor_state["State"] == "DEAD":
|
|
|
- dead_actors.append(actor)
|
|
|
- if dead_actors:
|
|
|
- raise RuntimeError("At least one Worker is dead. "
|
|
|
- f"Dead Workers: {dead_actors}. ")
|
|
|
+ self.model_executor.check_health()
|
|
|
|
|
|
|
|
|
setup_logger()
|