Kaynağa Gözat

fix: dtype mismatch for paligemma

AlpinDale 6 ay önce
ebeveyn
işleme
05e45aeb53

+ 1 - 0
aphrodite/modeling/models/gemma.py

@@ -274,6 +274,7 @@ class GemmaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if inputs_embeds is not None:

+ 11 - 4
aphrodite/modeling/models/paligemma.py

@@ -8,7 +8,8 @@ from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import SamplerOutput, SequenceData
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceData)
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -107,7 +108,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     orig_prompt = llm_inputs.get("prompt")
     orig_prompt_ids = llm_inputs.get("prompt_token_ids")
 
-    if image_token_str in orig_prompt:
+    if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
             f"The image token '{image_token_str}' was detected in the prompt "
             "and will be removed. Please follow the proper prompt format"
@@ -210,7 +211,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
 
-        image_outputs = vision_tower(pixel_values, output_hidden_states=True)
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
+                                     output_hidden_states=True)
 
         selected_image_features = image_outputs.last_hidden_state
 
@@ -232,9 +235,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
 
         return self.multi_modal_projector(image_features)
 
-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object) -> SamplerOutput:
 
         parsed_image_input = self._parse_and_validate_image_input(**kwargs)
@@ -259,6 +265,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
                                             positions,
                                             kv_caches,
                                             attn_metadata,
+                                            None,
                                             inputs_embeds=inputs_embeds)
 
         return hidden_states