
model: add support for paligemma2 (#936)

AlpinDale, 2 months ago
parent
commit c50309d386
2 changed files with 19 additions and 3 deletions
  1. +9 -1    aphrodite/modeling/models/gemma2.py
  2. +10 -2   aphrodite/modeling/models/paligemma.py

+ 9 - 1
aphrodite/modeling/models/gemma2.py

@@ -265,14 +265,22 @@ class Gemma2Model(nn.Module):
         normalizer = self.config.hidden_size**0.5
         self.register_buffer("normalizer", torch.tensor(normalizer))
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
         hidden_states *= self.normalizer
 
         residual = None

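The new inputs_embeds argument lets a multimodal wrapper pass Gemma2Model pre-computed embeddings instead of raw token IDs, e.g. text embeddings whose image placeholder positions have been overwritten with projected vision features. A minimal sketch of such a caller is below, assuming a flattened token layout; forward_with_image, vision_embeddings, and image_token_id are illustrative names, not code from this commit.

    import torch

    def forward_with_image(language_model, input_ids: torch.Tensor,
                           positions: torch.Tensor, kv_caches, attn_metadata,
                           vision_embeddings: torch.Tensor,
                           image_token_id: int) -> torch.Tensor:
        # Look up ordinary token embeddings via the new helper.
        inputs_embeds = language_model.get_input_embeddings(input_ids)
        # vision_embeddings is assumed to be (num_image_positions, hidden_size);
        # overwrite the image placeholder slots with those features.
        mask = input_ids == image_token_id
        inputs_embeds[mask] = vision_embeddings.to(inputs_embeds.dtype)
        # Passing inputs_embeds skips the embedding lookup inside forward().
        return language_model(input_ids=input_ids,
                              positions=positions,
                              kv_caches=kv_caches,
                              attn_metadata=attn_metadata,
                              inputs_embeds=inputs_embeds)
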
+ 10 - 2
aphrodite/modeling/models/paligemma.py

@@ -14,6 +14,7 @@ from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.gemma import GemmaModel
+from aphrodite.modeling.models.gemma2 import Gemma2Model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.multimodal.utils import cached_get_tokenizer
@@ -105,6 +106,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
         orig_prompt_ids.remove(hf_config.image_token_index)
 
     new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
+    # The PaliGemma 2 tokenizer does not include a starting BOS token
+    if orig_prompt_ids[0] != hf_config.bos_token_id:
+        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
     new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  #newline
 
     # NOTE: Create a defensive copy of the original inputs
@@ -148,8 +152,12 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
             projection_dim=config.vision_config.projection_dim)
 
         self.quant_config = quant_config
-        self.language_model = GemmaModel(config.text_config, cache_config,
-                                         quant_config)
+        if config.text_config.model_type == "gemma":
+            self.language_model = GemmaModel(config.text_config, cache_config,
+                                            quant_config)
+        else:
+            self.language_model = Gemma2Model(config.text_config, cache_config,
+                                            quant_config)
         self.unpadded_vocab_size = config.text_config.vocab_size
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
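
Two behaviors change in paligemma.py: the backbone is now selected from config.text_config.model_type, keeping GemmaModel for "gemma" and using Gemma2Model otherwise (PaliGemma 2 ships a Gemma 2 text config), and the input processor prepends a BOS token when the prompt IDs do not already start with one, since the PaliGemma 2 tokenizer omits it. The snippet below restates that BOS rule in isolation; the function name and the token IDs in the checks are hypothetical examples, not part of the commit.

    def ensure_leading_bos(token_ids: list, bos_token_id: int) -> list:
        # PaliGemma 1 prompt IDs already start with BOS; PaliGemma 2 prompt IDs
        # do not, so prepend it only when it is missing.
        if token_ids and token_ids[0] != bos_token_id:
            return [bos_token_id] + token_ids
        return token_ids

    # Hypothetical example values, assuming BOS id 2.
    assert ensure_leading_bos([2, 106, 1645], bos_token_id=2) == [2, 106, 1645]
    assert ensure_leading_bos([106, 1645], bos_token_id=2) == [2, 106, 1645]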