Browse Source

chore: add fault tolerance for RayTokenizerGroupPool

AlpinDale 7 months ago
parent
commit
4ed1bb9958

+ 2 - 0
aphrodite/engine/aphrodite_engine.py

@@ -932,6 +932,8 @@ class AphroditeEngine:
         return self.model_executor.pin_lora(lora_id)
         return self.model_executor.pin_lora(lora_id)
 
 
     def check_health(self) -> None:
     def check_health(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
         self.model_executor.check_health()
 
 
 
 

+ 2 - 0
aphrodite/engine/async_aphrodite.py

@@ -303,6 +303,8 @@ class _AsyncAphrodite(AphroditeEngine):
         )
         )
 
 
     async def check_health_async(self) -> None:
     async def check_health_async(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
         self.model_executor.check_health()
         self.model_executor.check_health()
 
 
 
 

+ 1 - 1
aphrodite/task_handler/cpu_worker.py

@@ -269,7 +269,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         if execute_model_req is None:
         if execute_model_req is None:
             seq_group_metadata_list = None
             seq_group_metadata_list = None
         else:
         else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            seq_group_metadata_list = execute_model_req.seq_group_metadata_listv
         if self.is_driver_worker:
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             assert seq_group_metadata_list is not None
             num_seq_groups: int = len(seq_group_metadata_list)
             num_seq_groups: int = len(seq_group_metadata_list)

+ 4 - 0
aphrodite/transformers_utils/tokenizer_group/base_tokenizer_group.py

@@ -53,3 +53,7 @@ class BaseTokenizerGroup(ABC):
     ) -> "PreTrainedTokenizer":
     ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         """Get a tokenizer for a LoRA request."""
         pass
         pass
+
+    def check_health(self):
+        """Raise exception if the tokenizer group is unhealthy."""
+        return

+ 91 - 31
aphrodite/transformers_utils/tokenizer_group/ray_tokenizer_group.py

@@ -2,16 +2,18 @@ import asyncio
 import os
 import os
 from typing import List, Optional
 from typing import List, Optional
 
 
+from ray.exceptions import ActorDiedError
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 from transformers import PreTrainedTokenizer
 from transformers import PreTrainedTokenizer
+from loguru import logger
 
 
 from aphrodite.common.config import TokenizerPoolConfig
 from aphrodite.common.config import TokenizerPoolConfig
-from aphrodite.lora.request import LoRARequest
 from aphrodite.executor.ray_utils import ray
 from aphrodite.executor.ray_utils import ray
-from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
-from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+from aphrodite.lora.request import LoRARequest
+from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import \
+    BaseTokenizerGroup
+from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import \
+    TokenizerGroup
 
 
 
 
 class RayTokenizerGroupPool(BaseTokenizerGroup):
 class RayTokenizerGroupPool(BaseTokenizerGroup):
@@ -46,24 +48,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
                  ray_actor_options: dict, **tokenizer_config):
                  ray_actor_options: dict, **tokenizer_config):
         # Store a local copy of the TokenizerGroup for quick access
         # Store a local copy of the TokenizerGroup for quick access
         # to underlying HF tokenizers.
         # to underlying HF tokenizers.
+        self._tokenizer_config = {
+            "tokenizer_id": tokenizer_id,
+            "enable_lora": enable_lora,
+            "max_num_seqs": max_num_seqs,
+            "max_input_length": max_input_length,
+            **tokenizer_config
+        }
         self._local_tokenizer_group = self._worker_cls(
         self._local_tokenizer_group = self._worker_cls(
-            tokenizer_id=tokenizer_id,
-            enable_lora=enable_lora,
-            max_num_seqs=max_num_seqs,
-            max_input_length=max_input_length,
-            **tokenizer_config,
-        )
-
-        ray_tokenizer_group_cls = ray.remote(
+            **self._tokenizer_config, )
+
+        self._ray_tokenizer_group_cls = ray.remote(
             self._worker_cls).options(**ray_actor_options)
             self._worker_cls).options(**ray_actor_options)
-        self.tokenizer_actors = [
-            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
-                                           max_num_seqs, max_input_length,
-                                           **tokenizer_config)
-            for _ in range(num_actors)
-        ]
+        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
         self._idle_actors: Optional[asyncio.Queue] = None
         self._idle_actors: Optional[asyncio.Queue] = None
 
 
+        # If set, actor is unhealthy. Will reraise on the next
+        # check_health call.
+        self._exception: Optional[ActorDiedError] = None
+
+    def _init_actor(self) -> ray.ObjectRef:
+        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
+
     @property
     @property
     def pool_size(self) -> int:
     def pool_size(self) -> int:
         return len(self.tokenizer_actors)
         return len(self.tokenizer_actors)
@@ -78,6 +84,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             for actor in self.tokenizer_actors:
             for actor in self.tokenizer_actors:
                 self._idle_actors.put_nowait(actor)
                 self._idle_actors.put_nowait(actor)
 
 
+    def _finalize_encode(self, actor: ray.ObjectRef,
+                         original_actor: ray.ObjectRef, actor_is_alive: bool):
+        assert self._idle_actors is not None
+        # Cleanup the dead actor.
+        if not actor_is_alive or original_actor is not actor:
+            self.tokenizer_actors.remove(original_actor)
+        if actor_is_alive:
+            # Put the actor back in the queue.
+            # This is done in a finally block to ensure that the actor is
+            # always put back in the queue, even if an exception/cancellation
+            # is raised.
+            self._idle_actors.put_nowait(actor)
+            # Add back the new actor.
+            if original_actor is not actor:
+                self.tokenizer_actors.append(actor)
+
     def encode(self,
     def encode(self,
                prompt: str,
                prompt: str,
                request_id: Optional[str] = None,
                request_id: Optional[str] = None,
@@ -88,22 +110,39 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         The actor is then put back in the queue for future use.
         The actor is then put back in the queue for future use.
         This is blocking.
         This is blocking.
         """
         """
+        self.check_health()
         self._ensure_queue_initialized()
         self._ensure_queue_initialized()
 
 
         if self._idle_actors.empty():
         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
             raise RuntimeError("No idle actors available.")
         actor = self._idle_actors.get_nowait()
         actor = self._idle_actors.get_nowait()
+        actor_is_alive = True
+        original_actor = actor
         try:
         try:
             ret = ray.get(
             ret = ray.get(
                 actor.encode.remote(request_id=request_id,
                 actor.encode.remote(request_id=request_id,
                                     prompt=prompt,
                                     prompt=prompt,
                                     lora_request=lora_request))
                                     lora_request=lora_request))
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            logger.warning(
+                f"{actor} died with ActorDiedError, reinitializing.",
+                exc_info=e)
+            actor = self._init_actor()
+            try:
+                ret = ray.get(
+                    actor.encode.remote(request_id=request_id,
+                                        prompt=prompt,
+                                        lora_request=lora_request))
+            except ActorDiedError as e:
+                logger.error(f"{actor} died for second time in a row, marking "
+                             "RayTokenizerGroupPool as unhealthy.")
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
         return ret
 
 
     async def encode_async(
     async def encode_async(
@@ -112,26 +151,42 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             request_id: Optional[str] = None,
             request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None) -> List[int]:
             lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
         """Encode a prompt using the tokenizer group.
-
         We pick an idle actor and use it to encode the prompt.
         We pick an idle actor and use it to encode the prompt.
         If there are no idle actors, we wait until one becomes
         If there are no idle actors, we wait until one becomes
         available.
         available.
         The actor is then put back in the queue for future use.
         The actor is then put back in the queue for future use.
         This is non-blocking.
         This is non-blocking.
         """
         """
+        self.check_health()
         self._ensure_queue_initialized()
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
 
 
         actor = await self._idle_actors.get()
         actor = await self._idle_actors.get()
+        actor_is_alive = True
+        original_actor = actor
         try:
         try:
             ret = await actor.encode.remote(request_id=request_id,
             ret = await actor.encode.remote(request_id=request_id,
                                             prompt=prompt,
                                             prompt=prompt,
                                             lora_request=lora_request)
                                             lora_request=lora_request)
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            logger.warning(
+                f"{actor} died with ActorDiedError, reinitializing.",
+                exc_info=e)
+            actor = self._init_actor()
+            try:
+                ret = await actor.encode.remote(request_id=request_id,
+                                                prompt=prompt,
+                                                lora_request=lora_request)
+            except ActorDiedError as e:
+                logger.error(f"{actor} died for second time in a row, marking "
+                             "RayTokenizerGroupPool as unhealthy.")
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
         return ret
 
 
     def get_max_input_len(self,
     def get_max_input_len(self,
@@ -153,6 +208,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         return await self._local_tokenizer_group.get_lora_tokenizer_async(
         return await self._local_tokenizer_group.get_lora_tokenizer_async(
             lora_request)
             lora_request)
 
 
+    def check_health(self):
+        if self._exception:
+            raise RuntimeError(
+                "TokenizerGroupPool is unhealthy.") from self._exception
+
 
 
 def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
 def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
     """Copy over all current process environment variables to the runtime_env.
     """Copy over all current process environment variables to the runtime_env.