|
@@ -14,6 +14,7 @@ from aphrodite.modeling.layers.logits_processor import LogitsProcessor
|
|
|
from aphrodite.modeling.layers.sampler import Sampler
|
|
|
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
|
|
|
from aphrodite.modeling.models.gemma import GemmaModel
|
|
|
+from aphrodite.modeling.models.gemma2 import Gemma2Model
|
|
|
from aphrodite.modeling.sampling_metadata import SamplingMetadata
|
|
|
from aphrodite.multimodal import MULTIMODAL_REGISTRY
|
|
|
from aphrodite.multimodal.utils import cached_get_tokenizer
|
|
@@ -105,6 +106,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
|
|
|
orig_prompt_ids.remove(hf_config.image_token_index)
|
|
|
|
|
|
new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
|
|
|
+ # The PaliGemma 2 tokenizer does not add a leading BOS token, so prepend it when it is missing
|
|
|
+ if orig_prompt_ids[0] != hf_config.bos_token_id:
|
|
|
+ orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
|
|
|
new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
|
|
|
|
|
|
# NOTE: Create a defensive copy of the original inputs
|
|
@@ -148,8 +152,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
|
|
projection_dim=config.vision_config.projection_dim)
|
|
|
|
|
|
self.quant_config = quant_config
|
|
|
- self.language_model = GemmaModel(config.text_config, cache_config,
|
|
|
- quant_config)
|
|
|
+ if config.text_config.model_type == "gemma":
|
|
|
+ self.language_model = GemmaModel(config.text_config, cache_config,
|
|
|
+ quant_config)
|
|
|
+ else:
|
|
|
+ self.language_model = Gemma2Model(config.text_config, cache_config,
|
|
|
+ quant_config)
|
|
|
self.unpadded_vocab_size = config.text_config.vocab_size
|
|
|
logit_scale = getattr(config, "logit_scale", 1.0)
|
|
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|