1
0

phi.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright (c) Microsoft Corporation.
  6. # Licensed under the MIT license.
  7. #
  8. # BSD 3-Clause License
  9. #
  10. # Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
  11. # All rights reserved.
  12. #
  13. # Redistribution and use in source and binary forms, with or without
  14. # modification, are permitted provided that the following conditions are met:
  15. #
  16. # * Redistributions of source code must retain the above copyright notice, this
  17. # list of conditions and the following disclaimer.
  18. #
  19. # * Redistributions in binary form must reproduce the above copyright notice,
  20. # this list of conditions and the following disclaimer in the documentation
  21. # and/or other materials provided with the distribution.
  22. #
  23. # * Neither the name of the copyright holder nor the names of its
  24. # contributors may be used to endorse or promote products derived from
  25. # this software without specific prior written permission.
  26. #
  27. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  28. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  29. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  30. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
  31. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  32. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
  33. # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
  34. # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
  35. # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  36. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  37. """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
  38. from typing import Iterable, List, Optional, Tuple
  39. import torch
  40. from torch import nn
  41. from transformers import PretrainedConfig
  42. from aphrodite.attention import Attention, AttentionMetadata
  43. from aphrodite.common.sequence import SamplerOutput
  44. from aphrodite.distributed import get_tensor_model_parallel_world_size
  45. from aphrodite.modeling.layers.activation import get_act_fn
  46. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  47. QKVParallelLinear,
  48. RowParallelLinear)
  49. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  50. from aphrodite.modeling.layers.rotary_embedding import get_rope
  51. from aphrodite.modeling.layers.sampler import Sampler
  52. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  53. ParallelLMHead, VocabParallelEmbedding)
  54. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  55. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  56. from aphrodite.quantization.base_config import QuantizationConfig
  57. class PhiAttention(nn.Module):
  58. def __init__(self,
  59. config: PretrainedConfig,
  60. quant_config: Optional[QuantizationConfig] = None):
  61. super().__init__()
  62. self.total_num_heads = config.num_attention_heads
  63. self.hidden_size = config.hidden_size
  64. self.head_size = self.hidden_size // self.total_num_heads
  65. tensor_model_parallel_world_size = (
  66. get_tensor_model_parallel_world_size())
  67. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  68. self.num_heads = (self.total_num_heads //
  69. tensor_model_parallel_world_size)
  70. # pylint: disable=C0103
  71. self.qkv_proj = QKVParallelLinear(
  72. self.hidden_size,
  73. self.head_size,
  74. self.total_num_heads,
  75. bias=True,
  76. quant_config=quant_config,
  77. )
  78. self.dense = RowParallelLinear(
  79. self.hidden_size,
  80. self.hidden_size,
  81. quant_config=quant_config,
  82. )
  83. scaling = self.head_size**-0.5
  84. rotary_dim = int(config.partial_rotary_factor *
  85. (config.hidden_size // config.num_attention_heads))
  86. assert rotary_dim % 2 == 0
  87. # pylint: disable=C0301
  88. # Refer to:
  89. # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
  90. rope_theta = 10000
  91. max_position_embeddings = getattr(config, "n_positions", 2048)
  92. self.rotary_emb = get_rope(
  93. self.head_size,
  94. rotary_dim=rotary_dim,
  95. max_position=max_position_embeddings,
  96. base=rope_theta,
  97. )
  98. self.attn = Attention(self.num_heads, self.head_size, scaling)
  99. def forward(
  100. self,
  101. position_ids: torch.Tensor,
  102. hidden_states: torch.Tensor,
  103. kv_cache: torch.Tensor,
  104. attn_metadata: AttentionMetadata,
  105. ) -> torch.Tensor:
  106. qkv, _ = self.qkv_proj(hidden_states)
  107. q, k, v = qkv.chunk(chunks=3, dim=-1)
  108. q, k = self.rotary_emb(position_ids, q, k)
  109. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  110. output, _ = self.dense(attn_output)
  111. return output
  112. class PhiMLP(nn.Module):
  113. def __init__(self,
  114. config: PretrainedConfig,
  115. quant_config: Optional[QuantizationConfig] = None):
  116. super().__init__()
  117. n_inner = getattr(config, "n_inner", None)
  118. n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
  119. self.fc1 = ColumnParallelLinear(
  120. config.hidden_size,
  121. n_inner,
  122. quant_config=quant_config,
  123. )
  124. self.fc2 = RowParallelLinear(
  125. n_inner,
  126. config.hidden_size,
  127. quant_config=quant_config,
  128. )
  129. quant_config = getattr(quant_config, "quant_config", None)
  130. self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
  131. def forward(self, hidden_states):
  132. hidden_states, _ = self.fc1(hidden_states)
  133. hidden_states = self.act(hidden_states)
  134. hidden_states, _ = self.fc2(hidden_states)
  135. return hidden_states
  136. class PhiLayer(nn.Module):
  137. def __init__(self,
  138. config: PretrainedConfig,
  139. quant_config: Optional[QuantizationConfig] = None):
  140. super().__init__()
  141. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  142. eps=config.layer_norm_eps)
  143. self.self_attn = PhiAttention(config, quant_config)
  144. self.mlp = PhiMLP(config, quant_config)
  145. def forward(
  146. self,
  147. position_ids: torch.Tensor,
  148. hidden_states: torch.Tensor,
  149. kv_cache: torch.Tensor,
  150. attn_metadata: AttentionMetadata,
  151. ) -> torch.Tensor:
  152. residual = hidden_states
  153. hidden_states = self.input_layernorm(hidden_states)
  154. attn_outputs = self.self_attn(
  155. position_ids=position_ids,
  156. hidden_states=hidden_states,
  157. kv_cache=kv_cache,
  158. attn_metadata=attn_metadata,
  159. )
  160. feed_forward_hidden_states = self.mlp(hidden_states)
  161. hidden_states = attn_outputs + feed_forward_hidden_states + residual
  162. return hidden_states
  163. class PhiModel(nn.Module):
  164. def __init__(self,
  165. config: PretrainedConfig,
  166. quant_config: Optional[QuantizationConfig] = None):
  167. super().__init__()
  168. self.config = config
  169. self.quant_config = quant_config
  170. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  171. config.hidden_size)
  172. self.layers = nn.ModuleList([
  173. PhiLayer(config, quant_config)
  174. for _ in range(config.num_hidden_layers)
  175. ])
  176. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  177. eps=config.layer_norm_eps)
  178. def forward(
  179. self,
  180. input_ids: torch.Tensor,
  181. positions: torch.Tensor,
  182. kv_caches: List[torch.Tensor],
  183. attn_metadata: AttentionMetadata,
  184. ) -> torch.Tensor:
  185. hidden_states = self.embed_tokens(input_ids)
  186. for i in range(self.config.num_hidden_layers):
  187. layer = self.layers[i]
  188. hidden_states = layer(
  189. positions,
  190. hidden_states,
  191. kv_caches[i],
  192. attn_metadata,
  193. )
  194. hidden_states = self.final_layernorm(hidden_states)
  195. return hidden_states
  196. class PhiForCausalLM(nn.Module):
  197. def __init__(self,
  198. config: PretrainedConfig,
  199. quant_config: Optional[QuantizationConfig] = None):
  200. super().__init__()
  201. self.config = config
  202. self.quant_config = quant_config
  203. self.model = PhiModel(config, quant_config)
  204. self.lm_head = ParallelLMHead(config.vocab_size,
  205. config.hidden_size,
  206. bias=True)
  207. self.logits_processor = LogitsProcessor(config.vocab_size)
  208. self.sampler = Sampler()
  209. def forward(
  210. self,
  211. input_ids: torch.Tensor,
  212. positions: torch.Tensor,
  213. kv_caches: List[torch.Tensor],
  214. attn_metadata: AttentionMetadata,
  215. ) -> torch.Tensor:
  216. hidden_states = self.model(input_ids, positions, kv_caches,
  217. attn_metadata)
  218. return hidden_states
  219. def compute_logits(self, hidden_states: torch.Tensor,
  220. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  221. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  222. sampling_metadata, self.lm_head.bias)
  223. return logits
  224. def sample(
  225. self,
  226. logits: torch.Tensor,
  227. sampling_metadata: SamplingMetadata,
  228. ) -> Optional[SamplerOutput]:
  229. next_tokens = self.sampler(logits, sampling_metadata)
  230. return next_tokens
  231. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  232. stacked_params_mapping = [
  233. # (param_name, shard_name, shard_id)
  234. ("qkv_proj", "q_proj", "q"),
  235. ("qkv_proj", "k_proj", "k"),
  236. ("qkv_proj", "v_proj", "v")
  237. ]
  238. params_dict = dict(self.named_parameters())
  239. for name, loaded_weight in weights:
  240. if "rotary_emb.inv_freq" in name:
  241. continue
  242. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  243. if weight_name not in name:
  244. continue
  245. name = name.replace(weight_name, param_name)
  246. # Skip loading extra bias for GPTQ models.
  247. if name.endswith(".bias") and name not in params_dict:
  248. continue
  249. param = params_dict[name]
  250. weight_loader = param.weight_loader
  251. weight_loader(param, loaded_weight, shard_id)
  252. break
  253. else:
  254. # Skip loading extra bias for GPTQ models.
  255. if name.endswith(".bias") and name not in params_dict:
  256. continue
  257. # pylint: disable=E1136
  258. param = params_dict[name]
  259. weight_loader = getattr(param, "weight_loader",
  260. default_weight_loader)
  261. weight_loader(param, loaded_weight)