@@ -2,16 +2,18 @@ import asyncio
 import os
 from typing import List, Optional
 
+from ray.exceptions import ActorDiedError
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 from transformers import PreTrainedTokenizer
+from loguru import logger
 
 from aphrodite.common.config import TokenizerPoolConfig
-from aphrodite.lora.request import LoRARequest
 from aphrodite.executor.ray_utils import ray
-from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
-    BaseTokenizerGroup)
-from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
-from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
+from aphrodite.lora.request import LoRARequest
+from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import \
+    BaseTokenizerGroup
+from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import \
+    TokenizerGroup
 
 
 class RayTokenizerGroupPool(BaseTokenizerGroup):
@@ -46,24 +48,28 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
                  ray_actor_options: dict, **tokenizer_config):
         # Store a local copy of the TokenizerGroup for quick access
         # to underlying HF tokenizers.
+        self._tokenizer_config = {
+            "tokenizer_id": tokenizer_id,
+            "enable_lora": enable_lora,
+            "max_num_seqs": max_num_seqs,
+            "max_input_length": max_input_length,
+            **tokenizer_config
+        }
         self._local_tokenizer_group = self._worker_cls(
-            tokenizer_id=tokenizer_id,
-            enable_lora=enable_lora,
-            max_num_seqs=max_num_seqs,
-            max_input_length=max_input_length,
-            **tokenizer_config,
-        )
-
-        ray_tokenizer_group_cls = ray.remote(
+            **self._tokenizer_config, )
+
+        self._ray_tokenizer_group_cls = ray.remote(
             self._worker_cls).options(**ray_actor_options)
-        self.tokenizer_actors = [
-            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
-                                           max_num_seqs, max_input_length,
-                                           **tokenizer_config)
-            for _ in range(num_actors)
-        ]
+        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
         self._idle_actors: Optional[asyncio.Queue] = None
 
+        # If set, actor is unhealthy. Will reraise on the next
+        # check_health call.
+        self._exception: Optional[ActorDiedError] = None
+
+    def _init_actor(self) -> ray.ObjectRef:
+        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
+
     @property
     def pool_size(self) -> int:
         return len(self.tokenizer_actors)
@@ -78,6 +84,22 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         for actor in self.tokenizer_actors:
             self._idle_actors.put_nowait(actor)
 
+    def _finalize_encode(self, actor: ray.ObjectRef,
+                         original_actor: ray.ObjectRef, actor_is_alive: bool):
+        assert self._idle_actors is not None
+        # Cleanup the dead actor.
+        if not actor_is_alive or original_actor is not actor:
+            self.tokenizer_actors.remove(original_actor)
+        if actor_is_alive:
+            # Put the actor back in the queue.
+            # This is done in a finally block to ensure that the actor is
+            # always put back in the queue, even if an exception/cancellation
+            # is raised.
+            self._idle_actors.put_nowait(actor)
+            # Add back the new actor.
+            if original_actor is not actor:
+                self.tokenizer_actors.append(actor)
+
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
@@ -88,22 +110,39 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         The actor is then put back in the queue for future use.
         This is blocking.
         """
+        self.check_health()
         self._ensure_queue_initialized()
 
         if self._idle_actors.empty():
             raise RuntimeError("No idle actors available.")
         actor = self._idle_actors.get_nowait()
+        actor_is_alive = True
+        original_actor = actor
         try:
             ret = ray.get(
                 actor.encode.remote(request_id=request_id,
                                     prompt=prompt,
                                     lora_request=lora_request))
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            logger.warning(
+                f"{actor} died with ActorDiedError, reinitializing.",
+                exc_info=e)
+            actor = self._init_actor()
+            try:
+                ret = ray.get(
+                    actor.encode.remote(request_id=request_id,
+                                        prompt=prompt,
+                                        lora_request=lora_request))
+            except ActorDiedError as e:
+                logger.error(f"{actor} died for second time in a row, marking "
+                             "RayTokenizerGroupPool as unhealthy.")
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
 
     async def encode_async(
@@ -112,26 +151,42 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
-
         We pick an idle actor and use it to encode the prompt.
         If there are no idle actors, we wait until one becomes
         available.
         The actor is then put back in the queue for future use.
         This is non-blocking.
         """
+        self.check_health()
         self._ensure_queue_initialized()
+        assert self._idle_actors is not None
 
         actor = await self._idle_actors.get()
+        actor_is_alive = True
+        original_actor = actor
         try:
             ret = await actor.encode.remote(request_id=request_id,
                                             prompt=prompt,
                                             lora_request=lora_request)
+        except ActorDiedError as e:
+            # If the actor is dead, we first try to reinitialize it.
+            logger.warning(
+                f"{actor} died with ActorDiedError, reinitializing.",
+                exc_info=e)
+            actor = self._init_actor()
+            try:
+                ret = await actor.encode.remote(request_id=request_id,
+                                                prompt=prompt,
+                                                lora_request=lora_request)
+            except ActorDiedError as e:
+                logger.error(f"{actor} died for second time in a row, marking "
+                             "RayTokenizerGroupPool as unhealthy.")
+                actor_is_alive = False
+                if not self._exception:
+                    self._exception = e
+                self.check_health()
         finally:
-            # Put the actor back in the queue.
-            # This is done in a finally block to ensure that the actor is
-            # always put back in the queue, even if an exception/cancellation
-            # is raised.
-            self._idle_actors.put_nowait(actor)
+            self._finalize_encode(actor, original_actor, actor_is_alive)
         return ret
 
     def get_max_input_len(self,
@@ -153,6 +208,11 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         return await self._local_tokenizer_group.get_lora_tokenizer_async(
             lora_request)
 
+    def check_health(self):
+        if self._exception:
+            raise RuntimeError(
+                "TokenizerGroupPool is unhealthy.") from self._exception
+
 
 def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
     """Copy over all current process environment variables to the runtime_env.