rotary_embedding.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Rotary Positional Embeddings."""
  25. import math
  26. from typing import Any, Dict, List, Optional, Tuple, Union
  27. import torch
  28. import torch.nn as nn
  29. from aphrodite.modeling._custom_op import CustomOp
  30. def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
  31. x1 = x[..., :x.shape[-1] // 2]
  32. x2 = x[..., x.shape[-1] // 2:]
  33. return torch.cat((-x2, x1), dim=-1)
  34. def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
  35. x1 = x[..., ::2]
  36. x2 = x[..., 1::2]
  37. x = torch.stack((-x2, x1), dim=-1)
  38. return x.flatten(-2)
  39. def _apply_rotary_emb(
  40. x: torch.Tensor,
  41. cos: torch.Tensor,
  42. sin: torch.Tensor,
  43. is_neox_style: bool,
  44. ) -> torch.Tensor:
  45. """
  46. Args:
  47. x: [num_tokens, num_heads, head_size]
  48. cos: [num_tokens, head_size // 2]
  49. sin: [num_tokens, head_size // 2]
  50. is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
  51. positional embeddings.
  52. """
  53. cos = cos.unsqueeze(-2).to(x.dtype)
  54. sin = sin.unsqueeze(-2).to(x.dtype)
  55. if is_neox_style:
  56. x1, x2 = torch.chunk(x, 2, dim=-1)
  57. else:
  58. x1 = x[..., ::2]
  59. x2 = x[..., 1::2]
  60. o1 = x1 * cos - x2 * sin
  61. o2 = x2 * cos + x1 * sin
  62. if is_neox_style:
  63. return torch.cat((o1, o2), dim=-1)
  64. else:
  65. return torch.stack((o1, o2), dim=-1).flatten(-2)
  66. class RotaryEmbedding(CustomOp):
  67. """Original rotary positional embedding."""
  68. def __init__(
  69. self,
  70. head_size: int,
  71. rotary_dim: int,
  72. max_position_embeddings: int,
  73. base: int,
  74. is_neox_style: bool,
  75. dtype: torch.dtype,
  76. ) -> None:
  77. super().__init__()
  78. self.head_size = head_size
  79. self.rotary_dim = rotary_dim
  80. self.max_position_embeddings = max_position_embeddings
  81. self.base = base
  82. self.is_neox_style = is_neox_style
  83. self.dtype = dtype
  84. cache = self._compute_cos_sin_cache()
  85. cache = cache.to(dtype)
  86. self.cos_sin_cache: torch.Tensor
  87. self.register_buffer("cos_sin_cache", cache, persistent=False)
  88. def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
  89. """Compute the inverse frequency."""
  90. # NOTE: To exactly match the HF implementation, we need to
  91. # use CPU to compute the cache and then move it to GPU. However, we
  92. # create the cache on GPU for faster initialization. This may cause
  93. # a slight numerical difference between the HF implementation and ours.
  94. inv_freq = 1.0 / (base**(torch.arange(
  95. 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
  96. return inv_freq
  97. def _compute_cos_sin_cache(self) -> torch.Tensor:
  98. """Compute the cos and sin cache."""
  99. inv_freq = self._compute_inv_freq(self.base)
  100. t = torch.arange(self.max_position_embeddings, dtype=torch.float)
  101. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  102. cos = freqs.cos()
  103. sin = freqs.sin()
  104. cache = torch.cat((cos, sin), dim=-1)
  105. return cache
  106. def forward_native(
  107. self,
  108. positions: torch.Tensor,
  109. query: torch.Tensor,
  110. key: torch.Tensor,
  111. offsets: Optional[torch.Tensor] = None,
  112. ) -> Tuple[torch.Tensor, torch.Tensor]:
  113. """A PyTorch-native implementation of forward()."""
  114. if offsets is not None:
  115. positions = positions + offsets
  116. positions = positions.flatten()
  117. num_tokens = positions.shape[0]
  118. cos_sin = self.cos_sin_cache.index_select(0, positions)
  119. cos, sin = cos_sin.chunk(2, dim=-1)
  120. query_shape = query.shape
  121. query = query.view(num_tokens, -1, self.head_size)
  122. query_rot = query[..., :self.rotary_dim]
  123. query_pass = query[..., self.rotary_dim:]
  124. query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
  125. query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
  126. key_shape = key.shape
  127. key = key.view(num_tokens, -1, self.head_size)
  128. key_rot = key[..., :self.rotary_dim]
  129. key_pass = key[..., self.rotary_dim:]
  130. key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
  131. key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
  132. return query, key
  133. def forward_cuda(
  134. self,
  135. positions: torch.Tensor,
  136. query: torch.Tensor,
  137. key: torch.Tensor,
  138. offsets: Optional[torch.Tensor] = None,
  139. ) -> Tuple[torch.Tensor, torch.Tensor]:
  140. from aphrodite import _custom_ops as ops
  141. self.cos_sin_cache = self.cos_sin_cache.to(query.device,
  142. dtype=query.dtype)
  143. # ops.rotary_embedding()/batched_rotary_embedding()
  144. # are in-place operations that update the query and key tensors.
  145. if offsets is not None:
  146. ops.batched_rotary_embedding(positions, query, key, self.head_size,
  147. self.cos_sin_cache,
  148. self.is_neox_style, self.rotary_dim,
  149. offsets)
  150. else:
  151. ops.rotary_embedding(positions, query, key, self.head_size,
  152. self.cos_sin_cache, self.is_neox_style)
  153. return query, key
  154. def forward_xpu(
  155. self,
  156. positions: torch.Tensor,
  157. query: torch.Tensor,
  158. key: torch.Tensor,
  159. offsets: Optional[torch.Tensor] = None,
  160. ) -> Tuple[torch.Tensor, torch.Tensor]:
  161. from aphrodite._ipex_ops import ipex_ops as ops
  162. self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
  163. dtype=query.dtype)
  164. # ops.rotary_embedding()/batched_rotary_embedding()
  165. # are in-place operations that update the query and key tensors.
  166. if offsets is not None:
  167. ops.batched_rotary_embedding(positions, query, key, self.head_size,
  168. self.cos_sin_cache,
  169. self.is_neox_style, self.rotary_dim,
  170. offsets)
  171. else:
  172. ops.rotary_embedding(positions, query, key, self.head_size,
  173. self.cos_sin_cache, self.is_neox_style)
  174. return query, key
  175. def extra_repr(self) -> str:
  176. s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
  177. s += f", max_position_embeddings={self.max_position_embeddings}"
  178. s += f", base={self.base}, is_neox_style={self.is_neox_style}"
  179. return s
  180. class LinearScalingRotaryEmbedding(RotaryEmbedding):
  181. """RotaryEmbedding extended with linear scaling.
  182. It supports multiple scaling factors. Since multiple LoRA adapters may have
  183. different scaling factors, we need multiple cos/sin caches. In this way,
  184. instead of running rotary embedding kernel per lora, we can run multiple
  185. lora in a batched way.
  186. In addition to that, we also keep the cos/sin cache for the scaling factor
  187. of 1 (default) at all times.
  188. Exemplary for two scaling factors x=1, y and z with embeddings
  189. [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
  190. [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
  191. [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
  192. we construct the cos/sin cache as follows:
  193. [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
  194. ...
  195. [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
  196. We then use offsets to index into the cos/sin cache for
  197. the respective scaling factors.
  198. The offset to cache can be accessed via `scaling_factor_to_offset` API.
  199. Credits to the Reddit user /u/kaiokendev
  200. """
  201. def __init__(
  202. self,
  203. head_size: int,
  204. rotary_dim: int,
  205. max_position_embeddings: int,
  206. base: int,
  207. is_neox_style: bool,
  208. scaling_factors: Union[List[float], float],
  209. dtype: torch.dtype,
  210. ) -> None:
  211. if isinstance(scaling_factors, float):
  212. scaling_factors = [scaling_factors]
  213. self.scaling_factors: List[float] = scaling_factors # noqa
  214. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  215. is_neox_style, dtype)
  216. # Lazy initialized.
  217. self._scaling_factor_to_offset: Dict[float, int]
  218. def _compute_cos_sin_cache(self) -> torch.Tensor:
  219. inv_freq = self._compute_inv_freq(self.base)
  220. cache_list: List[torch.Tensor] = []
  221. # offsets to the next cache in a tensor.
  222. # Each offset corresponds to the same index in scaling_factors.
  223. offsets: List[int] = []
  224. for scaling_factor in self.scaling_factors:
  225. # NOTE: self.max_position_embeddings is the original
  226. # maximum length before applying the rope scaling.
  227. # Thus, the maximum length after applying the rope scaling is
  228. # self.max_position_embeddings * self.scaling_factor.
  229. max_len = self.max_position_embeddings * scaling_factor
  230. t = torch.arange(max_len, dtype=torch.float)
  231. t = t / scaling_factor
  232. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  233. cos = freqs.cos()
  234. sin = freqs.sin()
  235. cache = torch.cat((cos, sin), dim=-1)
  236. if not cache_list:
  237. offset = 0
  238. else:
  239. last_offset = offsets[-1]
  240. next_max_len = cache_list[-1].shape[0]
  241. offset = last_offset + next_max_len
  242. offsets.append(offset)
  243. cache_list.append(cache)
  244. self._scaling_factor_to_offset = {
  245. float(scaling_factor): offsets[i]
  246. for i, scaling_factor in enumerate(self.scaling_factors)
  247. }
  248. assert len(self.scaling_factors) == len(offsets)
  249. return torch.cat(cache_list, dim=0)
  250. @property
  251. def scaling_factor_to_offset(self) -> Dict[float, int]:
  252. return self._scaling_factor_to_offset
  253. class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
  254. """RotaryEmbedding extended with Dynamic NTK scaling.
  255. Credits to the Reddit users /u/bloc97 and /u/emozilla
  256. """
  257. def __init__(
  258. self,
  259. head_size: int,
  260. rotary_dim: int,
  261. max_position_embeddings: int,
  262. base: int,
  263. is_neox_style: bool,
  264. scaling_factor: float,
  265. dtype: torch.dtype,
  266. ) -> None:
  267. self.scaling_factor = scaling_factor
  268. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  269. is_neox_style, dtype)
  270. def _compute_cos_sin_cache(self) -> torch.Tensor:
  271. # NOTE: self.max_position_embeddings is the original
  272. # maximum length before applying the rope scaling.
  273. # Thus, the maximum length after applying the rope scaling is
  274. # self.max_position_embeddings * self.scaling_factor.
  275. max_len = self.max_position_embeddings * self.scaling_factor
  276. base = self.base * (
  277. (self.scaling_factor * max_len / self.max_position_embeddings) -
  278. (self.scaling_factor - 1))**(self.rotary_dim /
  279. (self.rotary_dim - 2))
  280. inv_freq = self._compute_inv_freq(base)
  281. t = torch.arange(max_len, dtype=torch.float)
  282. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  283. cos = freqs.cos()
  284. sin = freqs.sin()
  285. cache = torch.cat((cos, sin), dim=-1)
  286. return cache
  287. # Inverse dim formula to find dim based on number of rotations
  288. def _yarn_find_correction_dim(num_rotations: int,
  289. dim: int,
  290. base: float = 10000,
  291. max_position_embeddings: int = 2048) -> float:
  292. return (dim * math.log(max_position_embeddings /
  293. (num_rotations * 2 * math.pi))) / (2 *
  294. math.log(base))
  295. # Find dim range bounds based on rotations
  296. def _yarn_find_correction_range(
  297. low_rot: int,
  298. high_rot: int,
  299. dim: int,
  300. base: float = 10000,
  301. max_position_embeddings: int = 2048) -> Tuple[int, int]:
  302. low = math.floor(
  303. _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
  304. high = math.ceil(
  305. _yarn_find_correction_dim(high_rot, dim, base,
  306. max_position_embeddings))
  307. return max(low, 0), min(high, dim - 1) # Clamp values just in case
  308. def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
  309. dtype: torch.dtype) -> torch.Tensor:
  310. if low == high:
  311. high += 0.001 # Prevent singularity
  312. linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
  313. ramp_func = torch.clamp(linear_func, 0, 1)
  314. return ramp_func
  315. def _yarn_get_mscale(scale: float = 1) -> float:
  316. if scale <= 1:
  317. return 1.0
  318. return 0.1 * math.log(scale) + 1.0
  319. class YaRNScalingRotaryEmbedding(RotaryEmbedding):
  320. """RotaryEmbedding extended with YaRN method.
  321. Credits to Peng et al. github.com/jquesnelle/yarn
  322. """
  323. def __init__(
  324. self,
  325. head_size: int,
  326. rotary_dim: int,
  327. max_position_embeddings: int,
  328. base: int,
  329. is_neox_style: bool,
  330. scaling_factor: float,
  331. dtype: torch.dtype,
  332. *,
  333. extrapolation_factor: float = 1,
  334. attn_factor: float = 1,
  335. beta_fast: int = 32,
  336. beta_slow: int = 1,
  337. ) -> None:
  338. self.scaling_factor = scaling_factor
  339. self.extrapolation_factor = extrapolation_factor
  340. self.attn_factor = attn_factor
  341. self.beta_fast = beta_fast
  342. self.beta_slow = beta_slow
  343. # Get n-d magnitude scaling corrected for interpolation
  344. self.mscale = float(
  345. _yarn_get_mscale(self.scaling_factor) * attn_factor)
  346. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  347. is_neox_style, dtype)
  348. def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
  349. pos_freqs = self.base**(
  350. torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
  351. self.rotary_dim)
  352. inv_freq_extrapolation = 1.0 / pos_freqs
  353. inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
  354. low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
  355. self.rotary_dim, self.base,
  356. self.max_position_embeddings)
  357. # Get n-d rotational scaling corrected for extrapolation
  358. inv_freq_mask = (1 - _yarn_linear_ramp_mask(
  359. low, high, self.rotary_dim // 2,
  360. dtype=torch.float)) * self.extrapolation_factor
  361. inv_freq = inv_freq_interpolation * (
  362. 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
  363. return inv_freq
  364. def _compute_cos_sin_cache(self) -> torch.Tensor:
  365. inv_freq = self._compute_inv_freq(self.scaling_factor)
  366. t = torch.arange(self.max_position_embeddings * self.scaling_factor,
  367. dtype=torch.float32)
  368. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  369. cos = (freqs.cos() * self.mscale)
  370. sin = (freqs.sin() * self.mscale)
  371. cache = torch.cat((cos, sin), dim=-1)
  372. return cache
  373. class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
  374. """Phi3 family of models scaled rotary embedding.
  375. Based on the original RotaryEmbedding implementation.
  376. """
  377. def __init__(
  378. self,
  379. head_size: int,
  380. rotary_dim: int,
  381. max_position_embeddings: int,
  382. original_max_position_embeddings: int,
  383. base: int,
  384. is_neox_style: bool,
  385. dtype: torch.dtype,
  386. short_factor: List[float],
  387. long_factor: List[float],
  388. short_mscale: Optional[float] = None,
  389. long_mscale: Optional[float] = None,
  390. ):
  391. super().__init__()
  392. if rotary_dim != head_size:
  393. raise ValueError(
  394. f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
  395. rotary_dim != head_size ({rotary_dim}!={head_size}).")
  396. if is_neox_style is False:
  397. raise ValueError(
  398. "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
  399. )
  400. self.head_size = head_size
  401. self.max_position_embeddings = max_position_embeddings
  402. self.original_max_position_embeddings = original_max_position_embeddings
  403. self.base = base
  404. self.short_factor = short_factor
  405. self.long_factor = long_factor
  406. scale = self.max_position_embeddings / \
  407. self.original_max_position_embeddings
  408. if scale <= 1.0:
  409. scaling_factor = 1.0
  410. else:
  411. scaling_factor = math.sqrt(
  412. 1 + math.log(scale) /
  413. math.log(self.original_max_position_embeddings))
  414. if short_mscale is None:
  415. short_mscale = scaling_factor
  416. if long_mscale is None:
  417. long_mscale = scaling_factor
  418. self.short_mscale = short_mscale
  419. self.long_mscale = long_mscale
  420. short_cache = self._compute_cos_sin_cache(
  421. original_max_position_embeddings, short_factor, short_mscale)
  422. short_cache = short_cache.to(dtype)
  423. self.register_buffer("short_cos_sin_cache",
  424. short_cache,
  425. persistent=False)
  426. long_cache = self._compute_cos_sin_cache(max_position_embeddings,
  427. long_factor, long_mscale)
  428. long_cache = long_cache.to(dtype)
  429. self.register_buffer("long_cos_sin_cache",
  430. long_cache,
  431. persistent=False)
  432. long_short_cache = torch.cat(
  433. [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
  434. self.register_buffer("long_short_cos_sin_cache",
  435. long_short_cache,
  436. persistent=False)
  437. def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
  438. rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
  439. inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
  440. 0, self.head_size, 2, dtype=torch.float) / self.head_size)))
  441. return inv_freq
  442. def _compute_cos_sin_cache(
  443. self,
  444. max_position_embeddings: int,
  445. rescale_factors: List[float],
  446. mscale: float,
  447. ) -> torch.Tensor:
  448. inv_freq = self._compute_inv_freq(rescale_factors)
  449. t = torch.arange(max_position_embeddings, dtype=torch.float)
  450. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  451. cos = freqs.cos() * mscale
  452. sin = freqs.sin() * mscale
  453. cache = torch.cat((cos, sin), dim=-1)
  454. return cache
  455. def forward(
  456. self,
  457. positions: torch.Tensor,
  458. query: torch.Tensor,
  459. key: torch.Tensor,
  460. offsets: Optional[torch.Tensor] = None,
  461. ) -> Tuple[torch.Tensor, torch.Tensor]:
  462. query = query.view(*query.shape[:-1], -1, self.head_size)
  463. key = key.view(*key.shape[:-1], -1, self.head_size)
  464. k = self.original_max_position_embeddings
  465. long_prompt_offset = (torch.any(positions > k).float() *
  466. torch.full_like(positions, k)).long()
  467. idx = (torch.add(positions, long_prompt_offset)
  468. if long_prompt_offset is not None else positions)
  469. self.long_short_cos_sin_cache: torch.Tensor = (
  470. self.long_short_cos_sin_cache.to(idx.device))
  471. idx = torch.add(idx, offsets) if offsets is not None else idx
  472. cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
  473. cos, sin = cos_sin.chunk(2, dim=-1)
  474. cos = cos.repeat(1, 2).unsqueeze(-2)
  475. sin = sin.repeat(1, 2).unsqueeze(-2)
  476. query = query * cos + _rotate_neox(query) * sin
  477. key = key * cos + _rotate_neox(key) * sin
  478. return query.flatten(-2), key.flatten(-2)
  479. def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
  480. if scale <= 1:
  481. return 1.0
  482. return 0.1 * mscale * math.log(scale) + 1.0
  483. class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
  484. """RotaryEmbedding extended with YaRN method.
  485. Credits to Peng et al. github.com/jquesnelle/yarn
  486. """
  487. def __init__(
  488. self,
  489. head_size: int,
  490. rotary_dim: int,
  491. max_position_embeddings: int,
  492. base: int,
  493. is_neox_style: bool,
  494. scaling_factor: float,
  495. dtype: torch.dtype,
  496. *,
  497. extrapolation_factor: float = 1,
  498. attn_factor: float = 1,
  499. beta_fast: int = 32,
  500. beta_slow: int = 1,
  501. mscale: float = 1,
  502. mscale_all_dim: float = 0,
  503. ) -> None:
  504. self.scaling_factor = scaling_factor
  505. self.extrapolation_factor = extrapolation_factor
  506. self.attn_factor = attn_factor
  507. self.beta_fast = beta_fast
  508. self.beta_slow = beta_slow
  509. # Get n-d magnitude scaling corrected for interpolation.
  510. self.mscale = float(
  511. yarn_get_mscale(self.scaling_factor, float(mscale)) /
  512. yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
  513. attn_factor)
  514. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  515. is_neox_style, dtype)
  516. def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
  517. pos_freqs = self.base**(torch.arange(
  518. 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
  519. self.rotary_dim)
  520. inv_freq_extrapolation = 1.0 / pos_freqs
  521. inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
  522. low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
  523. self.rotary_dim, self.base,
  524. self.max_position_embeddings)
  525. # Get n-d rotational scaling corrected for extrapolation
  526. inv_freq_mask = (1 - _yarn_linear_ramp_mask(
  527. low, high, self.rotary_dim // 2,
  528. dtype=torch.float)) * self.extrapolation_factor
  529. inv_freq = inv_freq_interpolation * (
  530. 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
  531. return inv_freq
  532. def _compute_cos_sin_cache(self) -> torch.Tensor:
  533. inv_freq = self._compute_inv_freq(self.scaling_factor)
  534. t = torch.arange(self.max_position_embeddings * self.scaling_factor,
  535. device="cuda",
  536. dtype=torch.float32)
  537. freqs = torch.einsum("i,j -> ij", t, inv_freq)
  538. cos = (freqs.cos() * self.mscale)
  539. sin = (freqs.sin() * self.mscale)
  540. cache = torch.cat((cos, sin), dim=-1)
  541. print("Cache shape", cache.shape)
  542. return cache
  543. def forward(
  544. self,
  545. positions: torch.Tensor,
  546. query: torch.Tensor,
  547. key: torch.Tensor,
  548. offsets: Optional[torch.Tensor] = None,
  549. ) -> Tuple[torch.Tensor, torch.Tensor]:
  550. """PyTorch-native implementation equivalent to forward()."""
  551. query_rot = query[..., :self.rotary_dim]
  552. key_rot = key[..., :self.rotary_dim]
  553. if self.rotary_dim < self.head_size:
  554. query_pass = query[..., self.rotary_dim:]
  555. key_pass = key[..., self.rotary_dim:]
  556. self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
  557. positions.device)
  558. cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
  559. if offsets is not None else positions]
  560. cos, sin = cos_sin.chunk(2, dim=-1)
  561. if self.is_neox_style:
  562. # NOTE: Here we assume that the positions tensor has the
  563. # shape [batch_size, seq_len].
  564. cos = cos.repeat(1, 1, 2).unsqueeze(-2)
  565. sin = sin.repeat(1, 1, 2).unsqueeze(-2)
  566. else:
  567. cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
  568. sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
  569. rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
  570. query_rot = query_rot * cos + rotate_fn(query_rot) * sin
  571. key_rot = key_rot * cos + rotate_fn(key_rot) * sin
  572. if self.rotary_dim < self.head_size:
  573. query = torch.cat((query_rot, query_pass), dim=-1)
  574. key = torch.cat((key_rot, key_pass), dim=-1)
  575. else:
  576. query = query_rot
  577. key = key_rot
  578. return query, key
  579. class GemmaRotaryEmbedding(RotaryEmbedding):
  580. def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
  581. # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
  582. inv_freq = 1.0 / (base**(
  583. torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
  584. self.rotary_dim))
  585. return inv_freq
  586. class Llama3RotaryEmbedding(RotaryEmbedding):
  587. def __init__(
  588. self,
  589. head_size: int,
  590. rotary_dim: int,
  591. max_position_embeddings: int,
  592. base: int,
  593. is_neox_style: bool,
  594. dtype: torch.dtype,
  595. scaling_factor: float,
  596. low_freq_factor: float,
  597. high_freq_factor: float,
  598. orig_max_position: int,
  599. ) -> None:
  600. self.scaling_factor = scaling_factor
  601. self.low_freq_factor = low_freq_factor
  602. self.high_freq_factor = high_freq_factor
  603. self.orig_max_position = orig_max_position
  604. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  605. is_neox_style, dtype)
  606. def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
  607. inv_freqs = super()._compute_inv_freq(base)
  608. low_freq_wavelen = self.orig_max_position / self.low_freq_factor
  609. high_freq_wavelen = self.orig_max_position / self.high_freq_factor
  610. wave_len = 2 * math.pi / inv_freqs
  611. if self.low_freq_factor != self.high_freq_factor:
  612. smooth = (self.orig_max_position / wave_len - self.low_freq_factor
  613. ) / (self.high_freq_factor - self.low_freq_factor)
  614. else:
  615. smooth = 0
  616. new_freqs = torch.where(
  617. wave_len < high_freq_wavelen,
  618. inv_freqs,
  619. torch.where(
  620. wave_len > low_freq_wavelen,
  621. inv_freqs / self.scaling_factor,
  622. (1 - smooth) * inv_freqs / self.scaling_factor +
  623. smooth * inv_freqs,
  624. ),
  625. )
  626. return new_freqs
  627. class MRotaryEmbedding(RotaryEmbedding):
  628. """Rotary Embedding with Multimodal Sections."""
  629. def __init__(
  630. self,
  631. head_size: int,
  632. rotary_dim: int,
  633. max_position_embeddings: int,
  634. base: int,
  635. is_neox_style: bool,
  636. dtype: torch.dtype,
  637. mrope_section: Optional[List[int]] = None,
  638. ) -> None:
  639. super().__init__(head_size, rotary_dim, max_position_embeddings, base,
  640. is_neox_style, dtype)
  641. self.mrope_section = mrope_section
  642. if self.mrope_section:
  643. assert sum(self.mrope_section) == rotary_dim // 2
  644. def forward(
  645. self,
  646. positions: torch.Tensor,
  647. query: torch.Tensor,
  648. key: torch.Tensor,
  649. ) -> Tuple[torch.Tensor, torch.Tensor]:
  650. """PyTorch-native implementation equivalent to forward().
  651. Args:
  652. positions:
  653. [num_tokens,] (text only) or
  654. [3, num_tokens] (T/H/W positions with multimodal inputs)
  655. query: [num_tokens, num_heads * head_size]
  656. key: [num_tokens, num_kv_heads * head_size]
  657. """
  658. assert positions.ndim == 1 or positions.ndim == 2
  659. num_tokens = positions.shape[-1]
  660. cos_sin = self.cos_sin_cache[positions]
  661. cos, sin = cos_sin.chunk(2, dim=-1)
  662. if positions.ndim == 2:
  663. assert self.mrope_section
  664. cos = torch.cat([
  665. m[i]
  666. for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
  667. ],
  668. dim=-1)
  669. sin = torch.cat([
  670. m[i]
  671. for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
  672. ],
  673. dim=-1)
  674. query_shape = query.shape
  675. query = query.view(num_tokens, -1, self.head_size)
  676. query_rot = query[..., :self.rotary_dim]
  677. query_pass = query[..., self.rotary_dim:]
  678. query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
  679. query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
  680. key_shape = key.shape
  681. key = key.view(num_tokens, -1, self.head_size)
  682. key_rot = key[..., :self.rotary_dim]
  683. key_pass = key[..., self.rotary_dim:]
  684. key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
  685. key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
  686. return query, key
  687. @staticmethod
  688. def get_input_positions(
  689. input_tokens: List[int],
  690. image_grid_thw: Union[List[List[int]], torch.Tensor],
  691. video_grid_thw: Union[List[List[int]], torch.Tensor],
  692. image_token_id: int,
  693. video_token_id: int,
  694. vision_start_token_id: int,
  695. vision_end_token_id: int,
  696. spatial_merge_size: int,
  697. context_len: int = 0,
  698. ) -> Tuple[List[List[int]], int]:
  699. """Get mrope input positions and delta value."""
  700. if isinstance(image_grid_thw, torch.Tensor):
  701. image_grid_thw = image_grid_thw.tolist()
  702. if isinstance(video_grid_thw, torch.Tensor):
  703. video_grid_thw = video_grid_thw.tolist()
  704. input_tokens_tensor = torch.tensor(input_tokens)
  705. vision_start_indices = torch.argwhere(
  706. input_tokens_tensor == vision_start_token_id).squeeze(1)
  707. vision_tokens = input_tokens_tensor[vision_start_indices + 1]
  708. image_nums = (vision_tokens == image_token_id).sum()
  709. video_nums = (vision_tokens == video_token_id).sum()
  710. llm_pos_ids_list: list = []
  711. st = 0
  712. remain_images, remain_videos = image_nums, video_nums
  713. image_index, video_index = 0, 0
  714. for _ in range(image_nums + video_nums):
  715. if image_token_id in input_tokens and remain_images > 0:
  716. ed_image = input_tokens.index(image_token_id, st)
  717. else:
  718. ed_image = len(input_tokens) + 1
  719. if video_token_id in input_tokens and remain_videos > 0:
  720. ed_video = input_tokens.index(video_token_id, st)
  721. else:
  722. ed_video = len(input_tokens) + 1
  723. if ed_image < ed_video:
  724. t, h, w = (
  725. image_grid_thw[image_index][0],
  726. image_grid_thw[image_index][1],
  727. image_grid_thw[image_index][2],
  728. )
  729. image_index += 1
  730. remain_images -= 1
  731. ed = ed_image
  732. else:
  733. t, h, w = (
  734. video_grid_thw[video_index][0],
  735. video_grid_thw[video_index][1],
  736. video_grid_thw[video_index][2],
  737. )
  738. video_index += 1
  739. remain_videos -= 1
  740. ed = ed_video
  741. llm_grid_t, llm_grid_h, llm_grid_w = \
  742. t, h // spatial_merge_size, w // spatial_merge_size
  743. text_len = ed - st
  744. st_idx = llm_pos_ids_list[-1].max() + 1 if len(
  745. llm_pos_ids_list) > 0 else 0
  746. llm_pos_ids_list.append(
  747. torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
  748. t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
  749. -1, llm_grid_h * llm_grid_w).flatten()
  750. h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
  751. llm_grid_t, -1, llm_grid_w).flatten()
  752. w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
  753. llm_grid_t, llm_grid_h, -1).flatten()
  754. llm_pos_ids_list.append(
  755. torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
  756. st = ed + llm_grid_t * llm_grid_h * llm_grid_w
  757. if st < len(input_tokens):
  758. st_idx = llm_pos_ids_list[-1].max() + 1 if len(
  759. llm_pos_ids_list) > 0 else 0
  760. text_len = len(input_tokens) - st
  761. llm_pos_ids_list.append(
  762. torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
  763. llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
  764. llm_positions = llm_positions[:, context_len:]
  765. mrope_position_delta = (llm_positions.max() + 1 -
  766. len(input_tokens)).item()
  767. return llm_positions.tolist(), mrope_position_delta
  768. @staticmethod
  769. def get_next_input_positions(
  770. mrope_position_delta: int,
  771. context_len: int,
  772. seq_len: int,
  773. ) -> List[List[int]]:
  774. return [
  775. list(
  776. range(context_len + mrope_position_delta,
  777. seq_len + mrope_position_delta)) for _ in range(3)
  778. ]
  779. _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
  780. def get_rope(
  781. head_size: int,
  782. rotary_dim: int,
  783. max_position: int,
  784. base: int,
  785. is_neox_style: bool = True,
  786. rope_scaling: Optional[Dict[str, Any]] = None,
  787. dtype: Optional[torch.dtype] = None,
  788. partial_rotary_factor: float = 1.0,
  789. ) -> RotaryEmbedding:
  790. if dtype is None:
  791. dtype = torch.get_default_dtype()
  792. if rope_scaling is not None:
  793. # Transforms every value that is a list into a tuple for caching calls
  794. rope_scaling_tuple = {
  795. k: tuple(v) if isinstance(v, list) else v
  796. for k, v in rope_scaling.items()
  797. }
  798. rope_scaling_args = tuple(rope_scaling_tuple.items())
  799. else:
  800. rope_scaling_args = None
  801. if partial_rotary_factor < 1.0:
  802. rotary_dim = int(rotary_dim * partial_rotary_factor)
  803. key = (head_size, rotary_dim, max_position, base, is_neox_style,
  804. rope_scaling_args, dtype)
  805. if key in _ROPE_DICT:
  806. return _ROPE_DICT[key]
  807. if rope_scaling is None:
  808. rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
  809. is_neox_style, dtype)
  810. else:
  811. scaling_type = rope_scaling[
  812. "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
  813. # The correct one should be "longrope" but keep "su" here
  814. # for backward compatible
  815. if scaling_type not in {"su", "longrope"}:
  816. scaling_factor = rope_scaling.get("factor", 1.0)
  817. if scaling_type == "llama3":
  818. low_freq_factor = rope_scaling["low_freq_factor"]
  819. high_freq_factor = rope_scaling["high_freq_factor"]
  820. original_max_position = rope_scaling[
  821. "original_max_position_embeddings"]
  822. rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
  823. max_position, base,
  824. is_neox_style, dtype,
  825. scaling_factor, low_freq_factor,
  826. high_freq_factor,
  827. original_max_position)
  828. elif scaling_type == "linear":
  829. rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
  830. max_position, base,
  831. is_neox_style,
  832. scaling_factor, dtype)
  833. elif scaling_type == "dynamic":
  834. rotary_emb = DynamicNTKScalingRotaryEmbedding(
  835. head_size, rotary_dim, max_position, base, is_neox_style,
  836. scaling_factor, dtype)
  837. elif scaling_type == "yarn":
  838. original_max_position = rope_scaling[
  839. "original_max_position_embeddings"]
  840. extra_kwargs = {
  841. k: v
  842. for k, v in rope_scaling.items()
  843. if k in ("extrapolation_factor", "attn_factor", "beta_fast",
  844. "beta_slow")
  845. }
  846. rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
  847. original_max_position,
  848. base, is_neox_style,
  849. scaling_factor, dtype,
  850. **extra_kwargs)
  851. elif scaling_type == "deepseek_yarn":
  852. original_max_position = rope_scaling[
  853. "original_max_position_embeddings"]
  854. # assert max_position == original_max_position * scaling_factor
  855. extra_kwargs = {
  856. k: v
  857. for k, v in rope_scaling.items()
  858. if k in ("extrapolation_factor", "attn_factor", "beta_fast",
  859. "beta_slow", "mscale", "mscale_all_dim")
  860. }
  861. rotary_emb = DeepseekScalingRotaryEmbedding(
  862. head_size, rotary_dim, original_max_position, base,
  863. is_neox_style, scaling_factor, dtype, **extra_kwargs)
  864. # The correct one should be "longrope" but keep "su" here
  865. # for backward compatible
  866. elif scaling_type == "su" or scaling_type == "longrope":
  867. short_factor = rope_scaling["short_factor"]
  868. long_factor = rope_scaling["long_factor"]
  869. original_max_position = rope_scaling[
  870. "original_max_position_embeddings"]
  871. extra_kwargs = {
  872. k: v
  873. for k, v in rope_scaling.items()
  874. if k in ("short_mscale", "long_mscale")
  875. }
  876. rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
  877. head_size, rotary_dim, max_position, original_max_position,
  878. base, is_neox_style, dtype, short_factor, long_factor,
  879. **extra_kwargs)
  880. elif scaling_type == "mrope":
  881. rotary_emb = MRotaryEmbedding(
  882. head_size,
  883. rotary_dim,
  884. max_position,
  885. base,
  886. is_neox_style,
  887. dtype,
  888. mrope_section=rope_scaling["mrope_section"],
  889. )
  890. else:
  891. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  892. _ROPE_DICT[key] = rotary_emb
  893. return rotary_emb