ray_tokenizer_group.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. import asyncio
  2. import os
  3. from typing import List, Optional
  4. from ray.exceptions import ActorDiedError
  5. from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
  6. from transformers import PreTrainedTokenizer
  7. from loguru import logger
  8. from aphrodite.common.config import TokenizerPoolConfig
  9. from aphrodite.executor.ray_utils import ray
  10. from aphrodite.lora.request import LoRARequest
  11. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import \
  12. BaseTokenizerGroup
  13. from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import \
  14. TokenizerGroup
  15. class RayTokenizerGroupPool(BaseTokenizerGroup):
  16. """A Ray-based pool of TokenizerGroups for async tokenization."""
  17. # Class to use for workers making up the pool.
  18. _worker_cls = TokenizerGroup
  19. @classmethod
  20. def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
  21. **init_kwargs) -> "RayTokenizerGroupPool":
  22. ray_actor_options = (tokenizer_pool_config.extra_config or {
  23. "num_cpus": 0
  24. })
  25. ray_actor_options.setdefault(
  26. "scheduling_strategy",
  27. NodeAffinitySchedulingStrategy(
  28. node_id=ray.get_runtime_context().get_node_id(), soft=True))
  29. # Carry over the env vars to the actors.
  30. # This is necessary for API keys and such.
  31. ray_actor_options.setdefault("runtime_env", {})
  32. _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"])
  33. init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
  34. init_kwargs["ray_actor_options"] = ray_actor_options
  35. return cls(**init_kwargs)
  36. def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
  37. max_input_length: Optional[int], num_actors: int,
  38. ray_actor_options: dict, **tokenizer_config):
  39. # Store a local copy of the TokenizerGroup for quick access
  40. # to underlying HF tokenizers.
  41. self._tokenizer_config = {
  42. "tokenizer_id": tokenizer_id,
  43. "enable_lora": enable_lora,
  44. "max_num_seqs": max_num_seqs,
  45. "max_input_length": max_input_length,
  46. **tokenizer_config
  47. }
  48. self._local_tokenizer_group = self._worker_cls(
  49. **self._tokenizer_config, )
  50. self._ray_tokenizer_group_cls = ray.remote(
  51. self._worker_cls).options(**ray_actor_options)
  52. self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
  53. self._idle_actors: Optional[asyncio.Queue] = None
  54. # If set, actor is unhealthy. Will reraise on the next
  55. # check_health call.
  56. self._exception: Optional[ActorDiedError] = None
  57. def _init_actor(self) -> ray.ObjectRef:
  58. return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)
  59. @property
  60. def pool_size(self) -> int:
  61. return len(self.tokenizer_actors)
  62. def ping(self):
  63. return ray.get(
  64. [actor.ping.remote() for actor in self.tokenizer_actors])
  65. def _ensure_queue_initialized(self):
  66. if self._idle_actors is None:
  67. self._idle_actors = asyncio.Queue()
  68. for actor in self.tokenizer_actors:
  69. self._idle_actors.put_nowait(actor)
  70. def _finalize_encode(self, actor: ray.ObjectRef,
  71. original_actor: ray.ObjectRef, actor_is_alive: bool):
  72. assert self._idle_actors is not None
  73. # Cleanup the dead actor.
  74. if not actor_is_alive or original_actor is not actor:
  75. self.tokenizer_actors.remove(original_actor)
  76. if actor_is_alive:
  77. # Put the actor back in the queue.
  78. # This is done in a finally block to ensure that the actor is
  79. # always put back in the queue, even if an exception/cancellation
  80. # is raised.
  81. self._idle_actors.put_nowait(actor)
  82. # Add back the new actor.
  83. if original_actor is not actor:
  84. self.tokenizer_actors.append(actor)
  85. def encode(self,
  86. prompt: str,
  87. request_id: Optional[str] = None,
  88. lora_request: Optional[LoRARequest] = None) -> List[int]:
  89. """Encode a prompt using the tokenizer group.
  90. We pick an idle actor and use it to encode the prompt.
  91. The actor is then put back in the queue for future use.
  92. This is blocking.
  93. """
  94. self.check_health()
  95. self._ensure_queue_initialized()
  96. if self._idle_actors.empty():
  97. raise RuntimeError("No idle actors available.")
  98. actor = self._idle_actors.get_nowait()
  99. actor_is_alive = True
  100. original_actor = actor
  101. try:
  102. ret = ray.get(
  103. actor.encode.remote(request_id=request_id,
  104. prompt=prompt,
  105. lora_request=lora_request))
  106. except ActorDiedError as e:
  107. # If the actor is dead, we first try to reinitialize it.
  108. logger.warning(
  109. f"{actor} died with ActorDiedError, reinitializing.",
  110. exc_info=e)
  111. actor = self._init_actor()
  112. try:
  113. ret = ray.get(
  114. actor.encode.remote(request_id=request_id,
  115. prompt=prompt,
  116. lora_request=lora_request))
  117. except ActorDiedError as e:
  118. logger.error(f"{actor} died for second time in a row, marking "
  119. "RayTokenizerGroupPool as unhealthy.")
  120. actor_is_alive = False
  121. if not self._exception:
  122. self._exception = e
  123. self.check_health()
  124. finally:
  125. self._finalize_encode(actor, original_actor, actor_is_alive)
  126. return ret
  127. async def encode_async(
  128. self,
  129. prompt: str,
  130. request_id: Optional[str] = None,
  131. lora_request: Optional[LoRARequest] = None) -> List[int]:
  132. """Encode a prompt using the tokenizer group.
  133. We pick an idle actor and use it to encode the prompt.
  134. If there are no idle actors, we wait until one becomes
  135. available.
  136. The actor is then put back in the queue for future use.
  137. This is non-blocking.
  138. """
  139. self.check_health()
  140. self._ensure_queue_initialized()
  141. assert self._idle_actors is not None
  142. actor = await self._idle_actors.get()
  143. actor_is_alive = True
  144. original_actor = actor
  145. try:
  146. ret = await actor.encode.remote(request_id=request_id,
  147. prompt=prompt,
  148. lora_request=lora_request)
  149. except ActorDiedError as e:
  150. # If the actor is dead, we first try to reinitialize it.
  151. logger.warning(
  152. f"{actor} died with ActorDiedError, reinitializing.",
  153. exc_info=e)
  154. actor = self._init_actor()
  155. try:
  156. ret = await actor.encode.remote(request_id=request_id,
  157. prompt=prompt,
  158. lora_request=lora_request)
  159. except ActorDiedError as e:
  160. logger.error(f"{actor} died for second time in a row, marking "
  161. "RayTokenizerGroupPool as unhealthy.")
  162. actor_is_alive = False
  163. if not self._exception:
  164. self._exception = e
  165. self.check_health()
  166. finally:
  167. self._finalize_encode(actor, original_actor, actor_is_alive)
  168. return ret
  169. def get_max_input_len(self,
  170. lora_request: Optional[LoRARequest] = None
  171. ) -> Optional[int]:
  172. """Get the maximum input length for the LoRA request."""
  173. return self._local_tokenizer_group.get_max_input_len(lora_request)
  174. def get_lora_tokenizer(
  175. self,
  176. lora_request: Optional[LoRARequest] = None
  177. ) -> "PreTrainedTokenizer":
  178. return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
  179. async def get_lora_tokenizer_async(
  180. self,
  181. lora_request: Optional[LoRARequest] = None
  182. ) -> "PreTrainedTokenizer":
  183. return await self._local_tokenizer_group.get_lora_tokenizer_async(
  184. lora_request)
  185. def check_health(self):
  186. if self._exception:
  187. raise RuntimeError(
  188. "TokenizerGroupPool is unhealthy.") from self._exception
  189. def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
  190. """Copy over all current process environment variables to the runtime_env.
  191. The variables in runtime_env will take precedence over the current process
  192. environment variables.
  193. runtime_env will be modified in place."""
  194. env_vars = os.environ.copy()
  195. runtime_env.setdefault("env_vars", {})
  196. env_vars.update(runtime_env["env_vars"])
  197. runtime_env["env_vars"] = env_vars