
fix: tokenizer when using ray (#366)

AlpinDale 11 months ago
parent
commit
b361096463

+ 1 - 1
aphrodite/endpoints/openai/serving_chat.py

@@ -66,7 +66,7 @@ class OpenAIServingChat(OpenAIServing):
             lora_request = self._maybe_get_lora(request)
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, self.engine.get_tokenizer()))
+                    request, await self.engine.get_tokenizer()))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []

+ 1 - 1
aphrodite/endpoints/openai/serving_completions.py

@@ -126,7 +126,7 @@ class OpenAIServingCompletion(OpenAIServing):
             lora_request = self._maybe_get_lora(request)
             guided_decode_logit_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, self.engine.get_tokenizer()))
+                    request, await self.engine.get_tokenizer()))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
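
For context: AsyncAphrodite.get_tokenizer becomes a coroutine in this change (see the async_aphrodite.py hunk below), so both serving endpoints must await it; without the await, a coroutine object rather than a tokenizer would be handed to the guided-decoding helper. A minimal, self-contained sketch of that pitfall (the Engine class here is a stand-in, not the real engine):

    import asyncio

    class Engine:
        # Stand-in for AsyncAphrodite: get_tokenizer is now async.
        async def get_tokenizer(self):
            return "tokenizer"

    async def main():
        engine = Engine()
        wrong = engine.get_tokenizer()        # coroutine object, not a tokenizer
        right = await engine.get_tokenizer()  # the actual tokenizer
        print(type(wrong).__name__, right)    # -> coroutine tokenizer
        wrong.close()                         # silence the "never awaited" warning

    asyncio.run(main())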

+ 6 - 1
aphrodite/engine/aphrodite_engine.py

@@ -15,6 +15,7 @@ from typing import (
     Union,
 )
 from loguru import logger
+from transformers import PreTrainedTokenizer
 
 import aphrodite
 from aphrodite.lora.request import LoRARequest
@@ -193,7 +194,11 @@ class AphroditeEngine:
         # the closure used to initialize Ray worker actors
         raise RuntimeError("AphroditeEngine should not be pickled!")
 
-    def get_tokenizer_for_seq(self, sequence: Sequence):
+    def get_tokenizer(self) -> "PreTrainedTokenizer":
+        return self.tokenizer.get_lora_tokenizer()
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
     def _dispatch_worker(self):

+ 6 - 2
aphrodite/engine/async_aphrodite.py

@@ -5,6 +5,7 @@ from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator, Callable)
 from loguru import logger
+from transformers import PreTrainedTokenizer
 
 from aphrodite.lora.request import LoRARequest
 from aphrodite.common.config import ModelConfig
@@ -377,8 +378,11 @@ class AsyncAphrodite:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    def get_tokenizer(self):
-        return self.engine.tokenizer.tokenizer
+    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer.remote()
+        else:
+            return self.engine.get_tokenizer()
 
     def start_background_loop(self) -> None:
         """Start the background loop."""

+ 4 - 2
aphrodite/transformers_utils/tokenizer.py

@@ -177,7 +177,8 @@ class TokenizerGroup:
 
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -190,7 +191,8 @@ class TokenizerGroup:
 
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
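
Giving lora_request a default of None is what lets the new AphroditeEngine.get_tokenizer() call get_lora_tokenizer() with no arguments and get the base tokenizer back. A rough standalone sketch of that fallback, assuming a simplified stand-in for TokenizerGroup (names beyond those shown in the hunk are illustrative):

    from typing import Optional

    class TokenizerGroupSketch:
        def __init__(self, tokenizer, enable_lora: bool):
            self.tokenizer = tokenizer             # base (non-LoRA) tokenizer
            self.enable_lora = enable_lora
            self.lora_tokenizers = {}              # adapter id -> tokenizer

        def get_lora_tokenizer(self, lora_request: Optional[object] = None):
            # No request (the new default) or LoRA disabled: base tokenizer.
            if not lora_request or not self.enable_lora:
                return self.tokenizer
            # Otherwise an adapter-specific tokenizer would be looked up,
            # falling back to the base tokenizer if none is cached.
            return self.lora_tokenizers.get(lora_request.lora_int_id,
                                            self.tokenizer)

    group = TokenizerGroupSketch("base-tokenizer", enable_lora=False)
    print(group.get_lora_tokenizer())              # -> base-tokenizer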