@@ -8,7 +8,8 @@ from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
-from aphrodite.common.sequence import SamplerOutput, SequenceData
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceData)
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -107,7 +108,7 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     orig_prompt = llm_inputs.get("prompt")
     orig_prompt_ids = llm_inputs.get("prompt_token_ids")

-    if image_token_str in orig_prompt:
+    if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
             f"The image token '{image_token_str}' was detected in the prompt "
             "and will be removed. Please follow the proper prompt format"
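Note on the guard above: `llm_inputs` may be built from token IDs alone, in which case `llm_inputs.get("prompt")` returns `None` and the old membership test raises `TypeError: argument of type 'NoneType' is not iterable`. A minimal standalone sketch of the failure mode (plain Python with a hypothetical dict stand-in for LLMInputs, not aphrodite code):

    # Hypothetical stand-in for LLMInputs built from token IDs only.
    llm_inputs = {"prompt_token_ids": [1, 2, 3]}   # no "prompt" key
    image_token_str = "<image>"

    orig_prompt = llm_inputs.get("prompt")         # -> None

    # Old check `image_token_str in orig_prompt` would raise TypeError here;
    # the added None guard short-circuits instead.
    if orig_prompt is not None and image_token_str in orig_prompt:
        print("image token found; it will be stripped")   # safely skipped when prompt is None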
@@ -210,7 +211,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
     def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                   pixel_values: torch.Tensor) -> torch.Tensor:

-        image_outputs = vision_tower(pixel_values, output_hidden_states=True)
+        target_dtype = vision_tower.get_input_embeddings().weight.dtype
+        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
+                                     output_hidden_states=True)

         selected_image_features = image_outputs.last_hidden_state
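The cast above follows a common pattern for reduced-precision vision towers: image preprocessors typically hand back float32 pixel tensors, and feeding those into fp16/bf16 SigLIP weights would otherwise trip a dtype-mismatch error in the first matmul. A minimal standalone sketch of the same idea (plain PyTorch with a hypothetical `TinyTower` module, not the aphrodite implementation):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for a vision tower whose weights are bf16.
    class TinyTower(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.patch_embed = nn.Linear(3 * 16 * 16, 32)

        def get_input_embeddings(self) -> nn.Linear:
            return self.patch_embed

        def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
            return self.patch_embed(pixel_values.flatten(1))

    tower = TinyTower().to(dtype=torch.bfloat16)
    pixel_values = torch.rand(1, 3 * 16 * 16)              # preprocessors usually emit float32

    target_dtype = tower.get_input_embeddings().weight.dtype
    features = tower(pixel_values.to(dtype=target_dtype))  # cast avoids a float32/bf16 mismatch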
@@ -232,9 +235,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):

         return self.multi_modal_projector(image_features)

-    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object) -> SamplerOutput:
         parsed_image_input = self._parse_and_validate_image_input(**kwargs)
@@ -259,6 +265,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
                                             positions,
                                             kv_caches,
                                             attn_metadata,
+                                            None,
                                             inputs_embeds=inputs_embeds)

         return hidden_states