rotary_embedding.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from aphrodite.modeling._custom_op import CustomOp
from aphrodite.platforms import current_platform


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
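
# NOTE: Illustration (comment only, not executed): the two helpers above
# differ only in how the rotated halves are laid out. For x = [x0, x1, x2, x3]:
#   _rotate_neox(x) -> [-x2, -x3, x0, x1]   (two contiguous halves)
#   _rotate_gptj(x) -> [-x1, x0, -x3, x2]   (adjacent even/odd pairs)
# Both realize the "rotate" term of RoPE; which one applies depends on how the
# loaded model interleaves its rotary dimensions (`is_neox_style`).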


# for TPUs
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
    """
    orig_dtype = x.dtype
    x = x.float()
    x1, x2 = torch.chunk(x, 2, dim=-1)
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    return torch.cat((o1, o2), dim=-1).to(orig_dtype)


class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self.use_native2 = current_platform.is_tpu() and is_neox_style

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE: The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE: To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
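
    # NOTE: Worked example of the cache layout above (comment only). With
    # rotary_dim = 8 and base = 10000, the inverse frequencies are
    #   inv_freq[i] = base ** (-2 * i / rotary_dim)  for i in {0, 1, 2, 3},
    # i.e. [1.0, 10000**-0.25, 10000**-0.5, 10000**-0.75]. For position p,
    # row p of the cache is
    #   [cos(p * inv_freq[0]), ..., cos(p * inv_freq[3]),
    #    sin(p * inv_freq[0]), ..., sin(p * inv_freq[3])],
    # so `cos_sin_cache` has shape [max_position_embeddings, rotary_dim] and
    # `chunk(2, dim=-1)` recovers the cos and sin halves.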

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation equivalent to forward().

        This method mimics the implementation of the custom CUDA kernel
        used in `forward_cuda()`.
        """
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
            positions.device, dtype=query.dtype)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE: Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        query = query.flatten(-2)
        key = key.flatten(-2)
        return query, key

    def forward_native2(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Another PyTorch-native implementation of forward().

        This method might perform better than `forward_native()` when
        compiled.
        """
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb(query_rot, cos, sin)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = _apply_rotary_emb(key_rot, cos, sin)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from aphrodite import _custom_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from aphrodite._ipex_ops import ipex_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_tpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        forward_fn = (self.forward_native2
                      if self.use_native2 else self.forward_native)
        return forward_fn(positions, query, key, offsets)

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
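
# NOTE: `RotaryEmbedding` is a `CustomOp`, so callers invoke it as
# `rope(positions, query, key)`; the base class is expected to route that call
# to `forward_cuda`, `forward_xpu`, `forward_tpu`, or `forward_native`
# depending on the current platform. Illustrative call (comment only; shapes
# assume 32 query heads and 8 KV heads of size 128):
#   rope = RotaryEmbedding(128, 128, 4096, 10000, True, torch.float16)
#   q, k = rope(positions,            # [num_tokens]
#               query,                # [num_tokens, 32 * 128]
#               key)                  # [num_tokens, 8 * 128]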


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may
    have different scaling factors, we need multiple cos/sin caches. In this
    way, instead of running the rotary embedding kernel per LoRA, we can run
    multiple LoRAs in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    For example, for three scaling factors x=1, y, and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]],
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
    we construct the cos/sin cache as follows:

    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
     ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ..., znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset into the cache can be accessed via the
    `scaling_factor_to_offset` property.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: List[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: Dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: List[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: List[int] = []
        for scaling_factor in self.scaling_factors:
            # NOTE: self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor
            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self._scaling_factor_to_offset
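
# NOTE: Worked example for the offset bookkeeping above (comment only, values
# are illustrative). With max_position_embeddings=2048 and
# scaling_factors=[1.0, 4.0], the per-factor caches have 2048 and 8192 rows,
# they are concatenated along dim 0 into a single [10240, rotary_dim] buffer,
# and `scaling_factor_to_offset` is {1.0: 0, 4.0: 2048}. A request using
# factor 4.0 then passes per-token offsets of 2048 so lookups land in its own
# sub-cache.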


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE: self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
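
# NOTE: For reference (comment only): since max_len above equals
# max_position_embeddings * scaling_factor, the rescaled base simplifies to
#   base' = base * (s**2 - s + 1) ** (rotary_dim / (rotary_dim - 2))
# with s = scaling_factor, i.e. the NTK-aware trick stretches the base (and so
# the rotary wavelengths) instead of compressing the position indices.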


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
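
# NOTE: Derivation sketch for `_yarn_find_correction_dim` (comment only).
# Dimension pair d rotates with wavelength 2 * pi * base**(2 * d / dim); it
# completes r full rotations over a context of length L when
#   L / (2 * pi * base**(2 * d / dim)) = r.
# Solving for d gives d = dim * ln(L / (2 * pi * r)) / (2 * ln(base)), which
# is exactly the expression returned above. The linear ramp mask then blends
# interpolated and extrapolated frequencies across the [low, high] band.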


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache
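
# NOTE: Summary of the blend above (comment only): dimensions whose wavelength
# is short relative to the original context (indices below `low`) keep their
# original ("extrapolated") frequencies, dimensions with long wavelengths
# (above `high`) are fully interpolated by `scaling_factor`, and dimensions in
# between are linearly mixed. `mscale` additionally rescales cos/sin to
# compensate for the attention-magnitude change YaRN observes at long context.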


class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        super().__init__()

        if rotary_dim != head_size:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` does not support "
                f"rotary_dim != head_size ({rotary_dim} != {head_size}).")
        if not is_neox_style:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        scale = (self.max_position_embeddings /
                 self.original_max_position_embeddings)
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) /
                math.log(self.original_max_position_embeddings))
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)
        self.register_buffer("short_cos_sin_cache",
                             short_cache,
                             persistent=False)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(dtype)
        self.register_buffer("long_cos_sin_cache",
                             long_cache,
                             persistent=False)

        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: List[float],
        mscale: float,
    ) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        long_prompt_offset = (torch.any(positions > k).float() *
                              torch.full_like(positions, k)).long()
        idx = (torch.add(positions, long_prompt_offset)
               if long_prompt_offset is not None else positions)
        self.long_short_cos_sin_cache: torch.Tensor = (
            self.long_short_cos_sin_cache.to(idx.device))
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)
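
# NOTE: Cache selection above (comment only): the short and long caches are
# stacked into one buffer, with the long cache starting at row
# `original_max_position_embeddings`. If any position in the batch exceeds the
# original context length, `long_prompt_offset` shifts every index by that
# amount so lookups land in the long-factor region; otherwise indices stay in
# the short-factor region.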


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device="cuda",
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
  595. print("Cache shape", cache.shape)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
            positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE: Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key


class GemmaRotaryEmbedding(RotaryEmbedding):

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
        inv_freq = 1.0 / (base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
            self.rotary_dim))
        return inv_freq


class Llama3RotaryEmbedding(RotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor

        wave_len = 2 * math.pi / inv_freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (self.orig_max_position / wave_len - self.low_freq_factor
                      ) / (self.high_freq_factor - self.low_freq_factor)
        else:
            smooth = 0
        new_freqs = torch.where(
            wave_len < high_freq_wavelen,
            inv_freqs,
            torch.where(
                wave_len > low_freq_wavelen,
                inv_freqs / self.scaling_factor,
                (1 - smooth) * inv_freqs / self.scaling_factor +
                smooth * inv_freqs,
            ),
        )
        return new_freqs
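
# NOTE: Band logic above (comment only): rotary dimensions with wavelengths
# shorter than orig_max_position / high_freq_factor keep their original
# frequencies, wavelengths longer than orig_max_position / low_freq_factor
# have their frequencies divided by `scaling_factor`, and the band in between
# is linearly interpolated via `smooth` so the frequency curve stays
# continuous.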


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    rotary_percent: float = 1.0,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if rotary_percent < 1.0:
        rotary_dim = int(rotary_dim * rotary_percent)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling[
            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
        # The correct name should be "longrope", but keep "su" here
        # for backward compatibility.
        if scaling_type not in {"su", "longrope"}:
            scaling_factor = rope_scaling["factor"]
        if scaling_type == "llama3":
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        # The correct name should be "longrope", but keep "su" here
        # for backward compatibility.
        elif scaling_type == "su" or scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
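

if __name__ == "__main__":
    # Minimal smoke test / usage sketch (illustrative only, not part of the
    # library): build an unscaled NeoX-style RoPE via `get_rope` and apply it
    # with the PyTorch-native path. The head counts and sizes below are
    # arbitrary example values.
    num_heads, num_kv_heads, head_size, seq_len = 8, 4, 64, 16
    rope = get_rope(head_size=head_size,
                    rotary_dim=head_size,
                    max_position=2048,
                    base=10000,
                    is_neox_style=True,
                    dtype=torch.float32)
    # forward_native assumes positions shaped [batch_size, seq_len] in the
    # NeoX-style branch, so use a batch dimension of 1 here.
    positions = torch.arange(seq_len).unsqueeze(0)
    query = torch.randn(1, seq_len, num_heads * head_size)
    key = torch.randn(1, seq_len, num_kv_heads * head_size)
    out_q, out_k = rope.forward_native(positions, query, key)
    assert out_q.shape == query.shape and out_k.shape == key.shape
    print(rope, "->", tuple(out_q.shape), tuple(out_k.shape))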