starcoder2.py

# coding=utf-8
# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import Starcoder2Config

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class Starcoder2Attention(nn.Module):
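    """Self-attention with grouped-query attention (GQA) support and
    NeoX-style rotary position embeddings."""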
    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings
        self.use_bias = config.use_bias

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=self.use_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=self.use_bias,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Starcoder2MLP(nn.Module):
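    """Feed-forward block: c_fc up-projection, activation, c_proj
    down-projection."""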
    def __init__(self,
                 config: Starcoder2Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.c_fc = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=config.use_bias,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config,
                              config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class Starcoder2DecoderLayer(nn.Module):
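    """Transformer decoder layer: pre-LayerNorm self-attention and MLP,
    each wrapped in a residual connection."""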
    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Starcoder2Attention(config,
                                             cache_config,
                                             quant_config=quant_config)
        self.mlp = Starcoder2MLP(config, quant_config=quant_config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_epsilon)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Starcoder2Model(nn.Module):
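    """Bare Starcoder2 transformer: token embeddings, a stack of decoder
    layers, and a final LayerNorm."""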
    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # TODO: consider padding_idx (currently removed)
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            Starcoder2DecoderLayer(config,
                                   cache_config,
                                   quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(positions, hidden_states, kv_caches[i],
                                  attn_metadata)
        hidden_states = self.norm(hidden_states)
        return hidden_states


class Starcoder2ForCausalLM(nn.Module):
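    """Starcoder2 model with an LM head, logits processor, and sampler for
    causal language modeling."""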
    def __init__(self,
                 config: Starcoder2Config,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.model = Starcoder2Model(config,
                                     cache_config,
                                     quant_config=quant_config)
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if self.config.tie_word_embeddings and "lm_head.weight" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
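

# Usage sketch (assumption: aphrodite exposes a vLLM-style `LLM` entrypoint
# and resolves Starcoder2ForCausalLM through its model registry; illustrative
# only, not part of the model definition):
#
#   from aphrodite import LLM, SamplingParams
#
#   llm = LLM(model="bigcode/starcoder2-7b")
#   params = SamplingParams(temperature=0.0, max_tokens=64)
#   outputs = llm.generate(["def fibonacci(n):"], params)
#   print(outputs[0].outputs[0].text)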