set block size at init

AlpinDale 8 months ago
parent commit 0e062e66d3

+ 8 - 12
aphrodite/task_handler/cpu_model_runner.py

@@ -4,8 +4,8 @@ import torch
 from torch import nn
 
 from aphrodite.attention import AttentionMetadata, get_attn_backend
-from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
-                                     ModelConfig, ParallelConfig,
+from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.common.utils import make_tensor_with_pad
@@ -24,6 +24,7 @@ class CPUModelRunner:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
+        cache_config: CacheConfig,
         load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
@@ -37,27 +38,22 @@ class CPUModelRunner:
         self.scheduler_config = scheduler_config
         # Currently, CPU worker doesn't support chunked prefill.
         assert self.scheduler_config.chunked_prefill_enabled is False
+        self.device_config = device_config
+        self.cache_config = cache_config
         self.lora_config = lora_config
         self.vision_language_config = vision_language_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
 
-        # model_config can be None in tests/samplers/test_sampler.py.
-        # FIXME: This is a hack to make the tests work. Refactor this.
-        self.sliding_window = (model_config.get_sliding_window()
-                               if model_config is not None else None)
-        self.device_config = (device_config
-                              if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
 
         self.kv_cache_dtype = kv_cache_dtype
-
-        self.attn_backend = get_attn_backend(
-            self.model_config.dtype if model_config is not None else None)
+        self.sliding_window = model_config.get_sliding_window()
+        self.block_size = cache_config.block_size
+        self.attn_backend = get_attn_backend(self.model_config.dtype)
 
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
-        self.block_size: int  # Set after initial profiling.
 
     def load_model(self) -> None:
         self.model = get_model(

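With this change, CPUModelRunner takes the CacheConfig up front and derives block_size in __init__ instead of waiting for it to be filled in after initial profiling. A minimal, self-contained sketch of that shift, using hypothetical stand-in classes rather than the real Aphrodite configs:

# Stand-in classes for illustration only; the real runner takes many more
# arguments (model, parallel, scheduler, device, load, LoRA configs, ...).
from dataclasses import dataclass

@dataclass
class FakeCacheConfig:
    block_size: int = 16        # placeholder value, not a real default

class FakeRunner:
    def __init__(self, cache_config: FakeCacheConfig) -> None:
        # Before this commit, `self.block_size: int` was only annotated and
        # filled in after profiling; now it is read directly from the cache
        # config at construction time.
        self.cache_config = cache_config
        self.block_size = cache_config.block_size

runner = FakeRunner(FakeCacheConfig(block_size=16))
assert runner.block_size == 16   # usable immediately, no later setter call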
+ 1 - 0
aphrodite/task_handler/cpu_worker.py

@@ -148,6 +148,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             parallel_config,
             scheduler_config,
             device_config,
+            cache_config,
             load_config=self.load_config,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,

+ 21 - 33
aphrodite/task_handler/model_runner.py

@@ -10,8 +10,8 @@ from loguru import logger
 
 from aphrodite.attention import (AttentionMetadata, AttentionMetadataPerStage,
                                  get_attn_backend)
-from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
-                                     ModelConfig, ParallelConfig,
+from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (MultiModalData, SamplerOutput,
@@ -109,6 +109,7 @@ class ModelRunner:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
+        cache_config: CacheConfig,
         load_config: LoadConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
@@ -118,48 +119,41 @@ class ModelRunner:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.cache_config = cache_config
         self.lora_config = lora_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
+        self.vision_language_config = vision_language_config
 
-        # model_config can be None in tests/samplers/test_sampler.py.
-        # FIXME: This is a hack to make the tests work. Refactor this.
-        self.sliding_window = (model_config.get_sliding_window()
-                               if model_config is not None else None)
-        self.device_config = (device_config
-                              if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
+        self.pin_memory = is_pin_memory_available()
 
-        # Set after load_model.
-        self.lora_manager: LRUCacheWorkerLoRAManager = None
-
+        self.kv_cache_dtype = kv_cache_dtype
+        self.sliding_window = model_config.get_sliding_window()
+        self.block_size = cache_config.block_size
+        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
         self.graph_runners: Dict[int, CUDAGraphRunner] = {}
         self.graph_memory_pool: Optional[Tuple[
             int, int]] = None  # Set during graph capture.
-
-        self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
-                                       if self.model_config is not None else 0)
-
-        self.pin_memory = is_pin_memory_available()
-        self.kv_cache_dtype = kv_cache_dtype
-        self.vision_language_config = vision_language_config
-
-        self.attn_backend = get_attn_backend(
-            self.model_config.dtype if model_config is not None else None)
-
-        # Lazy initialization
-        self.model: torch.nn.Module  # Set after load_model
-        self.block_size: int  # Set after initial profiling.
         # When using CUDA graph, the input block tables must be padded to
         # max_seq_len_to_capture. However, creating the block table in
         # Python can be expensive. To optimize this, we cache the block table
         # in numpy and only copy the actual input content at every iteration.
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
-        self.graph_block_tables: torch.Tensor  # Set after initial profiling.
+        self.graph_block_tables = np.zeros(
+            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            dtype=np.int32)
+        self.attn_backend = get_attn_backend(self.model_config.dtype)
+
+        # Lazy initialization
+        self.model: torch.nn.Module  # Set after load_model
 
         # Set if the backend is flashinfer.
         self.flashinfer_workspace_buffer: torch.Tensor
+        # Set after load_model.
+        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
 
     def load_model(self) -> None:
         with CudaMemoryProfiler() as m:
@@ -217,13 +211,6 @@ class ModelRunner:
                            "but the KV cache data type is not FP8. "
                            "KV cache scaling factors will not be used.")
 
-    def set_block_size(self, block_size: int) -> None:
-        self.block_size = block_size
-
-        self.graph_block_tables = np.zeros(
-            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
-            dtype=np.int32)
-
     def get_max_block_per_batch(self) -> int:
         block_size = self.block_size
         return (self.max_seq_len_to_capture + block_size - 1) // block_size
@@ -841,6 +828,7 @@ class ModelRunner:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
+            assert self.lora_manager is not None
             with self.lora_manager.dummy_lora_cache():
                 for idx in range(self.lora_config.max_loras):
                     lora_id = idx + 1

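Because block_size and max_seq_len_to_capture are both known in __init__, the CUDA-graph block table can be preallocated there, and get_max_block_per_batch is just an integer ceiling division. A worked example with illustrative numbers (not the engine's defaults):

import numpy as np

# Illustrative values only.
max_seq_len_to_capture = 8192
block_size = 16

# get_max_block_per_batch(): ceiling division without floats
max_blocks_per_batch = (max_seq_len_to_capture + block_size - 1) // block_size
assert max_blocks_per_batch == 512       # ceil(8192 / 16)

# graph_block_tables is preallocated with shape
# (max batch size to capture, max blocks per sequence).
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8]   # placeholder list for this sketch
graph_block_tables = np.zeros(
    (max(_BATCH_SIZES_TO_CAPTURE), max_blocks_per_batch), dtype=np.int32)
assert graph_block_tables.shape == (8, 512)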
+ 1 - 1
aphrodite/task_handler/worker.py

@@ -76,6 +76,7 @@ class Worker(WorkerBase):
             parallel_config,
             scheduler_config,
             device_config,
+            cache_config,
             load_config=load_config,
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
@@ -185,7 +186,6 @@ class Worker(WorkerBase):
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.gpu_cache = self.cache_engine.gpu_cache
-        self.model_runner.set_block_size(self.cache_engine.block_size)
 
     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
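On the worker side, the set_block_size() handshake disappears: the model runner and the CacheEngine now both read block_size from the same CacheConfig passed at construction, so they agree by definition. A hypothetical sketch of that invariant (stub classes, not the real ones):

class StubCacheConfig:
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size

class StubCacheEngine:
    def __init__(self, cache_config: StubCacheConfig) -> None:
        self.block_size = cache_config.block_size

class StubModelRunner:
    def __init__(self, cache_config: StubCacheConfig) -> None:
        self.block_size = cache_config.block_size

cfg = StubCacheConfig(block_size=16)
runner = StubModelRunner(cfg)   # block size fixed at init ...
engine = StubCacheEngine(cfg)   # ... and the cache engine matches it
# No runner.set_block_size(engine.block_size) call is needed anymore.
assert runner.block_size == engine.block_size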