jais.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
  6. # reserved.
  7. # Copyright 2023 Cerebras Systems.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only Jais model compatible with HuggingFace weights."""
  21. import math
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.common.config import CacheConfig
  27. from aphrodite.common.sequence import IntermediateTensors
  28. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  29. get_tensor_model_parallel_world_size)
  30. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  40. from aphrodite.transformers_utils.configs import JAISConfig
  41. class SwiGLUActivation(nn.Module):
  42. def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
  43. return x1 * nn.functional.silu(x2)
  44. def _get_alibi_slopes(n):
  45. def get_slopes_power_of_2(n):
  46. start = 2**(-(2**-(math.log2(n) - 3)))
  47. ratio = start
  48. return [start * ratio**i for i in range(n)]
  49. if math.log2(n).is_integer():
  50. return get_slopes_power_of_2(n)
  51. else:
  52. closest_power_of_2 = 2**math.floor(math.log2(n))
  53. return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
  54. 2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
  55. class JAISAttention(nn.Module):
  56. def __init__(
  57. self,
  58. config: JAISConfig,
  59. cache_config: Optional[CacheConfig] = None,
  60. quant_config: Optional[QuantizationConfig] = None,
  61. ):
  62. super().__init__()
  63. self.hidden_size = config.hidden_size
  64. total_num_heads = config.num_attention_heads
  65. tensor_model_parallel_world_size = (
  66. get_tensor_model_parallel_world_size())
  67. assert total_num_heads % tensor_model_parallel_world_size == 0
  68. self.num_heads = total_num_heads // tensor_model_parallel_world_size
  69. self.head_dim = self.hidden_size // total_num_heads
  70. if hasattr(config, "scale_qk_dot_by_d"):
  71. config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
  72. self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
  73. self.scale = self.head_dim**-self.attn_scale_power
  74. self.c_attn = QKVParallelLinear(
  75. self.hidden_size,
  76. self.head_dim,
  77. total_num_heads,
  78. bias=True,
  79. quant_config=quant_config,
  80. )
  81. self.c_proj = RowParallelLinear(
  82. self.hidden_size,
  83. self.hidden_size,
  84. bias=True,
  85. quant_config=quant_config,
  86. )
  87. tp_rank = get_tensor_model_parallel_rank()
  88. head_start = tp_rank * self.num_heads
  89. head_end = (tp_rank + 1) * self.num_heads
  90. alibi_slopes = _get_alibi_slopes(total_num_heads)
  91. alibi_slopes = alibi_slopes[head_start:head_end]
  92. self.attn = Attention(self.num_heads,
  93. self.head_dim,
  94. scale=self.scale,
  95. alibi_slopes=alibi_slopes,
  96. cache_config=cache_config,
  97. quant_config=quant_config)
  98. def forward(
  99. self,
  100. hidden_states: torch.Tensor,
  101. kv_cache: torch.Tensor,
  102. attn_metadata: AttentionMetadata,
  103. ) -> torch.Tensor:
  104. qkv, _ = self.c_attn(hidden_states)
  105. q, k, v = qkv.chunk(chunks=3, dim=-1)
  106. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  107. attn_output, _ = self.c_proj(attn_output)
  108. return attn_output
  109. class JAISMLP(nn.Module):
  110. def __init__(
  111. self,
  112. intermediate_size: int,
  113. config: JAISConfig,
  114. quant_config: Optional[QuantizationConfig] = None,
  115. ):
  116. super().__init__()
  117. hidden_size = config.hidden_size
  118. self.swiglu = config.activation_function == "swiglu"
  119. self.c_fc = ColumnParallelLinear(
  120. hidden_size,
  121. intermediate_size,
  122. bias=True,
  123. quant_config=quant_config,
  124. )
  125. self.c_fc2 = (ColumnParallelLinear(
  126. hidden_size,
  127. intermediate_size,
  128. bias=True,
  129. quant_config=quant_config,
  130. ) if self.swiglu else None)
  131. self.c_proj = RowParallelLinear(
  132. intermediate_size,
  133. hidden_size,
  134. bias=True,
  135. quant_config=quant_config,
  136. )
  137. self.act = SwiGLUActivation()
  138. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  139. if self.swiglu:
  140. hidden_states2, _ = self.c_fc2(hidden_states)
  141. hidden_states, _ = self.c_fc(hidden_states)
  142. hidden_states = (self.act(hidden_states, hidden_states2)
  143. if self.swiglu else self.act(hidden_states))
  144. hidden_states, _ = self.c_proj(hidden_states)
  145. return hidden_states
  146. class JAISBlock(nn.Module):
  147. def __init__(
  148. self,
  149. config: JAISConfig,
  150. cache_config: Optional[CacheConfig] = None,
  151. quant_config: Optional[QuantizationConfig] = None,
  152. ):
  153. super().__init__()
  154. hidden_size = config.hidden_size
  155. inner_dim = (config.n_inner if config.n_inner is not None else 4 *
  156. hidden_size)
  157. self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  158. self.attn = JAISAttention(config, cache_config, quant_config)
  159. self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  160. self.mlp = JAISMLP(inner_dim, config, quant_config)
  161. def forward(
  162. self,
  163. hidden_states: torch.Tensor,
  164. kv_cache: torch.Tensor,
  165. attn_metadata: AttentionMetadata,
  166. ) -> torch.Tensor:
  167. residual = hidden_states
  168. hidden_states = self.ln_1(hidden_states)
  169. attn_output = self.attn(
  170. hidden_states=hidden_states,
  171. kv_cache=kv_cache,
  172. attn_metadata=attn_metadata,
  173. )
  174. # residual connection
  175. hidden_states = attn_output + residual
  176. residual = hidden_states
  177. hidden_states = self.ln_2(hidden_states)
  178. feed_forward_hidden_states = self.mlp(hidden_states)
  179. # residual connection
  180. hidden_states = residual + feed_forward_hidden_states
  181. return hidden_states
  182. class JAISModel(nn.Module):
  183. def __init__(
  184. self,
  185. config: JAISConfig,
  186. cache_config: Optional[CacheConfig] = None,
  187. quant_config: Optional[QuantizationConfig] = None,
  188. ):
  189. super().__init__()
  190. self.config = config
  191. assert not config.add_cross_attention
  192. assert not config.scale_attn_by_inverse_layer_idx
  193. assert not config.reorder_and_upcast_attn
  194. self.embed_dim = config.hidden_size
  195. self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
  196. self.wpe = (nn.Embedding(config.max_position_embeddings,
  197. self.embed_dim)
  198. if config.position_embedding_type != "alibi" else None)
  199. if hasattr(config, "embeddings_scale"):
  200. self.embeddings_scale = config.embeddings_scale
  201. else:
  202. self.embeddings_scale = config.mup_embeddings_scale
  203. self.h = nn.ModuleList([
  204. JAISBlock(config, cache_config, quant_config)
  205. for _ in range(config.num_hidden_layers)
  206. ])
  207. self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  208. def forward(
  209. self,
  210. input_ids: torch.Tensor,
  211. position_ids: torch.Tensor,
  212. kv_caches: List[torch.Tensor],
  213. attn_metadata: AttentionMetadata,
  214. ) -> torch.Tensor:
  215. inputs_embeds = self.wte(input_ids)
  216. if self.wpe is not None:
  217. position_embeds = self.wpe(position_ids)
  218. hidden_states = inputs_embeds + position_embeds
  219. else:
  220. hidden_states = inputs_embeds
  221. hidden_states *= torch.tensor(float(self.embeddings_scale),
  222. dtype=hidden_states.dtype)
  223. for i in range(len(self.h)):
  224. layer = self.h[i]
  225. hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
  226. hidden_states = self.ln_f(hidden_states)
  227. return hidden_states
  228. class JAISLMHeadModel(nn.Module):
  229. def __init__(
  230. self,
  231. config: JAISConfig,
  232. cache_config: Optional[CacheConfig] = None,
  233. quant_config: Optional[QuantizationConfig] = None,
  234. ):
  235. super().__init__()
  236. self.config = config
  237. self.quant_config = quant_config
  238. self.transformer = JAISModel(config, cache_config, quant_config)
  239. self.lm_head = self.transformer.wte
  240. if hasattr(config, "width_scale"):
  241. self.output_logits_scale = config.width_scale
  242. else:
  243. self.output_logits_scale = (config.mup_output_alpha *
  244. config.mup_width_scale)
  245. self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
  246. scale=self.output_logits_scale)
  247. self.sampler = Sampler()
  248. def forward(
  249. self,
  250. input_ids: torch.Tensor,
  251. positions: torch.Tensor,
  252. kv_caches: List[torch.Tensor],
  253. attn_metadata: AttentionMetadata,
  254. intermediate_tensors: Optional[IntermediateTensors] = None,
  255. ) -> torch.Tensor:
  256. hidden_states = self.transformer(input_ids, positions, kv_caches,
  257. attn_metadata)
  258. return hidden_states
  259. def compute_logits(
  260. self,
  261. hidden_states: torch.Tensor,
  262. sampling_metadata: SamplingMetadata,
  263. ) -> Optional[torch.Tensor]:
  264. logits = self.logits_processor(self.lm_head, hidden_states,
  265. sampling_metadata)
  266. return logits
  267. def sample(
  268. self,
  269. logits: torch.Tensor,
  270. sampling_metadata: SamplingMetadata,
  271. ) -> Optional[SamplerOutput]:
  272. next_tokens = self.sampler(logits, sampling_metadata)
  273. return next_tokens
  274. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  275. params_dict = dict(self.named_parameters(remove_duplicate=False))
  276. for name, loaded_weight in weights:
  277. if "lm_head.weight" in name:
  278. # GPT-2 ties the weights of the embedding layer and the final
  279. # linear layer.
  280. continue
  281. if ".attn.bias" in name or ".attn.masked_bias" in name:
  282. # Skip attention mask.
  283. # NOTE: "c_attn.bias" should not be skipped.
  284. continue
  285. if "relative_pe" in name:
  286. continue
  287. if not name.startswith("transformer."):
  288. name = "transformer." + name
  289. param = params_dict[name]
  290. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
  291. # Because of this, we need to transpose the weights.
  292. # Note(zhuohan): the logic below might break quantized models.
  293. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
  294. if conv1d_weight_name not in name:
  295. continue
  296. if not name.endswith(".weight"):
  297. continue
  298. loaded_weight = loaded_weight.t()
  299. weight_loader = getattr(param, "weight_loader",
  300. default_weight_loader)
  301. weight_loader(param, loaded_weight)