phi.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright (c) Microsoft Corporation.
  7. # Licensed under the MIT license.
  8. #
  9. # BSD 3-Clause License
  10. #
  11. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  12. # All rights reserved.
  13. #
  14. # Redistribution and use in source and binary forms, with or without
  15. # modification, are permitted provided that the following conditions are met:
  16. #
  17. # * Redistributions of source code must retain the above copyright notice, this
  18. # list of conditions and the following disclaimer.
  19. #
  20. # * Redistributions in binary form must reproduce the above copyright notice,
  21. # this list of conditions and the following disclaimer in the documentation
  22. # and/or other materials provided with the distribution.
  23. #
  24. # * Neither the name of the copyright holder nor the names of its
  25. # contributors may be used to endorse or promote products derived from
  26. # this software without specific prior written permission.
  27. #
  28. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  29. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  30. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  31. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  32. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  33. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  34. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  35. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  36. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  37. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  38. """Inference-only Phi model compatible with HuggingFace weights."""
  39. from typing import List, Optional, Tuple
  40. import torch
  41. from torch import nn
  42. from transformers import PretrainedConfig
  43. from aphrodite.modeling.metadata import InputMetadata
  44. from aphrodite.modeling.layers.activation import get_act_fn
  45. from aphrodite.modeling.layers.attention import PagedAttention
  46. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  47. LinearMethodBase,
  48. QKVParallelLinear,
  49. RowParallelLinear)
  50. from aphrodite.modeling.layers.rotary_embedding import get_rope
  51. from aphrodite.modeling.layers.sampler import Sampler
  52. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  53. VocabParallelEmbedding, ParallelLMHead)
  54. from aphrodite.modeling.megatron.parallel_state import (
  55. get_tensor_model_parallel_world_size)
  56. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  57. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  58. hf_model_weights_iterator)
  59. from aphrodite.common.sequence import SamplerOutput
  60. KVCache = Tuple[torch.Tensor, torch.Tensor]
  61. class PhiAttention(nn.Module):
  62. def __init__(self,
  63. config: PretrainedConfig,
  64. linear_method: Optional[LinearMethodBase] = None):
  65. super().__init__()
  66. self.total_num_heads = config.num_attention_heads
  67. self.hidden_size = config.hidden_size
  68. self.head_size = self.hidden_size // self.total_num_heads
  69. tensor_model_parallel_world_size = (
  70. get_tensor_model_parallel_world_size())
  71. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  72. self.num_heads = (self.total_num_heads //
  73. tensor_model_parallel_world_size)
  74. # pylint: disable=C0103
  75. if linear_method is not None and not linear_method.quant_config.merge_weight(
  76. ):
  77. self.merge_weight = False
  78. self.q_proj = ColumnParallelLinear(self.hidden_size,
  79. self.hidden_size,
  80. bias=True,
  81. linear_method=linear_method)
  82. self.k_proj = ColumnParallelLinear(self.hidden_size,
  83. self.hidden_size,
  84. bias=True,
  85. linear_method=linear_method)
  86. self.v_proj = ColumnParallelLinear(self.hidden_size,
  87. self.hidden_size,
  88. bias=True,
  89. linear_method=linear_method)
  90. else:
  91. self.merge_weight = True
  92. self.qkv_proj = QKVParallelLinear(
  93. self.hidden_size,
  94. self.head_size,
  95. self.total_num_heads,
  96. bias=True,
  97. linear_method=linear_method,
  98. )
  99. self.dense = RowParallelLinear(
  100. self.hidden_size,
  101. self.hidden_size,
  102. linear_method=linear_method,
  103. )
  104. scaling = self.head_size**-0.5
  105. rotary_dim = int(config.partial_rotary_factor *
  106. (config.hidden_size // config.num_attention_heads))
  107. assert rotary_dim % 2 == 0
  108. # pylint: disable=C0301
  109. # Refer to:
  110. # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
  111. rope_theta = 10000
  112. max_position_embeddings = getattr(config, "n_positions", 2048)
  113. is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
  114. ) is None else linear_method.quant_config.rope_style()
  115. self.rotary_emb = get_rope(
  116. self.head_size,
  117. rotary_dim=rotary_dim,
  118. max_position=max_position_embeddings,
  119. base=rope_theta,
  120. is_neox_style=is_neox_style,
  121. )
  122. self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
  123. def forward(
  124. self,
  125. position_ids: torch.Tensor,
  126. hidden_states: torch.Tensor,
  127. kv_cache: KVCache,
  128. input_metadata: InputMetadata,
  129. ) -> torch.Tensor:
  130. if self.merge_weight:
  131. qkv, _ = self.qkv_proj(hidden_states)
  132. q, k, v = qkv.chunk(chunks=3, dim=-1)
  133. else:
  134. q, _ = self.q_proj(hidden_states)
  135. k, _ = self.k_proj(hidden_states)
  136. v, _ = self.v_proj(hidden_states)
  137. q, k = self.rotary_emb(position_ids, q, k)
  138. k_cache, v_cache = kv_cache
  139. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  140. output, _ = self.dense(attn_output)
  141. return output
  142. class PhiMLP(nn.Module):
  143. def __init__(self,
  144. config: PretrainedConfig,
  145. linear_method: Optional[LinearMethodBase] = None):
  146. super().__init__()
  147. n_inner = getattr(config, "n_inner", None)
  148. n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
  149. self.fc1 = ColumnParallelLinear(
  150. config.hidden_size,
  151. n_inner,
  152. linear_method=linear_method,
  153. )
  154. self.fc2 = RowParallelLinear(
  155. n_inner,
  156. config.hidden_size,
  157. linear_method=linear_method,
  158. )
  159. quant_config = getattr(linear_method, "quant_config", None)
  160. self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
  161. def forward(self, hidden_states):
  162. hidden_states, _ = self.fc1(hidden_states)
  163. hidden_states = self.act(hidden_states)
  164. hidden_states, _ = self.fc2(hidden_states)
  165. return hidden_states
  166. class PhiLayer(nn.Module):
  167. def __init__(self,
  168. config: PretrainedConfig,
  169. linear_method: Optional[LinearMethodBase] = None):
  170. super().__init__()
  171. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  172. eps=config.layer_norm_eps)
  173. self.self_attn = PhiAttention(config, linear_method)
  174. self.mlp = PhiMLP(config, linear_method)
  175. def forward(
  176. self,
  177. position_ids: torch.Tensor,
  178. hidden_states: torch.Tensor,
  179. kv_cache: KVCache,
  180. input_metadata: InputMetadata,
  181. ) -> torch.Tensor:
  182. residual = hidden_states
  183. hidden_states = self.input_layernorm(hidden_states)
  184. attn_outputs = self.self_attn(
  185. position_ids=position_ids,
  186. hidden_states=hidden_states,
  187. kv_cache=kv_cache,
  188. input_metadata=input_metadata,
  189. )
  190. feed_forward_hidden_states = self.mlp(hidden_states)
  191. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  192. return hidden_states
  193. class PhiModel(nn.Module):
  194. def __init__(self,
  195. config: PretrainedConfig,
  196. linear_method: Optional[LinearMethodBase] = None):
  197. super().__init__()
  198. self.config = config
  199. self.linear_method = linear_method
  200. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  201. config.hidden_size,
  202. linear_method=linear_method)
  203. self.layers = nn.ModuleList([
  204. PhiLayer(config, linear_method)
  205. for _ in range(config.num_hidden_layers)
  206. ])
  207. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  208. eps=config.layer_norm_eps)
  209. def forward(
  210. self,
  211. input_ids: torch.Tensor,
  212. positions: torch.Tensor,
  213. kv_caches: List[KVCache],
  214. input_metadata: InputMetadata,
  215. ) -> torch.Tensor:
  216. hidden_states = self.embed_tokens(input_ids)
  217. for i in range(self.config.num_hidden_layers):
  218. layer = self.layers[i]
  219. hidden_states = layer(
  220. positions,
  221. hidden_states,
  222. kv_caches[i],
  223. input_metadata,
  224. )
  225. hidden_states = self.final_layernorm(hidden_states)
  226. return hidden_states
  227. class PhiForCausalLM(nn.Module):
  228. def __init__(self,
  229. config: PretrainedConfig,
  230. linear_method: Optional[LinearMethodBase] = None):
  231. super().__init__()
  232. self.config = config
  233. self.linear_method = linear_method
  234. self.model = PhiModel(config, linear_method)
  235. self.lm_head = ParallelLMHead(config.vocab_size,
  236. config.hidden_size,
  237. bias=True,
  238. linear_method=linear_method)
  239. self.sampler = Sampler(config.vocab_size)
  240. def forward(
  241. self,
  242. input_ids: torch.Tensor,
  243. positions: torch.Tensor,
  244. kv_caches: List[KVCache],
  245. input_metadata: InputMetadata,
  246. ) -> torch.Tensor:
  247. hidden_states = self.model(input_ids, positions, kv_caches,
  248. input_metadata)
  249. return hidden_states
  250. def sample(
  251. self,
  252. hidden_states: torch.Tensor,
  253. sampling_metadata: SamplingMetadata,
  254. ) -> Optional[SamplerOutput]:
  255. head = self.lm_head # pylint: disable=unused-variable
  256. next_tokens = self.sampler(self.lm_head(hidden_states),
  257. sampling_metadata)
  258. return next_tokens
  259. def load_weights(self,
  260. model_name_or_path: str,
  261. cache_dir: Optional[str] = None,
  262. load_format: str = "auto",
  263. revision: Optional[str] = None):
  264. stacked_params_mapping = [
  265. # (param_name, shard_name, shard_id)
  266. ("qkv_proj", "q_proj", "q"),
  267. ("qkv_proj", "k_proj", "k"),
  268. ("qkv_proj", "v_proj", "v")
  269. ]
  270. if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
  271. ):
  272. stacked_params_mapping = []
  273. params_dict = dict(self.named_parameters())
  274. for name, loaded_weight in hf_model_weights_iterator(
  275. model_name_or_path, cache_dir, load_format, revision):
  276. if "rotary_emb.inv_freq" in name:
  277. continue
  278. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  279. if weight_name not in name:
  280. continue
  281. name = name.replace(weight_name, param_name)
  282. # Skip loading extra bias for GPTQ models.
  283. if name.endswith(".bias") and name not in params_dict:
  284. continue
  285. param = params_dict[name]
  286. weight_loader = param.weight_loader
  287. weight_loader(param, loaded_weight, shard_id)
  288. break
  289. else:
  290. # Skip loading extra bias for GPTQ models.
  291. if name.endswith(".bias") and name not in params_dict:
  292. continue
  293. # pylint: disable=E1136
  294. param = params_dict[name]
  295. weight_loader = getattr(param, "weight_loader",
  296. default_weight_loader)
  297. weight_loader(param, loaded_weight)