stablelm.py

# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
  20. """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
  21. model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class StablelmMLP(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = MergedColumnParallelLinear(
            config.hidden_size, [config.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(config.intermediate_size,
                                           config.hidden_size,
                                           bias=False)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
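

# A minimal single-device sketch of the computation StablelmMLP performs,
# written with plain torch ops for reference: MergedColumnParallelLinear
# fuses the gate and up projections into one matmul, and SiluAndMul splits
# that output in half and returns silu(gate) * up. The weight names below
# (w_gate, w_up, w_down) are illustrative stand-ins, not parameters of this
# module.
def _reference_gated_mlp(x: torch.Tensor, w_gate: torch.Tensor,
                         w_up: torch.Tensor,
                         w_down: torch.Tensor) -> torch.Tensor:
    gate = nn.functional.silu(x @ w_gate.t())
    up = x @ w_up.t()
    return (gate * up) @ w_down.t()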


class StablelmAttention(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_key_value_heads = config.num_key_value_heads
        if self.total_num_key_value_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size, so
            # we partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_key_value_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_key_value_heads == 0
        self.num_key_value_heads = max(
            1, self.total_num_key_value_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        rope_pct = getattr(config, "rope_pct",
                           getattr(config, "partial_rotary_factor", 1))
        self.rotary_ndims = int(self.head_dim * rope_pct)
        self.scaling = self.head_dim**-0.5
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim
        self.qkv_bias = getattr(config, "use_qkv_bias", False)
        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
            raise ValueError(f"hidden_size must be divisible by num_heads "
                             f"(got `hidden_size`: {self.hidden_size}"
                             f" and `num_heads`: {self.num_heads}).")

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_key_value_heads,
                                          self.qkv_bias,
                                          quant_config=quant_config)
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_ndims,
            max_position=self.config.max_position_embeddings,
            base=self.config.rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_key_value_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
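

# A worked example of the sizing arithmetic above, using hypothetical
# stablelm-3b-4e1t-like numbers (hidden_size=2560, 32 attention heads,
# 32 KV heads, rope_pct=0.25) and tensor parallel size 2. The values are
# illustrative; the real ones come from the checkpoint's config.json.
def _attention_shard_sizes(hidden_size: int, num_attention_heads: int,
                           num_key_value_heads: int, rope_pct: float,
                           tp_size: int) -> Tuple[int, int, int, int]:
    head_dim = hidden_size // num_attention_heads
    # Query heads are always split evenly across tensor-parallel ranks.
    num_heads = num_attention_heads // tp_size
    # KV heads are partitioned when there are at least tp_size of them;
    # otherwise each rank keeps a replicated copy (hence the max(1, ...)).
    num_kv_heads = max(1, num_key_value_heads // tp_size)
    # Only the first rotary_ndims channels of each head receive RoPE.
    rotary_ndims = int(head_dim * rope_pct)
    return head_dim, num_heads, num_kv_heads, rotary_ndims


# _attention_shard_sizes(2560, 32, 32, 0.25, 2) -> (80, 16, 16, 20)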


class StablelmDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(config, cache_config, quant_config)
        self.mlp = StablelmMLP(config, quant_config)
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual
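

# The decoder layer uses the standard pre-norm residual pattern: each
# sub-block sees a LayerNorm-ed input, and its output is added back onto the
# un-normalized residual stream. A minimal sketch with stand-in callables
# (attn_fn and mlp_fn are placeholders, not this module's layers):
def _pre_norm_block(x: torch.Tensor, norm1: nn.LayerNorm, attn_fn,
                    norm2: nn.LayerNorm, mlp_fn) -> torch.Tensor:
    x = x + attn_fn(norm1(x))  # attention sub-block
    x = x + mlp_fn(norm2(x))  # feed-forward sub-block
    return x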


class StableLMEpochModel(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            StablelmDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        norm_eps = getattr(config, "norm_eps",
                           getattr(config, "layer_norm_eps", 1e-05))
        self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
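

# Different StableLM checkpoints name the LayerNorm epsilon either
# `norm_eps` or `layer_norm_eps` in config.json, which is why the nested
# getattr fallback above exists. A small self-contained illustration (the
# epsilon values in the comments are made up):
def _resolve_norm_eps(config: PretrainedConfig) -> float:
    return getattr(config, "norm_eps",
                   getattr(config, "layer_norm_eps", 1e-05))


# _resolve_norm_eps(PretrainedConfig(norm_eps=1e-5))        -> 1e-05
# _resolve_norm_eps(PretrainedConfig(layer_norm_eps=1e-6))  -> 1e-06
# _resolve_norm_eps(PretrainedConfig())                     -> 1e-05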


class StablelmForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = StableLMEpochModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
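

# load_weights above folds per-projection checkpoint tensors into this
# model's fused parameters: q_proj/k_proj/v_proj shards land in qkv_proj and
# gate_proj/up_proj shards land in gate_up_proj, with the shard id telling
# the fused parameter's weight_loader where to place each slice. A standalone
# sketch of just the name-remapping step (the example checkpoint name in the
# trailing comment is hypothetical):
def _remap_checkpoint_name(name: str):
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


# _remap_checkpoint_name("model.layers.0.self_attn.k_proj.weight")
#   -> ("model.layers.0.self_attn.qkv_proj.weight", "k")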