123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546 |
- # coding=utf-8
- # Adapted from
- # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The PygmalionAI team.
- # Copyright 2023 The vLLM team.
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Rotary Positional Embeddings."""
- import math
- from typing import Any, Dict, List, Optional, Tuple, Union
- import torch
- import torch.nn as nn
- from aphrodite._C import ops
- def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
- def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
- x1 = x[..., ::2]
- x2 = x[..., 1::2]
- x = torch.stack((-x2, x1), dim=-1)
- return x.flatten(-2)
class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding.

    A cos/sin table is precomputed at construction time and kept in the
    non-persistent buffer ``cos_sin_cache``. Row ``p`` of the table is
    ``cat(cos(p * inv_freq), sin(p * inv_freq))``, i.e. ``rotary_dim``
    columns for an even ``rotary_dim``. Only the first ``rotary_dim``
    channels of each head are rotated; remaining channels pass through.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        """
        Args:
            head_size: Number of channels in one attention head.
            rotary_dim: Number of leading channels per head that are rotated.
            max_position_embeddings: Number of rows (positions) in the cache.
            base: Base of the geometric inverse-frequency progression.
            is_neox_style: If True, rotate the two halves of the rotary
                channels (NeoX layout); otherwise rotate interleaved pairs
                (GPT-J layout).
        """
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        # The cache is computed in float32 and then cast to the current
        # default dtype. Subclasses customise it by overriding
        # _compute_cos_sin_cache() / _compute_inv_freq().
        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequencies ``1 / base**(2i / rotary_dim)``."""
        # NOTE: The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE: No explicit device is passed, so the tensors are created on
        # torch's current default device. HF computes the cache on CPU; if
        # the default device is a GPU, results may differ slightly from HF.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the ``[max_position_embeddings, rotary_dim]`` cos/sin table."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        ``query`` and ``key`` are viewed as ``[..., num_heads, head_size]``
        (their last dimension must be a multiple of ``head_size``). The
        leading dimensions must match ``positions``, which may be 1-D
        (``[num_tokens]``) or 2-D (``[batch_size, seq_len]``). The results
        are flattened back to the input layout.

        Args:
            positions: Position ids used to index ``cos_sin_cache``.
            query: Query tensor.
            key: Key tensor.
            offsets: Optional tensor added to ``positions`` before indexing.

        Returns:
            The rotated ``(query, key)``.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
            positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NeoX layout rotates the two halves of the rotary channels, so
            # both halves use the same angles. Concatenating along the last
            # dim works for positions of any rank; the former
            # ``repeat(1, 1, 2)`` assumed [batch_size, seq_len] positions
            # and prepended a spurious dim for 1-D positions.
            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        else:
            # GPT-J layout rotates interleaved pairs, so each angle is
            # duplicated in place.
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        query = query.flatten(-2)
        key = key.flatten(-2)
        return query, key

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding using the custom ops.

        With ``offsets`` the batched op is used (it also receives
        ``rotary_dim`` and ``offsets``); otherwise the plain op is used.
        The ops update ``query`` and ``key`` in place, and the same tensors
        are returned.
        """
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev

    One cos/sin table is built per scaling factor, with positions divided by
    that factor. The tables are concatenated along the position dimension in
    the order the factors are given.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
    ) -> None:
        # Accept a single number as well as a list. An integral factor
        # (e.g. ``2`` parsed from a JSON config) must be wrapped too;
        # otherwise iterating it in _compute_cos_sin_cache() raises
        # TypeError.
        if isinstance(scaling_factors, (int, float)):
            scaling_factors = [scaling_factors]
        # Set before super().__init__(), which builds the cache.
        self.scaling_factors = scaling_factors
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build and concatenate one linearly-interpolated table per factor."""
        inv_freq = self._compute_inv_freq(self.base)
        cache_list = []
        for scaling_factor in self.scaling_factors:
            # NOTE: self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            cache_list.append(cache)
        return torch.cat(cache_list, dim=0)
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla

    The cache covers ``max_position_embeddings * scaling_factor`` positions
    and uses an enlarged base computed from the scaling factor.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        # Must be available before the parent constructor builds the cache.
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the cos/sin table using the NTK-adjusted base."""
        factor = self.scaling_factor
        original_len = self.max_position_embeddings
        # max_position_embeddings is the pre-scaling length; the scaled
        # context is factor times longer.
        scaled_len = original_len * factor

        ntk_ratio = (factor * scaled_len / original_len) - (factor - 1)
        exponent = self.rotary_dim / (self.rotary_dim - 2)
        adjusted_base = self.base * ntk_ratio**exponent

        inv_freq = self._compute_inv_freq(adjusted_base)
        positions = torch.arange(scaled_len, dtype=torch.float)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1)
- # Inverse dim formula to find dim based on number of rotations
- def _yarn_find_correction_dim(num_rotations: int,
- dim: int,
- base: float = 10000,
- max_position_embeddings: int = 2048) -> float:
- return (dim * math.log(max_position_embeddings /
- (num_rotations * 2 * math.pi))) / (2 *
- math.log(base))
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    """Return the integer dim range ``(low, high)`` bounded by two rotation counts.

    ``low`` is floored and ``high`` is ceiled, then both are clamped to
    ``[0, dim - 1]``.
    """
    low_dim = _yarn_find_correction_dim(low_rot, dim, base,
                                        max_position_embeddings)
    high_dim = _yarn_find_correction_dim(high_rot, dim, base,
                                         max_position_embeddings)
    # Clamp just in case the formula lands outside the valid index range.
    lower = max(math.floor(low_dim), 0)
    upper = min(math.ceil(high_dim), dim - 1)
    return lower, upper
- def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
- dtype: torch.dtype) -> torch.Tensor:
- if low == high:
- high += 0.001 # Prevent singularity
- linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
- ramp_func = torch.clamp(linear_func, 0, 1)
- return ramp_func
- def _yarn_get_mscale(scale: float = 1) -> float:
- if scale <= 1:
- return 1.0
- return 0.1 * math.log(scale) + 1.0
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn

    Inverse frequencies are blended per channel between an interpolated
    version (divided by ``scaling_factor``) and the original
    (extrapolated) one. The blend follows a linear ramp between the
    correction dims derived from ``beta_fast``/``beta_slow``, and the whole
    cos/sin table is multiplied by ``mscale``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        # All of these are read while the parent constructor builds the
        # cache, so they are assigned first.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # n-d magnitude scaling, corrected for interpolation.
        magnitude = _yarn_get_mscale(self.scaling_factor) * attn_factor
        self.mscale = float(magnitude)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        """Blend interpolated and extrapolated inverse frequencies."""
        exponents = torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        pos_freqs = self.base**exponents
        extrapolated = 1.0 / pos_freqs
        interpolated = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast,
                                                self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        ramp = _yarn_linear_ramp_mask(low,
                                      high,
                                      self.rotary_dim // 2,
                                      dtype=torch.float)
        # n-d rotational scaling, corrected for extrapolation: this weight
        # multiplies the extrapolated frequencies, (1 - weight) the
        # interpolated ones.
        weight = (1 - ramp) * self.extrapolation_factor
        return interpolated * (1 - weight) + extrapolated * weight

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the mscale-weighted cos/sin table for the scaled length."""
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        positions = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
            dtype=torch.float32)
        angles = torch.einsum("i,j -> ij", positions, inv_freq)
        return torch.cat((angles.cos(), angles.sin()), dim=-1) * self.mscale
class Phi3LongRoPERotaryEmbedding(nn.Module):
    """LongRoPE rotary embedding (NeoX layout only).

    Two cos/sin tables are precomputed with per-channel rescale factors:

    * a short table (``short_factor``, ``short_mscale``) with
      ``original_max_position_embeddings`` rows;
    * a long table (``long_factor``, ``long_mscale``) with
      ``max_position_embeddings`` rows.

    Both are concatenated into ``long_short_cos_sin_cache``, so the long
    table starts at row ``original_max_position_embeddings``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: float = 1.1,
        long_mscale: float = 1.225,
    ):
        """
        Raises:
            ValueError: if ``rotary_dim != head_size`` or if
                ``is_neox_style`` is falsy.
        """
        super().__init__()

        if rotary_dim != head_size:
            raise ValueError(
                f"Rotary dim must be equal to head size, got {rotary_dim} "
                f"and {head_size}")
        if not is_neox_style:
            raise ValueError(
                "Phi3LongRoPERotaryEmbedding only supports Neox style")

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(torch.get_default_dtype())
        self.register_buffer("short_cos_sin_cache",
                             short_cache,
                             persistent=False)

        # forward() routes positions >= original_max_position_embeddings
        # into this table, so it must span the full max_position_embeddings
        # range rather than only the original context length.
        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(torch.get_default_dtype())
        self.register_buffer("long_cos_sin_cache",
                             long_cache,
                             persistent=False)

        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        """Return ``1 / (rescale_factors * base**(2i / head_size))``."""
        factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (factors * (self.base**(torch.arange(
            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
        return inv_freq

    def _compute_cos_sin_cache(self, max_position_embeddings: int,
                               rescale_factors: List[float],
                               mscale: float) -> torch.Tensor:
        """Build a ``[max_position_embeddings, head_size]`` table scaled by mscale."""
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply LongRoPE to ``query``/``key`` (PyTorch implementation).

        ``positions`` (plus ``offsets`` if given) is used as an
        ``index_select`` index, so it must be 1-D. If any position is
        ``>= original_max_position_embeddings``, every token in the call
        uses the long table; otherwise all tokens use the short table.

        Returns:
            The rotated ``(query, key)``, flattened back to the input layout.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        # The short table only has rows 0..k-1, so position k already
        # needs the long table; with ``>`` it would land on long-table
        # row 0 (the angle for position 0).
        long_prompt_offset = (torch.any(positions >= k).float() *
                              torch.full_like(positions, k)).long()
        idx = torch.add(positions, long_prompt_offset)
        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
            idx.device)
        if offsets is not None:
            idx = torch.add(idx, offsets)
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        # NeoX layout: both halves of the head share the same angles.
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)
class ExtendedRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with wavelength-dependent frequency scaling.

    Each base inverse frequency is adjusted by ``apply_scaling()``:

    * wavelength < ``original_max_position / high_freq_factor``: unchanged;
    * wavelength > ``original_max_position / low_freq_factor``: divided by
      ``scaling_factor``;
    * otherwise: linearly blended between those two values.

    The keyword defaults reproduce the constants previously hard-coded in
    ``apply_scaling()`` (8, 1, 4, 8192). NOTE(review): these appear to
    match the Llama 3.1 rope-scaling recipe — confirm before relying on it.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        *,
        scaling_factor: float = 8,
        low_freq_factor: float = 1,
        high_freq_factor: float = 4,
        original_max_position: int = 8192,
    ) -> None:
        # These must exist before super().__init__(), which builds the
        # cos/sin cache via _compute_inv_freq() -> apply_scaling().
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.original_max_position = original_max_position
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the base inverse frequencies after ``apply_scaling()``."""
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(self, freqs: torch.Tensor) -> torch.Tensor:
        """Rescale inverse frequencies according to their wavelength.

        Args:
            freqs: Tensor of inverse frequencies (1-D when called from
                ``_compute_inv_freq``).

        Returns:
            A tensor with the same shape, dtype and device as ``freqs``.
        """
        low_freq_wavelen = self.original_max_position / self.low_freq_factor
        high_freq_wavelen = self.original_max_position / self.high_freq_factor

        wavelen = 2 * math.pi / freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (self.original_max_position / wavelen -
                      self.low_freq_factor) / (self.high_freq_factor -
                                               self.low_freq_factor)
        else:
            # The blend band is empty; use the fully scaled value instead of
            # dividing by zero.
            smooth = torch.zeros_like(freqs)
        blended = (1 - smooth) * freqs / self.scaling_factor + smooth * freqs

        return torch.where(
            wavelen < high_freq_wavelen,
            freqs,
            torch.where(wavelen > low_freq_wavelen,
                        freqs / self.scaling_factor, blended),
        )
- _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
) -> nn.Module:
    """Build (or return the cached) rotary embedding for a configuration.

    Instances are memoized in ``_ROPE_DICT``, so identical configurations
    share one module.

    Args:
        head_size: Number of channels per attention head.
        rotary_dim: Number of leading channels per head that are rotated.
        max_position: Maximum position passed to the embedding.
        base: RoPE base.
        is_neox_style: NeoX (halves) vs GPT-J (interleaved) rotation.
        rope_scaling: Optional scaling config. ``"type"`` selects the
            implementation:

            * ``"linear"``, ``"dynamic"`` and ``"yarn"`` read ``"factor"``;
            * ``"yarn"``, ``"su"`` and ``"longrope"`` also read
              ``"original_max_position_embeddings"``;
            * ``"su"``/``"longrope"`` read ``"short_factor"`` and
              ``"long_factor"``.

    Returns:
        A RotaryEmbedding (or subclass) instance, or a
        Phi3LongRoPERotaryEmbedding for ``"su"``/``"longrope"``.

    Raises:
        KeyError: if ``"type"`` or a key required by that type is missing.
        ValueError: if the scaling type is not recognized.
    """
    if rope_scaling is not None:
        # Lists are unhashable, so freeze them into tuples. Sort the items so
        # equal configs share a cache entry regardless of dict order.
        rope_scaling_args = tuple(
            sorted((k, tuple(v) if isinstance(v, list) else v)
                   for k, v in rope_scaling.items()))
    else:
        rope_scaling_args = None
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style)
    else:
        scaling_type = rope_scaling["type"]
        # "factor" is read only by the branches that use it, so an unknown
        # type reaches the ValueError below instead of failing with a
        # KeyError on "factor".
        if scaling_type == "extended":
            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                 max_position, base,
                                                 is_neox_style)
        elif scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      rope_scaling["factor"])
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                rope_scaling["factor"])
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor,
                                                    **extra_kwargs)
        elif scaling_type in ("su", "longrope"):
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPERotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
|