phi.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright (c) Microsoft Corporation.
  7. # Licensed under the MIT license.
  8. #
  9. # BSD 3-Clause License
  10. #
  11. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  12. # All rights reserved.
  13. #
  14. # Redistribution and use in source and binary forms, with or without
  15. # modification, are permitted provided that the following conditions are met:
  16. #
  17. # * Redistributions of source code must retain the above copyright notice, this
  18. # list of conditions and the following disclaimer.
  19. #
  20. # * Redistributions in binary form must reproduce the above copyright notice,
  21. # this list of conditions and the following disclaimer in the documentation
  22. # and/or other materials provided with the distribution.
  23. #
  24. # * Neither the name of the copyright holder nor the names of its
  25. # contributors may be used to endorse or promote products derived from
  26. # this software without specific prior written permission.
  27. #
  28. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  29. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  30. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  31. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  32. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  33. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  34. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  35. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  36. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  37. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  38. """Inference-only Phi model compatible with HuggingFace weights."""
  39. from typing import List, Optional
  40. import torch
  41. from torch import nn
  42. from transformers import PretrainedConfig
  43. from aphrodite.attention import Attention, AttentionMetadata
  44. from aphrodite.modeling.layers.activation import get_act_fn
  45. from aphrodite.modeling.layers.linear import (
  46. ColumnParallelLinear,
  47. LinearMethodBase,
  48. QKVParallelLinear,
  49. RowParallelLinear,
  50. )
  51. from aphrodite.modeling.layers.rotary_embedding import get_rope
  52. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  53. from aphrodite.modeling.layers.sampler import Sampler
  54. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  55. VocabParallelEmbedding,
  56. ParallelLMHead,
  57. )
  58. from aphrodite.distributed import (
  59. get_tensor_model_parallel_world_size, )
  60. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  61. from aphrodite.modeling.hf_downloader import (
  62. default_weight_loader,
  63. hf_model_weights_iterator,
  64. )
  65. from aphrodite.common.sequence import SamplerOutput
  66. class PhiAttention(nn.Module):
  67. def __init__(
  68. self,
  69. config: PretrainedConfig,
  70. linear_method: Optional[LinearMethodBase] = None,
  71. ):
  72. super().__init__()
  73. self.total_num_heads = config.num_attention_heads
  74. self.hidden_size = config.hidden_size
  75. self.head_size = self.hidden_size // self.total_num_heads
  76. tensor_model_parallel_world_size = (
  77. get_tensor_model_parallel_world_size())
  78. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  79. self.num_heads = (self.total_num_heads //
  80. tensor_model_parallel_world_size)
  81. # pylint: disable=C0103
  82. if (linear_method is not None
  83. and not linear_method.quant_config.merge_weight()):
  84. self.merge_weight = False
  85. self.q_proj = ColumnParallelLinear(
  86. self.hidden_size,
  87. self.hidden_size,
  88. bias=True,
  89. linear_method=linear_method,
  90. )
  91. self.k_proj = ColumnParallelLinear(
  92. self.hidden_size,
  93. self.hidden_size,
  94. bias=True,
  95. linear_method=linear_method,
  96. )
  97. self.v_proj = ColumnParallelLinear(
  98. self.hidden_size,
  99. self.hidden_size,
  100. bias=True,
  101. linear_method=linear_method,
  102. )
  103. else:
  104. self.merge_weight = True
  105. self.qkv_proj = QKVParallelLinear(
  106. self.hidden_size,
  107. self.head_size,
  108. self.total_num_heads,
  109. bias=True,
  110. linear_method=linear_method,
  111. )
  112. self.dense = RowParallelLinear(
  113. self.hidden_size,
  114. self.hidden_size,
  115. linear_method=linear_method,
  116. )
  117. scaling = self.head_size**-0.5
  118. rotary_dim = int(config.partial_rotary_factor *
  119. (config.hidden_size // config.num_attention_heads))
  120. assert rotary_dim % 2 == 0
  121. # pylint: disable=C0301
  122. # Refer to:
  123. # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
  124. rope_theta = 10000
  125. max_position_embeddings = getattr(config, "n_positions", 2048)
  126. is_neox_style = (True if linear_method is None
  127. or linear_method.quant_config.rope_style() is None
  128. else linear_method.quant_config.rope_style())
  129. self.rotary_emb = get_rope(
  130. self.head_size,
  131. rotary_dim=rotary_dim,
  132. max_position=max_position_embeddings,
  133. base=rope_theta,
  134. is_neox_style=is_neox_style,
  135. )
  136. self.attn = Attention(self.num_heads, self.head_size, scaling)
  137. def forward(
  138. self,
  139. position_ids: torch.Tensor,
  140. hidden_states: torch.Tensor,
  141. kv_cache: torch.Tensor,
  142. attn_metadata: AttentionMetadata,
  143. ) -> torch.Tensor:
  144. if self.merge_weight:
  145. qkv, _ = self.qkv_proj(hidden_states)
  146. q, k, v = qkv.chunk(chunks=3, dim=-1)
  147. else:
  148. q, _ = self.q_proj(hidden_states)
  149. k, _ = self.k_proj(hidden_states)
  150. v, _ = self.v_proj(hidden_states)
  151. q, k = self.rotary_emb(position_ids, q, k)
  152. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  153. output, _ = self.dense(attn_output)
  154. return output
  155. class PhiMLP(nn.Module):
  156. def __init__(
  157. self,
  158. config: PretrainedConfig,
  159. linear_method: Optional[LinearMethodBase] = None,
  160. ):
  161. super().__init__()
  162. n_inner = getattr(config, "n_inner", None)
  163. n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
  164. self.fc1 = ColumnParallelLinear(
  165. config.hidden_size,
  166. n_inner,
  167. linear_method=linear_method,
  168. )
  169. self.fc2 = RowParallelLinear(
  170. n_inner,
  171. config.hidden_size,
  172. linear_method=linear_method,
  173. )
  174. quant_config = getattr(linear_method, "quant_config", None)
  175. self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
  176. def forward(self, hidden_states):
  177. hidden_states, _ = self.fc1(hidden_states)
  178. hidden_states = self.act(hidden_states)
  179. hidden_states, _ = self.fc2(hidden_states)
  180. return hidden_states
  181. class PhiLayer(nn.Module):
  182. def __init__(
  183. self,
  184. config: PretrainedConfig,
  185. linear_method: Optional[LinearMethodBase] = None,
  186. ):
  187. super().__init__()
  188. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  189. eps=config.layer_norm_eps)
  190. self.self_attn = PhiAttention(config, linear_method)
  191. self.mlp = PhiMLP(config, linear_method)
  192. def forward(
  193. self,
  194. position_ids: torch.Tensor,
  195. hidden_states: torch.Tensor,
  196. kv_cache: torch.Tensor,
  197. attn_metadata: AttentionMetadata,
  198. ) -> torch.Tensor:
  199. residual = hidden_states
  200. hidden_states = self.input_layernorm(hidden_states)
  201. attn_outputs = self.self_attn(
  202. position_ids=position_ids,
  203. hidden_states=hidden_states,
  204. kv_cache=kv_cache,
  205. attn_metadata=attn_metadata,
  206. )
  207. feed_forward_hidden_states = self.mlp(hidden_states)
  208. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  209. return hidden_states
  210. class PhiModel(nn.Module):
  211. def __init__(
  212. self,
  213. config: PretrainedConfig,
  214. linear_method: Optional[LinearMethodBase] = None,
  215. ):
  216. super().__init__()
  217. self.config = config
  218. self.linear_method = linear_method
  219. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  220. config.hidden_size,
  221. linear_method=linear_method)
  222. self.layers = nn.ModuleList([
  223. PhiLayer(config, linear_method)
  224. for _ in range(config.num_hidden_layers)
  225. ])
  226. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  227. eps=config.layer_norm_eps)
  228. def forward(
  229. self,
  230. input_ids: torch.Tensor,
  231. positions: torch.Tensor,
  232. kv_caches: List[torch.Tensor],
  233. attn_metadata: AttentionMetadata,
  234. ) -> torch.Tensor:
  235. hidden_states = self.embed_tokens(input_ids)
  236. for i in range(self.config.num_hidden_layers):
  237. layer = self.layers[i]
  238. hidden_states = layer(
  239. positions,
  240. hidden_states,
  241. kv_caches[i],
  242. attn_metadata,
  243. )
  244. hidden_states = self.final_layernorm(hidden_states)
  245. return hidden_states
  246. class PhiForCausalLM(nn.Module):
  247. def __init__(
  248. self,
  249. config: PretrainedConfig,
  250. linear_method: Optional[LinearMethodBase] = None,
  251. ):
  252. super().__init__()
  253. self.config = config
  254. self.linear_method = linear_method
  255. self.model = PhiModel(config, linear_method)
  256. self.lm_head = ParallelLMHead(
  257. config.vocab_size,
  258. config.hidden_size,
  259. bias=True,
  260. linear_method=linear_method,
  261. )
  262. self.logits_processor = LogitsProcessor(config.vocab_size,
  263. config.tokenizer_vocab_size)
  264. self.sampler = Sampler()
  265. def forward(
  266. self,
  267. input_ids: torch.Tensor,
  268. positions: torch.Tensor,
  269. kv_caches: List[torch.Tensor],
  270. attn_metadata: AttentionMetadata,
  271. ) -> torch.Tensor:
  272. hidden_states = self.model(input_ids, positions, kv_caches,
  273. attn_metadata)
  274. return hidden_states
  275. def compute_logits(self, hidden_states: torch.Tensor,
  276. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  277. logits = self.logits_processor(self.lm_head, hidden_states,
  278. sampling_metadata, self.lm_head.bias)
  279. return logits
  280. def sample(
  281. self,
  282. logits: torch.Tensor,
  283. sampling_metadata: SamplingMetadata,
  284. ) -> Optional[SamplerOutput]:
  285. next_tokens = self.sampler(logits, sampling_metadata)
  286. return next_tokens
  287. def load_weights(
  288. self,
  289. model_name_or_path: str,
  290. cache_dir: Optional[str] = None,
  291. load_format: str = "auto",
  292. revision: Optional[str] = None,
  293. ):
  294. stacked_params_mapping = [
  295. # (param_name, shard_name, shard_id)
  296. ("qkv_proj", "q_proj", "q"),
  297. ("qkv_proj", "k_proj", "k"),
  298. ("qkv_proj", "v_proj", "v"),
  299. ]
  300. if (self.linear_method is not None
  301. and not self.linear_method.quant_config.merge_weight()):
  302. stacked_params_mapping = []
  303. params_dict = dict(self.named_parameters())
  304. for name, loaded_weight in hf_model_weights_iterator(
  305. model_name_or_path, cache_dir, load_format, revision,
  306. self.config):
  307. if "rotary_emb.inv_freq" in name:
  308. continue
  309. for param_name, weight_name, shard_id in stacked_params_mapping:
  310. if weight_name not in name:
  311. continue
  312. name = name.replace(weight_name, param_name)
  313. # Skip loading extra bias for GPTQ models.
  314. if name.endswith(".bias") and name not in params_dict:
  315. continue
  316. param = params_dict[name]
  317. weight_loader = param.weight_loader
  318. weight_loader(param, loaded_weight, shard_id)
  319. break
  320. else:
  321. # Skip loading extra bias for GPTQ models.
  322. if name.endswith(".bias") and name not in params_dict:
  323. continue
  324. # pylint: disable=E1136
  325. param = params_dict[name]
  326. weight_loader = getattr(param, "weight_loader",
  327. default_weight_loader)
  328. weight_loader(param, loaded_weight)