phi.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright (c) Microsoft Corporation.
  6. # Licensed under the MIT license.
  7. #
  8. # BSD 3-Clause License
  9. #
  10. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  11. # All rights reserved.
  12. #
  13. # Redistribution and use in source and binary forms, with or without
  14. # modification, are permitted provided that the following conditions are met:
  15. #
  16. # * Redistributions of source code must retain the above copyright notice, this
  17. # list of conditions and the following disclaimer.
  18. #
  19. # * Redistributions in binary form must reproduce the above copyright notice,
  20. # this list of conditions and the following disclaimer in the documentation
  21. # and/or other materials provided with the distribution.
  22. #
  23. # * Neither the name of the copyright holder nor the names of its
  24. # contributors may be used to endorse or promote products derived from
  25. # this software without specific prior written permission.
  26. #
  27. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  28. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  29. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  30. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  31. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  32. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  33. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  34. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  35. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  36. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  37. """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
  38. from typing import Iterable, List, Optional, Tuple
  39. import torch
  40. from torch import nn
  41. from transformers import PhiConfig
  42. from aphrodite.attention import Attention, AttentionMetadata
  43. from aphrodite.common.config import CacheConfig, LoRAConfig
  44. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  45. from aphrodite.distributed import get_tensor_model_parallel_world_size
  46. from aphrodite.modeling.layers.activation import get_act_fn
  47. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  48. QKVParallelLinear,
  49. RowParallelLinear)
  50. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  51. from aphrodite.modeling.layers.rotary_embedding import get_rope
  52. from aphrodite.modeling.layers.sampler import Sampler
  53. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  54. ParallelLMHead, VocabParallelEmbedding)
  55. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  56. from aphrodite.modeling.models.interfaces import SupportsLoRA
  57. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  58. from aphrodite.quantization.base_config import QuantizationConfig
  class PhiAttention(nn.Module):
      """Self-attention block used inside each Phi decoder layer.

      The hidden states go through one fused ``QKVParallelLinear`` (with
      bias), rotary position embeddings are applied to the queries and keys,
      ``Attention`` is run against the KV cache, and the result is mapped
      back through the ``dense`` output projection.
      """

      def __init__(self,
                   config: PhiConfig,
                   cache_config: Optional[CacheConfig] = None,
                   quant_config: Optional[QuantizationConfig] = None):
          """Create the projections, rotary embedding and attention op.

          Args:
              config: Model configuration. Supplies ``num_attention_heads``,
                  ``hidden_size`` and ``partial_rotary_factor``; the rotary
                  base and maximum position are read from it when present.
              cache_config: Passed through to ``Attention``.
              quant_config: Passed through to the linear layers and to
                  ``Attention``.
          """
          super().__init__()
          self.total_num_heads = config.num_attention_heads
          self.hidden_size = config.hidden_size
          self.head_size = self.hidden_size // self.total_num_heads

          tensor_model_parallel_world_size = (
              get_tensor_model_parallel_world_size())
          # The heads must divide evenly across the tensor-parallel ranks.
          assert self.total_num_heads % tensor_model_parallel_world_size == 0
          self.num_heads = (self.total_num_heads //
                            tensor_model_parallel_world_size)

          # pylint: disable=C0103
          self.qkv_proj = QKVParallelLinear(
              self.hidden_size,
              self.head_size,
              self.total_num_heads,
              bias=True,
              quant_config=quant_config,
          )
          self.dense = RowParallelLinear(
              self.hidden_size,
              self.hidden_size,
              quant_config=quant_config,
          )

          scaling = self.head_size**-0.5
          # Only a fraction of each head is rotated (partial rotary).
          rotary_dim = int(config.partial_rotary_factor * self.head_size)
          assert rotary_dim % 2 == 0

          # pylint: disable=C0301
          # Refer to:
          # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
          # Prefer the values carried by the config; the previous constants
          # (base 10000, 2048 positions) remain as fallbacks for configs
          # that do not define them.
          rope_theta = getattr(config, "rope_theta", None)
          if rope_theta is None:
              rope_theta = 10000

          # ``n_positions`` keeps precedence so configs that define it
          # behave exactly as before; ``max_position_embeddings`` is the
          # name used by ``transformers.PhiConfig``.
          max_position_embeddings = 2048
          for attr_name in ("n_positions", "max_position_embeddings"):
              configured = getattr(config, attr_name, None)
              if configured is not None:
                  max_position_embeddings = configured
                  break

          self.rotary_emb = get_rope(
              self.head_size,
              rotary_dim=rotary_dim,
              max_position=max_position_embeddings,
              base=rope_theta,
          )
          self.attn = Attention(self.num_heads,
                                self.head_size,
                                scaling,
                                cache_config=cache_config,
                                quant_config=quant_config)

      def forward(
          self,
          position_ids: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run attention over ``hidden_states``.

          Args:
              position_ids: Positions handed to the rotary embedding.
              hidden_states: Input to the fused QKV projection.
              kv_cache: Cache tensor forwarded to ``Attention``.
              attn_metadata: Metadata forwarded to ``Attention``.

          Returns:
              The output of the ``dense`` projection.
          """
          qkv, _ = self.qkv_proj(hidden_states)
          # The fused projection is split into three equal parts on the
          # last dimension.
          q, k, v = qkv.chunk(chunks=3, dim=-1)
          q, k = self.rotary_emb(position_ids, q, k)
          attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
          output, _ = self.dense(attn_output)
          return output
  class PhiMLP(nn.Module):
      """Feed-forward block: ``fc1`` -> activation -> ``fc2``."""

      def __init__(self,
                   config: PhiConfig,
                   quant_config: Optional[QuantizationConfig] = None):
          """Build the two linear layers and the activation.

          Args:
              config: Supplies ``hidden_size``, ``hidden_act`` and,
                  optionally, ``n_inner``.
              quant_config: Passed to both linear layers and to
                  ``get_act_fn``.
          """
          super().__init__()
          width = config.hidden_size

          # Use the configured inner width when there is one, otherwise
          # four times the hidden size.
          inner_width = getattr(config, "n_inner", None)
          if inner_width is None:
              inner_width = 4 * width

          self.fc1 = ColumnParallelLinear(width,
                                          inner_width,
                                          quant_config=quant_config)
          self.fc2 = RowParallelLinear(inner_width,
                                       width,
                                       quant_config=quant_config)
          self.act = get_act_fn(config.hidden_act, quant_config, inner_width)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Apply ``fc1``, the activation and ``fc2`` in sequence."""
          # Both linear layers return a pair; only the first item is used.
          expanded, _ = self.fc1(hidden_states)
          activated = self.act(expanded)
          contracted, _ = self.fc2(activated)
          return contracted
  class PhiLayer(nn.Module):
      """One Phi decoder layer.

      A single layer norm feeds both the attention and the MLP branch; the
      two branch outputs and the un-normalised input are summed.
      """

      def __init__(self,
                   config: PhiConfig,
                   cache_config: Optional[CacheConfig] = None,
                   quant_config: Optional[QuantizationConfig] = None):
          """Create the layer norm, the attention block and the MLP."""
          super().__init__()
          self.input_layernorm = nn.LayerNorm(
              config.hidden_size,
              eps=config.layer_norm_eps,
          )
          self.self_attn = PhiAttention(config, cache_config, quant_config)
          self.mlp = PhiMLP(config, quant_config)

      def forward(
          self,
          position_ids: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Return ``attention + mlp + input`` for this layer.

          Args:
              position_ids: Forwarded to the attention block.
              hidden_states: Layer input; also used as the residual.
              kv_cache: Forwarded to the attention block.
              attn_metadata: Forwarded to the attention block.
          """
          normed = self.input_layernorm(hidden_states)

          # Both branches consume the same normalised activations.
          attn_branch = self.self_attn(
              position_ids=position_ids,
              hidden_states=normed,
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )
          mlp_branch = self.mlp(normed)

          # Same summation order as before: attention, MLP, then residual.
          return attn_branch + mlp_branch + hidden_states
  class PhiModel(nn.Module):
      """Token embedding, the stack of ``PhiLayer`` blocks and a final norm."""

      def __init__(self,
                   config: PhiConfig,
                   cache_config: Optional[CacheConfig] = None,
                   quant_config: Optional[QuantizationConfig] = None):
          """Create the embedding, ``num_hidden_layers`` layers and the norm."""
          super().__init__()
          self.config = config
          self.quant_config = quant_config
          self.embed_tokens = VocabParallelEmbedding(
              config.vocab_size,
              config.hidden_size,
          )
          decoder_layers = [
              PhiLayer(config, cache_config, quant_config)
              for _ in range(config.num_hidden_layers)
          ]
          self.layers = nn.ModuleList(decoder_layers)
          self.final_layernorm = nn.LayerNorm(
              config.hidden_size,
              eps=config.layer_norm_eps,
          )

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids`` and run them through every layer.

          Args:
              input_ids: Token ids passed to the embedding.
              positions: Positions forwarded to each layer.
              kv_caches: One cache entry per layer, indexed by layer number.
              attn_metadata: Forwarded to each layer.

          Returns:
              The output of the final layer norm.
          """
          activations = self.embed_tokens(input_ids)
          for layer_idx, decoder_layer in enumerate(self.layers):
              activations = decoder_layer(
                  positions,
                  activations,
                  kv_caches[layer_idx],
                  attn_metadata,
              )
          return self.final_layernorm(activations)
  class PhiForCausalLM(nn.Module, SupportsLoRA):
      """Phi decoder with an LM head, a logits processor and a sampler."""

      # Checkpoint projections that are stored in one fused parameter.
      packed_modules_mapping = {
          "qkv_proj": [
              "q_proj",
              "k_proj",
              "v_proj",
          ]
      }

      # LoRA specific attributes
      supported_lora_modules = [
          "qkv_proj",
          "dense",
          "fc1",
          "fc2",
      ]
      embedding_modules = {}
      embedding_padding_modules = []

      def __init__(
          self,
          config: PhiConfig,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
          lora_config: Optional[LoRAConfig] = None,
      ):
          """Build the decoder, the LM head and the sampling helpers.

          Args:
              config: Model configuration.
              cache_config: Forwarded to ``PhiModel``.
              quant_config: Forwarded to ``PhiModel`` and the LM head.
              lora_config: Stored on the instance; not otherwise used here.
          """
          super().__init__()
          self.config = config
          self.lora_config = lora_config
          self.quant_config = quant_config

          self.model = PhiModel(config, cache_config, quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size,
                                        config.hidden_size,
                                        bias=True,
                                        quant_config=quant_config)
          self.logits_processor = LogitsProcessor(config.vocab_size)
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          intermediate_tensors: Optional[IntermediateTensors] = None,
      ) -> torch.Tensor:
          """Return the decoder hidden states for ``input_ids``.

          ``intermediate_tensors`` is accepted but not used by this method.
          """
          hidden_states = self.model(input_ids, positions, kv_caches,
                                     attn_metadata)
          return hidden_states

      def compute_logits(
          self,
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[torch.Tensor]:
          """Turn hidden states into logits using the LM head and its bias."""
          logits = self.logits_processor(self.lm_head, hidden_states,
                                         sampling_metadata, self.lm_head.bias)
          return logits

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Run the sampler on ``logits`` and return its output."""
          next_tokens = self.sampler(logits, sampling_metadata)
          return next_tokens

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Copy checkpoint tensors into this module's parameters.

          Names containing ``q_proj``/``k_proj``/``v_proj`` are redirected
          to the fused ``qkv_proj`` parameter and loaded with the matching
          shard id. ``rotary_emb.inv_freq`` entries, and ``.bias`` entries
          with no matching parameter, are skipped.

          Args:
              weights: Iterable of ``(name, tensor)`` pairs.

          Raises:
              KeyError: If a name that is not skipped has no matching
                  parameter.
          """
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v")
          ]
          params_dict = dict(self.named_parameters())

          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue

              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  # Keep the checkpoint name intact and build the fused
                  # name separately. "v_proj" is a substring of "qkv_proj",
                  # so rewriting ``name`` in place would let a later mapping
                  # entry match the already-rewritten name.
                  fused_name = name.replace(weight_name, param_name)
                  # Skip loading extra bias for GPTQ models.
                  if (fused_name.endswith(".bias")
                          and fused_name not in params_dict):
                      # ``break`` bypasses the ``else`` branch, so the
                      # weight is dropped outright.
                      break
                  param = params_dict[fused_name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
                  # No stacked mapping matched: load under its own name.
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  # pylint: disable=E1136
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)