
chore: add fault tolerance for RayTokenizerGroupPool

AlpinDale 7 months ago
parent
commit
4ed1bb9958

+ 2 - 0
aphrodite/engine/aphrodite_engine.py

@@ -932,6 +932,8 @@ class AphroditeEngine:
         return self.model_executor.pin_lora(lora_id)
 
     def check_health(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
 
 
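With this guard, AphroditeEngine.check_health() now surfaces tokenizer-pool failures in addition to executor failures. A minimal sketch of a caller, assuming the usual from_engine_args construction path (recreate_engine is a hypothetical recovery hook):

    engine = AphroditeEngine.from_engine_args(engine_args)  # assumed setup
    try:
        engine.check_health()
    except RuntimeError:
        # Raised when the Ray tokenizer pool (or the executor) is unhealthy;
        # the pool chains the underlying ActorDiedError as __cause__.
        recreate_engine()  # hypothetical: rebuild the engine and its pool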

+ 2 - 0
aphrodite/engine/async_aphrodite.py

@@ -303,6 +303,8 @@ class _AsyncAphrodite(AphroditeEngine):
         )
 
     async def check_health_async(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
 
 
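The async variant lets a server poll engine health without blocking the event loop. A small sketch under the same assumptions (watch_health is a hypothetical helper, not part of this commit):

    import asyncio

    async def watch_health(engine, period: float = 5.0) -> None:
        # check_health_async() raises once the tokenizer pool or the
        # model executor becomes unhealthy; callers decide how to recover.
        while True:
            await engine.check_health_async()
            await asyncio.sleep(period)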

+ 2 - 1
aphrodite/task_handler/cpu_worker.py

@@ -269,7 +269,8 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         if execute_model_req is None:
             seq_group_metadata_list = None
         else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            seq_group_metadata_list = (
+                execute_model_req.seq_group_metadata_list)
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             num_seq_groups: int = len(seq_group_metadata_list)

+ 4 - 0
aphrodite/transformers_utils/tokenizer_group/base_tokenizer_group.py

@@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
     ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
+
+    def check_health(self):
+        """Raise exception if the tokenizer group is unhealthy."""
+        return

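The base class deliberately makes check_health() a no-op, so in-process tokenizer groups stay healthy by default and only pooled implementations need to override it. A sketch of such an override (MyTokenizerGroup and _workers_alive are hypothetical names; the other abstract methods are elided):

    class MyTokenizerGroup(BaseTokenizerGroup):
        def check_health(self):
            # Replace the no-op default: surface backend failures eagerly.
            if not self._workers_alive():
                raise RuntimeError("MyTokenizerGroup is unhealthy.")
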
+ 91 - 31
aphrodite/transformers_utils/tokenizer_group/ray_tokenizer_group.py

@@ -2,16 +2,18 @@ import asyncio
 import os
 from typing import List, Optional
 
+from ray.exceptions import ActorDiedError
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 from transformers import PreTrainedTokenizer
+from loguru import logger
 
 from aphrodite.common.config import TokenizerPoolConfig
-from aphrodite.lora.request import LoRARequest
 from aphrodite.executor.ray_utils import ray
-from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
-from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+from aphrodite.lora.request import LoRARequest
+from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import \
+    BaseTokenizerGroup
+from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import \
+    TokenizerGroup
 
 
 class RayTokenizerGroupPool(BaseTokenizerGroup):
@@ -46,24 +48,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
                  ray_actor_options: dict, **tokenizer_config):
         # Store a local copy of the TokenizerGroup for quick access
         # to underlying HF tokenizers.
+        self._tokenizer_config = {
+            "tokenizer_id": tokenizer_id,
+            "enable_lora": enable_lora,
+            "max_num_seqs": max_num_seqs,
+            "max_input_length": max_input_length,
+            **tokenizer_config
+        }
         self._local_tokenizer_group = self._worker_cls(
-            tokenizer_id=tokenizer_id,
-            enable_lora=enable_lora,
-            max_num_seqs=max_num_seqs,
-            max_input_length=max_input_length,
-            **tokenizer_config,
-        )
-
-        ray_tokenizer_group_cls = ray.remote(
+            **self._tokenizer_config)
+
+        self._ray_tokenizer_group_cls = ray.remote(
             self._worker_cls).options(**ray_actor_options)
-        self.tokenizer_actors = [
-            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
-                                           max_num_seqs, max_input_length,
-                                           **tokenizer_config)
-            for _ in range(num_actors)
-        ]
+        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
         self._idle_actors: Optional[asyncio.Queue] = None
 
+        # If set, actor is unhealthy. Will reraise on the next
+        # check_health call.
+        self._exception: Optional[ActorDiedError] = None
+
+    def _init_actor(self) -> ray.ObjectRef:
+        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
+
     @property
     def pool_size(self) -> int:
         return len(self.tokenizer_actors)
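
Capturing the constructor arguments in self._tokenizer_config is what makes recovery possible: _init_actor() can stamp out a replacement actor with exactly the same settings whenever one dies. The same pattern in isolation, with ActorPool and its members as hypothetical names:

    import ray

    class ActorPool:
        def __init__(self, worker_cls, num_actors, **config):
            self._config = config  # kept so dead workers can be respawned
            self._remote_cls = ray.remote(worker_cls)
            self.actors = [self._spawn() for _ in range(num_actors)]

        def _spawn(self):
            # Every replacement is configured identically to the original.
            return self._remote_cls.remote(**self._config)
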
@@ -78,6 +84,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             for actor in self.tokenizer_actors:
                 self._idle_actors.put_nowait(actor)
 
+    def _finalize_encode(self, actor: ray.ObjectRef,
+                         original_actor: ray.ObjectRef, actor_is_alive: bool):
+        assert self._idle_actors is not None
+        # Clean up the original actor if it died or was replaced.
+        if not actor_is_alive or original_actor is not actor:
+            self.tokenizer_actors.remove(original_actor)
+        if actor_is_alive:
+            # Put the actor back in the queue.
+            # This is done in a finally block to ensure that the actor is
+            # always put back in the queue, even if an exception/cancellation
+            # is raised.
+            self._idle_actors.put_nowait(actor)
+            # Add back the new actor.
+            if original_actor is not actor:
+                self.tokenizer_actors.append(actor)
+
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
@@ -88,22 +110,39 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         The actor is then put back in the queue for future use.
         This is blocking.
         """
+        self.check_health()
         self._ensure_queue_initialized()
 
         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
         actor = self._idle_actors.get_nowait()
+        actor_is_alive = True
+        original_actor = actor
         try:
             ret = ray.get(
                 actor.encode.remote(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request))
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            # loguru has no exc_info kwarg; attach the traceback via opt().
+            logger.opt(exception=e).warning(
+                f"{actor} died with ActorDiedError, reinitializing.")
+            actor = self._init_actor()
+            try:
+                ret = ray.get(
+                    actor.encode.remote(request_id=request_id,
+                                        prompt=prompt,
+                                        lora_request=lora_request))
+            except ActorDiedError as e:
+                logger.error(f"{actor} died for second time in a row, marking "
+                             "RayTokenizerGroupPool as unhealthy.")
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
 
     async def encode_async(
@@ -112,26 +151,42 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
-
         We pick an idle actor and use it to encode the prompt.
         If there are no idle actors, we wait until one becomes
         available.
         The actor is then put back in the queue for future use.
         This is non-blocking.
         """
+        self.check_health()
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
 
         actor = await self._idle_actors.get()
+        actor_is_alive = True
+        original_actor = actor
         try:
             ret = await actor.encode.remote(request_id=request_id,
                                             prompt=prompt,
                                             lora_request=lora_request)
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            # loguru has no exc_info kwarg; attach the traceback via opt().
+            logger.opt(exception=e).warning(
+                f"{actor} died with ActorDiedError, reinitializing.")
+            actor = self._init_actor()
+            try:
+                ret = await actor.encode.remote(request_id=request_id,
+                                                prompt=prompt,
+                                                lora_request=lora_request)
+            except ActorDiedError as e:
+                logger.error(f"{actor} died for second time in a row, marking "
+                             "RayTokenizerGroupPool as unhealthy.")
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
 
     def get_max_input_len(self,
@@ -153,6 +208,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         return await self._local_tokenizer_group.get_lora_tokenizer_async(
             lora_request)
 
+    def check_health(self):
+        if self._exception:
+            raise RuntimeError(
+                "TokenizerGroupPool is unhealthy.") from self._exception
+
 
 def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
     """Copy over all current process environment variables to the runtime_env.