Преглед на файлове

fix: correct auto ntk scaling_factor for 4k ctx case (#101)

* fix: correct auto ntk scaling_factor for 4k ctx case

* shorten line length

* format code
sandwichdoge преди 1 година
родител
ревизия
99293aaff0
променени са 1 файла, в които са добавени 8 реда и са изтрити 2 реда
  1. 8 2
      aphrodite/common/config.py

+ 8 - 2
aphrodite/common/config.py

@@ -8,6 +8,8 @@ from aphrodite.common.logger import init_logger
 from aphrodite.transformers_utils.config import get_config
 from aphrodite.common.utils import get_cpu_memory
 
+from math import exp, log
+
 logger = init_logger(__name__)
 
 _GB = 1 << 30
@@ -371,8 +373,12 @@ def _get_and_verify_max_len(
     if max_model_len is None:
         max_model_len = derived_max_model_len
     elif max_model_len > derived_max_model_len:
-        # hope this works
-        scaling_factor = max_model_len / derived_max_model_len
+        if derived_max_model_len == 4096:
+            scaling_factor = exp(
+                log((max_model_len - 1150.29) / 2982.33) / .884113)
+        else:
+            scaling_factor = max_model_len / derived_max_model_len
+
         hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
         logger.warning(
             f"User-specified max_model_len {max_model_len} is higher than "