|
@@ -9,9 +9,9 @@ from loguru import logger
|
|
|
from transformers import PreTrainedTokenizer
|
|
|
from typing_extensions import assert_never
|
|
|
|
|
|
-from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
|
|
- EngineConfig, LoadConfig, LoRAConfig,
|
|
|
- ModelConfig, ParallelConfig,
|
|
|
+from aphrodite.common.config import (CacheConfig, CFGConfig, DecodingConfig,
|
|
|
+ DeviceConfig, EngineConfig, LoadConfig,
|
|
|
+ LoRAConfig, ModelConfig, ParallelConfig,
|
|
|
PromptAdapterConfig, SchedulerConfig,
|
|
|
SpeculativeConfig)
|
|
|
from aphrodite.common.logger import setup_logger
|
|
@@ -70,9 +70,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
|
|
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
|
|
|
|
|
PromptComponents = Tuple[Optional[str], List[int],
|
|
|
- Optional[MultiModalDataDict]]
|
|
|
+ Optional[MultiModalDataDict],
|
|
|
+ Optional[str], Optional[List[int]]]  # negative_prompt, negative_prompt_token_ids
|
|
|
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
|
|
|
- Optional[MultiModalDataDict]]
|
|
|
+ Optional[MultiModalDataDict],
|
|
|
+ Optional[str], Optional[List[int]]]  # negative_prompt, negative_prompt_token_ids
|
|
|
|
|
|
|
|
|
class AphroditeEngine:
|
|
@@ -171,6 +173,7 @@ class AphroditeEngine:
|
|
|
speculative_config: Optional[SpeculativeConfig],
|
|
|
decoding_config: Optional[DecodingConfig],
|
|
|
prompt_adapter_config: Optional[PromptAdapterConfig],
|
|
|
+ cfg_config: Optional[CFGConfig],
|
|
|
executor_class: Type[ExecutorBase],
|
|
|
log_stats: bool,
|
|
|
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
|
@@ -185,6 +188,7 @@ class AphroditeEngine:
|
|
|
config_dict = {
|
|
|
"Model": model_config.model,
|
|
|
"Speculative Config": speculative_config,
|
|
|
+ "CFG Config": cfg_config,
|
|
|
"DataType": model_config.dtype,
|
|
|
"Model Load Format": load_config.load_format,
|
|
|
"Tensor Parallel Size": parallel_config.tensor_parallel_size,
|
|
@@ -233,6 +237,7 @@ class AphroditeEngine:
|
|
|
self.load_config = load_config
|
|
|
self.decoding_config = decoding_config or DecodingConfig()
|
|
|
self.prompt_adapter_config = prompt_adapter_config
|
|
|
+ self.cfg_config = cfg_config
|
|
|
self.log_stats = log_stats
|
|
|
|
|
|
if not self.model_config.skip_tokenizer_init:
|
|
@@ -269,6 +274,7 @@ class AphroditeEngine:
|
|
|
speculative_config=speculative_config,
|
|
|
load_config=load_config,
|
|
|
prompt_adapter_config=prompt_adapter_config,
|
|
|
+ cfg_config=cfg_config,
|
|
|
)
|
|
|
|
|
|
if not self.model_config.embedding_mode:
|
|
@@ -533,6 +539,16 @@ class AphroditeEngine:
|
|
|
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
|
|
lora_request, prompt_adapter_request)
|
|
|
|
|
|
+ negative_seq = None  # stays None unless negative-prompt token ids were actually supplied
|
|
|
+ if processed_inputs.get('negative_prompt_token_ids') is not None:  # the input builders pass this key even when its value is None
|
|
|
+ negative_seq = Sequence(seq_id,
|
|
|
+ processed_inputs,
|
|
|
+ block_size,
|
|
|
+ eos_token_id,
|
|
|
+ lora_request,
|
|
|
+ prompt_adapter_request,
|
|
|
+ from_negative_prompt=True)
|
|
|
+
|
|
|
encoder_seq = None
|
|
|
if 'encoder_prompt_token_ids' in processed_inputs:
|
|
|
encoder_seq = Sequence(seq_id,
|
|
@@ -553,6 +569,7 @@ class AphroditeEngine:
|
|
|
lora_request=lora_request,
|
|
|
prompt_adapter_request=prompt_adapter_request,
|
|
|
encoder_seq=encoder_seq,
|
|
|
+ negative_seq=negative_seq,
|
|
|
)
|
|
|
elif isinstance(params, PoolingParams):
|
|
|
seq_group = self._create_sequence_group_with_pooling(
|
|
@@ -661,6 +678,8 @@ class AphroditeEngine:
|
|
|
lora_request=lora_request,
|
|
|
)
|
|
|
multi_modal_data = None
|
|
|
+ negative_prompt = None
|
|
|
+ negative_prompt_token_ids = None
|
|
|
elif isinstance(inputs, dict):
|
|
|
if "prompt_token_ids" in inputs:
|
|
|
prompt = None
|
|
@@ -674,11 +693,27 @@ class AphroditeEngine:
|
|
|
lora_request=lora_request,
|
|
|
)
|
|
|
|
|
|
+ if "negative_prompt_token_ids" in inputs:
|
|
|
+ negative_prompt = None
|
|
|
+ negative_prompt_token_ids = inputs["negative_prompt_token_ids"]
|
|
|
+ elif "negative_prompt" in inputs:
|
|
|
+ negative_prompt = parsed_negative_prompt = inputs[
|
|
|
+ "negative_prompt"]
|
|
|
+ negative_prompt_token_ids = self._tokenize_prompt(
|
|
|
+ parsed_negative_prompt,
|
|
|
+ request_id=request_id,
|
|
|
+ lora_request=lora_request,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ negative_prompt = None
|
|
|
+ negative_prompt_token_ids = None
|
|
|
+
|
|
|
multi_modal_data = inputs.get("multi_modal_data")
|
|
|
else:
|
|
|
assert_never(inputs)
|
|
|
|
|
|
- return prompt, prompt_token_ids, multi_modal_data
|
|
|
+ return (prompt, prompt_token_ids, multi_modal_data,
|
|
|
+ negative_prompt, negative_prompt_token_ids)
|
|
|
|
|
|
def _apply_prompt_adapter(
|
|
|
self,
|
|
@@ -728,8 +763,10 @@ class AphroditeEngine:
|
|
|
encoder_comps: PromptComponents,
|
|
|
decoder_comps: DecoderPromptComponents,
|
|
|
) -> EncoderDecoderLLMInputs:
|
|
|
- encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
|
|
|
- decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
|
|
|
+ encoder_prompt, encoder_prompt_ids, encoder_mm_data, \
|
|
|
+ encoder_negative_prompt, encoder_negative_prompt_ids = encoder_comps
|
|
|
+ decoder_prompt, decoder_prompt_ids, decoder_mm_data, \
|
|
|
+ decoder_negative_prompt, decoder_negative_prompt_ids= decoder_comps
|
|
|
|
|
|
if encoder_mm_data is not None or decoder_mm_data is not None:
|
|
|
raise ValueError("Multi-modal encoder-decoder models are "
|
|
@@ -737,12 +774,18 @@ class AphroditeEngine:
|
|
|
|
|
|
decoder_prompt_ids = (
|
|
|
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
|
|
|
+ decoder_negative_prompt_ids = (
|
|
|
+ self._prepare_decoder_input_ids_for_generation(decoder_negative_prompt_ids))
|
|
|
|
|
|
return EncoderDecoderLLMInputs(
|
|
|
prompt_token_ids=decoder_prompt_ids,
|
|
|
prompt=decoder_prompt,
|
|
|
+ negative_prompt_token_ids=decoder_negative_prompt_ids,
|
|
|
+ negative_prompt=decoder_negative_prompt,
|
|
|
encoder_prompt_token_ids=encoder_prompt_ids,
|
|
|
encoder_prompt=encoder_prompt,
|
|
|
+ encoder_negative_prompt_token_ids=encoder_negative_prompt_ids,
|
|
|
+ encoder_negative_prompt=encoder_negative_prompt,
|
|
|
)
|
|
|
|
|
|
def _process_encoder_decoder_prompt(
|
|
@@ -787,7 +830,7 @@ class AphroditeEngine:
|
|
|
)
|
|
|
|
|
|
if (decoder_input := inputs["decoder_prompt"]) is None:
|
|
|
- decoder_comps = None, None, None
|
|
|
+ decoder_comps = None, None, None, None, None
|
|
|
else:
|
|
|
decoder_comps = self._extract_prompt_components(
|
|
|
decoder_input,
|
|
@@ -799,7 +842,7 @@ class AphroditeEngine:
|
|
|
request_id=request_id,
|
|
|
)
|
|
|
|
|
|
- decoder_comps = None, None, None
|
|
|
+ decoder_comps = None, None, None, None, None
|
|
|
|
|
|
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
|
|
|
|
|
@@ -808,14 +851,17 @@ class AphroditeEngine:
|
|
|
prompt_comps: PromptComponents,
|
|
|
prompt_adapter_request: Optional[PromptAdapterRequest],
|
|
|
) -> LLMInputs:
|
|
|
- prompt, prompt_token_ids, multi_modal_data = prompt_comps
|
|
|
+ prompt, prompt_token_ids, multi_modal_data, \
|
|
|
+ negative_prompt, negative_prompt_token_ids = prompt_comps
|
|
|
|
|
|
prompt_token_ids = self._apply_prompt_adapter(
|
|
|
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
|
|
|
|
|
|
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
|
|
prompt=prompt,
|
|
|
- multi_modal_data=multi_modal_data)
|
|
|
+ multi_modal_data=multi_modal_data,
|
|
|
+ negative_prompt_token_ids=negative_prompt_token_ids,
|
|
|
+ negative_prompt=negative_prompt)
|
|
|
|
|
|
def _process_decoder_only_prompt(
|
|
|
self,
|
|
@@ -960,6 +1006,7 @@ class AphroditeEngine:
|
|
|
lora_request: Optional[LoRARequest],
|
|
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
|
encoder_seq: Optional[Sequence] = None,
|
|
|
+ negative_seq: Optional[Sequence] = None,
|
|
|
) -> SequenceGroup:
|
|
|
"""Creates a SequenceGroup with SamplingParams."""
|
|
|
max_logprobs = self.get_model_config().max_logprobs
|
|
@@ -984,7 +1031,8 @@ class AphroditeEngine:
|
|
|
sampling_params=sampling_params,
|
|
|
lora_request=lora_request,
|
|
|
prompt_adapter_request=prompt_adapter_request,
|
|
|
- encoder_seq=encoder_seq)
|
|
|
+ encoder_seq=encoder_seq,
|
|
|
+ negative_seq=negative_seq)
|
|
|
|
|
|
return seq_group
|
|
|
|
|
@@ -997,6 +1045,7 @@ class AphroditeEngine:
|
|
|
lora_request: Optional[LoRARequest],
|
|
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
|
|
encoder_seq: Optional[Sequence] = None,
|
|
|
+ negative_seq: Optional[Sequence] = None,
|
|
|
) -> SequenceGroup:
|
|
|
"""Creates a SequenceGroup with PoolingParams."""
|
|
|
# Defensive copy of PoolingParams, which are used by the pooler
|
|
@@ -1009,7 +1058,8 @@ class AphroditeEngine:
|
|
|
lora_request=lora_request,
|
|
|
pooling_params=pooling_params,
|
|
|
prompt_adapter_request=prompt_adapter_request,
|
|
|
- encoder_seq=encoder_seq)
|
|
|
+ encoder_seq=encoder_seq,
|
|
|
+ negative_seq=negative_seq)
|
|
|
|
|
|
return seq_group
|
|
|
|