jais.py

# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs import JAISConfig


class SwiGLUActivation(nn.Module):

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return x1 * nn.functional.silu(x2)
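

# ALiBi slope schedule: for a power-of-two number of heads n, the slopes form
# the geometric sequence slope_i = 2 ** (-8 * (i + 1) / n); for example, n = 4
# yields [0.25, 0.0625, 0.015625, 0.00390625]. For a non-power-of-two n, the
# slopes of the closest smaller power of two are extended with every other
# slope computed for twice that size, following the reference ALiBi recipe.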
def _get_alibi_slopes(n):

    def get_slopes_power_of_2(n):
        start = 2**(-(2**-(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
            2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
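

# Multi-head self-attention with ALiBi positional biases. The slopes are
# sharded across tensor-parallel ranks, and the attention scale follows the
# muP convention: 1/d_head when `mup_scale_qk_dot_by_d` (or its alias
# `scale_qk_dot_by_d`, copied onto it when present) is set, otherwise the
# usual 1/sqrt(d_head).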
class JAISAttention(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output
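

# Feed-forward block. Jais uses SwiGLU: two parallel up-projections (c_fc and
# c_fc2) are combined as x1 * silu(x2) before the down-projection c_proj;
# when `activation_function` is not "swiglu", only the single c_fc path is
# taken.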
class JAISMLP(nn.Module):

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        self.act = SwiGLUActivation()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.swiglu:
            hidden_states2, _ = self.c_fc2(hidden_states)
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = (self.act(hidden_states, hidden_states2)
                         if self.swiglu else self.act(hidden_states))
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states
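

# Pre-LayerNorm transformer block: LayerNorm -> attention -> residual add,
# then LayerNorm -> MLP -> residual add.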
class JAISBlock(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states
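

# Transformer backbone: token embeddings (wte), plus learned position
# embeddings (wpe) unless ALiBi is configured, scaled by `embeddings_scale`
# (or `mup_embeddings_scale`), followed by the block stack and a final
# LayerNorm.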
class JAISModel(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn

        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale
        self.h = nn.ModuleList([
            JAISBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.wte(input_ids)
        if self.wpe is not None:
            position_embeds = self.wpe(position_ids)
            hidden_states = inputs_embeds + position_embeds
        else:
            hidden_states = inputs_embeds
        hidden_states *= torch.tensor(float(self.embeddings_scale),
                                      dtype=hidden_states.dtype)

        for i in range(len(self.h)):
            layer = self.h[i]
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states
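

# Causal LM head. The output projection is tied to the input embedding table
# (lm_head = transformer.wte), and logits are scaled by `width_scale` or by
# `mup_output_alpha * mup_width_scale` inside the LogitsProcessor.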
class JAISLMHeadModel(nn.Module):

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.wte
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                                scale=self.output_logits_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
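
# Usage sketch (illustrative only; the model name and flags below are
# assumptions, and Aphrodite normally instantiates this class through its
# model registry rather than directly):
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="core42/jais-13b-chat", trust_remote_code=True)
#     outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))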