rotary_embedding.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Rotary Positional Embeddings."""
  25. import math
  26. from typing import Any, Dict, Optional, Tuple, Union
  27. import torch
  28. import torch.nn as nn
  29. from aphrodite._C import ops
  30. def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
  31. """PyTorch-native implementation."""
  32. x1 = x[..., :x.shape[-1] // 2]
  33. x2 = x[..., x.shape[-1] // 2:]
  34. return torch.cat((-x2, x1), dim=-1)
  35. def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
  36. """PyTorch-native implementation."""
  37. x1 = x[..., ::2]
  38. x2 = x[..., 1::2]
  39. x = torch.stack((-x2, x1), dim=-1)
  40. return x.flatten(-2)
  41. class RotaryEmbedding(nn.Module):
  42. """Original rotary positional embedding."""
  43. def __init__(
  44. self,
  45. head_size: int,
  46. rotary_dim: int,
  47. max_position_embeddings: int,
  48. base: int,
  49. is_neox_style: bool,
  50. ) -> None:
  51. super().__init__()
  52. self.head_size = head_size
  53. self.rotary_dim = rotary_dim
  54. self.max_position_embeddings = max_position_embeddings
  55. self.base = base
  56. self.is_neox_style = is_neox_style
  57. cache = self._compute_cos_sin_cache()
  58. cache = cache.to(torch.get_default_dtype())
  59. self.register_buffer("cos_sin_cache", cache, persistent=False)
  60. def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
  61. """Compute the inverse frequency."""
  62. # NOTE: The HF implementation uses `torch.arange(...).float()`.
  63. # However, we use `torch.arange(..., dtype=torch.float)` instead to
  64. # avoid numerical issues with large base values (e.g., 10000000).
  65. # This may cause a slight numerical difference between the HF
  66. # implementation and ours.
  67. # NOTE: To exactly match the HF implementation, we need to
  68. # use CPU to compute the cache and then move it to GPU. However, we
  69. # create the cache on GPU for faster initialization. This may cause
  70. # a slight numerical difference between the HF implementation and ours.
  71. inv_freq = 1.0 / (base**(torch.arange(
  72. 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
  73. return inv_freq
  74. def _compute_cos_sin_cache(self) -> torch.Tensor:
  75. """Compute the cos and sin cache."""
  76. inv_freq = self._compute_inv_freq(self.base)
  77. t = torch.arange(self.max_position_embeddings, dtype=torch.float)
  78. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  79. cos = freqs.cos()
  80. sin = freqs.sin()
  81. cache = torch.cat((cos, sin), dim=-1)
  82. return cache
  83. def _forward(
  84. self,
  85. positions: torch.Tensor,
  86. query: torch.Tensor,
  87. key: torch.Tensor,
  88. ) -> Tuple[torch.Tensor, torch.Tensor]:
  89. """PyTorch-native implementation equivalent to forward()."""
  90. query = query.view(*query.shape[:-1], -1, self.head_size)
  91. key = key.view(*key.shape[:-1], -1, self.head_size)
  92. query_rot = query[..., :self.rotary_dim]
  93. key_rot = key[..., :self.rotary_dim]
  94. if self.rotary_dim < self.head_size:
  95. query_pass = query[..., self.rotary_dim:]
  96. key_pass = key[..., self.rotary_dim:]
  97. cos_sin = self.cos_sin_cache[positions]
  98. cos, sin = cos_sin.chunk(2, dim=-1)
  99. if self.is_neox_style:
  100. # NOTE: Here we assume that the positions tensor has the
  101. # shape [batch_size, seq_len].
  102. cos = cos.repeat(1, 1, 2).unsqueeze(-2)
  103. sin = sin.repeat(1, 1, 2).unsqueeze(-2)
  104. else:
  105. cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
  106. sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
  107. rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
  108. query_rot = query_rot * cos + rotate_fn(query_rot) * sin
  109. key_rot = key_rot * cos + rotate_fn(key_rot) * sin
  110. if self.rotary_dim < self.head_size:
  111. query = torch.cat((query_rot, query_pass), dim=-1)
  112. key = torch.cat((key_rot, key_pass), dim=-1)
  113. else:
  114. query = query_rot
  115. key = key_rot
  116. query = query.flatten(-2)
  117. key = key.flatten(-2)
  118. return query, key
  119. def forward(
  120. self,
  121. positions: torch.Tensor,
  122. query: torch.Tensor,
  123. key: torch.Tensor,
  124. ) -> Tuple[torch.Tensor, torch.Tensor]:
  125. # ops.rotary_embedding() is an in-place operation that
  126. # updates the query and key tensors.
  127. ops.rotary_embedding(positions, query, key, self.head_size,
  128. self.cos_sin_cache, self.is_neox_style)
  129. return query, key
  130. class LinearScalingRotaryEmbedding(RotaryEmbedding):
  131. """RotaryEmbedding extended with linear scaling.
  132. Credits to the Reddit user /u/kaiokendev
  133. """
  134. def __init__(
  135. self,
  136. head_size: int,
  137. rotary_dim: int,
  138. max_position_embeddings: int,
  139. base: int,
  140. is_neox_style: bool,
  141. scaling_factor: float,
  142. ) -> None:
  143. self.scaling_factor = scaling_factor
  144. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  145. is_neox_style)
  146. def _compute_cos_sin_cache(self) -> torch.Tensor:
  147. inv_freq = self._compute_inv_freq(self.base)
  148. # NOTE: self.max_position_embeddings is the original
  149. # maximum length before applying the rope scaling.
  150. # Thus, the maximum length after applying the rope scaling is
  151. # self.max_position_embeddings * self.scaling_factor.
  152. max_len = self.max_position_embeddings * self.scaling_factor
  153. t = torch.arange(max_len, dtype=torch.float)
  154. t = t / self.scaling_factor
  155. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  156. cos = freqs.cos()
  157. sin = freqs.sin()
  158. cache = torch.cat((cos, sin), dim=-1)
  159. return cache
  160. class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
  161. """RotaryEmbedding extended with Dynamic NTK scaling.
  162. Credits to the Reddit users /u/bloc97 and /u/emozilla
  163. """
  164. def __init__(
  165. self,
  166. head_size: int,
  167. rotary_dim: int,
  168. max_position_embeddings: int,
  169. base: int,
  170. is_neox_style: bool,
  171. scaling_factor: float,
  172. ) -> None:
  173. self.scaling_factor = scaling_factor
  174. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  175. is_neox_style)
  176. def _compute_cos_sin_cache(self) -> torch.Tensor:
  177. # NOTE: self.max_position_embeddings is the original
  178. # maximum length before applying the rope scaling.
  179. # Thus, the maximum length after applying the rope scaling is
  180. # self.max_position_embeddings * self.scaling_factor.
  181. max_len = self.max_position_embeddings * self.scaling_factor
  182. base = self.base * (
  183. (self.scaling_factor * max_len / self.max_position_embeddings) -
  184. (self.scaling_factor - 1))**(self.rotary_dim /
  185. (self.rotary_dim - 2))
  186. inv_freq = self._compute_inv_freq(base)
  187. t = torch.arange(max_len, dtype=torch.float)
  188. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  189. cos = freqs.cos()
  190. sin = freqs.sin()
  191. cache = torch.cat((cos, sin), dim=-1)
  192. return cache
  193. # Inverse dim formula to find dim based on number of rotations
  194. def _yarn_find_correction_dim(num_rotations: int,
  195. dim: int,
  196. base: float = 10000,
  197. max_position_embeddings: int = 2048) -> float:
  198. return (dim * math.log(max_position_embeddings /
  199. (num_rotations * 2 * math.pi))) / (2 *
  200. math.log(base))
  201. # Find dim range bounds based on rotations
  202. def _yarn_find_correction_range(low_rot: int,
  203. high_rot: int,
  204. dim: int,
  205. base: float = 10000,
  206. max_position_embeddings: int = 2048) -> int:
  207. low = math.floor(
  208. _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
  209. high = math.ceil(
  210. _yarn_find_correction_dim(high_rot, dim, base,
  211. max_position_embeddings))
  212. return max(low, 0), min(high, dim - 1) # Clamp values just in case
  213. def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
  214. dtype: torch.dtype,
  215. device: torch.device) -> torch.Tensor:
  216. if low == high:
  217. high += 0.001 # Prevent singularity
  218. linear_func = (torch.arange(dim, dtype=dtype, device=device) -
  219. low) / (high - low)
  220. ramp_func = torch.clamp(linear_func, 0, 1)
  221. return ramp_func
  222. def _yarn_get_mscale(scale: float = 1) -> float:
  223. if scale <= 1:
  224. return 1.0
  225. return 0.1 * math.log(scale) + 1.0
  226. class YaRNScalingRotaryEmbedding(RotaryEmbedding):
  227. """RotaryEmbedding extended with YaRN method.
  228. Credits to Peng et al. github.com/jquesnelle/yarn
  229. """
  230. def __init__(
  231. self,
  232. head_size: int,
  233. rotary_dim: int,
  234. max_position_embeddings: int,
  235. base: int,
  236. is_neox_style: bool,
  237. scaling_factor: float,
  238. *,
  239. extrapolation_factor: float = 1,
  240. attn_factor: float = 1,
  241. beta_fast: float = 32,
  242. beta_slow: float = 1,
  243. ) -> None:
  244. self.scaling_factor = scaling_factor
  245. self.extrapolation_factor = extrapolation_factor
  246. self.attn_factor = attn_factor
  247. self.beta_fast = beta_fast
  248. self.beta_slow = beta_slow
  249. # Get n-d magnitude scaling corrected for interpolation
  250. self.mscale = float(
  251. _yarn_get_mscale(self.scaling_factor) * attn_factor)
  252. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  253. is_neox_style)
  254. def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
  255. pos_freqs = self.base**(
  256. torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
  257. self.rotary_dim)
  258. inv_freq_extrapolation = 1.0 / pos_freqs
  259. inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
  260. low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
  261. self.rotary_dim, self.base,
  262. self.max_position_embeddings)
  263. # Get n-d rotational scaling corrected for extrapolation
  264. # FIXME: Add device here.
  265. # pylint: disable=no-value-for-parameter
  266. inv_freq_mask = (1 - _yarn_linear_ramp_mask(
  267. low, high, self.rotary_dim // 2,
  268. dtype=torch.float)) * self.extrapolation_factor
  269. inv_freq = inv_freq_interpolation * (
  270. 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
  271. return inv_freq
  272. def _compute_cos_sin_cache(self) -> torch.Tensor:
  273. inv_freq = self._compute_inv_freq(self.scaling_factor)
  274. t = torch.arange(self.max_position_embeddings * self.scaling_factor,
  275. dtype=torch.float32)
  276. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  277. cos = (freqs.cos() * self.mscale)
  278. sin = (freqs.sin() * self.mscale)
  279. cache = torch.cat((cos, sin), dim=-1)
  280. return cache
  281. _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
  282. def get_rope(
  283. head_size: int,
  284. rotary_dim: int,
  285. max_position: int,
  286. base: int,
  287. is_neox_style: bool = True,
  288. rope_scaling: Optional[Dict[str, Any]] = None,
  289. ) -> RotaryEmbedding:
  290. key = (head_size, rotary_dim, max_position, base, is_neox_style,
  291. tuple(rope_scaling.items()) if rope_scaling is not None else None)
  292. if key in _ROPE_DICT:
  293. return _ROPE_DICT[key]
  294. if rope_scaling is None:
  295. rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
  296. is_neox_style)
  297. else:
  298. scaling_type = rope_scaling["type"]
  299. scaling_factor = rope_scaling["factor"]
  300. if scaling_type == "linear":
  301. rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
  302. max_position, base,
  303. is_neox_style,
  304. scaling_factor)
  305. elif scaling_type == "dynamic":
  306. rotary_emb = DynamicNTKScalingRotaryEmbedding(
  307. head_size, rotary_dim, max_position, base, is_neox_style,
  308. scaling_factor)
  309. elif scaling_type == "yarn":
  310. original_max_position = rope_scaling[
  311. "original_max_position_embeddings"]
  312. assert max_position == original_max_position * scaling_factor
  313. extra_kwargs = {
  314. k: v
  315. for k, v in rope_scaling.items()
  316. if k in ("extrapolation_factor", "attn_factor", "beta_fast",
  317. "beta_slow")
  318. }
  319. rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
  320. original_max_position,
  321. base, is_neox_style,
  322. scaling_factor,
  323. **extra_kwargs)
  324. else:
  325. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  326. _ROPE_DICT[key] = rotary_emb
  327. return rotary_emb