rotary_embedding.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Rotary Positional Embeddings."""
  25. import math
  26. from typing import Any, Dict, List, Optional, Tuple, Union
  27. import torch
  28. import torch.nn as nn
  29. from aphrodite._C import ops
  30. def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
  31. """PyTorch-native implementation."""
  32. x1 = x[..., :x.shape[-1] // 2]
  33. x2 = x[..., x.shape[-1] // 2:]
  34. return torch.cat((-x2, x1), dim=-1)
  35. def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
  36. """PyTorch-native implementation."""
  37. x1 = x[..., ::2]
  38. x2 = x[..., 1::2]
  39. x = torch.stack((-x2, x1), dim=-1)
  40. return x.flatten(-2)
  41. class RotaryEmbedding(nn.Module):
  42. """Original rotary positional embedding."""
  43. def __init__(
  44. self,
  45. head_size: int,
  46. rotary_dim: int,
  47. max_position_embeddings: int,
  48. base: int,
  49. is_neox_style: bool,
  50. ) -> None:
  51. super().__init__()
  52. self.head_size = head_size
  53. self.rotary_dim = rotary_dim
  54. self.max_position_embeddings = max_position_embeddings
  55. self.base = base
  56. self.is_neox_style = is_neox_style
  57. cache = self._compute_cos_sin_cache()
  58. cache = cache.to(torch.get_default_dtype())
  59. self.register_buffer("cos_sin_cache", cache, persistent=False)
  60. def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
  61. """Compute the inverse frequency."""
  62. # NOTE: The HF implementation uses `torch.arange(...).float()`.
  63. # However, we use `torch.arange(..., dtype=torch.float)` instead to
  64. # avoid numerical issues with large base values (e.g., 10000000).
  65. # This may cause a slight numerical difference between the HF
  66. # implementation and ours.
  67. # NOTE: To exactly match the HF implementation, we need to
  68. # use CPU to compute the cache and then move it to GPU. However, we
  69. # create the cache on GPU for faster initialization. This may cause
  70. # a slight numerical difference between the HF implementation and ours.
  71. inv_freq = 1.0 / (base**(torch.arange(
  72. 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
  73. return inv_freq
  74. def _compute_cos_sin_cache(self) -> torch.Tensor:
  75. """Compute the cos and sin cache."""
  76. inv_freq = self._compute_inv_freq(self.base)
  77. t = torch.arange(self.max_position_embeddings, dtype=torch.float)
  78. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  79. cos = freqs.cos()
  80. sin = freqs.sin()
  81. cache = torch.cat((cos, sin), dim=-1)
  82. return cache
  83. def _forward(
  84. self,
  85. positions: torch.Tensor,
  86. query: torch.Tensor,
  87. key: torch.Tensor,
  88. offsets: Optional[torch.Tensor] = None,
  89. ) -> Tuple[torch.Tensor, torch.Tensor]:
  90. """PyTorch-native implementation equivalent to forward()."""
  91. query = query.view(*query.shape[:-1], -1, self.head_size)
  92. key = key.view(*key.shape[:-1], -1, self.head_size)
  93. query_rot = query[..., :self.rotary_dim]
  94. key_rot = key[..., :self.rotary_dim]
  95. if self.rotary_dim < self.head_size:
  96. query_pass = query[..., self.rotary_dim:]
  97. key_pass = key[..., self.rotary_dim:]
  98. self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
  99. cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
  100. if offsets is not None else positions]
  101. cos, sin = cos_sin.chunk(2, dim=-1)
  102. if self.is_neox_style:
  103. # NOTE: Here we assume that the positions tensor has the
  104. # shape [batch_size, seq_len].
  105. cos = cos.repeat(1, 1, 2).unsqueeze(-2)
  106. sin = sin.repeat(1, 1, 2).unsqueeze(-2)
  107. else:
  108. cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
  109. sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
  110. rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
  111. query_rot = query_rot * cos + rotate_fn(query_rot) * sin
  112. key_rot = key_rot * cos + rotate_fn(key_rot) * sin
  113. if self.rotary_dim < self.head_size:
  114. query = torch.cat((query_rot, query_pass), dim=-1)
  115. key = torch.cat((key_rot, key_pass), dim=-1)
  116. else:
  117. query = query_rot
  118. key = key_rot
  119. query = query.flatten(-2)
  120. key = key.flatten(-2)
  121. return query, key
  122. def forward(
  123. self,
  124. positions: torch.Tensor,
  125. query: torch.Tensor,
  126. key: torch.Tensor,
  127. offsets: Optional[torch.Tensor] = None,
  128. ) -> Tuple[torch.Tensor, torch.Tensor]:
  129. # ops.rotary_embedding() is an in-place operation that
  130. # updates the query and key tensors.
  131. self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
  132. # ops.rotary_embedding()/batched_rotary_embedding() are in-place ops
  133. # that update the qk tensors
  134. if offsets is not None:
  135. ops.batched_rotary_embedding(positions, query, key, self.head_size,
  136. self.cos_sin_cache,
  137. self.is_neox_style, self.rotary_dim,
  138. offsets)
  139. else:
  140. ops.rotary_embedding(positions, query, key, self.head_size,
  141. self.cos_sin_cache, self.is_neox_style)
  142. return query, key
  143. class LinearScalingRotaryEmbedding(RotaryEmbedding):
  144. """RotaryEmbedding extended with linear scaling.
  145. Credits to the Reddit user /u/kaiokendev
  146. """
  147. def __init__(
  148. self,
  149. head_size: int,
  150. rotary_dim: int,
  151. max_position_embeddings: int,
  152. base: int,
  153. is_neox_style: bool,
  154. scaling_factors: Union[List[float], float],
  155. ) -> None:
  156. if isinstance(scaling_factors, float):
  157. scaling_factors = [scaling_factors]
  158. self.scaling_factors = scaling_factors
  159. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  160. is_neox_style)
  161. def _compute_cos_sin_cache(self) -> torch.Tensor:
  162. inv_freq = self._compute_inv_freq(self.base)
  163. cache_list = []
  164. for scaling_factor in self.scaling_factors:
  165. # NOTE: self.max_position_embeddings is the original
  166. # maximum length before applying the rope scaling.
  167. # Thus, the maximum length after applying the rope scaling is
  168. # self.max_position_embeddings * self.scaling_factor.
  169. max_len = self.max_position_embeddings * scaling_factor
  170. t = torch.arange(max_len, dtype=torch.float)
  171. t = t / scaling_factor
  172. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  173. cos = freqs.cos()
  174. sin = freqs.sin()
  175. cache = torch.cat((cos, sin), dim=-1)
  176. cache_list.append(cache)
  177. return torch.cat(cache_list, dim=0)
  178. class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
  179. """RotaryEmbedding extended with Dynamic NTK scaling.
  180. Credits to the Reddit users /u/bloc97 and /u/emozilla
  181. """
  182. def __init__(
  183. self,
  184. head_size: int,
  185. rotary_dim: int,
  186. max_position_embeddings: int,
  187. base: int,
  188. is_neox_style: bool,
  189. scaling_factor: float,
  190. ) -> None:
  191. self.scaling_factor = scaling_factor
  192. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  193. is_neox_style)
  194. def _compute_cos_sin_cache(self) -> torch.Tensor:
  195. # NOTE: self.max_position_embeddings is the original
  196. # maximum length before applying the rope scaling.
  197. # Thus, the maximum length after applying the rope scaling is
  198. # self.max_position_embeddings * self.scaling_factor.
  199. max_len = self.max_position_embeddings * self.scaling_factor
  200. base = self.base * (
  201. (self.scaling_factor * max_len / self.max_position_embeddings) -
  202. (self.scaling_factor - 1))**(self.rotary_dim /
  203. (self.rotary_dim - 2))
  204. inv_freq = self._compute_inv_freq(base)
  205. t = torch.arange(max_len, dtype=torch.float)
  206. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  207. cos = freqs.cos()
  208. sin = freqs.sin()
  209. cache = torch.cat((cos, sin), dim=-1)
  210. return cache
  211. # Inverse dim formula to find dim based on number of rotations
  212. def _yarn_find_correction_dim(num_rotations: int,
  213. dim: int,
  214. base: float = 10000,
  215. max_position_embeddings: int = 2048) -> float:
  216. return (dim * math.log(max_position_embeddings /
  217. (num_rotations * 2 * math.pi))) / (2 *
  218. math.log(base))
  219. # Find dim range bounds based on rotations
  220. def _yarn_find_correction_range(low_rot: int,
  221. high_rot: int,
  222. dim: int,
  223. base: float = 10000,
  224. max_position_embeddings: int = 2048) -> int:
  225. low = math.floor(
  226. _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
  227. high = math.ceil(
  228. _yarn_find_correction_dim(high_rot, dim, base,
  229. max_position_embeddings))
  230. return max(low, 0), min(high, dim - 1) # Clamp values just in case
  231. def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
  232. dtype: torch.dtype,
  233. device: torch.device) -> torch.Tensor:
  234. if low == high:
  235. high += 0.001 # Prevent singularity
  236. linear_func = (torch.arange(dim, dtype=dtype, device=device) -
  237. low) / (high - low)
  238. ramp_func = torch.clamp(linear_func, 0, 1)
  239. return ramp_func
  240. def _yarn_get_mscale(scale: float = 1) -> float:
  241. if scale <= 1:
  242. return 1.0
  243. return 0.1 * math.log(scale) + 1.0
  244. class YaRNScalingRotaryEmbedding(RotaryEmbedding):
  245. """RotaryEmbedding extended with YaRN method.
  246. Credits to Peng et al. github.com/jquesnelle/yarn
  247. """
  248. def __init__(
  249. self,
  250. head_size: int,
  251. rotary_dim: int,
  252. max_position_embeddings: int,
  253. base: int,
  254. is_neox_style: bool,
  255. scaling_factor: float,
  256. *,
  257. extrapolation_factor: float = 1,
  258. attn_factor: float = 1,
  259. beta_fast: float = 32,
  260. beta_slow: float = 1,
  261. ) -> None:
  262. self.scaling_factor = scaling_factor
  263. self.extrapolation_factor = extrapolation_factor
  264. self.attn_factor = attn_factor
  265. self.beta_fast = beta_fast
  266. self.beta_slow = beta_slow
  267. # Get n-d magnitude scaling corrected for interpolation
  268. self.mscale = float(
  269. _yarn_get_mscale(self.scaling_factor) * attn_factor)
  270. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  271. is_neox_style)
  272. def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
  273. pos_freqs = self.base**(
  274. torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
  275. self.rotary_dim)
  276. inv_freq_extrapolation = 1.0 / pos_freqs
  277. inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
  278. low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
  279. self.rotary_dim, self.base,
  280. self.max_position_embeddings)
  281. # Get n-d rotational scaling corrected for extrapolation
  282. # FIXME: Add device here.
  283. # pylint: disable=no-value-for-parameter
  284. inv_freq_mask = (1 - _yarn_linear_ramp_mask(
  285. low, high, self.rotary_dim // 2,
  286. dtype=torch.float)) * self.extrapolation_factor
  287. inv_freq = inv_freq_interpolation * (
  288. 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
  289. return inv_freq
  290. def _compute_cos_sin_cache(self) -> torch.Tensor:
  291. inv_freq = self._compute_inv_freq(self.scaling_factor)
  292. t = torch.arange(self.max_position_embeddings * self.scaling_factor,
  293. dtype=torch.float32)
  294. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  295. cos = (freqs.cos() * self.mscale)
  296. sin = (freqs.sin() * self.mscale)
  297. cache = torch.cat((cos, sin), dim=-1)
  298. return cache
  299. _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
  300. def get_rope(
  301. head_size: int,
  302. rotary_dim: int,
  303. max_position: int,
  304. base: int,
  305. is_neox_style: bool = True,
  306. rope_scaling: Optional[Dict[str, Any]] = None,
  307. ) -> RotaryEmbedding:
  308. key = (head_size, rotary_dim, max_position, base, is_neox_style,
  309. tuple(rope_scaling.items()) if rope_scaling is not None else None)
  310. if key in _ROPE_DICT:
  311. return _ROPE_DICT[key]
  312. if rope_scaling is None:
  313. rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
  314. is_neox_style)
  315. else:
  316. scaling_type = rope_scaling["type"]
  317. scaling_factor = rope_scaling["factor"]
  318. if scaling_type == "linear":
  319. rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
  320. max_position, base,
  321. is_neox_style,
  322. scaling_factor)
  323. elif scaling_type == "dynamic":
  324. rotary_emb = DynamicNTKScalingRotaryEmbedding(
  325. head_size, rotary_dim, max_position, base, is_neox_style,
  326. scaling_factor)
  327. elif scaling_type == "yarn":
  328. original_max_position = rope_scaling[
  329. "original_max_position_embeddings"]
  330. assert max_position == original_max_position * scaling_factor
  331. extra_kwargs = {
  332. k: v
  333. for k, v in rope_scaling.items()
  334. if k in ("extrapolation_factor", "attn_factor", "beta_fast",
  335. "beta_slow")
  336. }
  337. rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
  338. original_max_position,
  339. base, is_neox_style,
  340. scaling_factor,
  341. **extra_kwargs)
  342. else:
  343. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  344. _ROPE_DICT[key] = rotary_emb
  345. return rotary_emb