1
0

starcoder2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # coding=utf-8
  2. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """ PyTorch Starcoder2 model."""
  21. from typing import Iterable, List, Optional, Tuple
  22. import torch
  23. from torch import nn
  24. from transformers import Starcoder2Config
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.common.sequence import SamplerOutput
  27. from aphrodite.distributed import get_tensor_model_parallel_world_size
  28. from aphrodite.modeling.layers.activation import get_act_fn
  29. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  30. QKVParallelLinear,
  31. RowParallelLinear)
  32. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  33. from aphrodite.modeling.layers.rotary_embedding import get_rope
  34. from aphrodite.modeling.layers.sampler import Sampler
  35. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  36. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  37. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  38. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  39. from aphrodite.quantization.base_config import QuantizationConfig
  class Starcoder2Attention(nn.Module):
      """Self-attention sub-layer of a Starcoder2 decoder block.

      Query heads are split evenly across the tensor-parallel group. Key/value
      heads are partitioned when there are at least as many of them as ranks,
      and replicated otherwise. Q, K and V come from one fused projection,
      rotary position embeddings are applied to Q and K, and the attention
      result is mapped back to ``hidden_size`` by a row-parallel projection.
      """

      def __init__(self,
                   config: Starcoder2Config,
                   quant_config: Optional[QuantizationConfig] = None):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          world_size = get_tensor_model_parallel_world_size()

          # Query heads: every rank must own the same number of heads.
          self.total_num_heads = config.num_attention_heads
          assert self.total_num_heads % world_size == 0
          self.num_heads = self.total_num_heads // world_size

          # KV heads: with fewer heads than ranks each head is replicated on
          # several ranks, otherwise the heads are partitioned. Either way the
          # larger count must be an exact multiple of the smaller one.
          self.total_num_kv_heads = config.num_key_value_heads
          if self.total_num_kv_heads < world_size:
              assert world_size % self.total_num_kv_heads == 0
          else:
              assert self.total_num_kv_heads % world_size == 0
          # Every rank keeps at least one KV head, even when replicating.
          self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

          # Widths used in forward() to split the fused QKV output.
          self.head_dim = self.hidden_size // self.total_num_heads
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5

          self.rope_theta = config.rope_theta
          self.max_position_embeddings = config.max_position_embeddings
          self.use_bias = config.use_bias
          self.sliding_window = config.sliding_window

          self.qkv_proj = QKVParallelLinear(
              self.hidden_size,
              self.head_dim,
              self.total_num_heads,
              self.total_num_kv_heads,
              bias=self.use_bias,
              quant_config=quant_config,
          )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              self.hidden_size,
              bias=self.use_bias,
              quant_config=quant_config,
          )
          # NOTE(review): int() drops any fractional part of rope_theta;
          # confirm that is acceptable for the configs this model loads.
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=self.max_position_embeddings,
              base=int(self.rope_theta),
              is_neox_style=True,
          )
          self.attn = Attention(
              self.num_heads,
              self.head_dim,
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              sliding_window=self.sliding_window,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run fused QKV projection, rotary embedding, attention and o_proj."""
          fused, _ = self.qkv_proj(hidden_states)
          widths = [self.q_size, self.kv_size, self.kv_size]
          query, key, value = fused.split(widths, dim=-1)
          query, key = self.rotary_emb(positions, query, key)
          context = self.attn(query, key, value, kv_cache, attn_metadata)
          projected, _ = self.o_proj(context)
          return projected
  class Starcoder2MLP(nn.Module):
      """Feed-forward sub-layer: column-parallel c_fc, activation, row-parallel c_proj."""

      def __init__(self,
                   config: Starcoder2Config,
                   quant_config: Optional[QuantizationConfig] = None):
          super().__init__()
          d_model = config.hidden_size
          d_inner = config.intermediate_size
          with_bias = config.use_bias
          self.c_fc = ColumnParallelLinear(
              d_model,
              d_inner,
              bias=with_bias,
              quant_config=quant_config,
          )
          self.c_proj = RowParallelLinear(
              d_inner,
              d_model,
              bias=with_bias,
              quant_config=quant_config,
          )
          # NOTE(review): this getattr yields None unless quant_config exposes
          # a `quant_config` attribute, so the activation is presumably built
          # without any quantization config. Confirm this is intentional (e.g.
          # for AWQ checkpoints that ship activation scales).
          act_quant_config = getattr(quant_config, "quant_config", None)
          self.act = get_act_fn(config.hidden_act, act_quant_config, d_inner)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Apply c_fc, the activation, then c_proj to ``hidden_states``."""
          expanded, _ = self.c_fc(hidden_states)
          contracted, _ = self.c_proj(self.act(expanded))
          return contracted
  class Starcoder2DecoderLayer(nn.Module):
      """One Starcoder2 transformer block.

      LayerNorm is applied before both the attention and MLP sub-layers, and
      each sub-layer's output is added back onto its input (residual).
      """

      def __init__(self,
                   config: Starcoder2Config,
                   quant_config: Optional[QuantizationConfig] = None):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
          self.mlp = Starcoder2MLP(config, quant_config=quant_config)
          eps = config.norm_epsilon
          self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=eps)
          self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                       eps=eps)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Return ``hidden_states`` after the attention and MLP residual steps."""
          # Attention sub-layer (pre-norm), residual = incoming hidden_states.
          attn_delta = self.self_attn(
              positions=positions,
              hidden_states=self.input_layernorm(hidden_states),
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )
          hidden_states = hidden_states + attn_delta

          # MLP sub-layer (pre-norm), residual = post-attention hidden_states.
          mlp_delta = self.mlp(self.post_attention_layernorm(hidden_states))
          return hidden_states + mlp_delta
  class Starcoder2Model(nn.Module):
      """Starcoder2 backbone: token embedding, decoder layers, final LayerNorm."""

      def __init__(self,
                   config: Starcoder2Config,
                   quant_config: Optional[QuantizationConfig] = None):
          super().__init__()
          self.config = config
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          # TODO: padding_idx is stored but not passed to the embedding yet.
          self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
          blocks = [
              Starcoder2DecoderLayer(config, quant_config=quant_config)
              for _ in range(config.num_hidden_layers)
          ]
          self.layers = nn.ModuleList(blocks)
          self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every layer, and apply the final norm.

          ``kv_caches[i]`` is handed to ``self.layers[i]``.
          """
          hidden_states = self.embed_tokens(input_ids)
          for idx, block in enumerate(self.layers):
              hidden_states = block(positions, hidden_states, kv_caches[idx],
                                    attn_metadata)
          return self.norm(hidden_states)
  class Starcoder2ForCausalLM(nn.Module):
      """Starcoder2 backbone plus output projection, logits processing and sampling."""

      def __init__(self,
                   config: Starcoder2Config,
                   quant_config: Optional[QuantizationConfig] = None):
          super().__init__()
          self.config = config
          self.model = Starcoder2Model(config, quant_config=quant_config)
          self.vocab_size = config.vocab_size
          self.unpadded_vocab_size = config.vocab_size
          if not config.tie_word_embeddings:
              self.lm_head = ParallelLMHead(
                  self.unpadded_vocab_size,
                  config.hidden_size,
                  org_num_embeddings=config.vocab_size,
                  padding_size=DEFAULT_VOCAB_PADDING_SIZE,
              )
              self.lm_head_weight = self.lm_head.weight
          else:
              # Tied embeddings: the input embedding matrix doubles as the
              # output projection, so no separate lm_head is created.
              self.lm_head_weight = self.model.embed_tokens.weight
          self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                  config.vocab_size)
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Return the backbone's final hidden states (logits are computed separately)."""
          return self.model(input_ids, positions, kv_caches, attn_metadata)

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Project ``hidden_states`` with ``lm_head_weight`` via the logits processor."""
          return self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)

      def sample(
          self,
          logits: Optional[torch.Tensor],
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Run the sampler over ``logits`` and return its output."""
          return self.sampler(logits, sampling_metadata)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Copy checkpoint tensors into this module's parameters.

          - ``rotary_emb.inv_freq`` entries are ignored.
          - Separate q/k/v projection tensors are routed into the fused
            ``qkv_proj`` parameter through its shard-aware ``weight_loader``.
          - ``lm_head.weight`` is ignored when word embeddings are tied.
          - Every other tensor goes through the parameter's own
            ``weight_loader``, falling back to ``default_weight_loader``.

          Raises KeyError if a checkpoint name has no matching parameter.
          """
          # (fused parameter name, checkpoint shard name, shard id)
          stacked_params_mapping = [
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
          ]
          params_dict = dict(self.named_parameters(remove_duplicate=False))
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue

              # First mapping entry whose shard name occurs in `name`, if any.
              stacked = next(
                  ((fused, part, shard_id)
                   for fused, part, shard_id in stacked_params_mapping
                   if part in name),
                  None,
              )
              if stacked is not None:
                  fused, part, shard_id = stacked
                  param = params_dict[name.replace(part, fused)]
                  param.weight_loader(param, loaded_weight, shard_id)
                  continue

              if self.config.tie_word_embeddings and "lm_head.weight" in name:
                  continue
              param = params_dict[name]
              loader = getattr(param, "weight_loader", default_weight_loader)
              loader(param, loaded_weight)