@@ -9,102 +9,98 @@ from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
-from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
-                   SingletonPromptInputs)
+from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
+                   SingletonPrompt)
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
 
 if TYPE_CHECKING:
     from aphrodite.multimodal import MultiModalDataDict
 
 
-PromptComponents = Tuple[
-    Optional[str], List[int], Optional["MultiModalDataDict"]
-]
-DecoderPromptComponents = Tuple[
-    Optional[str], Optional[List[int]], Optional["MultiModalDataDict"]
-]
+PromptComponents = Tuple[Optional[str], List[int],
+                         Optional["MultiModalDataDict"]]
+DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
+                                Optional["MultiModalDataDict"]]
 
 
 class InputPreprocessor:
+
     def __init__(
         self,
         model_config: ModelConfig,
         tokenizer: Optional[BaseTokenizerGroup],
     ) -> None:
         super().__init__()
+
         self.model_config = model_config
         self.tokenizer = tokenizer
 
     def get_tokenizer_group(self) -> BaseTokenizerGroup:
         if self.tokenizer is None:
-            raise ValueError(
-                "You cannot pass text prompts when "
-                "`skip_tokenizer_init` is True"
-            )
+            raise ValueError("You cannot pass text prompts when "
+                             "`skip_tokenizer_init` is True")
+
         return self.tokenizer
 
-    def get_bos_token_id(
-        self, lora_request: Optional[LoRARequest] = None
-    ) -> Optional[int]:
+    def get_bos_token_id(self,
+                         lora_request: Optional[LoRARequest] = None
+                         ) -> Optional[int]:
         if self.tokenizer is None:
-            logger.warning(
-                "Using None for BOS token id because tokenizer "
-                "is not initialized"
-            )
+            logger.warning("Using None for BOS token id because tokenizer "
+                           "is not initialized")
             return None
+
         return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
 
-    def get_eos_token_id(
-        self, lora_request: Optional[LoRARequest] = None
-    ) -> Optional[int]:
+    def get_eos_token_id(self,
+                         lora_request: Optional[LoRARequest] = None
+                         ) -> Optional[int]:
         if self.tokenizer is None:
-            logger.warning(
-                "Using None for EOS token id because tokenizer "
-                "is not initialized"
-            )
+            logger.warning("Using None for EOS token id because tokenizer "
+                           "is not initialized")
             return None
+
         return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
 
     def get_decoder_start_token_id(self) -> Optional[int]:
-        """
+        '''
         Obtain the decoder start token id employed by an encoder/decoder
         model. Returns None for non-encoder/decoder models or if the
         model config is unavailable.
-        """
+        '''
+
         if not self.is_encoder_decoder_model():
-            logger.warning(
-                "Using None for decoder start token id because "
-                "this is not an encoder/decoder model."
-            )
+            logger.warning("Using None for decoder start token id because "
+                           "this is not an encoder/decoder model.")
             return None
-        if self.model_config is None or self.model_config.hf_config is None:
-            logger.warning(
-                "Using None for decoder start token id because "
-                "model config is not available."
-            )
+
+        if (self.model_config is None or self.model_config.hf_config is None):
+            logger.warning("Using None for decoder start token id because "
+                           "model config is not available.")
             return None
-        dec_start_token_id = getattr(
-            self.model_config.hf_config, "decoder_start_token_id", None
-        )
+
+        dec_start_token_id = getattr(self.model_config.hf_config,
+                                     'decoder_start_token_id', None)
         if dec_start_token_id is None:
-            logger.warning(
-                "Falling back on <BOS> for decoder start token id "
-                "because decoder start token id is not available."
-            )
+            logger.warning("Falling back on <BOS> for decoder start token id "
+                           "because decoder start token id is not available.")
             dec_start_token_id = self.get_bos_token_id()
+
         return dec_start_token_id
 
     def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
-        """
+        '''
         Specifically for encoder/decoder models:
         generate a default decoder prompt for when
         the user specifies only the encoder prompt.
+
         Encoder/decoder models utilize the decoder
         prompt in different ways; as new models are
         added, it is intended that this function
         will be extended to produce differing
         default decoder prompts, depending on the
         model variety.
+
         Absent a special case, the default behavior
         of this method is to mirror the behavior of
         the HuggingFace (HF) GenerationMixin for a None
@@ -112,14 +108,18 @@ class InputPreprocessor:
         setting to force the first decoded token to be <BOS>.
         Here, this behavior is approximated by having the
         "default" decoder prompt be <BOS>.
+
         However, it is possible that in the future
-        other models may have different or more
+        other models may have different or more
         complex logic for the default decoder prompt.
         This motivates having a special helper method
         for default decoder prompts.
+
         Returns:
+
         * prompt_token_ids
-        """
+        '''
+
         bos_token_id = self.get_bos_token_id()
         assert bos_token_id is not None
         return [bos_token_id]
@@ -130,27 +130,36 @@ class InputPreprocessor:
     ) -> List[int]:
         """
         Prepares `decoder_input_ids` for generation with encoder-decoder models.
+
         Based on
+
         https://github.com/huggingface/transformers/blob/
         4037a2b5b1278736e566aec12e169100275545ea/
         src/transformers/generation/utils.py
+
         specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
+
         Arguments:
+
         * decoder_input_ids: input token ids to preprocess
+
         Returns:
+
         * Processed token list
         """
+
         decoder_start_token_id = self.get_decoder_start_token_id()
         assert decoder_start_token_id is not None
+
         if decoder_input_ids is None:
             # no decoder prompt input ->
             # use decoder_start_token_id as decoder_input_ids
             decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
-        if (
-            len(decoder_input_ids) == 0
-            or decoder_input_ids[0] != decoder_start_token_id
-        ):
+
+        if (len(decoder_input_ids) == 0
+                or decoder_input_ids[0] != decoder_start_token_id):
             decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
         return decoder_input_ids
 
     def _apply_prompt_adapter(
@@ -161,8 +170,8 @@ class InputPreprocessor:
         if prompt_adapter_request:
             prompt_token_ids = (
                 [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
-                + prompt_token_ids
-            )
+                + prompt_token_ids)
+
         return prompt_token_ids
 
     def _tokenize_prompt(
@@ -176,9 +185,10 @@ class InputPreprocessor:
         corresponding token IDs.
         """
         tokenizer = self.get_tokenizer_group()
-        return tokenizer.encode(
-            request_id=request_id, prompt=prompt, lora_request=lora_request
-        )
+
+        return tokenizer.encode(request_id=request_id,
+                                prompt=prompt,
+                                lora_request=lora_request)
 
     async def _tokenize_prompt_async(
         self,
@@ -188,83 +198,93 @@ class InputPreprocessor:
     ) -> List[int]:
         """Async version of :meth:`_tokenize_prompt`."""
         tokenizer = self.get_tokenizer_group()
-        return await tokenizer.encode_async(
-            request_id=request_id, prompt=prompt, lora_request=lora_request
-        )
+
+        return await tokenizer.encode_async(request_id=request_id,
+                                            prompt=prompt,
+                                            lora_request=lora_request)
 
     def _extract_prompt_components(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
-        """
+        '''
         Extract the components of any single encoder or decoder input prompt.
+
         Arguments:
+
         * request_id
-        * inputs: single encoder or decoder input prompt
+        * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts
+
         Returns:
+
         * prompt
         * prompt_token_ids
         * multi_modal_data
-        """
-        parsed = parse_singleton_prompt(inputs)
+        '''
+
+        parsed = parse_singleton_prompt(prompt)
+
         if parsed["type"] == "str":
-            prompt = parsed["content"]
+            prompt_text = parsed["content"]
             prompt_token_ids = self._tokenize_prompt(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
             multi_modal_data = None
         elif parsed["type"] == "tokens":
-            prompt = None
+            prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
         elif parsed["type"] == "text":
-            prompt = parsed["content"]["prompt"]
+            prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = self._tokenize_prompt(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
            )
             multi_modal_data = parsed["content"].get("multi_modal_data")
         else:
             assert_never(parsed)
-        return prompt, prompt_token_ids, multi_modal_data
+
+        return prompt_text, prompt_token_ids, multi_modal_data
 
     async def _extract_prompt_components_async(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
         """Async version of :meth:`_extract_prompt_components`."""
-        parsed = parse_singleton_prompt(inputs)
+        parsed = parse_singleton_prompt(prompt)
+
         if parsed["type"] == "str":
-            prompt = parsed["content"]
+            prompt_text = parsed["content"]
             prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
             multi_modal_data = None
         elif parsed["type"] == "tokens":
-            prompt = None
+            prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
         elif parsed["type"] == "text":
-            prompt = parsed["content"]["prompt"]
+            prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
             multi_modal_data = parsed["content"].get("multi_modal_data")
         else:
             assert_never(parsed)
-        return prompt, prompt_token_ids, multi_modal_data
+
+        return prompt_text, prompt_token_ids, multi_modal_data
 
     def _build_enc_dec_llm_inputs(
         self,
@@ -273,13 +293,14 @@ class InputPreprocessor:
     ) -> EncoderDecoderLLMInputs:
         encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
         decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
+
         if encoder_mm_data is not None or decoder_mm_data is not None:
-            raise ValueError(
-                "Multi-modal encoder-decoder models are " "not supported yet"
-            )
-        decoder_prompt_ids = self._prepare_decoder_input_ids_for_generation(
-            decoder_prompt_ids
-        )
+            raise ValueError("Multi-modal encoder-decoder models are "
+                             "not supported yet")
+
+        decoder_prompt_ids = (
+            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
+
         return EncoderDecoderLLMInputs(
             prompt_token_ids=decoder_prompt_ids,
             prompt=decoder_prompt,
@@ -289,43 +310,52 @@ class InputPreprocessor:
 
     def _process_encoder_decoder_prompt(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
     ) -> EncoderDecoderLLMInputs:
-        """
+        '''
         For encoder/decoder models only:
         Process an input prompt into an
         :class:`EncoderDecoderLLMInputs` instance.
+
         There are two types of input prompts:
         singleton prompts which carry only the
         encoder prompt, and explicit encoder/decoder
         prompts which carry both the encoder and the
         decoder prompts as member variables.
+
         This function handles the following scenarios:
         * Singleton encoder prompt: extract encoder prompt
           token ids & infer default decoder prompt token ids
         * Explicit encoder/decoder prompt: extract encoder
           and decoder prompt token ids
+
         Note that for Explicit encoder/decoder prompts,
         each sub-prompt (encoder or decoder prompt) can
         have any possible singleton type; thus this
         method relies on helper functions to obtain
         token ids for the sub-prompts.
-
+
         Arguments:
-        * inputs: an input prompt
+
+        * prompt: an input prompt
         * request_id
+
         Returns:
+
         * :class:`EncoderDecoderLLMInputs` instance
-        """
+        '''
+
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents
-        if is_explicit_encoder_decoder_prompt(inputs):
+
+        if is_explicit_encoder_decoder_prompt(prompt):
             encoder_comps = self._extract_prompt_components(
-                inputs["encoder_prompt"],
+                prompt["encoder_prompt"],
                 request_id=request_id,
             )
-            if (decoder_input := inputs["decoder_prompt"]) is None:
+
+            if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_comps = None, None, None
             else:
                 decoder_comps = self._extract_prompt_components(
@@ -334,26 +364,30 @@ class InputPreprocessor:
                 )
         else:
             encoder_comps = self._extract_prompt_components(
-                inputs,
+                prompt,
                 request_id=request_id,
             )
+
             decoder_comps = None, None, None
+
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
 
     async def _process_encoder_decoder_prompt_async(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
     ) -> EncoderDecoderLLMInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents
-        if is_explicit_encoder_decoder_prompt(inputs):
+
+        if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._extract_prompt_components_async(
-                inputs["encoder_prompt"],
+                prompt["encoder_prompt"],
                 request_id=request_id,
             )
-            if (decoder_input := inputs["decoder_prompt"]) is None:
+
+            if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_comps = await encoder_task
                 decoder_comps = None, None, None
             else:
@@ -361,15 +395,17 @@ class InputPreprocessor:
                     decoder_input,
                     request_id=request_id,
                 )
+
                 encoder_comps, decoder_comps = await asyncio.gather(
-                    encoder_task, decoder_task
-                )
+                    encoder_task, decoder_task)
         else:
             encoder_comps = await self._extract_prompt_components_async(
-                inputs,
+                prompt,
                 request_id=request_id,
             )
+
             decoder_comps = None, None, None
+
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
 
     def _build_decoder_only_llm_inputs(
@@ -378,38 +414,43 @@ class InputPreprocessor:
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> LLMInputs:
         prompt, prompt_token_ids, multi_modal_data = prompt_comps
+
         prompt_token_ids = self._apply_prompt_adapter(
-            prompt_token_ids, prompt_adapter_request=prompt_adapter_request
-        )
-        return LLMInputs(
-            prompt_token_ids=prompt_token_ids,
-            prompt=prompt,
-            multi_modal_data=multi_modal_data,
-        )
+            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=prompt,
+                         multi_modal_data=multi_modal_data)
 
     def _process_decoder_only_prompt(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
-        """
+        '''
         For decoder-only models:
         Process an input prompt into an :class:`LLMInputs` instance.
+
         Arguments:
-        * inputs: input prompt
+
+        * prompt: input prompt
         * request_id
         * lora_request
         * prompt_adapter_request
+
         Returns:
+
         * :class:`LLMInputs` instance
-        """
+        '''
+
         prompt_comps = self._extract_prompt_components(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
         )
+
         return self._build_decoder_only_llm_inputs(
             prompt_comps,
             prompt_adapter_request=prompt_adapter_request,
@@ -417,17 +458,18 @@ class InputPreprocessor:
 
     async def _process_decoder_only_prompt_async(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._extract_prompt_components_async(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
         )
+
         return self._build_decoder_only_llm_inputs(
             prompt_comps,
             prompt_adapter_request=prompt_adapter_request,
@@ -435,7 +477,7 @@ class InputPreprocessor:
 
     def preprocess(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -445,16 +487,17 @@ class InputPreprocessor:
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
-                inputs,
+                prompt,
                 request_id=request_id,
             )
-        if is_explicit_encoder_decoder_prompt(inputs):
-            raise ValueError(
-                "Cannot pass encoder-decoder prompt " "to decoder-only models"
-            )
+
+        if is_explicit_encoder_decoder_prompt(prompt):
+            raise ValueError("Cannot pass encoder-decoder prompt "
+                             "to decoder-only models")
+
         # Decoder-only operation
         return self._process_decoder_only_prompt(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -462,7 +505,7 @@ class InputPreprocessor:
 
     async def preprocess_async(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -472,16 +515,17 @@ class InputPreprocessor:
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(
-                inputs,
+                prompt,
                 request_id=request_id,
             )
-        if is_explicit_encoder_decoder_prompt(inputs):
-            raise ValueError(
-                "Cannot pass encoder-decoder prompt " "to decoder-only models"
-            )
+
+        if is_explicit_encoder_decoder_prompt(prompt):
+            raise ValueError("Cannot pass encoder-decoder prompt "
+                             "to decoder-only models")
+
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
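
Reviewer sketch (not part of the patch): the caller-visible change above is that
InputPreprocessor.preprocess/preprocess_async now take `prompt: PromptType` instead of
`inputs: PromptInputs`. A minimal illustration of the renamed keyword, assuming
`model_config` and `tokenizer_group` are constructed by the engine exactly as before and
that the module lives at `aphrodite.inputs.preprocess` (both assumptions, not shown in
this diff):

    from aphrodite.inputs.preprocess import InputPreprocessor

    # model_config and tokenizer_group come from the engine setup (assumed here).
    preprocessor = InputPreprocessor(model_config, tokenizer_group)

    # Decoder-only path: the keyword argument is now `prompt`, not `inputs`.
    llm_inputs = preprocessor.preprocess(
        prompt="Hello, world!",
        request_id="request-0",
    )
    print(llm_inputs["prompt_token_ids"])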