stablelm.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # coding=utf-8
  2. # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
  3. # All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # This code is based off the following work:
  18. # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
  19. # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
  20. """Inference-only StableLM (https://github.com/Stability-AI/StableLM) model
  21. compatible with HuggingFace weights."""
  22. from typing import List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from transformers import PretrainedConfig
  26. from aphrodite.modeling.metadata import InputMetadata
  27. from aphrodite.modeling.layers.activation import SiluAndMul
  28. from aphrodite.modeling.layers.attention import PagedAttention
  29. from aphrodite.modeling.layers.linear import (
  30. LinearMethodBase,
  31. MergedColumnParallelLinear,
  32. QKVParallelLinear,
  33. RowParallelLinear,
  34. ColumnParallelLinear,
  35. )
  36. from aphrodite.modeling.layers.rotary_embedding import get_rope
  37. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  38. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  39. VocabParallelEmbedding,
  40. ParallelLMHead,
  41. )
  42. from aphrodite.modeling.megatron.parallel_state import (
  43. get_tensor_model_parallel_world_size, )
  44. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  45. from aphrodite.modeling.hf_downloader import (
  46. default_weight_loader,
  47. hf_model_weights_iterator,
  48. )
  49. from aphrodite.common.sequence import SamplerOutput
  50. KVCache = Tuple[torch.Tensor, torch.Tensor]
  51. class StablelmMLP(nn.Module):
  52. def __init__(
  53. self,
  54. config: PretrainedConfig,
  55. linear_method: Optional[LinearMethodBase] = None,
  56. ) -> None:
  57. super().__init__()
  58. self.config = config
  59. self.hidden_size = config.hidden_size
  60. self.intermediate_size = config.intermediate_size
  61. if (linear_method is not None
  62. and not linear_method.quant_config.merge_weight()):
  63. self.merge_weight = False
  64. self.gate_proj = ColumnParallelLinear(
  65. config.hidden_size,
  66. config.intermediate_size,
  67. bias=False,
  68. linear_method=linear_method,
  69. )
  70. self.up_proj = ColumnParallelLinear(
  71. config.hidden_size,
  72. config.intermediate_size,
  73. bias=False,
  74. linear_method=linear_method,
  75. )
  76. else:
  77. self.merge_weight = True
  78. self.gate_up_proj = MergedColumnParallelLinear(
  79. config.hidden_size,
  80. [config.intermediate_size] * 2,
  81. bias=False,
  82. linear_method=linear_method,
  83. )
  84. self.down_proj = RowParallelLinear(config.intermediate_size,
  85. config.hidden_size,
  86. bias=False)
  87. self.act_fn = SiluAndMul()
  88. def forward(self, x: torch.Tensor) -> torch.Tensor:
  89. if self.merge_weight:
  90. gate_up, _ = self.gate_up_proj(x)
  91. else:
  92. up, _ = self.up_proj(x)
  93. gate, _ = self.gate_proj(x)
  94. gate_up = torch.cat([gate, up], dim=-1)
  95. x = self.act_fn(gate_up)
  96. x, _ = self.down_proj(x)
  97. return x
  98. class StablelmAttention(nn.Module):
  99. def __init__(
  100. self,
  101. config: PretrainedConfig,
  102. linear_method: Optional[LinearMethodBase] = None,
  103. ) -> None:
  104. super().__init__()
  105. self.config = config
  106. self.hidden_size = config.hidden_size
  107. tp_size = get_tensor_model_parallel_world_size()
  108. self.total_num_heads = config.num_attention_heads
  109. self.num_heads = self.total_num_heads // tp_size
  110. self.total_num_key_value_heads = config.num_key_value_heads
  111. if self.total_num_key_value_heads >= tp_size:
  112. # Number of KV heads is greater than TP size, so we partition
  113. # the KV heads across multiple tensor parallel GPUs.
  114. assert self.total_num_key_value_heads % tp_size == 0
  115. else:
  116. # Number of KV heads is less than TP size, so we replicate
  117. # the KV heads across multiple tensor parallel GPUs.
  118. assert tp_size % self.total_num_key_value_heads == 0
  119. self.num_key_value_heads = max(
  120. 1, self.total_num_key_value_heads // tp_size)
  121. self.head_dim = self.hidden_size // self.total_num_heads
  122. self.max_position_embeddings = config.max_position_embeddings
  123. rope_pct = self.config.partial_rotary_factor
  124. self.rotary_ndims = int(self.head_dim * rope_pct)
  125. self.scaling = self.head_dim**-0.5
  126. self.q_size = self.num_heads * self.head_dim
  127. self.kv_size = self.num_key_value_heads * self.head_dim
  128. self.qkv_bias = getattr(config, "use_qkv_bias", False)
  129. if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
  130. raise ValueError("hidden_size must be divisible by num_heads (got "
  131. f"`hidden_size`: {self.hidden_size}"
  132. f" and `num_heads`: {self.num_heads}).")
  133. if (linear_method is not None
  134. and not linear_method.quant_config.merge_weight()):
  135. self.merge_weight = False
  136. self.q_proj = ColumnParallelLinear(
  137. self.hidden_size,
  138. self.q_size,
  139. bias=self.qkv_bias,
  140. linear_method=linear_method,
  141. )
  142. self.k_proj = ColumnParallelLinear(
  143. self.hidden_size,
  144. self.kv_size,
  145. bias=self.qkv_bias,
  146. linear_method=linear_method,
  147. )
  148. self.v_proj = ColumnParallelLinear(
  149. self.hidden_size,
  150. self.kv_size,
  151. bias=self.qkv_bias,
  152. linear_method=linear_method,
  153. )
  154. else:
  155. self.merge_weight = True
  156. self.qkv_proj = QKVParallelLinear(
  157. self.hidden_size,
  158. self.head_dim,
  159. self.total_num_heads,
  160. self.total_num_key_value_heads,
  161. self.qkv_bias,
  162. linear_method=linear_method,
  163. )
  164. self.o_proj = RowParallelLinear(
  165. self.total_num_heads * self.head_dim,
  166. self.hidden_size,
  167. bias=False,
  168. linear_method=linear_method,
  169. )
  170. self.rotary_ndims = int(self.head_dim *
  171. self.config.partial_rotary_factor)
  172. self.rotary_emb = get_rope(
  173. self.head_dim,
  174. rotary_dim=self.rotary_ndims,
  175. max_position=self.config.max_position_embeddings,
  176. base=self.config.rope_theta,
  177. )
  178. self.attn = PagedAttention(
  179. self.num_heads,
  180. self.head_dim,
  181. self.scaling,
  182. num_kv_heads=self.num_key_value_heads,
  183. )
  184. def forward(
  185. self,
  186. positions: torch.Tensor,
  187. hidden_states: torch.Tensor,
  188. kv_cache: KVCache,
  189. input_metadata: InputMetadata,
  190. ) -> torch.Tensor:
  191. if self.merge_weight:
  192. qkv, _ = self.qkv_proj(hidden_states)
  193. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  194. dim=-1)
  195. else:
  196. q, _ = self.q_proj(hidden_states)
  197. k, _ = self.k_proj(hidden_states)
  198. v, _ = self.v_proj(hidden_states)
  199. q, k = self.rotary_emb(positions, q, k)
  200. k_cache, v_cache = kv_cache
  201. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  202. output, _ = self.o_proj(attn_output)
  203. return output
  204. class StablelmDecoderLayer(nn.Module):
  205. def __init__(
  206. self,
  207. config: PretrainedConfig,
  208. linear_method: Optional[LinearMethodBase] = None,
  209. ) -> None:
  210. super().__init__()
  211. self.self_attn = StablelmAttention(config)
  212. self.mlp = StablelmMLP(config, linear_method)
  213. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  214. eps=config.layer_norm_eps)
  215. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  216. eps=config.layer_norm_eps)
  217. def forward(
  218. self,
  219. positions: torch.Tensor,
  220. hidden_states: torch.Tensor,
  221. kv_cache: KVCache,
  222. input_metadata: InputMetadata,
  223. ) -> Tuple[torch.Tensor, torch.Tensor]:
  224. # Self Attention
  225. residual = hidden_states
  226. hidden_states = self.input_layernorm(hidden_states)
  227. hidden_states = self.self_attn(
  228. positions=positions,
  229. hidden_states=hidden_states,
  230. kv_cache=kv_cache,
  231. input_metadata=input_metadata,
  232. )
  233. hidden_states = residual + hidden_states
  234. # Fully Connected
  235. residual = hidden_states
  236. hidden_states = self.post_attention_layernorm(hidden_states)
  237. hidden_states = self.mlp(hidden_states)
  238. hidden_states = residual + hidden_states
  239. return hidden_states, residual
  240. class StableLMEpochModel(nn.Module):
  241. def __init__(
  242. self,
  243. config: PretrainedConfig,
  244. linear_method: Optional[LinearMethodBase] = None,
  245. ) -> None:
  246. super().__init__()
  247. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  248. config.hidden_size,
  249. linear_method=linear_method)
  250. self.layers = nn.ModuleList([
  251. StablelmDecoderLayer(config, linear_method)
  252. for _ in range(config.num_hidden_layers)
  253. ])
  254. self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  255. def forward(
  256. self,
  257. input_ids: torch.Tensor,
  258. positions: torch.Tensor,
  259. kv_caches: List[KVCache],
  260. input_metadata: InputMetadata,
  261. ) -> torch.Tensor:
  262. hidden_states = self.embed_tokens(input_ids)
  263. for i in range(len(self.layers)):
  264. layer = self.layers[i]
  265. # pylint: disable=unused-variable
  266. hidden_states, residual = layer(
  267. positions,
  268. hidden_states,
  269. kv_caches[i],
  270. input_metadata,
  271. )
  272. hidden_states = self.norm(hidden_states)
  273. return hidden_states
  274. class StablelmForCausalLM(nn.Module):
  275. def __init__(
  276. self,
  277. config: PretrainedConfig,
  278. linear_method: Optional[LinearMethodBase] = None,
  279. ) -> None:
  280. super().__init__()
  281. self.config = config
  282. self.linear_method = linear_method
  283. self.model = StableLMEpochModel(config, linear_method)
  284. self.lm_head = ParallelLMHead(config.vocab_size,
  285. config.hidden_size,
  286. linear_method=linear_method)
  287. self.sampler = Sampler(config.vocab_size)
  288. self.quant_sampler = QuantSampler(config.vocab_size)
  289. def forward(
  290. self,
  291. input_ids: torch.Tensor,
  292. positions: torch.Tensor,
  293. kv_caches: List[KVCache],
  294. input_metadata: InputMetadata,
  295. ) -> torch.Tensor:
  296. hidden_states = self.model(input_ids, positions, kv_caches,
  297. input_metadata)
  298. return hidden_states
  299. def sample(
  300. self,
  301. hidden_states: torch.Tensor,
  302. sampling_metadata: SamplingMetadata,
  303. ) -> Optional[SamplerOutput]:
  304. if (self.linear_method is not None
  305. and not self.linear_method.quant_config.merge_weight()):
  306. next_tokens = self.quant_sampler(self.lm_head(hidden_states),
  307. sampling_metadata)
  308. else:
  309. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  310. sampling_metadata)
  311. return next_tokens
  312. def load_weights(
  313. self,
  314. model_name_or_path: str,
  315. cache_dir: Optional[str] = None,
  316. load_format: str = "auto",
  317. revision: Optional[str] = None,
  318. ):
  319. stacked_params_mapping = [
  320. # (param_name, shard_name, shard_id)
  321. ("qkv_proj", "q_proj", "q"),
  322. ("qkv_proj", "k_proj", "k"),
  323. ("qkv_proj", "v_proj", "v"),
  324. ("gate_up_proj", "gate_proj", 0),
  325. ("gate_up_proj", "up_proj", 1),
  326. ]
  327. if (self.linear_method is not None
  328. and not self.linear_method.quant_config.merge_weight()):
  329. stacked_params_mapping = []
  330. params_dict = dict(self.named_parameters())
  331. for name, loaded_weight in hf_model_weights_iterator(
  332. model_name_or_path, cache_dir, load_format, revision,
  333. self.config):
  334. if "rotary_emb.inv_freq" in name:
  335. continue
  336. if ("rotary_emb.cos_cached" in name
  337. or "rotary_emb.sin_cached" in name):
  338. # Models trained using ColossalAI may include these tensors in
  339. # the checkpoint. Skip them.
  340. continue
  341. for param_name, weight_name, shard_id in stacked_params_mapping:
  342. if weight_name not in name:
  343. continue
  344. name = name.replace(weight_name, param_name)
  345. # Skip loading extra bias for GPTQ models.
  346. if name.endswith(".bias") and name not in params_dict:
  347. continue
  348. param = params_dict[name]
  349. weight_loader = param.weight_loader
  350. weight_loader(param, loaded_weight, shard_id)
  351. break
  352. else:
  353. # Skip loading extra bias for GPTQ models.
  354. if name.endswith(".bias") and name not in params_dict:
  355. continue
  356. param = params_dict[name]
  357. weight_loader = getattr(param, "weight_loader",
  358. default_weight_loader)
  359. weight_loader(param, loaded_weight)