# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from aphrodite._C import ops


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation.

    Splits the last dimension of ``x`` into a first half ``x1`` and a second
    half ``x2`` and returns ``(-x2, x1)`` concatenated along that dimension.
    """
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    """PyTorch-native implementation.

    Treats the last dimension of ``x`` as interleaved pairs
    ``(x[..., 0::2], x[..., 1::2])`` and maps every pair ``(a, b)`` to
    ``(-b, a)``, keeping the interleaved layout.
    """
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        """Store the configuration and build the cos/sin cache.

        The cache returned by ``_compute_cos_sin_cache()`` is cast to
        ``torch.get_default_dtype()`` and registered as the non-persistent
        buffer ``cos_sin_cache``.
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE: The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE: To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache directly on the current default device (no
        # explicit device is passed). This may cause a slight numerical
        # difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns a 2-D tensor with one row per position in
        ``range(max_position_embeddings)``; the cosines occupy the first half
        of the last dimension and the sines the second half.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        # Outer product: freqs[i, j] = t[i] * inv_freq[j].
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        # Split the last dimension into (-1, head_size) so that the rotation
        # is applied per head.
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        # Only the first `rotary_dim` features of every head are rotated; the
        # remainder (if any) is passed through unchanged.
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE: Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        # Merge the (-1, head_size) split back into a single last dimension.
        query = query.flatten(-2)
        key = key.flatten(-2)
        return query, key

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding through the custom op.

        Returns the same ``query`` and ``key`` tensor objects that were
        passed in.
        """
        # ops.rotary_embedding() is an in-place operation that
        # updates the query and key tensors.
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
        return query, key


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        # Must be set before super().__init__(), which calls
        # _compute_cos_sin_cache() and therefore reads self.scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos/sin cache with positions divided by the factor."""
        inv_freq = self._compute_inv_freq(self.base)
        # NOTE: self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        t = torch.arange(max_len, dtype=torch.float)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        # Must be set before super().__init__(), which calls
        # _compute_cos_sin_cache() and therefore reads self.scaling_factor.
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos/sin cache using a rescaled base."""
        # NOTE: self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    """Return the (fractional) dimension index for ``num_rotations``."""
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    """Return the ``(low, high)`` dimension bounds for the two rotation counts.

    ``low`` is floored and clamped to be at least 0; ``high`` is ceiled and
    clamped to be at most ``dim - 1``.
    """
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(
        low: float,
        high: float,
        dim: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None) -> torch.Tensor:
    """Return a 1-D ramp of length ``dim`` clamped to ``[0, 1]``.

    Element ``i`` equals ``(i - low) / (high - low)`` before clamping.

    Args:
        low: Index at which the ramp is 0.
        high: Index at which the ramp reaches 1.
        dim: Number of elements in the returned tensor.
        dtype: dtype of the returned tensor.
        device: Device of the returned tensor. ``None`` (the default) is
            forwarded to ``torch.arange``, which then uses the current
            default device.
    """
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
                   low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    """Return 1.0 for ``scale <= 1``, otherwise ``0.1 * ln(scale) + 1.0``."""
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ) -> None:
        # All of these must be set before super().__init__(), which calls
        # _compute_cos_sin_cache() and therefore reads them.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies.

        NOTE: unlike the base class, the argument here is the scaling factor,
        not the base; ``self.base`` is read directly.
        """
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        # The mask is created on the same device as the frequencies it is
        # combined with below.
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low,
            high,
            self.rotary_dim // 2,
            dtype=torch.float,
            device=pos_freqs.device)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos/sin cache, with both halves scaled by ``mscale``."""
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache


# Cache of already-constructed embeddings, keyed by the get_rope() arguments.
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    """Return a (cached) rotary embedding for the given configuration.

    Instances are memoized in ``_ROPE_DICT``; a second call with the same
    arguments returns the same object.

    Args:
        head_size: Forwarded to the embedding constructor.
        rotary_dim: Forwarded to the embedding constructor.
        max_position: Maximum position. For ``"yarn"`` it must equal
            ``original_max_position_embeddings * factor``.
        base: Forwarded to the embedding constructor.
        is_neox_style: Forwarded to the embedding constructor.
        rope_scaling: ``None`` for the plain ``RotaryEmbedding``; otherwise a
            dict with keys ``"type"`` (``"linear"``, ``"dynamic"`` or
            ``"yarn"``) and ``"factor"``. ``"yarn"`` additionally requires
            ``"original_max_position_embeddings"`` and accepts the optional
            keys ``"extrapolation_factor"``, ``"attn_factor"``,
            ``"beta_fast"`` and ``"beta_slow"``.

    Raises:
        KeyError: If a required key is missing from ``rope_scaling``.
        AssertionError: If the ``"yarn"`` length check fails.
        ValueError: If ``rope_scaling["type"]`` is not a known type.
    """
    if rope_scaling is None:
        rope_scaling_key = None
    else:
        # Lists are unhashable and would make the cache lookup raise, so
        # they are converted to tuples for the key only.
        rope_scaling_key = tuple(
            (name, tuple(value) if isinstance(value, list) else value)
            for name, value in rope_scaling.items())
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_key)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style)
    else:
        scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            # The YaRN class takes the *original* maximum length and derives
            # the scaled length itself.
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor,
                                                    **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb