starcoder2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # coding=utf-8
  2. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """ PyTorch Starcoder2 model."""
  21. from typing import Iterable, List, Optional, Tuple
  22. import torch
  23. from torch import nn
  24. from transformers import Starcoder2Config
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.common.config import CacheConfig
  27. from aphrodite.common.sequence import IntermediateTensors
  28. from aphrodite.distributed import get_tensor_model_parallel_world_size
  29. from aphrodite.modeling.layers.activation import get_act_fn
  30. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear)
  33. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  34. from aphrodite.modeling.layers.rotary_embedding import get_rope
  35. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  36. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  37. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  38. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  39. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  40. from aphrodite.quantization.base_config import QuantizationConfig
  41. class Starcoder2Attention(nn.Module):
  42. def __init__(self,
  43. config: Starcoder2Config,
  44. cache_config: Optional[CacheConfig] = None,
  45. quant_config: Optional[QuantizationConfig] = None):
  46. super().__init__()
  47. self.config = config
  48. self.hidden_size = config.hidden_size
  49. tp_size = get_tensor_model_parallel_world_size()
  50. self.total_num_heads = config.num_attention_heads
  51. assert self.total_num_heads % tp_size == 0
  52. self.num_heads = self.total_num_heads // tp_size
  53. self.total_num_kv_heads = config.num_key_value_heads
  54. if self.total_num_kv_heads >= tp_size:
  55. # Number of KV heads is greater than TP size, so we partition
  56. # the KV heads across multiple tensor parallel GPUs.
  57. assert self.total_num_kv_heads % tp_size == 0
  58. else:
  59. # Number of KV heads is less than TP size, so we replicate
  60. # the KV heads across multiple tensor parallel GPUs.
  61. assert tp_size % self.total_num_kv_heads == 0
  62. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  63. self.head_dim = self.hidden_size // self.total_num_heads
  64. self.q_size = self.num_heads * self.head_dim
  65. self.kv_size = self.num_kv_heads * self.head_dim
  66. self.scaling = self.head_dim**-0.5
  67. self.rope_theta = config.rope_theta
  68. self.max_position_embeddings = config.max_position_embeddings
  69. self.use_bias = config.use_bias
  70. self.qkv_proj = QKVParallelLinear(
  71. self.hidden_size,
  72. self.head_dim,
  73. self.total_num_heads,
  74. self.total_num_kv_heads,
  75. bias=self.use_bias,
  76. quant_config=quant_config,
  77. )
  78. self.o_proj = RowParallelLinear(
  79. self.total_num_heads * self.head_dim,
  80. self.hidden_size,
  81. bias=self.use_bias,
  82. quant_config=quant_config,
  83. )
  84. self.rotary_emb = get_rope(
  85. self.head_dim,
  86. rotary_dim=self.head_dim,
  87. max_position=self.max_position_embeddings,
  88. base=int(self.rope_theta),
  89. is_neox_style=True,
  90. )
  91. self.attn = Attention(self.num_heads,
  92. self.head_dim,
  93. self.scaling,
  94. num_kv_heads=self.num_kv_heads,
  95. cache_config=cache_config,
  96. quant_config=quant_config)
  97. def forward(
  98. self,
  99. positions: torch.Tensor,
  100. hidden_states: torch.Tensor,
  101. kv_cache: torch.Tensor,
  102. attn_metadata: AttentionMetadata,
  103. ) -> torch.Tensor:
  104. qkv, _ = self.qkv_proj(hidden_states)
  105. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  106. q, k = self.rotary_emb(positions, q, k)
  107. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  108. output, _ = self.o_proj(attn_output)
  109. return output
  110. class Starcoder2MLP(nn.Module):
  111. def __init__(self,
  112. config: Starcoder2Config,
  113. quant_config: Optional[QuantizationConfig] = None):
  114. super().__init__()
  115. self.c_fc = ColumnParallelLinear(
  116. config.hidden_size,
  117. config.intermediate_size,
  118. bias=config.use_bias,
  119. quant_config=quant_config,
  120. )
  121. self.c_proj = RowParallelLinear(
  122. config.intermediate_size,
  123. config.hidden_size,
  124. bias=config.use_bias,
  125. quant_config=quant_config,
  126. )
  127. self.act = get_act_fn(config.hidden_act, quant_config,
  128. config.intermediate_size)
  129. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  130. hidden_states, _ = self.c_fc(hidden_states)
  131. hidden_states = self.act(hidden_states)
  132. hidden_states, _ = self.c_proj(hidden_states)
  133. return hidden_states
  134. class Starcoder2DecoderLayer(nn.Module):
  135. def __init__(self,
  136. config: Starcoder2Config,
  137. cache_config: Optional[CacheConfig] = None,
  138. quant_config: Optional[QuantizationConfig] = None):
  139. super().__init__()
  140. self.hidden_size = config.hidden_size
  141. self.self_attn = Starcoder2Attention(config,
  142. cache_config,
  143. quant_config=quant_config)
  144. self.mlp = Starcoder2MLP(config, quant_config=quant_config)
  145. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  146. eps=config.norm_epsilon)
  147. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
  148. eps=config.norm_epsilon)
  149. def forward(
  150. self,
  151. positions: torch.Tensor,
  152. hidden_states: torch.Tensor,
  153. kv_cache: torch.Tensor,
  154. attn_metadata: AttentionMetadata,
  155. ) -> torch.Tensor:
  156. # Self Attention
  157. residual = hidden_states
  158. hidden_states = self.input_layernorm(hidden_states)
  159. hidden_states = self.self_attn(
  160. positions=positions,
  161. hidden_states=hidden_states,
  162. kv_cache=kv_cache,
  163. attn_metadata=attn_metadata,
  164. )
  165. hidden_states = residual + hidden_states
  166. # Fully Connected
  167. residual = hidden_states
  168. hidden_states = self.post_attention_layernorm(hidden_states)
  169. hidden_states = self.mlp(hidden_states)
  170. hidden_states = residual + hidden_states
  171. return hidden_states
  172. class Starcoder2Model(nn.Module):
  173. def __init__(self,
  174. config: Starcoder2Config,
  175. cache_config: Optional[CacheConfig] = None,
  176. quant_config: Optional[QuantizationConfig] = None):
  177. super().__init__()
  178. self.config = config
  179. self.padding_idx = config.pad_token_id
  180. self.vocab_size = config.vocab_size
  181. # TODO: consider padding_idx (currently removed)
  182. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  183. config.hidden_size)
  184. self.layers = nn.ModuleList([
  185. Starcoder2DecoderLayer(config,
  186. cache_config,
  187. quant_config=quant_config)
  188. for _ in range(config.num_hidden_layers)
  189. ])
  190. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  191. def forward(
  192. self,
  193. input_ids: torch.Tensor,
  194. positions: torch.Tensor,
  195. kv_caches: List[torch.Tensor],
  196. attn_metadata: AttentionMetadata,
  197. ) -> torch.Tensor:
  198. hidden_states = self.embed_tokens(input_ids)
  199. for i in range(len(self.layers)):
  200. layer = self.layers[i]
  201. hidden_states = layer(positions, hidden_states, kv_caches[i],
  202. attn_metadata)
  203. hidden_states = self.norm(hidden_states)
  204. return hidden_states
  205. class Starcoder2ForCausalLM(nn.Module):
  206. def __init__(self,
  207. config: Starcoder2Config,
  208. cache_config: Optional[CacheConfig] = None,
  209. quant_config: Optional[QuantizationConfig] = None):
  210. super().__init__()
  211. self.config = config
  212. self.model = Starcoder2Model(config,
  213. cache_config,
  214. quant_config=quant_config)
  215. self.vocab_size = config.vocab_size
  216. self.unpadded_vocab_size = config.vocab_size
  217. if config.tie_word_embeddings:
  218. self.lm_head = self.model.embed_tokens
  219. else:
  220. self.unpadded_vocab_size = config.vocab_size
  221. self.lm_head = ParallelLMHead(
  222. self.unpadded_vocab_size,
  223. config.hidden_size,
  224. org_num_embeddings=config.vocab_size,
  225. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  226. quant_config=quant_config,
  227. )
  228. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  229. config.vocab_size)
  230. self.sampler = Sampler()
  231. def forward(
  232. self,
  233. input_ids: torch.Tensor,
  234. positions: torch.Tensor,
  235. kv_caches: List[torch.Tensor],
  236. attn_metadata: AttentionMetadata,
  237. intermediate_tensors: Optional[IntermediateTensors] = None,
  238. ) -> torch.Tensor:
  239. hidden_states = self.model(input_ids, positions, kv_caches,
  240. attn_metadata)
  241. return hidden_states
  242. def compute_logits(
  243. self,
  244. hidden_states: torch.Tensor,
  245. sampling_metadata: SamplingMetadata,
  246. ) -> Optional[torch.Tensor]:
  247. logits = self.logits_processor(self.lm_head, hidden_states,
  248. sampling_metadata)
  249. return logits
  250. def sample(
  251. self,
  252. logits: Optional[torch.Tensor],
  253. sampling_metadata: SamplingMetadata,
  254. ) -> Optional[SamplerOutput]:
  255. next_tokens = self.sampler(logits, sampling_metadata)
  256. return next_tokens
  257. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  258. stacked_params_mapping = [
  259. # (param_name, shard_name, shard_id)
  260. ("qkv_proj", "q_proj", "q"),
  261. ("qkv_proj", "k_proj", "k"),
  262. ("qkv_proj", "v_proj", "v"),
  263. ]
  264. params_dict = dict(self.named_parameters(remove_duplicate=False))
  265. for name, loaded_weight in weights:
  266. if "rotary_emb.inv_freq" in name:
  267. continue
  268. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  269. if weight_name not in name:
  270. continue
  271. name = name.replace(weight_name, param_name)
  272. param = params_dict[name]
  273. weight_loader = param.weight_loader
  274. weight_loader(param, loaded_weight, shard_id)
  275. break
  276. else:
  277. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  278. continue
  279. param = params_dict[name]
  280. weight_loader = getattr(param, "weight_loader",
  281. default_weight_loader)
  282. weight_loader(param, loaded_weight)