stablelm.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # coding=utf-8
  2. # Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
  3. # All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. # This code is based off the following work:
  18. # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
  19. # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
  20. """Inference-only StableLM (https://github.com/Stability-AI/StableLM)
  21. model compatible with HuggingFace weights."""
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from transformers import PretrainedConfig
  26. from aphrodite.attention import Attention, AttentionMetadata
  27. from aphrodite.common.sequence import SamplerOutput
  28. from aphrodite.distributed import get_tensor_model_parallel_world_size
  29. from aphrodite.modeling.layers.activation import SiluAndMul
  30. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.rotary_embedding import get_rope
  35. from aphrodite.modeling.layers.sampler import Sampler
  36. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  37. ParallelLMHead, VocabParallelEmbedding)
  38. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  39. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  40. from aphrodite.quantization.base_config import QuantizationConfig
  class StablelmMLP(nn.Module):
      """Gated feed-forward block used inside each StableLM decoder layer.

      A fused ``gate_up_proj`` produces two ``intermediate_size``-wide shards.
      ``SiluAndMul`` combines them (presumably ``silu(gate) * up``; confirm
      against the activation layer). ``down_proj`` then projects back to
      ``hidden_size``.
      """

      def __init__(self,
                   config: PretrainedConfig,
                   quant_config: Optional[QuantizationConfig] = None) -> None:
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size
          # Gate and up projections fused into one column-parallel layer with
          # two output shards of size intermediate_size each.
          self.gate_up_proj = MergedColumnParallelLinear(
              config.hidden_size, [config.intermediate_size] * 2,
              bias=False,
              quant_config=quant_config)
          # down_proj must use the same quantization config as gate_up_proj.
          # Otherwise a quantized checkpoint's down_proj tensors would be
          # loaded into an unquantized layer.
          self.down_proj = RowParallelLinear(config.intermediate_size,
                                             config.hidden_size,
                                             bias=False,
                                             quant_config=quant_config)
          self.act_fn = SiluAndMul()

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          """Apply the gated MLP to ``x`` and return the projected result.

          Each parallel linear layer returns a pair; only the first element is
          used (the second is presumably an optional bias; confirm).
          """
          gate_up, _ = self.gate_up_proj(x)
          x = self.act_fn(gate_up)
          x, _ = self.down_proj(x)
          return x
  class StablelmAttention(nn.Module):
      """Self-attention for one StableLM decoder layer, sharded for tensor parallelism.

      Query heads are divided evenly across tensor-parallel ranks. KV heads
      are partitioned across ranks when there are at least as many KV heads
      as ranks, and replicated otherwise. Rotary position embeddings cover
      ``rotary_ndims = int(head_dim * rope_pct)`` dimensions of each head.
      """

      def __init__(self,
                   config: PretrainedConfig,
                   quant_config: Optional[QuantizationConfig] = None) -> None:
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = config.num_attention_heads
          # Query heads owned by this rank.
          self.num_heads = self.total_num_heads // tp_size
          self.total_num_key_value_heads = config.num_key_value_heads
          if self.total_num_key_value_heads >= tp_size:
              # Number of KV heads is greater than TP size, so we partition
              # the KV heads across multiple tensor parallel GPUs.
              assert self.total_num_key_value_heads % tp_size == 0
          else:
              # Number of KV heads is less than TP size, so we replicate
              # the KV heads across multiple tensor parallel GPUs.
              assert tp_size % self.total_num_key_value_heads == 0
          # Every rank holds at least one KV head (the replicated case).
          self.num_key_value_heads = max(
              1, self.total_num_key_value_heads // tp_size)
          self.head_dim = self.hidden_size // self.total_num_heads
          self.max_position_embeddings = config.max_position_embeddings
          # Accept either config spelling of the rotary fraction. The default
          # of 1 makes rotary_ndims equal to head_dim (the whole head).
          rope_pct = getattr(config, "rope_pct",
                             getattr(config, "partial_rotary_factor", 1))
          self.rotary_ndims = int(self.head_dim * rope_pct)
          self.scaling = self.head_dim**-0.5
          # Widths of this rank's q and k/v slices of the fused QKV output.
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_key_value_heads * self.head_dim
          self.qkv_bias = getattr(config, "use_qkv_bias", False)
          # This check fails when hidden_size is not a multiple of the head
          # count, or when the heads do not split evenly across TP ranks.
          # Report the configured head count and TP size rather than the
          # per-rank head count, which is misleading when tp_size > 1.
          if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
              raise ValueError(f"hidden_size must be divisible by num_heads "
                               f"(got `hidden_size`: {self.hidden_size}"
                               f" and `num_heads`: {self.total_num_heads}"
                               f", tensor parallel size: {tp_size}).")
          self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                            self.head_dim,
                                            self.total_num_heads,
                                            self.total_num_key_value_heads,
                                            self.qkv_bias,
                                            quant_config=quant_config)
          self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                          self.hidden_size,
                                          bias=False,
                                          quant_config=quant_config)
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.rotary_ndims,
              max_position=self.config.max_position_embeddings,
              base=self.config.rope_theta,
          )
          self.attn = Attention(self.num_heads,
                                self.head_dim,
                                self.scaling,
                                num_kv_heads=self.num_key_value_heads)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Project to q/k/v, apply rotary embeddings, attend, and project out.

          Args:
              positions: positions passed to the rotary embedding.
              hidden_states: layer input; the last dim is assumed to be
                  ``hidden_size`` (confirm).
              kv_cache: this layer's KV cache, forwarded to ``Attention``.
              attn_metadata: batch attention metadata, forwarded to
                  ``Attention``.

          Returns:
              The ``o_proj`` output.
          """
          qkv, _ = self.qkv_proj(hidden_states)
          # Split the last dim into this rank's q, k and v slices.
          q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
          q, k = self.rotary_emb(positions, q, k)
          attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
          output, _ = self.o_proj(attn_output)
          return output
  class StablelmDecoderLayer(nn.Module):
      """One StableLM transformer block: LayerNorm -> attention -> residual add,
      then LayerNorm -> MLP -> residual add."""

      def __init__(
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          # Attention must get the same quant_config as the MLP. Otherwise
          # quantized qkv_proj / o_proj checkpoint weights would be loaded
          # into unquantized layers.
          self.self_attn = StablelmAttention(config, quant_config)
          self.mlp = StablelmMLP(config, quant_config)
          # The epsilon may be named `norm_eps` or `layer_norm_eps` in the
          # config; fall back to 1e-5 when neither is present.
          norm_eps = getattr(config, "norm_eps",
                             getattr(config, "layer_norm_eps", 1e-05))
          self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
          self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                       eps=norm_eps)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Run the block on ``hidden_states``.

          Returns:
              ``(hidden_states, residual)``. ``hidden_states`` is the block
              output with both residual adds applied. ``residual`` is the
              post-attention activation that was added back after the MLP.
          """
          # Self Attention
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
          hidden_states = self.self_attn(
              positions=positions,
              hidden_states=hidden_states,
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )
          hidden_states = residual + hidden_states
          # Fully Connected
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          return hidden_states, residual
  class StableLMEpochModel(nn.Module):
      """StableLM backbone: token embedding, decoder layers, final LayerNorm."""

      def __init__(self,
                   config: PretrainedConfig,
                   quant_config: Optional[QuantizationConfig] = None) -> None:
          super().__init__()
          self.embed_tokens = VocabParallelEmbedding(
              config.vocab_size,
              config.hidden_size,
          )
          self.layers = nn.ModuleList([
              StablelmDecoderLayer(config, quant_config)
              for _ in range(config.num_hidden_layers)
          ])
          # The epsilon may be named `norm_eps` or `layer_norm_eps` in the
          # config; fall back to 1e-5 when neither is present.
          norm_eps = getattr(config, "norm_eps",
                             getattr(config, "layer_norm_eps", 1e-05))
          self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every decoder layer, and apply the final norm.

          ``kv_caches[i]`` is given to ``self.layers[i]``, so ``kv_caches``
          needs at least one entry per layer.
          """
          hidden_states = self.embed_tokens(input_ids)
          for layer_idx, layer in enumerate(self.layers):
              # The layer's second output (its residual) is not needed here.
              hidden_states, _ = layer(
                  positions,
                  hidden_states,
                  kv_caches[layer_idx],
                  attn_metadata,
              )
          hidden_states = self.norm(hidden_states)
          return hidden_states
  class StablelmForCausalLM(nn.Module):
      """StableLM causal language model built on a tensor-parallel backbone.

      ``forward`` returns the backbone's final hidden states. ``compute_logits``
      projects them through the LM head weight. ``sample`` draws next tokens.
      ``load_weights`` maps checkpoint tensors onto this module's parameters.
      """

      def __init__(
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.quant_config = quant_config
          self.model = StableLMEpochModel(config, quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run the backbone and return its final (normalized) hidden states."""
          return self.model(input_ids, positions, kv_caches, attn_metadata)

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Turn hidden states into logits using the LM head's weight matrix."""
          return self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Delegate token selection to the sampler."""
          return self.sampler(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Copy ``(name, tensor)`` checkpoint pairs into this module's parameters.

          Separate q/k/v and gate/up checkpoint tensors are routed into the
          fused ``qkv_proj`` / ``gate_up_proj`` parameters with the matching
          shard id. Rotary buffers stored in some checkpoints are skipped, as
          are bias tensors that have no matching parameter.
          """
          # (fused parameter name, checkpoint tensor name, shard id)
          fused_targets = [
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
          # Checkpoint entries that are not parameters of this model. The cos/sin
          # caches appear in models trained using ColossalAI.
          ignored_markers = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                             "rotary_emb.sin_cached")
          own_params = dict(self.named_parameters())
          for ckpt_name, tensor in weights:
              if any(marker in ckpt_name for marker in ignored_markers):
                  continue
              for fused_name, shard_name, shard_id in fused_targets:
                  if shard_name not in ckpt_name:
                      continue
                  # NOTE: the renamed ckpt_name carries over to later
                  # iterations and to the else branch below.
                  ckpt_name = ckpt_name.replace(shard_name, fused_name)
                  # GPTQ checkpoints may ship biases this model has no slot for.
                  if (ckpt_name.endswith(".bias")
                          and ckpt_name not in own_params):
                      continue
                  target = own_params[ckpt_name]
                  target.weight_loader(target, tensor, shard_id)
                  break
              else:
                  # Reached when no fused mapping loaded the tensor (no shard
                  # name matched, or only unmatched biases were skipped).
                  if (ckpt_name.endswith(".bias")
                          and ckpt_name not in own_params):
                      continue
                  target = own_params[ckpt_name]
                  loader = getattr(target, "weight_loader",
                                   default_weight_loader)
                  loader(target, tensor)