
fix: remove event and stream, add typing (#382)

* fix: remove cache and stream, add typing

* remove cache stream and event from cache engine

* remove cache events from worker

* formatting
AlpinDale committed 11 months ago
commit e3252edd07

+ 11 - 9
aphrodite/common/utils.py

@@ -5,7 +5,7 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 from packaging.version import parse, Version
 
 import psutil
@@ -51,10 +51,10 @@ class Counter:
         self.counter = 0
 
 
-class LRUCache:
+class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
@@ -63,10 +63,10 @@ class LRUCache:
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
         return self.get(key)
 
-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
 
     def __delitem__(self, key: Hashable) -> None:
@@ -75,7 +75,9 @@ class LRUCache:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)
 
-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
@@ -83,12 +85,12 @@ class LRUCache:
             value = default_value
         return value
 
-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass
 
     def remove_oldest(self):
@@ -101,7 +103,7 @@ class LRUCache:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
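
For context, a minimal usage sketch of the now-generic cache (this assumes the module also defines T = TypeVar("T") alongside the Generic import, which is not shown in this hunk):

    from aphrodite.common.utils import LRUCache

    # The value type is tracked by the type checker instead of being Any.
    cache = LRUCache[str](capacity=2)
    cache.put(1, "a")
    cache.put(2, "b")
    cache.get(1)        # touching key 1 leaves key 2 as the least recently used
    cache.put(3, "c")   # over capacity: key 2 is evicted via remove_oldest()
    assert 2 not in cache
    assert cache.get(2) is None   # get() now returns Optional[T]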

+ 3 - 3
aphrodite/lora/models.py

@@ -4,7 +4,7 @@ import logging
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
 
 import safetensors.torch
 import torch
@@ -537,14 +537,14 @@ class LoRAModelManager:
                 replacement_loras)
 
 
-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):
 
     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
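
A hedged sketch of what the typed subclass buys: evicting an entry still routes through the deactivate callback, but the stored value is now known to be a LoRAModel. The ids and model objects below are illustrative, and deactivate() is a stand-in for the manager's real callback:

    def deactivate(lora_int_id):
        # placeholder for the deactivate_lora_fn supplied by LoRAModelManager
        print(f"deactivating LoRA {lora_int_id}")

    cache = LoRALRUCache(capacity=1, deactivate_lora_fn=deactivate)
    cache.put(1, lora_model_a)   # lora_model_a / lora_model_b: LoRAModel instances
    cache.put(2, lora_model_b)   # over capacity: key 1 is evicted, _on_remove fires, deactivate(1) runs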

+ 8 - 18
aphrodite/task_handler/cache_engine.py

@@ -36,7 +36,7 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
-        # Skip initializing CUDA stream and buffer for Neuron backend.
+        # Skip initializing KV Cache for Neuron backend.
         if is_neuron():
             return
 
@@ -49,12 +49,6 @@ class CacheEngine:
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()
 
-        # Initialize the stream for caching operations.
-        self.cache_stream = torch.cuda.Stream()
-        assert self.cache_stream != torch.cuda.current_stream()
-        # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
-
     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
         x = 16 // element_size
@@ -124,17 +118,13 @@ class CacheEngine:
     ) -> None:
         from aphrodite._C import cache_ops
 
-        with torch.cuda.stream(self.cache_stream):
-            for i in range(self.num_layers):
-                src_key_cache, src_value_cache = src[i]
-                dst_key_cache, dst_value_cache = dst[i]
-                # Copy the key blocks.
-                cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
-                # Copy the value blocks.
-                cache_ops.swap_blocks(src_value_cache, dst_value_cache,
-                                      src_to_dst)
-                event = self.events[i]
-                event.record(stream=self.cache_stream)
+        for i in range(self.num_layers):
+            src_key_cache, src_value_cache = src[i]
+            dst_key_cache, dst_value_cache = dst[i]
+            # Copy the key blocks.
+            cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+            # Copy the value blocks.
+            cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
         self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
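
With the dedicated cache stream and per-layer events removed, swaps now run synchronously on the current CUDA stream. A rough usage sketch, assuming the three config objects are constructed as usual:

    engine = CacheEngine(cache_config, model_config, parallel_config)
    # Move CPU block 0 into GPU block 4; cache_ops.swap_blocks is issued on the
    # current stream, so no event.record()/event.wait() bookkeeping is needed.
    engine.swap_in({0: 4})
    engine.swap_out({4: 0})   # and back out again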

+ 1 - 14
aphrodite/task_handler/worker.py

@@ -70,7 +70,6 @@ class Worker:
         # self.init_cache_engine().
         self.cache_config = None
         self.cache_engine = None
-        self.cache_events = None
         self.gpu_cache = None
 
     def init_model(self, cupy_port: Optional[int] = None) -> None:
@@ -158,7 +157,6 @@ class Worker:
         self.cache_config = cache_config
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
-        self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)
 
@@ -176,24 +174,13 @@ class Worker:
         blocks_to_copy: Dict[int, List[int]],
     ) -> None:
         # Issue cache operations.
-        issued_cache_op = False
+        # TODO: Profile the overhead of swapping operations and optimize
         if blocks_to_swap_in:
             self.cache_engine.swap_in(blocks_to_swap_in)
-            issued_cache_op = True
         if blocks_to_swap_out:
             self.cache_engine.swap_out(blocks_to_swap_out)
-            issued_cache_op = True
         if blocks_to_copy:
             self.cache_engine.copy(blocks_to_copy)
-            issued_cache_op = True
-
-        cache_events = self.cache_events if issued_cache_op else None
-
-        # Wait for cache operations to finish.
-        # TODO: Profile swapping overhead and optimize if needed.
-        if cache_events is not None:
-            for event in cache_events:
-                event.wait()
 
     @torch.inference_mode()
     def execute_model(
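
A sketch of calling the slimmed-down cache_swap (the block indices are made up); each branch now issues its copy directly rather than recording events to wait on later:

    worker.cache_swap(
        blocks_to_swap_in={0: 12},    # CPU block 0 -> GPU block 12
        blocks_to_swap_out={},        # nothing to swap out this step
        blocks_to_copy={12: [13]},    # fork GPU block 12 into GPU block 13
    )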

+ 13 - 6
aphrodite/transformers_utils/tokenizer_group/base_tokenizer_group.py

@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
         pass
 
     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
 
     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
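
Because request_id and lora_request now default to None, callers can encode a plain prompt without threading those arguments through. A small sketch, assuming tokenizer_group is any concrete BaseTokenizerGroup implementation:

    token_ids = tokenizer_group.encode("Hello, world!")
    # equivalent to the previously mandatory spelling:
    token_ids = tokenizer_group.encode("Hello, world!",
                                       request_id=None,
                                       lora_request=None)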

+ 2 - 4
aphrodite/transformers_utils/tokenizer_group/tokenizer_group.py

@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None
 
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""