# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from aphrodite._C import ops


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
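
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): the two helpers above
# implement the two common RoPE layouts. On a toy last dimension of 4:
#
#     >>> x = torch.arange(4.0)     # tensor([0., 1., 2., 3.])
#     >>> _rotate_neox(x)           # tensor([-2., -3.,  0.,  1.])  halves swapped
#     >>> _rotate_gptj(x)           # tensor([-1.,  0., -3.,  2.])  pairs interleaved
#
# GPT-NeoX style pairs element i with element i + dim/2; GPT-J style pairs
# adjacent elements (2i, 2i + 1).
# ---------------------------------------------------------------------------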


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE: The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE: To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
            positions.device, dtype=query.dtype)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE: Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        query = query.flatten(-2)
        key = key.flatten(-2)
        return query, key

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                         self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
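
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module), assuming positions
# shaped [batch_size, seq_len] (as noted in the NeoX branch above) and
# query/key shaped [batch_size, seq_len, num_heads * head_size]: the
# PyTorch-native `_forward` path can be exercised without the custom CUDA ops.
#
#     >>> rope = RotaryEmbedding(head_size=64, rotary_dim=64,
#     ...                        max_position_embeddings=2048, base=10000,
#     ...                        is_neox_style=True, dtype=torch.float32)
#     >>> positions = torch.arange(16).unsqueeze(0)  # [batch_size=1, seq_len=16]
#     >>> query = torch.randn(1, 16, 8 * 64)         # 8 heads of size 64
#     >>> key = torch.randn(1, 16, 8 * 64)
#     >>> query, key = rope._forward(positions, query, key)
#     >>> query.shape
#     torch.Size([1, 16, 512])
# ---------------------------------------------------------------------------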


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may
    have different scaling factors, we need multiple cos/sin caches. In this
    way, instead of running the rotary embedding kernel once per LoRA, we can
    run multiple LoRAs in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    For example, for three scaling factors x=1, y, and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]],
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
     ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset to cache can be accessed via the `scaling_factor_to_offset`
    property.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: List[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: Dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: List[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: List[int] = []
        for scaling_factor in self.scaling_factors:
            # NOTE: self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self._scaling_factor_to_offset
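
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): with two scaling
# factors, the per-factor caches are stacked along dim 0, and
# `scaling_factor_to_offset` gives the row at which each factor's cache
# starts, which callers can pass to `forward()` via `offsets`.
#
#     >>> rope = LinearScalingRotaryEmbedding(
#     ...     head_size=64, rotary_dim=64, max_position_embeddings=2048,
#     ...     base=10000, is_neox_style=True, scaling_factors=[1.0, 4.0],
#     ...     dtype=torch.float32)
#     >>> rope.scaling_factor_to_offset
#     {1.0: 0, 4.0: 2048}
#     >>> rope.cos_sin_cache.shape      # 2048 * 1 + 2048 * 4 rows
#     torch.Size([10240, 64])
# ---------------------------------------------------------------------------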


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE: self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
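
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): Dynamic NTK stretches
# the context by growing the rotary base instead of rescaling positions. For
# example, with base=10000, rotary_dim=128, max_position_embeddings=4096 and
# scaling_factor=2, the effective base becomes roughly
#
#     10000 * ((2 * 8192 / 4096) - 1) ** (128 / 126) ≈ 3.05e4,
#
# and the cos/sin cache is then built for 8192 positions with that larger base.
# ---------------------------------------------------------------------------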


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
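
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): for a typical
# Llama-style setup (dim=128, base=10000, max_position_embeddings=2048) the
# helpers above give
#
#     _yarn_find_correction_range(32, 1, 128, 10000, 2048)  ->  (16, 41)
#     _yarn_get_mscale(4.0)  ->  0.1 * ln(4) + 1  ≈  1.1386
#
# i.e. frequency indices below 16 are pure extrapolation (kept at the original
# rotation speed), indices above 41 are pure interpolation, and the linear
# ramp mask blends the range in between.
# ---------------------------------------------------------------------------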


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache
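
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): extending a
# 2048-token model by 4x with YaRN. Note that `max_position_embeddings` here
# is the original (pre-scaling) context length; the cache covers 2048 * 4
# positions, and cos/sin are pre-multiplied by mscale (≈1.1386 for scale 4).
#
#     >>> rope = YaRNScalingRotaryEmbedding(
#     ...     head_size=128, rotary_dim=128, max_position_embeddings=2048,
#     ...     base=10000, is_neox_style=True, scaling_factor=4.0,
#     ...     dtype=torch.float32)
#     >>> rope.cos_sin_cache.shape
#     torch.Size([8192, 128])
# ---------------------------------------------------------------------------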


class Phi3LongRoPERotaryEmbedding(nn.Module):
    """Phi-3 family LongRoPE rotary embedding, with separate short and long
    rescale factors for positions below and above the original context
    length."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: float = 1.1,
        long_mscale: float = 1.225,
    ):
        super().__init__()

        if rotary_dim != head_size:
            raise ValueError(
                f"Rotary dim must be equal to head size, got {rotary_dim} "
                f"and {head_size}")
        if is_neox_style is False:
            raise ValueError(
                "Phi3SuScaledRotaryEmbedding only supports Neox style")

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor
        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)
        self.register_buffer("short_cos_sin_cache",
                             short_cache,
                             persistent=False)

        long_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, long_factor, long_mscale)
        long_cache = long_cache.to(dtype)
        self.register_buffer("long_cos_sin_cache",
                             long_cache,
                             persistent=False)

        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
        return inv_freq

    def _compute_cos_sin_cache(self, max_position_embeddings: int,
                               rescale_factors: List[float],
                               mscale: float) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        long_prompt_offset = (torch.any(positions > k).float() *
                              torch.full_like(positions, k)).long()
        idx = (torch.add(positions, long_prompt_offset)
               if long_prompt_offset is not None else positions)
        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
            idx.device)
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)
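
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): the short and long
# caches are concatenated along dim 0, so selecting the long cache amounts to
# shifting the position index by `original_max_position_embeddings` rows.
# Because `torch.any(positions > k)` is a single scalar, the whole batch
# switches to the long cache as soon as any position exceeds the original
# context length.
#
#     >>> k = 4096                       # original max positions (example value)
#     >>> positions = torch.tensor([10, 5000])
#     >>> (torch.any(positions > k).float() *
#     ...  torch.full_like(positions, k)).long()
#     tensor([4096, 4096])
# ---------------------------------------------------------------------------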


class ExtendedRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with the Llama 3 style frequency rescaling applied."""

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        return self.apply_scaling(inv_freqs)

    def apply_scaling(self, freqs: torch.Tensor):
        scale_factor = 8
        low_freq_factor = 1
        high_freq_factor = 4
        old_context_len = 8192

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        new_freqs = []
        for freq in freqs:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                new_freqs.append(freq / scale_factor)
            else:
                assert low_freq_wavelen != high_freq_wavelen
                smooth = (old_context_len / wavelen - low_freq_factor) / (
                    high_freq_factor - low_freq_factor)
                new_freqs.append((1 - smooth) * freq / scale_factor +
                                 smooth * freq)
        return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
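
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original module): apply_scaling splits
# the inverse frequencies into three wavelength bands relative to the original
# 8192-token context:
#
#   * wavelength < 8192 / 4  -> kept as-is (high-frequency, local positions)
#   * wavelength > 8192 / 1  -> divided by scale_factor=8 (low-frequency, long-range)
#   * in between             -> linearly blended between the two
#
# Only the frequencies change; the cos/sin cache construction is inherited
# unchanged from RotaryEmbedding.
# ---------------------------------------------------------------------------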


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Convert list values to tuples so rope_scaling can be part of a
        # hashable cache key.
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
        if scaling_type not in {"su", "longrope", "llama3"}:
            scaling_factor = rope_scaling["factor"]
        if scaling_type == "llama3":
            rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
                                                 max_position, base,
                                                 is_neox_style, dtype)
        elif scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "su" or scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPERotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
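

if __name__ == "__main__":
    # Illustrative sketch (not part of the original module): build a plain
    # RoPE and a YaRN-scaled RoPE through get_rope() and exercise the
    # PyTorch-native path. The head counts and sequence length below are
    # arbitrary values chosen for this demo only.
    rope = get_rope(head_size=64, rotary_dim=64, max_position=2048,
                    base=10000, is_neox_style=True)
    yarn_rope = get_rope(head_size=64, rotary_dim=64, max_position=2048,
                         base=10000, is_neox_style=True,
                         rope_scaling={
                             "type": "yarn",
                             "factor": 4.0,
                             "original_max_position_embeddings": 2048,
                         })

    positions = torch.arange(8).unsqueeze(0)   # [batch_size=1, seq_len=8]
    query = torch.randn(1, 8, 4 * 64)          # 4 heads of size 64
    key = torch.randn(1, 8, 4 * 64)
    q_out, k_out = rope._forward(positions, query, key)

    print(q_out.shape, k_out.shape)            # torch.Size([1, 8, 256]) twice
    print(yarn_rope.cos_sin_cache.shape)       # torch.Size([8192, 64])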