rotary_embedding.py

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from aphrodite.modeling._custom_op import CustomOp

def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)

def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
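
# Illustrative sketch (not part of the original module): how the helpers above
# are combined in practice, assuming the rotary dimension covers the full head
# size. `_example_apply_rope` and its argument names are hypothetical; the body
# mirrors what RotaryEmbedding.forward_native() below does for a single
# [num_tokens, num_heads, head_size] tensor.
def _example_apply_rope(x: torch.Tensor,
                        positions: torch.Tensor,
                        cos_sin_cache: torch.Tensor,
                        is_neox_style: bool = True) -> torch.Tensor:
    # cos_sin_cache: [max_position, head_size], with cos in the first half of
    # the last dim and sin in the second half (the layout produced by
    # RotaryEmbedding._compute_cos_sin_cache()).
    cos, sin = cos_sin_cache.index_select(0, positions).chunk(2, dim=-1)
    return _apply_rotary_emb(x, cos, sin, is_neox_style)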

class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE: To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from aphrodite import _custom_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key,
                                         self.head_size, self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from aphrodite._ipex_ops import ipex_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                   dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(positions, query, key,
                                         self.head_size, self.cos_sin_cache,
                                         self.is_neox_style, self.rotary_dim,
                                         offsets)
        else:
            ops.rotary_embedding(positions, query, key, self.head_size,
                                 self.cos_sin_cache, self.is_neox_style)
        return query, key

    def forward_triton(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.forward_cuda(positions, query, key, offsets)

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
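
# Illustrative sketch (not part of the original module): constructing the base
# rotary embedding and applying it through the backend-agnostic native path.
# The shapes and hyperparameters are assumptions chosen for the example only.
def _example_rotary_embedding_usage() -> Tuple[torch.Tensor, torch.Tensor]:
    head_size, num_heads, num_tokens = 64, 8, 16
    rope = RotaryEmbedding(head_size=head_size,
                           rotary_dim=head_size,
                           max_position_embeddings=2048,
                           base=10000,
                           is_neox_style=True,
                           dtype=torch.float32)
    positions = torch.arange(num_tokens)
    query = torch.randn(num_tokens, num_heads * head_size)
    key = torch.randn(num_tokens, num_heads * head_size)
    # forward_native() is the reference implementation; forward_cuda()/
    # forward_xpu() dispatch to fused in-place kernels with the same semantics.
    return rope.forward_native(positions, query, key)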

class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may have
    different scaling factors, we need multiple cos/sin caches. In this way,
    instead of running the rotary embedding kernel per LoRA, we can run
    multiple LoRAs in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    For example, for three scaling factors x=1, y, and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]],
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
     ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset of each cache can be accessed via the
    `scaling_factor_to_offset` property.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: List[float] = scaling_factors  # noqa
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)
        # Lazy initialized.
        self._scaling_factor_to_offset: Dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: List[torch.Tensor] = []
        # Offset of each individual cache within the concatenated cache tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: List[int] = []
        for scaling_factor in self.scaling_factors:
            # NOTE: self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor
            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self._scaling_factor_to_offset
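
# Illustrative sketch (not part of the original module): with several scaling
# factors the per-factor caches are concatenated along dim 0, and
# `scaling_factor_to_offset` reports the row at which each factor's cache
# starts. The values below are assumptions for the example only.
def _example_linear_scaling_offsets() -> Dict[float, int]:
    rope = LinearScalingRotaryEmbedding(head_size=64,
                                        rotary_dim=64,
                                        max_position_embeddings=2048,
                                        base=10000,
                                        is_neox_style=True,
                                        scaling_factors=[1.0, 4.0],
                                        dtype=torch.float32)
    # Expected result: {1.0: 0, 4.0: 2048} -- the factor-4 cache begins right
    # after the 2048 rows of the factor-1 cache.
    return rope.scaling_factor_to_offset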

class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE: self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
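
# Illustrative sketch (not part of the original module): a hypothetical
# standalone restatement of the dynamic-NTK base adjustment used above. Rather
# than shrinking positions, the RoPE base is enlarged; e.g. with base=10000,
# rotary_dim=128, and scaling_factor=2 it returns roughly 3.05x the original
# base.
def _example_dynamic_ntk_base(base: float, rotary_dim: int,
                              scaling_factor: float,
                              max_position_embeddings: int) -> float:
    max_len = max_position_embeddings * scaling_factor
    return base * ((scaling_factor * max_len / max_position_embeddings) -
                   (scaling_factor - 1))**(rotary_dim / (rotary_dim - 2))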

# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
        low_rot: int,
        high_rot: int,
        dim: int,
        base: float = 10000,
        max_position_embeddings: int = 2048) -> Tuple[int, int]:
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base,
                                  max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
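
# Illustrative sketch (not part of the original module): how the helpers above
# blend interpolated and extrapolated frequencies. This is a hypothetical
# restatement of YaRNScalingRotaryEmbedding._compute_inv_freq() below with
# extrapolation_factor fixed to 1; all default values are assumptions for the
# example.
def _example_yarn_blend(rotary_dim: int = 128,
                        base: float = 10000,
                        scaling_factor: float = 4.0,
                        max_position_embeddings: int = 2048) -> torch.Tensor:
    pos_freqs = base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) /
                       rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
    low, high = _yarn_find_correction_range(32, 1, rotary_dim, base,
                                            max_position_embeddings)
    # mask is 1 below `low` (high-frequency dims keep the original frequency),
    # 0 above `high` (fully interpolated), and ramps linearly in between.
    mask = 1 - _yarn_linear_ramp_mask(low, high, rotary_dim // 2, torch.float)
    return inv_freq_interpolation * (1 - mask) + inv_freq_extrapolation * mask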

class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
            self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        super().__init__()

        if rotary_dim != head_size:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` does not support "
                f"rotary_dim != head_size ({rotary_dim}!={head_size}).")
        if is_neox_style is False:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        scale = self.max_position_embeddings / \
            self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) /
                math.log(self.original_max_position_embeddings))
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale)
        short_cache = short_cache.to(dtype)
        self.register_buffer("short_cos_sin_cache",
                             short_cache,
                             persistent=False)

        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
                                                 long_factor, long_mscale)
        long_cache = long_cache.to(dtype)
        self.register_buffer("long_cos_sin_cache",
                             long_cache,
                             persistent=False)

        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
        self.register_buffer("long_short_cos_sin_cache",
                             long_short_cache,
                             persistent=False)

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: List[float],
        mscale: float,
    ) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        long_prompt_offset = (torch.any(positions > k).float() *
                              torch.full_like(positions, k)).long()
        idx = (torch.add(positions, long_prompt_offset)
               if long_prompt_offset is not None else positions)
        self.long_short_cos_sin_cache: torch.Tensor = (
            self.long_short_cos_sin_cache.to(idx.device))
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)
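
# Illustrative sketch (not part of the original module): how forward() above
# selects between the short and long caches. As soon as any position exceeds
# original_max_position_embeddings, every token is shifted into the long half
# of long_short_cos_sin_cache. The function and its arguments are hypothetical.
def _example_phi3_cache_index(positions: torch.Tensor,
                              original_max_position_embeddings: int
                              ) -> torch.Tensor:
    k = original_max_position_embeddings
    long_prompt_offset = (torch.any(positions > k).float() *
                          torch.full_like(positions, k)).long()
    # Row indices into the concatenated [short; long] cos/sin cache.
    return positions + long_prompt_offset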

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale)) /
            yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
            attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2,
            dtype=torch.float)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device="cuda",
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]

        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
            positions.device)
        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                     if offsets is not None else positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE: Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key

class GemmaRotaryEmbedding(RotaryEmbedding):

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
        inv_freq = 1.0 / (base**(
            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
            self.rotary_dim))
        return inv_freq

class Llama3RotaryEmbedding(RotaryEmbedding):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor

        wave_len = 2 * math.pi / inv_freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (self.orig_max_position / wave_len - self.low_freq_factor
                      ) / (self.high_freq_factor - self.low_freq_factor)
        else:
            smooth = 0
        new_freqs = torch.where(
            wave_len < high_freq_wavelen,
            inv_freqs,
            torch.where(
                wave_len > low_freq_wavelen,
                inv_freqs / self.scaling_factor,
                (1 - smooth) * inv_freqs / self.scaling_factor +
                smooth * inv_freqs,
            ),
        )
        return new_freqs
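
# Illustrative sketch (not part of the original module): the Llama-3 rule above
# partitions frequencies by wavelength. With the hypothetical values below
# (orig_max_position=8192, low_freq_factor=1, high_freq_factor=4), wavelengths
# shorter than 8192/4 = 2048 positions are kept as-is, wavelengths longer than
# 8192/1 = 8192 are divided by scaling_factor, and the band in between is
# smoothly interpolated.
def _example_llama3_wavelen_bounds(
        orig_max_position: int = 8192,
        low_freq_factor: float = 1.0,
        high_freq_factor: float = 4.0) -> Tuple[float, float]:
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    return low_freq_wavelen, high_freq_wavelen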

class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[List[int]] = None,
    ) -> None:
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

        self.mrope_section = mrope_section
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 1 or positions.ndim == 2

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            cos = torch.cat([
                m[i]
                for i, m in enumerate(cos.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)
            sin = torch.cat([
                m[i]
                for i, m in enumerate(sin.split(self.mrope_section, dim=-1))
            ],
                            dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @staticmethod
    def get_input_positions(
        input_tokens: List[int],
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        video_grid_thw: Union[List[List[int]], torch.Tensor],
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
        context_len: int = 0,
    ) -> Tuple[List[List[int]], int]:
        """Get mrope input positions and delta value."""
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        if isinstance(video_grid_thw, torch.Tensor):
            video_grid_thw = video_grid_thw.tolist()

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video
            llm_grid_t, llm_grid_h, llm_grid_w = \
                t, h // spatial_merge_size, w // spatial_merge_size
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
                -1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
                llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
                llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:]
        mrope_position_delta = (llm_positions.max() + 1 -
                                len(input_tokens)).item()
        return llm_positions.tolist(), mrope_position_delta

    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> List[List[int]]:
        return [
            list(
                range(context_len + mrope_position_delta,
                      seq_len + mrope_position_delta)) for _ in range(3)
        ]
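
# Illustrative sketch (not part of the original module): once
# get_input_positions() has produced an mrope_position_delta for the prompt,
# decode steps reuse it through get_next_input_positions(). The numbers are
# assumptions for the example; with delta=-4, context_len=10 and seq_len=12
# the result is [[6, 7], [6, 7], [6, 7]] (identical positions on the T/H/W
# axes).
def _example_mrope_next_positions() -> List[List[int]]:
    return MRotaryEmbedding.get_next_input_positions(
        mrope_position_delta=-4, context_len=10, seq_len=12)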

_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}

def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
    if rope_scaling is not None:
        # Transforms every value that is a list into a tuple for caching calls
        rope_scaling_tuple = {
            k: tuple(v) if isinstance(v, list) else v
            for k, v in rope_scaling.items()
        }
        rope_scaling_args = tuple(rope_scaling_tuple.items())
    else:
        rope_scaling_args = None
    if partial_rotary_factor < 1.0:
        rotary_dim = int(rotary_dim * partial_rotary_factor)
    key = (head_size, rotary_dim, max_position, base, is_neox_style,
           rope_scaling_args, dtype)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling[
            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
        # The correct key should be "longrope", but "su" is kept here
        # for backward compatibility.
        if scaling_type not in {"su", "longrope"}:
            scaling_factor = rope_scaling.get("factor", 1.0)
        if scaling_type == "llama3":
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
                                               max_position, base,
                                               is_neox_style, dtype,
                                               scaling_factor, low_freq_factor,
                                               high_freq_factor,
                                               original_max_position)
        elif scaling_type == "linear":
            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                      max_position, base,
                                                      is_neox_style,
                                                      scaling_factor, dtype)
        elif scaling_type == "dynamic":
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor, dtype)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
                                                    original_max_position,
                                                    base, is_neox_style,
                                                    scaling_factor, dtype,
                                                    **extra_kwargs)
        elif scaling_type == "deepseek_yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                         "beta_slow", "mscale", "mscale_all_dim")
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, dtype, **extra_kwargs)
        # The correct key should be "longrope", but "su" is kept here
        # for backward compatibility.
        elif scaling_type == "su" or scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size, rotary_dim, max_position, original_max_position,
                base, is_neox_style, dtype, short_factor, long_factor,
                **extra_kwargs)
        elif scaling_type == "mrope":
            rotary_emb = MRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                mrope_section=rope_scaling["mrope_section"],
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
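
# Illustrative sketch (not part of the original module): typical ways a model
# definition would request a rotary embedding through get_rope(). The
# hyperparameters are assumptions for the example; identical argument tuples
# return the same cached module from _ROPE_DICT.
def _example_get_rope_usage() -> Tuple[RotaryEmbedding, RotaryEmbedding]:
    # Plain RoPE over the full head size.
    rope = get_rope(head_size=128,
                    rotary_dim=128,
                    max_position=4096,
                    base=10000,
                    is_neox_style=True)
    # Linear ("position interpolation") scaling: a factor of 4 stretches the
    # 4096-position cache to cover 16384 positions.
    scaled = get_rope(head_size=128,
                      rotary_dim=128,
                      max_position=4096,
                      base=10000,
                      is_neox_style=True,
                      rope_scaling={"rope_type": "linear", "factor": 4.0})
    return rope, scaled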