@@ -341,6 +341,104 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache
 
 
+class Phi3LongRoPERotaryEmbedding(nn.Module):
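+    """Phi-3 LongRoPE scaled rotary embedding.
+
+    Builds one cos/sin cache from `short_factor` for prompts that fit in the
+    original context window and another from `long_factor` for prompts that
+    exceed it.
+    """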
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        original_max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        short_factor: List[float],
+        long_factor: List[float],
+        short_mscale: float = 1.1,
+        long_mscale: float = 1.225,
+    ):
+        super().__init__()
+
+        if rotary_dim != head_size:
+            raise ValueError(
+                f"Rotary dim must be equal to head size, got {rotary_dim} "
+                f"and {head_size}")
+        if is_neox_style is False:
+            raise ValueError(
+                "Phi3LongRoPERotaryEmbedding only supports Neox style")
+        self.head_size = head_size
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.base = base
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
+
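+        # Build one cos/sin cache per factor set; the two are concatenated
+        # below so forward() can select rows from either half with a single
+        # index tensor.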
+        short_cache = self._compute_cos_sin_cache(
+            original_max_position_embeddings, short_factor, short_mscale)
+        short_cache = short_cache.to(torch.get_default_dtype())
+        self.register_buffer("short_cos_sin_cache",
+                             short_cache,
+                             persistent=False)
+
+        long_cache = self._compute_cos_sin_cache(
+            max_position_embeddings, long_factor, long_mscale)
+        long_cache = long_cache.to(torch.get_default_dtype())
+        self.register_buffer("long_cos_sin_cache",
+                             long_cache,
+                             persistent=False)
+
+        long_short_cache = torch.cat(
+            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        self.register_buffer("long_short_cos_sin_cache",
+                             long_short_cache,
+                             persistent=False)
+
+    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
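+        # The per-dimension rescale factors multiply the denominator, i.e.
+        # they divide the standard RoPE inverse frequencies.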
+        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
+        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
+            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
+        return inv_freq
+
+    def _compute_cos_sin_cache(self, max_position_embeddings: int,
+                               rescale_factors: List[float],
+                               mscale: float) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(rescale_factors)
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
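+        # When any position exceeds the original context window, shift all
+        # lookup indices by original_max_position_embeddings so they select
+        # rows from the long-factor half of the concatenated cache.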
+        k = self.original_max_position_embeddings
+        long_prompt_offset = (torch.any(positions > k).float() *
+                              torch.full_like(positions, k)).long()
+        idx = (torch.add(positions, long_prompt_offset)
+               if long_prompt_offset is not None else positions)
+        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
+            idx.device)
+        idx = torch.add(idx, offsets) if offsets is not None else idx
+        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+        cos, sin = cos_sin.chunk(2, dim=-1)
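+        # Neox style rotates the two halves of each head together, so repeat
+        # cos/sin across both halves and add a broadcast dim over the heads.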
+        cos = cos.repeat(1, 2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2)
+
+        query = query * cos + _rotate_neox(query) * sin
+        key = key * cos + _rotate_neox(key) * sin
+
+        return query.flatten(-2), key.flatten(-2)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -352,8 +450,16 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
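+    # List values in rope_scaling (e.g. LongRoPE's factor lists) are not
+    # hashable, so convert them to tuples before building the cache key.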
+    if rope_scaling is not None:
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           tuple(rope_scaling.items()) if rope_scaling is not None else None)
+           rope_scaling_args)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
 
@@ -362,7 +468,8 @@ def get_rope(
                                      is_neox_style)
     else:
         scaling_type = rope_scaling["type"]
-        scaling_factor = rope_scaling["factor"]
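+        # LongRoPE configs provide short/long factor lists rather than a
+        # single "factor" entry.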
+        if scaling_type != "su" and scaling_type != "longrope":
+            scaling_factor = rope_scaling["factor"]
         if scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
@@ -386,6 +493,19 @@ def get_rope(
                                                     base, is_neox_style,
                                                     scaling_factor,
                                                     **extra_kwargs)
+        elif scaling_type == "su" or scaling_type == "longrope":
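+            # "su" is the type name used by the original Phi-3 configs;
+            # "longrope" is the newer name for the same scheme.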
+            short_factor = rope_scaling["short_factor"]
+            long_factor = rope_scaling["long_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("short_mscale", "long_mscale")
+            }
+            rotary_emb = Phi3LongRoPERotaryEmbedding(
+                head_size, rotary_dim, max_position, original_max_position,
+                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb