# ray_tokenizer_group.py
  1. import asyncio
  2. import os
  3. from typing import List, Optional
  4. from transformers import PreTrainedTokenizer
  5. from aphrodite.common.config import TokenizerPoolConfig
  6. from aphrodite.lora.request import LoRARequest
  7. from aphrodite.engine.ray_tools import ray
  8. from aphrodite.transformers_utils.tokenizer_group.base_tokenizer_group import (
  9. BaseTokenizerGroup)
  10. from aphrodite.transformers_utils.tokenizer_group.tokenizer_group import (
  11. TokenizerGroup)
  12. from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
class RayTokenizerGroupPool(BaseTokenizerGroup):
    """A Ray-based pool of TokenizerGroups for async tokenization.

    Maintains ``num_actors`` remote :class:`TokenizerGroup` actors plus one
    local copy used for cheap metadata queries (max input length, LoRA
    tokenizer lookup) that do not need to go through Ray.
    """

    # Class to use for workers making up the pool.
    _worker_cls = TokenizerGroup

    @classmethod
    def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
                    **init_kwargs) -> "RayTokenizerGroupPool":
        """Build a pool from a TokenizerPoolConfig.

        Extra Ray actor options come from ``tokenizer_pool_config.extra_config``
        (defaulting to ``{"num_cpus": 0}`` so actors don't reserve CPU slots).
        NOTE(review): when ``extra_config`` is provided, the ``setdefault``
        calls below mutate that dict in place — confirm callers don't reuse it.
        """
        ray_actor_options = (tokenizer_pool_config.extra_config or {
            "num_cpus": 0
        })
        # Prefer (softly) scheduling the actors on the current node so
        # tokenization stays close to the driver.
        ray_actor_options.setdefault(
            "scheduling_strategy",
            NodeAffinitySchedulingStrategy(
                node_id=ray.get_runtime_context().get_node_id(), soft=True))
        # Carry over the env vars to the actors.
        # This is necessary for API keys and such.
        ray_actor_options.setdefault("runtime_env", {})
        _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"])
        init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
        init_kwargs["ray_actor_options"] = ray_actor_options
        return cls(**init_kwargs)

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], num_actors: int,
                 ray_actor_options: dict, **tokenizer_config):
        # Store a local copy of the TokenizerGroup for quick access
        # to underlying HF tokenizers.
        self._local_tokenizer_group = self._worker_cls(
            tokenizer_id=tokenizer_id,
            enable_lora=enable_lora,
            max_num_seqs=max_num_seqs,
            max_input_length=max_input_length,
            **tokenizer_config,
        )
        # Spawn the remote worker actors with the same construction args.
        ray_tokenizer_group_cls = ray.remote(
            self._worker_cls).options(**ray_actor_options)
        self.tokenizer_actors = [
            ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
                                           max_num_seqs, max_input_length,
                                           **tokenizer_config)
            for _ in range(num_actors)
        ]
        # Created lazily (see _ensure_queue_initialized) so the queue is
        # bound to the event loop that first uses the pool.
        self._idle_actors: Optional[asyncio.Queue] = None

    @property
    def pool_size(self) -> int:
        """Number of remote tokenizer actors in the pool."""
        return len(self.tokenizer_actors)

    def ping(self):
        """Block until every actor responds to a ping; returns their results."""
        return ray.get(
            [actor.ping.remote() for actor in self.tokenizer_actors])

    def _ensure_queue_initialized(self):
        # Lazily create the idle-actor queue and seed it with every actor.
        if self._idle_actors is None:
            self._idle_actors = asyncio.Queue()
            for actor in self.tokenizer_actors:
                self._idle_actors.put_nowait(actor)

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group.
        We pick an idle actor and use it to encode the prompt.
        The actor is then put back in the queue for future use.
        This is blocking.

        NOTE(review): despite "blocking", this raises RuntimeError when no
        actor is idle instead of waiting for one (see the empty() check).
        """
        self._ensure_queue_initialized()
        if self._idle_actors.empty():
            raise RuntimeError("No idle actors available.")
        actor = self._idle_actors.get_nowait()
        try:
            ret = ray.get(
                actor.encode.remote(request_id=request_id,
                                    prompt=prompt,
                                    lora_request=lora_request))
        finally:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)
        return ret

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Encode a prompt using the tokenizer group.
        We pick an idle actor and use it to encode the prompt.
        If there are no idle actors, we wait until one becomes
        available.
        The actor is then put back in the queue for future use.
        This is non-blocking.
        """
        self._ensure_queue_initialized()
        actor = await self._idle_actors.get()
        try:
            ret = await actor.encode.remote(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)
        finally:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)
        return ret

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request."""
        # Answered locally — no round trip to the remote actors.
        return self._local_tokenizer_group.get_max_input_len(lora_request)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the (possibly LoRA-specific) HF tokenizer, locally."""
        return self._local_tokenizer_group.get_lora_tokenizer(lora_request)

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Async variant of get_lora_tokenizer, also served locally."""
        return await self._local_tokenizer_group.get_lora_tokenizer_async(
            lora_request)
  132. def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
  133. """Copy over all current process environment variables to the runtime_env.
  134. The variables in runtime_env will take precedence over the current process
  135. environment variables.
  136. runtime_env will be modified in place."""
  137. env_vars = os.environ.copy()
  138. runtime_env.setdefault("env_vars", {})
  139. env_vars.update(runtime_env["env_vars"])
  140. runtime_env["env_vars"] = env_vars