Răsfoiți Sursa

fix: llava 1.6 feature size calculation

AlpinDale 6 luni în urmă
părinte
comite
0ab35652d3
1 a modificat fișierele cu 10 adăugiri și 8 ștergeri
  1. 10 8
      aphrodite/modeling/models/llava_next.py

+ 10 - 8
aphrodite/modeling/models/llava_next.py

@@ -70,19 +70,21 @@ def _get_llava_next_num_unpadded_features(
 ) -> Tuple[int, int]:
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
+    current_height = torch.tensor(current_height).to("cuda")
+    current_width = torch.tensor(current_width).to("cuda")
 
     aspect_ratio: float = width / height
     current_aspect_ratio: float = current_width / current_height
     if aspect_ratio > current_aspect_ratio:
-        new_height = (height * current_width) // width
-        if new_height % 2 == 1:
-            new_height += 1
-        current_height = new_height
+        scale_factor = current_width / width
+        new_height = int(height * scale_factor)
+        padding = (current_height - new_height) // 2
+        current_height -= padding * 2
     else:
-        new_width = (width * current_height) // height
-        if new_width % 2 == 1:
-            new_width += 1
-        current_width = new_width
+        scale_factor = current_height / height
+        new_width = int(width * scale_factor)
+        padding = (current_width - new_width) // 2
+        current_width -= padding * 2
 
     unpadded_features = current_height * current_width
     newline_features = current_height