# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PhiConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class PhiAttention(nn.Module):
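    """Self-attention with partial rotary position embeddings, sharded
    across tensor-parallel ranks."""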

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
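        # Phi applies rotary embeddings to only a fraction of each head's
        # dimensions, as set by partial_rotary_factor.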
        rotary_dim = int(config.partial_rotary_factor *
                         (config.hidden_size // config.num_attention_heads))
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = 10000
        max_position_embeddings = getattr(config, "n_positions", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads,
                              self.head_size,
                              scaling,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
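        # The fused QKV projection emits [q; k; v] along the last dim; Phi
        # uses plain multi-head attention (no grouped KV heads), so three
        # equal chunks recover q, k, and v.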
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
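    """Feed-forward block; fc1 is column-parallel and fc2 row-parallel so
    the intermediate activation stays sharded across ranks."""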

    def __init__(self,
                 config: PhiConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

    def forward(self, hidden_states):
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
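    """Transformer block in which attention and the MLP run in parallel off
    a single shared input layernorm."""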

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config, cache_config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
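        # Parallel residual block (as in GPT-J): the MLP consumes the same
        # normalized hidden states as attention, and both branch outputs are
        # added to the residual in a single step.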
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


class PhiModel(nn.Module):
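    """Token embedding, the stack of PhiLayers, and the final layernorm."""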

    def __init__(self,
                 config: PhiConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            PhiLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(self.config.num_hidden_layers):
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class PhiForCausalLM(nn.Module, SupportsLoRA):
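    """Phi model with a biased LM head, logits processing, and sampling."""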

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "dense",
        "fc1",
        "fc2",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PhiConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config
        self.model = PhiModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
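        # Phi's LM head carries a bias, so it is passed to the logits
        # processor explicitly.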
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
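        # Checkpoint names like "...self_attn.q_proj.weight" are remapped to
        # the fused "...self_attn.qkv_proj.weight" and loaded into the
        # matching shard ("q", "k", or "v").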
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
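            # Rotary inverse-frequency buffers are recreated by get_rope,
            # so the checkpoint copies are skipped.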
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
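            # The for/else takes this branch only when no stacked mapping
            # matched, i.e. the weight loads into a regular parameter.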
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)