@@ -56,6 +56,7 @@ class RotaryEmbedding(nn.Module):
         max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -65,7 +66,7 @@ class RotaryEmbedding(nn.Module):
         self.is_neox_style = is_neox_style
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(torch.get_default_dtype())
+        cache = cache.to(dtype)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
@@ -181,12 +182,13 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factors: Union[List[float], float],
+        dtype: torch.dtype,
     ) -> None:
         if isinstance(scaling_factors, float):
             scaling_factors = [scaling_factors]
         self.scaling_factors = scaling_factors
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.base)
@@ -222,10 +224,11 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
     ) -> None:
         self.scaling_factor = scaling_factor
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         # NOTE: self.max_position_embeddings is the original
@@ -302,6 +305,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         base: int,
         is_neox_style: bool,
         scaling_factor: float,
+        dtype: torch.dtype,
         *,
         extrapolation_factor: float = 1,
         attn_factor: float = 1,
@@ -317,7 +321,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.mscale = float(
             _yarn_get_mscale(self.scaling_factor) * attn_factor)
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style)
+                         is_neox_style, dtype)
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base**(
@@ -358,6 +362,7 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
         original_max_position_embeddings: int,
         base: int,
         is_neox_style: bool,
+        dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
         short_mscale: float = 1.1,
@@ -383,14 +388,14 @@ class Phi3LongRoPERotaryEmbedding(nn.Module):
 
         short_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, short_factor, short_mscale)
-        short_cache = short_cache.to(torch.get_default_dtype())
+        short_cache = short_cache.to(dtype)
         self.register_buffer("short_cos_sin_cache",
                              short_cache,
                              persistent=False)
 
         long_cache = self._compute_cos_sin_cache(
             original_max_position_embeddings, long_factor, long_mscale)
-        long_cache = long_cache.to(torch.get_default_dtype())
+        long_cache = long_cache.to(dtype)
         self.register_buffer("long_cos_sin_cache",
                              long_cache,
                              persistent=False)
@@ -486,7 +491,10 @@ def get_rope(
     base: int,
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
     if rope_scaling is not None:
         rope_scaling_tuple = {
             k: tuple(v) if isinstance(v, list) else v
@@ -496,13 +504,13 @@ def get_rope(
     else:
         rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           rope_scaling_args)
+           rope_scaling_args, dtype)
    if key in _ROPE_DICT:
         return _ROPE_DICT[key]
 
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
-                                     is_neox_style)
+                                     is_neox_style, dtype)
     else:
         scaling_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
         if scaling_type not in {"su", "longrope", "llama3"}:
@@ -510,16 +518,16 @@ def get_rope(
         if scaling_type == "llama3":
             rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                  max_position, base,
-                                                 is_neox_style)
+                                                 is_neox_style, dtype)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
-                                                      scaling_factor)
+                                                      scaling_factor, dtype)
         elif scaling_type == "dynamic":
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor)
+                scaling_factor, dtype)
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
@@ -532,7 +540,7 @@ def get_rope(
             rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                     original_max_position,
                                                     base, is_neox_style,
-                                                    scaling_factor,
+                                                    scaling_factor, dtype,
                                                     **extra_kwargs)
         elif scaling_type == "su" or scaling_type == "longrope":
             short_factor = rope_scaling["short_factor"]
@@ -546,7 +554,8 @@ def get_rope(
             }
             rotary_emb = Phi3LongRoPERotaryEmbedding(
                 head_size, rotary_dim, max_position, original_max_position,
-                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
+                base, is_neox_style, dtype, short_factor, long_factor,
+                **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
         _ROPE_DICT[key] = rotary_emb
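
For reference, a minimal usage sketch of the updated `get_rope` entry point. The import path and the concrete head/position sizes are illustrative assumptions, not part of the diff; the only behavior the sketch relies on is what the hunks above show: an optional `dtype` keyword that falls back to `torch.get_default_dtype()` when omitted, and a `_ROPE_DICT` cache key that now includes the dtype.

import torch

# Assumed module path for illustration; the diff does not name the file.
from vllm.model_executor.layers.rotary_embedding import get_rope

# Explicit dtype: the cos/sin cache buffer is materialized in float16
# regardless of the process-wide default dtype.
rope_fp16 = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    rope_scaling=None,
    dtype=torch.float16,
)

# Omitting dtype keeps the pre-change behavior for existing callers:
# the cache is cast to torch.get_default_dtype().
rope_default = get_rope(128, 128, 4096, 10000)

# Because dtype is now part of the cache key, these two calls yield
# distinct cached RotaryEmbedding instances (assuming the default
# dtype here is float32, not float16).
assert rope_fp16 is not rope_default

Folding `dtype` into the `_ROPE_DICT` key is what makes the explicit parameter safe: without it, the first caller's dtype would be silently reused for every later caller with otherwise identical arguments.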