# stablelm.py
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StableLM (https://github.com/Stability-AI/StableLM) model
compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. class StablelmMLP(nn.Module):
  51. def __init__(
  52. self,
  53. config: PretrainedConfig,
  54. linear_method: Optional[LinearMethodBase] = None,
  55. ) -> None:
  56. super().__init__()
  57. self.config = config
  58. self.hidden_size = config.hidden_size
  59. self.intermediate_size = config.intermediate_size
  60. if (linear_method is not None
  61. and not linear_method.quant_config.merge_weight()):
  62. self.merge_weight = False
  63. self.gate_proj = ColumnParallelLinear(
  64. config.hidden_size,
  65. config.intermediate_size,
  66. bias=False,
  67. linear_method=linear_method,
  68. )
  69. self.up_proj = ColumnParallelLinear(
  70. config.hidden_size,
  71. config.intermediate_size,
  72. bias=False,
  73. linear_method=linear_method,
  74. )
  75. else:
  76. self.merge_weight = True
  77. self.gate_up_proj = MergedColumnParallelLinear(
  78. config.hidden_size,
  79. [config.intermediate_size] * 2,
  80. bias=False,
  81. linear_method=linear_method,
  82. )
  83. self.down_proj = RowParallelLinear(config.intermediate_size,
  84. config.hidden_size,
  85. bias=False)
  86. self.act_fn = SiluAndMul()
  87. def forward(self, x: torch.Tensor) -> torch.Tensor:
  88. if self.merge_weight:
  89. gate_up, _ = self.gate_up_proj(x)
  90. else:
  91. up, _ = self.up_proj(x)
  92. gate, _ = self.gate_proj(x)
  93. gate_up = torch.cat([gate, up], dim=-1)
  94. x = self.act_fn(gate_up)
  95. x, _ = self.down_proj(x)
  96. return x
  97. class StablelmAttention(nn.Module):
  98. def __init__(
  99. self,
  100. config: PretrainedConfig,
  101. linear_method: Optional[LinearMethodBase] = None,
  102. ) -> None:
  103. super().__init__()
  104. self.config = config
  105. self.hidden_size = config.hidden_size
  106. tp_size = get_tensor_model_parallel_world_size()
  107. self.total_num_heads = config.num_attention_heads
  108. self.num_heads = self.total_num_heads // tp_size
  109. self.total_num_key_value_heads = config.num_key_value_heads
  110. if self.total_num_key_value_heads >= tp_size:
  111. # Number of KV heads is greater than TP size, so we partition
  112. # the KV heads across multiple tensor parallel GPUs.
  113. assert self.total_num_key_value_heads % tp_size == 0
  114. else:
  115. # Number of KV heads is less than TP size, so we replicate
  116. # the KV heads across multiple tensor parallel GPUs.
  117. assert tp_size % self.total_num_key_value_heads == 0
  118. self.num_key_value_heads = max(
  119. 1, self.total_num_key_value_heads // tp_size)
  120. self.head_dim = self.hidden_size // self.total_num_heads
  121. self.max_position_embeddings = config.max_position_embeddings
  122. rope_pct = self.config.partial_rotary_factor
  123. self.rotary_ndims = int(self.head_dim * rope_pct)
  124. self.scaling = self.head_dim**-0.5
  125. self.q_size = self.num_heads * self.head_dim
  126. self.kv_size = self.num_key_value_heads * self.head_dim
  127. self.qkv_bias = getattr(config, "use_qkv_bias", False)
  128. if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
  129. raise ValueError("hidden_size must be divisible by num_heads (got "
  130. f"`hidden_size`: {self.hidden_size}"
  131. f" and `num_heads`: {self.num_heads}).")
  132. if (linear_method is not None
  133. and not linear_method.quant_config.merge_weight()):
  134. self.merge_weight = False
  135. self.q_proj = ColumnParallelLinear(
  136. self.hidden_size,
  137. self.q_size,
  138. bias=self.qkv_bias,
  139. linear_method=linear_method,
  140. )
  141. self.k_proj = ColumnParallelLinear(
  142. self.hidden_size,
  143. self.kv_size,
  144. bias=self.qkv_bias,
  145. linear_method=linear_method,
  146. )
  147. self.v_proj = ColumnParallelLinear(
  148. self.hidden_size,
  149. self.kv_size,
  150. bias=self.qkv_bias,
  151. linear_method=linear_method,
  152. )
  153. else:
  154. self.merge_weight = True
  155. self.qkv_proj = QKVParallelLinear(
  156. self.hidden_size,
  157. self.head_dim,
  158. self.total_num_heads,
  159. self.total_num_key_value_heads,
  160. self.qkv_bias,
  161. linear_method=linear_method,
  162. )
  163. self.o_proj = RowParallelLinear(
  164. self.total_num_heads * self.head_dim,
  165. self.hidden_size,
  166. bias=False,
  167. linear_method=linear_method,
  168. )
  169. self.rotary_ndims = int(self.head_dim *
  170. self.config.partial_rotary_factor)
  171. self.rotary_emb = get_rope(
  172. self.head_dim,
  173. rotary_dim=self.rotary_ndims,
  174. max_position=self.config.max_position_embeddings,
  175. base=self.config.rope_theta,
  176. )
  177. self.attn = Attention(
  178. self.num_heads,
  179. self.head_dim,
  180. self.scaling,
  181. num_kv_heads=self.num_key_value_heads,
  182. )
  183. def forward(
  184. self,
  185. positions: torch.Tensor,
  186. hidden_states: torch.Tensor,
  187. kv_cache: torch.Tensor,
  188. attn_metadata: AttentionMetadata,
  189. ) -> torch.Tensor:
  190. if self.merge_weight:
  191. qkv, _ = self.qkv_proj(hidden_states)
  192. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  193. dim=-1)
  194. else:
  195. q, _ = self.q_proj(hidden_states)
  196. k, _ = self.k_proj(hidden_states)
  197. v, _ = self.v_proj(hidden_states)
  198. q, k = self.rotary_emb(positions, q, k)
  199. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  200. output, _ = self.o_proj(attn_output)
  201. return output
  202. class StablelmDecoderLayer(nn.Module):
  203. def __init__(
  204. self,
  205. config: PretrainedConfig,
  206. linear_method: Optional[LinearMethodBase] = None,
  207. ) -> None:
  208. super().__init__()
  209. self.self_attn = StablelmAttention(config)
  210. self.mlp = StablelmMLP(config, linear_method)
  211. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  212. eps=config.layer_norm_eps)
  213. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  214. eps=config.layer_norm_eps)
  215. def forward(
  216. self,
  217. positions: torch.Tensor,
  218. hidden_states: torch.Tensor,
  219. kv_cache: torch.Tensor,
  220. attn_metadata: AttentionMetadata,
  221. ) -> Tuple[torch.Tensor, torch.Tensor]:
  222. # Self Attention
  223. residual = hidden_states
  224. hidden_states = self.input_layernorm(hidden_states)
  225. hidden_states = self.self_attn(
  226. positions=positions,
  227. hidden_states=hidden_states,
  228. kv_cache=kv_cache,
  229. attn_metadata=attn_metadata,
  230. )
  231. hidden_states = residual + hidden_states
  232. # Fully Connected
  233. residual = hidden_states
  234. hidden_states = self.post_attention_layernorm(hidden_states)
  235. hidden_states = self.mlp(hidden_states)
  236. hidden_states = residual + hidden_states
  237. return hidden_states, residual
  238. class StableLMEpochModel(nn.Module):
  239. def __init__(
  240. self,
  241. config: PretrainedConfig,
  242. linear_method: Optional[LinearMethodBase] = None,
  243. ) -> None:
  244. super().__init__()
  245. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  246. config.hidden_size,
  247. linear_method=linear_method)
  248. self.layers = nn.ModuleList([
  249. StablelmDecoderLayer(config, linear_method)
  250. for _ in range(config.num_hidden_layers)
  251. ])
  252. self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  253. def forward(
  254. self,
  255. input_ids: torch.Tensor,
  256. positions: torch.Tensor,
  257. kv_caches: List[torch.Tensor],
  258. attn_metadata: AttentionMetadata,
  259. ) -> torch.Tensor:
  260. hidden_states = self.embed_tokens(input_ids)
  261. for i in range(len(self.layers)):
  262. layer = self.layers[i]
  263. # pylint: disable=unused-variable
  264. hidden_states, residual = layer(
  265. positions,
  266. hidden_states,
  267. kv_caches[i],
  268. attn_metadata,
  269. )
  270. hidden_states = self.norm(hidden_states)
  271. return hidden_states
  272. class StablelmForCausalLM(nn.Module):
  273. def __init__(
  274. self,
  275. config: PretrainedConfig,
  276. linear_method: Optional[LinearMethodBase] = None,
  277. ) -> None:
  278. super().__init__()
  279. self.config = config
  280. self.linear_method = linear_method
  281. self.model = StableLMEpochModel(config, linear_method)
  282. self.lm_head = ParallelLMHead(config.vocab_size,
  283. config.hidden_size,
  284. linear_method=linear_method)
  285. self.logits_processor = LogitsProcessor(config.vocab_size)
  286. self.sampler = Sampler()
  287. def forward(
  288. self,
  289. input_ids: torch.Tensor,
  290. positions: torch.Tensor,
  291. kv_caches: List[torch.Tensor],
  292. attn_metadata: AttentionMetadata,
  293. ) -> torch.Tensor:
  294. hidden_states = self.model(input_ids, positions, kv_caches,
  295. attn_metadata)
  296. return hidden_states
  297. def compute_logits(self, hidden_states: torch.Tensor,
  298. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  299. logits = self.logits_processor(self.lm_head, hidden_states,
  300. sampling_metadata)
  301. return logits
  302. def sample(
  303. self,
  304. logits: torch.Tensor,
  305. sampling_metadata: SamplingMetadata,
  306. ) -> Optional[SamplerOutput]:
  307. next_tokens = self.sampler(logits, sampling_metadata)
  308. return next_tokens
  309. def load_weights(
  310. self,
  311. model_name_or_path: str,
  312. cache_dir: Optional[str] = None,
  313. load_format: str = "auto",
  314. revision: Optional[str] = None,
  315. ):
  316. stacked_params_mapping = [
  317. # (param_name, shard_name, shard_id)
  318. ("qkv_proj", "q_proj", "q"),
  319. ("qkv_proj", "k_proj", "k"),
  320. ("qkv_proj", "v_proj", "v"),
  321. ("gate_up_proj", "gate_proj", 0),
  322. ("gate_up_proj", "up_proj", 1),
  323. ]
  324. if (self.linear_method is not None
  325. and not self.linear_method.quant_config.merge_weight()):
  326. stacked_params_mapping = []
  327. params_dict = dict(self.named_parameters())
  328. for name, loaded_weight in hf_model_weights_iterator(
  329. model_name_or_path, cache_dir, load_format, revision,
  330. self.config):
  331. if "rotary_emb.inv_freq" in name:
  332. continue
  333. if ("rotary_emb.cos_cached" in name
  334. or "rotary_emb.sin_cached" in name):
  335. # Models trained using ColossalAI may include these tensors in
  336. # the checkpoint. Skip them.
  337. continue
  338. for param_name, weight_name, shard_id in stacked_params_mapping:
  339. if weight_name not in name:
  340. continue
  341. name = name.replace(weight_name, param_name)
  342. # Skip loading extra bias for GPTQ models.
  343. if name.endswith(".bias") and name not in params_dict:
  344. continue
  345. param = params_dict[name]
  346. weight_loader = param.weight_loader
  347. weight_loader(param, loaded_weight, shard_id)
  348. break
  349. else:
  350. # Skip loading extra bias for GPTQ models.
  351. if name.endswith(".bias") and name not in params_dict:
  352. continue
  353. param = params_dict[name]
  354. weight_loader = getattr(param, "weight_loader",
  355. default_weight_loader)
  356. weight_loader(param, loaded_weight)