|
@@ -8,6 +8,8 @@ from aphrodite.common.logger import init_logger
|
|
|
from aphrodite.transformers_utils.config import get_config
|
|
|
from aphrodite.common.utils import get_cpu_memory
|
|
|
|
|
|
+from math import exp, log
|
|
|
+
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
_GB = 1 << 30
|
|
@@ -371,8 +373,12 @@ def _get_and_verify_max_len(
|
|
|
if max_model_len is None:
|
|
|
max_model_len = derived_max_model_len
|
|
|
elif max_model_len > derived_max_model_len:
|
|
|
- # hope this works
|
|
|
- scaling_factor = max_model_len / derived_max_model_len
|
|
|
+ if derived_max_model_len == 4096:
|
|
|
+ scaling_factor = exp(
|
|
|
+ log((max_model_len - 1150.29) / 2982.33) / .884113)
|
|
|
+ else:
|
|
|
+ scaling_factor = max_model_len / derived_max_model_len
|
|
|
+
|
|
|
hf_config.rope_scaling = {"factor": scaling_factor, "type": "dynamic"}
|
|
|
logger.warning(
|
|
|
f"User-specified max_model_len {max_model_len} is higher than "
|