# coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py # Copyright 2023 The PygmalionAI team. # Copyright 2023 The PygmalionAI team. # Copyright 2023 The PygmalionAI team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from aphrodite._C import ops


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dim NeoX-style: ``[x1, x2] -> [-x2, x1]``.

    ``x1`` and ``x2`` are the first and second halves of the last dimension.
    """
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dim GPT-J-style on interleaved (even, odd) pairs.

    Each pair ``(x[2i], x[2i + 1])`` becomes ``(-x[2i + 1], x[2i])``.
    """
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        # Non-persistent: the cache is derived from the hyperparameters above,
        # so it is excluded from the module's state_dict.
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequencies ``1 / base**(2i / rotary_dim)``.

        Returns a 1-D float32 tensor of length ``rotary_dim // 2``.
        """
        # NOTE: The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE: To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on the current default device (presumably the GPU
        # during model initialization) for faster initialization. This may
        # cause a slight numerical difference between the HF implementation
        # and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns a ``[max_position_embeddings, rotary_dim]`` tensor whose last
        dim holds the cosines in its first half and the sines in its second.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Unlike forward(), this builds new output tensors instead of writing
        into ``query``/``key``. Only the first ``rotary_dim`` channels of each
        head are rotated; the remaining channels pass through unchanged.
        ``offsets``, if given, is added to ``positions`` before the cache
        lookup.
        """
        # Split the last dim into (num_heads, head_size).
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # Duplicate along the last dim so both halves of the rotary dims
            # see the same angles. Unlike `repeat(1, 1, 2)`, concatenation
            # keeps the rank of `positions`, so this works for both flat
            # [num_tokens] and [batch_size, seq_len] positions.
            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        else:
            # GPT-J style rotates interleaved pairs, so each angle is
            # repeated for the two adjacent channels of its pair.
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        # Merge (num_heads, head_size) back into one dim.
        query = query.flatten(-2)
        key = key.flatten(-2)
        return query, key

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding using the custom ``ops`` kernels.

        The same ``query``/``key`` tensor objects are returned after the
        kernels update them in place.
        """
        # Keep the cache on the query's device and dtype. The kernels receive
        # the raw cache tensor alongside the query; presumably (as in the
        # upstream vLLM kernels) they read it with the query's scalar type,
        # so a dtype mismatch must not reach them. This is a no-op when the
        # device and dtype already match.
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key,
                                         self.head_size, self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Positions are divided by the scaling factor before computing angles.
    With several scaling factors, one cache is built per factor and the
    caches are concatenated along the position dim; callers select a section
    via ``offsets``.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
    ) -> None:
        # Wrap a scalar factor in a list. A bare int must be wrapped too: a
        # config may store the factor as an integer (e.g. `"factor": 4`),
        # and iterating over an int would raise TypeError.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [scaling_factors]
        self.scaling_factors = scaling_factors
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build and concatenate one cos/sin cache per scaling factor."""
        inv_freq = self._compute_inv_freq(self.base)
        cache_list = []
        for scaling_factor in self.scaling_factors:
            # NOTE: self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            cache_list.append(cache)
        return torch.cat(cache_list, dim=0)


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    The rotary base is enlarged according to the scaling factor. Note that
    the cache is computed once, at construction, for the maximum scaled
    length.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cache over the scaled length with an NTK-adjusted base."""
        # NOTE: self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    """Return the (fractional) rotary dim index whose frequency completes
    ``num_rotations`` full turns over ``max_position_embeddings`` positions.
    """
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    """Return integer dim bounds ``(low, high)`` for the YaRN blend region,
    clamped to ``[0, dim - 1]``."""
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype) -> torch.Tensor:
    """Return a length-``dim`` ramp: 0 at or below ``low``, 1 at or above
    ``high``, linear in between."""
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    """Return the YaRN magnitude scale ``0.1 * ln(scale) + 1``, or 1.0 when
    ``scale <= 1``."""
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend extrapolated (``1 / pos_freqs``) and interpolated
        (``1 / (scaling_factor * pos_freqs)``) inverse frequencies.

        Dims below the correction range mostly keep the extrapolated value,
        weighted by ``extrapolation_factor``; dims above it use the
        interpolated value; dims inside the range are blended linearly.
        """
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast,
                                                self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cache over the scaled length, scaled by ``mscale``."""
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache


class Phi3LongRoPERotaryEmbedding(nn.Module):
    """Rotary embedding with separate "short" and "long" rescale factors.

    Two caches are built: a short one covering the original context length
    and a long one covering the extended context. They are stored back to
    back in ``long_short_cos_sin_cache``. If any position in a batch exceeds
    ``original_max_position_embeddings``, every token in that batch uses the
    long cache. Only NeoX style with ``rotary_dim == head_size`` is supported.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: float = 1.1,
        long_mscale: float = 1.225,
    ) -> None:
        super().__init__()

        if rotary_dim != head_size:
            raise ValueError(
                f"Rotary dim must be equal to head size, got {rotary_dim} "
                f"and {head_size}")
        if not is_neox_style:
            raise ValueError(
                "Phi3LongRoPERotaryEmbedding only supports Neox style")

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        # Short cache: positions [0, original_max_position_embeddings).
        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(torch.get_default_dtype())
        self.register_buffer("short_cos_sin_cache",
                             short_cache,
                             persistent=False)

        # Long cache: must cover the whole extended context, because
        # forward() indexes it with positions up to max_position_embeddings.
        # Sizing it to the original context would make every long prompt
        # index past the end of the combined cache.
        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(torch.get_default_dtype())
        self.register_buffer("long_cos_sin_cache",
                             long_cache,
                             persistent=False)

        # Rows [0, k) hold the short cache and rows [k, ...) the long cache,
        # where k = original_max_position_embeddings. forward() selects the
        # long section by adding k to the positions.
        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        """Standard RoPE inverse frequencies, each divided by its rescale
        factor."""
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: List[float],
        mscale: float,
    ) -> torch.Tensor:
        """Return a ``[max_position_embeddings, head_size]`` cos|sin cache
        scaled by ``mscale``."""
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply NeoX-style rotary embedding, returning new query/key tensors.

        ``torch.index_select`` requires a 1-D index, so ``positions`` (plus
        ``offsets``, if given) is expected to be flat.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        # Tensor arithmetic instead of a Python `if` on a tensor value: the
        # offset is k for every token if any position exceeds k, otherwise 0.
        long_prompt_offset = (torch.any(positions > k).float() *
                              torch.full_like(positions, k)).long()
        idx = torch.add(positions, long_prompt_offset)
        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
            idx.device)
        if offsets is not None:
            idx = torch.add(idx, offsets)
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # Duplicate along the last dim for the two NeoX halves.
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)


class ExtendedRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with frequency-dependent scaling of the inverse freqs.

    Frequencies whose wavelength is shorter than
    ``original_max_position_embeddings / high_freq_factor`` are kept as is.
    Those with a wavelength longer than
    ``original_max_position_embeddings / low_freq_factor`` are divided by
    ``scaling_factor``. Frequencies in between are interpolated linearly.
    The defaults reproduce the previously hard-coded values (8, 1, 4, 8192),
    which appear to correspond to the Llama 3.1 ``rope_scaling`` config
    (assumption).
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float = 8.0,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0,
        original_max_position_embeddings: int = 8192,
    ) -> None:
        # Set before super().__init__(), which builds the cache and therefore
        # calls apply_scaling().
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.original_max_position_embeddings = (
            original_max_position_embeddings)
        super().__init__(head_size, rotary_dim, max_position_embeddings,
                         base, is_neox_style)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the standard inverse frequencies, then rescale them."""
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(self, freqs: torch.Tensor) -> torch.Tensor:
        """Rescale ``freqs`` elementwise according to their wavelength.

        This is vectorized with ``torch.where``. The elementwise arithmetic
        matches a per-element loop and keeps the dtype and device of
        ``freqs``.
        """
        old_context_len = self.original_max_position_embeddings
        low_freq_wavelen = old_context_len / self.low_freq_factor
        high_freq_wavelen = old_context_len / self.high_freq_factor

        wavelen = 2 * math.pi / freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (old_context_len / wavelen - self.low_freq_factor) / (
                self.high_freq_factor - self.low_freq_factor)
        else:
            # Degenerate band: no interpolation region to blend over.
            smooth = torch.zeros_like(freqs)
        interpolated = ((1 - smooth) * freqs / self.scaling_factor +
                        smooth * freqs)

        return torch.where(
            wavelen < high_freq_wavelen,
            freqs,
            torch.where(wavelen > low_freq_wavelen,
                        freqs / self.scaling_factor, interpolated),
        )


# Module-level cache: identical construction arguments share one embedding
# module instance (see get_rope).
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Return a rotary embedding module for the given configuration.

    The module class is chosen by ``rope_scaling["type"]``. Results are
    memoized in ``_ROPE_DICT``, so repeated calls with equal arguments return
    the same instance.

    Raises:
        ValueError: if ``rope_scaling["type"]`` is not a known scaling type.
    """
    if rope_scaling is not None:
        # Lists are unhashable; convert them so the config can be part of
        # the cache key.
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style)
    else:
        scaling_type = rope_scaling["type"]
        if scaling_type not in {"su", "longrope", "extended"}:
            scaling_factor = rope_scaling["factor"]
        if scaling_type == "extended":
            # Forward the scaling parameters when the config provides them;
            # ExtendedRotaryEmbedding's defaults apply otherwise.
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("low_freq_factor", "high_freq_factor",
                         "original_max_position_embeddings")
            }
            if "factor" in rope_scaling:
                extra_kwargs["scaling_factor"] = rope_scaling["factor"]
            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                 max_position, base,
                                                 is_neox_style, **extra_kwargs)
        elif scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor,
                                                    **extra_kwargs)
        elif scaling_type == "su" or scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPERotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb