phi1_5.py

from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              RowParallelLinear,
                                              QKVParallelLinear,
                                              LinearMethodBase)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class PhiEmbedding(nn.Module):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

    def forward(self, input_ids: torch.LongTensor):
        return self.wte(input_ids)


class PhiAttention(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.Wqkv = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            linear_method=linear_method,
        )

        scaling = self.head_size**-0.5
        rotary_dim = config.rotary_dim
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = 10000
        max_position_embeddings = getattr(config, "n_positions", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        # The fused projection emits [q | k | v] along the last dimension,
        # so an even three-way chunk recovers query, key, and value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class PhiMLP(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()

        n_inner = getattr(config, "n_inner", None)
        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            linear_method=linear_method,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn(config.activation_function, quant_config,
                              n_inner)

    def forward(self, hidden_states):
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size,
                               eps=config.layer_norm_epsilon)
        self.mixer = PhiAttention(config, linear_method)
        self.mlp = PhiMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln(hidden_states)
        attn_outputs = self.mixer(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states
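
# Note on the block above: PhiLayer uses a parallel residual formulation. The
# attention branch and the MLP branch both read the same LayerNorm output, and
# their results are summed with the untouched residual:
#
#     out = x + Attn(LN(x)) + MLP(LN(x))
#
# This differs from the sequential pre-norm pattern, in which the MLP would
# consume the attention output rather than the shared normalized input.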


class PhiModel(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.embd = PhiEmbedding(config)
        self.h = nn.ModuleList([
            PhiLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embd(input_ids)
        for i in range(self.config.num_hidden_layers):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        return hidden_states


class PhiCausalLMHead(nn.Module):
    """Final LayerNorm plus LM head.

    No forward() is defined: PhiForCausalLM applies `ln` itself and hands
    `linear` to the sampler directly.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size,
                               eps=config.layer_norm_epsilon)
        self.linear = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     bias=True)


class PhiForCausalLM(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method

        self.transformer = PhiModel(config, linear_method)
        self.lm_head = PhiCausalLMHead(config)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        hidden_states = self.lm_head.ln(hidden_states)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        head = self.lm_head.linear
        next_tokens = self.sampler(head.weight, hidden_states,
                                   sampling_metadata, head.bias)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith("bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
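

# Weight-name mapping (a sketch, not part of the runtime path): the attribute
# names above -- `transformer.embd.wte`, `transformer.h[i].ln / .mixer / .mlp`,
# `lm_head.ln`, `lm_head.linear` -- mirror the module layout of the upstream
# modeling_phi.py referenced earlier. Assuming the checkpoint follows that
# layout, the tensor names yielded by `hf_model_weights_iterator` line up
# one-to-one with `dict(self.named_parameters())`, so `load_weights` needs no
# renaming logic. For example (names shown are illustrative of that assumed
# layout):
#
#     model = PhiForCausalLM(config)
#     params = dict(model.named_parameters())
#     params["transformer.h.0.mixer.Wqkv.weight"]   # fused QKV projection
#     params["lm_head.linear.weight"]               # output embedding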