fix RoPE dtype mismatch when loading models with non-default dtypes

AlpinDale 7 months ago
parent commit 14a2d6f624
1 changed file with 22 additions and 13 deletions
  1. aphrodite/modeling/layers/rotary_embedding.py (+22, -13)

aphrodite/modeling/layers/rotary_embedding.py (+22, -13)

@@ -56,6 +56,7 @@ class RotaryEmbedding(nn.Module):
         max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -65,7 +66,7 @@ class RotaryEmbedding(nn.Module):
         self.is_neox_style = is_neox_style
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(torch.get_default_dtype())
+        cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -181,12 +182,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factors: Union[List[float], float],
+        dtype: torch.dtype,
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
         self.scaling_factors = scaling_factors
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.base)
@@ -222,10 +224,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
     ) -> None:
         self.scaling_factor = scaling_factor
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         # NOTE: self.max_position_embeddings is the original
@@ -302,6 +305,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -317,7 +321,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.mscale = float(
             _yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(
@@ -358,6 +362,7 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
         original_max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
         short_mscale: float = 1.1,
@@ -383,14 +388,14 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
-        short_cache = short_cache.to(torch.get_default_dtype())
+        short_cache = short_cache.to(dtype)
         self.register_buffer("short_cos_sin_cache",
                              short_cache,
                              persistent=False)
 
         long_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, long_factor, long_mscale)
-        long_cache = long_cache.to(torch.get_default_dtype())
+        long_cache = long_cache.to(dtype)
         self.register_buffer("long_cos_sin_cache",
                              long_cache,
                              persistent=False)
@@ -486,7 +491,10 @@ def get_rope(
     base: int,
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
     if rope_scaling is not None:
         rope_scaling_tuple = {
             k: tuple(v) if isinstance(v, list) else v
@@ -496,13 +504,13 @@ def get_rope(
     else:
         rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           rope_scaling_args)
+           rope_scaling_args, dtype)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
 
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                                     is_neox_style)
+                                     is_neox_style, dtype)
     else:
         scaling_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
         if scaling_type not in {"su", "longrope", "llama3"}:
@@ -510,16 +518,16 @@ def get_rope(
         if scaling_type == "llama3":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                  max_position, base,
-                                                 is_neox_style)
+                                                 is_neox_style, dtype)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
-                                                      scaling_factor)
+                                                      scaling_factor, dtype)
         elif scaling_type == "dynamic":
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor)
+                scaling_factor, dtype)
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
@@ -532,7 +540,7 @@ def get_rope(
             rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                     original_max_position,
                                                     base, is_neox_style,
-                                                    scaling_factor,
+                                                    scaling_factor, dtype,
                                                     **extra_kwargs)
         elif scaling_type == "su" or scaling_type == "longrope":
             short_factor = rope_scaling["short_factor"]
@@ -546,7 +554,8 @@ def get_rope(
             }
             rotary_emb = Phi3LongRoPERotaryEmbedding(
                 head_size, rotary_dim, max_position, original_max_position,
-                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
+                base, is_neox_style, dtype, short_factor, long_factor,
+                **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
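
A minimal usage sketch (not part of this commit) of the new `dtype` parameter on `get_rope()`; the head size, rotary dimension, and max-position values below are illustrative placeholders, not taken from the diff:

    import torch
    from aphrodite.modeling.layers.rotary_embedding import get_rope

    # dtype is now part of the _ROPE_DICT cache key, so callers using different
    # dtypes get distinct RotaryEmbedding instances whose cos/sin caches match
    # their model dtype.
    rope_fp16 = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                         base=10000, is_neox_style=True, dtype=torch.float16)
    rope_bf16 = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                         base=10000, is_neox_style=True, dtype=torch.bfloat16)
    assert rope_fp16 is not rope_bf16
    assert rope_fp16.cos_sin_cache.dtype == torch.float16
    assert rope_bf16.cos_sin_cache.dtype == torch.bfloat16

    # Omitting dtype falls back to torch.get_default_dtype(), preserving the
    # old behavior for callers that do not pass it.
    rope_default = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                            base=10000)
    assert rope_default.cos_sin_cache.dtype == torch.get_default_dtype()

Previously the cos/sin cache was always cast to torch.get_default_dtype(), so a model loaded in, say, float16 while the process default was float32 ended up with a mismatched rotary cache; threading dtype through the rotary embedding constructors avoids that.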