Parcourir la source

fix: improve consistency between feature size calc and dummy data for profiling

AlpinDale il y a 6 mois
Parent
commit
526163003d
2 fichiers modifiés avec 17 ajouts et 26 suppressions
  1. 8 13
      aphrodite/modeling/models/llava_next.py
  2. 9 13
      aphrodite/modeling/models/phi3v.py

+ 8 - 13
aphrodite/modeling/models/llava_next.py

@@ -33,6 +33,9 @@ _KEYS_TO_MODIFY_MAPPING = {
     "language_model.model": "language_model",
 }
 
+# Result in the max possible feature size (2x2 grid of 336x336px tiles)
+MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
+
 
 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -124,13 +127,11 @@ def get_llava_next_image_feature_size(
 
 
 def get_max_llava_next_image_tokens(ctx: InputContext):
-    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
-    dummy_height = dummy_width = 448
 
     return get_llava_next_image_feature_size(
         ctx.get_hf_config(LlavaNextConfig),
-        input_height=dummy_height,
-        input_width=dummy_width,
+        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
 
 
@@ -138,13 +139,7 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaNextConfig)
     vision_config = hf_config.vision_config
 
-    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
-    dummy_height = dummy_width = 448
-    image_feature_size = get_llava_next_image_feature_size(
-        hf_config,
-        input_height=dummy_height,
-        input_width=dummy_width,
-    )
+    image_feature_size = get_max_llava_next_image_tokens(ctx)
 
     if isinstance(vision_config, CLIPVisionConfig):
         seq_data = dummy_seq_data_for_clip(
@@ -156,8 +151,8 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
 
         mm_data = dummy_image_for_clip(
             vision_config,
-            image_width_override=dummy_width,
-            image_height_override=dummy_height,
+            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         )
 
         return seq_data, mm_data

+ 9 - 13
aphrodite/modeling/models/phi3v.py

@@ -51,6 +51,10 @@ _KEYS_TO_MODIFY_MAPPING = {
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 32044
 
+# Result in the max possible feature size (h:w = 16:1)
+MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
+MAX_IMAGE_FEATURE_SIZE_WIDTH = 50
+
 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      hidden_act="quick_gelu",
                                                      hidden_size=1024,
@@ -320,24 +324,16 @@ def get_phi3v_image_feature_size(
 
 
 def get_max_phi3v_image_tokens(ctx: InputContext):
-    # Result in the max possible feature size (h:w = 16:1)
-    dummy_height, dummy_width = 8000, 50
 
     return get_phi3v_image_feature_size(
         ctx.get_hf_config(PretrainedConfig),
-        input_height=dummy_height,
-        input_width=dummy_width,
+        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
 
 
 def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
-    # Result in the max possible feature size (h:w = 16:1)
-    dummy_height, dummy_width = 8000, 50
-    image_feature_size = get_phi3v_image_feature_size(
-        ctx.get_hf_config(PretrainedConfig),
-        input_height=dummy_height,
-        input_width=dummy_width,
-    )
+    image_feature_size = get_max_phi3v_image_tokens(ctx)
 
     seq_data = dummy_seq_data_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -347,8 +343,8 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
     )
     mm_data = dummy_image_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        image_width_override=dummy_width,
-        image_height_override=dummy_height,
+        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
     )
 
     return seq_data, mm_data