phi.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright (c) Microsoft Corporation.
  7. # Licensed under the MIT license.
  8. #
  9. # BSD 3-Clause License
  10. #
  11. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  12. # All rights reserved.
  13. #
  14. # Redistribution and use in source and binary forms, with or without
  15. # modification, are permitted provided that the following conditions are met:
  16. #
  17. # * Redistributions of source code must retain the above copyright notice, this
  18. # list of conditions and the following disclaimer.
  19. #
  20. # * Redistributions in binary form must reproduce the above copyright notice,
  21. # this list of conditions and the following disclaimer in the documentation
  22. # and/or other materials provided with the distribution.
  23. #
  24. # * Neither the name of the copyright holder nor the names of its
  25. # contributors may be used to endorse or promote products derived from
  26. # this software without specific prior written permission.
  27. #
  28. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  29. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  30. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  31. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  32. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  33. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  34. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  35. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  36. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  37. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  38. """Inference-only Phi model compatible with HuggingFace weights."""
  39. from typing import List, Optional, Tuple
  40. import torch
  41. from torch import nn
  42. from transformers import PretrainedConfig
  43. from aphrodite.modeling.metadata import InputMetadata
  44. from aphrodite.modeling.layers.activation import get_act_fn
  45. from aphrodite.modeling.layers.attention import PagedAttention
  46. from aphrodite.modeling.layers.linear import (
  47. ColumnParallelLinear,
  48. LinearMethodBase,
  49. QKVParallelLinear,
  50. RowParallelLinear,
  51. )
  52. from aphrodite.modeling.layers.rotary_embedding import get_rope
  53. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  54. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  55. VocabParallelEmbedding,
  56. ParallelLMHead,
  57. )
  58. from aphrodite.modeling.megatron.parallel_state import (
  59. get_tensor_model_parallel_world_size, )
  60. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  61. from aphrodite.modeling.hf_downloader import (
  62. default_weight_loader,
  63. hf_model_weights_iterator,
  64. )
  65. from aphrodite.common.sequence import SamplerOutput
# Per-layer KV cache: a (key_cache, value_cache) pair of tensors. Attention
# layers unpack it and hand both tensors to PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  67. class PhiAttention(nn.Module):
  68. def __init__(
  69. self,
  70. config: PretrainedConfig,
  71. linear_method: Optional[LinearMethodBase] = None,
  72. ):
  73. super().__init__()
  74. self.total_num_heads = config.num_attention_heads
  75. self.hidden_size = config.hidden_size
  76. self.head_size = self.hidden_size // self.total_num_heads
  77. tensor_model_parallel_world_size = (
  78. get_tensor_model_parallel_world_size())
  79. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  80. self.num_heads = (self.total_num_heads //
  81. tensor_model_parallel_world_size)
  82. # pylint: disable=C0103
  83. if (linear_method is not None
  84. and not linear_method.quant_config.merge_weight()):
  85. self.merge_weight = False
  86. self.q_proj = ColumnParallelLinear(
  87. self.hidden_size,
  88. self.hidden_size,
  89. bias=True,
  90. linear_method=linear_method,
  91. )
  92. self.k_proj = ColumnParallelLinear(
  93. self.hidden_size,
  94. self.hidden_size,
  95. bias=True,
  96. linear_method=linear_method,
  97. )
  98. self.v_proj = ColumnParallelLinear(
  99. self.hidden_size,
  100. self.hidden_size,
  101. bias=True,
  102. linear_method=linear_method,
  103. )
  104. else:
  105. self.merge_weight = True
  106. self.qkv_proj = QKVParallelLinear(
  107. self.hidden_size,
  108. self.head_size,
  109. self.total_num_heads,
  110. bias=True,
  111. linear_method=linear_method,
  112. )
  113. self.dense = RowParallelLinear(
  114. self.hidden_size,
  115. self.hidden_size,
  116. linear_method=linear_method,
  117. )
  118. scaling = self.head_size**-0.5
  119. rotary_dim = int(config.partial_rotary_factor *
  120. (config.hidden_size // config.num_attention_heads))
  121. assert rotary_dim % 2 == 0
  122. # pylint: disable=C0301
  123. # Refer to:
  124. # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
  125. rope_theta = 10000
  126. max_position_embeddings = getattr(config, "n_positions", 2048)
  127. is_neox_style = (True if linear_method is None
  128. or linear_method.quant_config.rope_style() is None
  129. else linear_method.quant_config.rope_style())
  130. self.rotary_emb = get_rope(
  131. self.head_size,
  132. rotary_dim=rotary_dim,
  133. max_position=max_position_embeddings,
  134. base=rope_theta,
  135. is_neox_style=is_neox_style,
  136. )
  137. self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
  138. def forward(
  139. self,
  140. position_ids: torch.Tensor,
  141. hidden_states: torch.Tensor,
  142. kv_cache: KVCache,
  143. input_metadata: InputMetadata,
  144. ) -> torch.Tensor:
  145. if self.merge_weight:
  146. qkv, _ = self.qkv_proj(hidden_states)
  147. q, k, v = qkv.chunk(chunks=3, dim=-1)
  148. else:
  149. q, _ = self.q_proj(hidden_states)
  150. k, _ = self.k_proj(hidden_states)
  151. v, _ = self.v_proj(hidden_states)
  152. q, k = self.rotary_emb(position_ids, q, k)
  153. k_cache, v_cache = kv_cache
  154. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  155. output, _ = self.dense(attn_output)
  156. return output
  157. class PhiMLP(nn.Module):
  158. def __init__(
  159. self,
  160. config: PretrainedConfig,
  161. linear_method: Optional[LinearMethodBase] = None,
  162. ):
  163. super().__init__()
  164. n_inner = getattr(config, "n_inner", None)
  165. n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
  166. self.fc1 = ColumnParallelLinear(
  167. config.hidden_size,
  168. n_inner,
  169. linear_method=linear_method,
  170. )
  171. self.fc2 = RowParallelLinear(
  172. n_inner,
  173. config.hidden_size,
  174. linear_method=linear_method,
  175. )
  176. quant_config = getattr(linear_method, "quant_config", None)
  177. self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
  178. def forward(self, hidden_states):
  179. hidden_states, _ = self.fc1(hidden_states)
  180. hidden_states = self.act(hidden_states)
  181. hidden_states, _ = self.fc2(hidden_states)
  182. return hidden_states
  183. class PhiLayer(nn.Module):
  184. def __init__(
  185. self,
  186. config: PretrainedConfig,
  187. linear_method: Optional[LinearMethodBase] = None,
  188. ):
  189. super().__init__()
  190. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  191. eps=config.layer_norm_eps)
  192. self.self_attn = PhiAttention(config, linear_method)
  193. self.mlp = PhiMLP(config, linear_method)
  194. def forward(
  195. self,
  196. position_ids: torch.Tensor,
  197. hidden_states: torch.Tensor,
  198. kv_cache: KVCache,
  199. input_metadata: InputMetadata,
  200. ) -> torch.Tensor:
  201. residual = hidden_states
  202. hidden_states = self.input_layernorm(hidden_states)
  203. attn_outputs = self.self_attn(
  204. position_ids=position_ids,
  205. hidden_states=hidden_states,
  206. kv_cache=kv_cache,
  207. input_metadata=input_metadata,
  208. )
  209. feed_forward_hidden_states = self.mlp(hidden_states)
  210. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  211. return hidden_states
  212. class PhiModel(nn.Module):
  213. def __init__(
  214. self,
  215. config: PretrainedConfig,
  216. linear_method: Optional[LinearMethodBase] = None,
  217. ):
  218. super().__init__()
  219. self.config = config
  220. self.linear_method = linear_method
  221. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  222. config.hidden_size,
  223. linear_method=linear_method)
  224. self.layers = nn.ModuleList([
  225. PhiLayer(config, linear_method)
  226. for _ in range(config.num_hidden_layers)
  227. ])
  228. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  229. eps=config.layer_norm_eps)
  230. def forward(
  231. self,
  232. input_ids: torch.Tensor,
  233. positions: torch.Tensor,
  234. kv_caches: List[KVCache],
  235. input_metadata: InputMetadata,
  236. ) -> torch.Tensor:
  237. hidden_states = self.embed_tokens(input_ids)
  238. for i in range(self.config.num_hidden_layers):
  239. layer = self.layers[i]
  240. hidden_states = layer(
  241. positions,
  242. hidden_states,
  243. kv_caches[i],
  244. input_metadata,
  245. )
  246. hidden_states = self.final_layernorm(hidden_states)
  247. return hidden_states
  248. class PhiForCausalLM(nn.Module):
  249. def __init__(
  250. self,
  251. config: PretrainedConfig,
  252. linear_method: Optional[LinearMethodBase] = None,
  253. ):
  254. super().__init__()
  255. self.config = config
  256. self.linear_method = linear_method
  257. self.model = PhiModel(config, linear_method)
  258. self.lm_head = ParallelLMHead(
  259. config.vocab_size,
  260. config.hidden_size,
  261. bias=True,
  262. linear_method=linear_method,
  263. )
  264. self.sampler = Sampler(config.vocab_size)
  265. self.quant_sampler = QuantSampler(config.vocab_size)
  266. def forward(
  267. self,
  268. input_ids: torch.Tensor,
  269. positions: torch.Tensor,
  270. kv_caches: List[KVCache],
  271. input_metadata: InputMetadata,
  272. ) -> torch.Tensor:
  273. hidden_states = self.model(input_ids, positions, kv_caches,
  274. input_metadata)
  275. return hidden_states
  276. def sample(
  277. self,
  278. hidden_states: torch.Tensor,
  279. sampling_metadata: SamplingMetadata,
  280. ) -> Optional[SamplerOutput]:
  281. if (self.linear_method is not None
  282. and not self.linear_method.quant_config.merge_weight()):
  283. next_tokens = self.quant_sampler(self.lm_head(hidden_states),
  284. sampling_metadata)
  285. else:
  286. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  287. sampling_metadata)
  288. return next_tokens
  289. def load_weights(
  290. self,
  291. model_name_or_path: str,
  292. cache_dir: Optional[str] = None,
  293. load_format: str = "auto",
  294. revision: Optional[str] = None,
  295. ):
  296. stacked_params_mapping = [
  297. # (param_name, shard_name, shard_id)
  298. ("qkv_proj", "q_proj", "q"),
  299. ("qkv_proj", "k_proj", "k"),
  300. ("qkv_proj", "v_proj", "v"),
  301. ]
  302. if (self.linear_method is not None
  303. and not self.linear_method.quant_config.merge_weight()):
  304. stacked_params_mapping = []
  305. params_dict = dict(self.named_parameters())
  306. for name, loaded_weight in hf_model_weights_iterator(
  307. model_name_or_path, cache_dir, load_format, revision,
  308. self.config):
  309. if "rotary_emb.inv_freq" in name:
  310. continue
  311. for param_name, weight_name, shard_id in stacked_params_mapping:
  312. if weight_name not in name:
  313. continue
  314. name = name.replace(weight_name, param_name)
  315. # Skip loading extra bias for GPTQ models.
  316. if name.endswith(".bias") and name not in params_dict:
  317. continue
  318. param = params_dict[name]
  319. weight_loader = param.weight_loader
  320. weight_loader(param, loaded_weight, shard_id)
  321. break
  322. else:
  323. # Skip loading extra bias for GPTQ models.
  324. if name.endswith(".bias") and name not in params_dict:
  325. continue
  326. # pylint: disable=E1136
  327. param = params_dict[name]
  328. weight_loader = getattr(param, "weight_loader",
  329. default_weight_loader)
  330. weight_loader(param, loaded_weight)