# phi3_small.py (16 KB)
  1. import math
  2. from typing import Iterable, List, Optional, Tuple
  3. import torch
  4. from torch import nn
  5. from transformers.configuration_utils import PretrainedConfig
  6. from aphrodite.attention import Attention, AttentionMetadata
  7. from aphrodite.common.config import CacheConfig, LoRAConfig
  8. from aphrodite.common.sequence import IntermediateTensors
  9. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  10. get_tensor_model_parallel_world_size)
  11. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  12. QKVParallelLinear,
  13. RowParallelLinear)
  14. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  15. from aphrodite.modeling.layers.rotary_embedding import get_rope
  16. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  17. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  18. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  19. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  20. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  21. from aphrodite.quantization.base_config import QuantizationConfig
  22. def load_column_parallel_weight(param: torch.nn.Parameter,
  23. loaded_weight: torch.Tensor):
  24. tp = get_tensor_model_parallel_world_size()
  25. rk = get_tensor_model_parallel_rank()
  26. assert param.size(0) * tp == loaded_weight.size(0)
  27. s = rk * param.size(0)
  28. e = (rk + 1) * param.size(0)
  29. loaded_weight = loaded_weight[s:e]
  30. assert param.shape == loaded_weight.shape
  31. param.data.copy_(loaded_weight)
  32. class HeadMajorQKVParallelLinear(QKVParallelLinear):
  33. def weight_loader(self, param: torch.nn.Parameter,
  34. loaded_weight: torch.Tensor):
  35. return load_column_parallel_weight(param, loaded_weight)
  36. class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
  37. def weight_loader(self, param: torch.nn.Parameter,
  38. loaded_weight: torch.Tensor):
  39. return load_column_parallel_weight(param, loaded_weight)
  40. @torch.jit.script
  41. def quick_gelu(x):
  42. return x * torch.sigmoid(1.702 * x)
  43. @torch.jit.script
  44. def gegelu(input, limit: Optional[float] = None):
  45. a_gelu, a_linear = input[..., ::2], input[..., 1::2]
  46. if limit is not None:
  47. a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
  48. a_gelu.clamp(min=None, max=limit))
  49. a_linear = torch.where(
  50. torch.isinf(a_linear),
  51. a_linear,
  52. a_linear.clamp(min=-limit, max=limit),
  53. )
  54. out_gelu = quick_gelu(a_gelu)
  55. return out_gelu * (a_linear + 1)
  56. class Phi3SmallMLP(nn.Module):
  57. def __init__(
  58. self,
  59. config: PretrainedConfig,
  60. quant_config: Optional[QuantizationConfig] = None,
  61. ) -> None:
  62. super().__init__()
  63. self.config = config
  64. assert (self.config.hidden_act == "gegelu"
  65. ), "Only `gegelu` is supported for the 4.7 series of models .."
  66. self.hidden_size = config.hidden_size
  67. self.gegelu_limit = config.gegelu_limit
  68. self.intermediate_size = config.intermediate_size
  69. self.up_proj = HeadMajorColumnParallelLinear(
  70. self.hidden_size,
  71. 2 * [self.intermediate_size],
  72. bias=True,
  73. quant_config=quant_config,
  74. )
  75. self.down_proj = RowParallelLinear(
  76. self.intermediate_size,
  77. self.hidden_size,
  78. bias=True,
  79. quant_config=quant_config,
  80. )
  81. def forward(self, x):
  82. gate_up, _ = self.up_proj(x)
  83. x = gegelu(gate_up)
  84. x, _ = self.down_proj(x)
  85. return x
  86. class Phi3SmallSelfAttention(nn.Module):
  87. def __init__(
  88. self,
  89. config: PretrainedConfig,
  90. layer_idx: int,
  91. cache_config: Optional[CacheConfig] = None,
  92. quant_config: Optional[QuantizationConfig] = None,
  93. ) -> None:
  94. super().__init__()
  95. self.layer_idx = layer_idx
  96. self.config = config
  97. self.sparse_block_size = config.blocksparse_block_size
  98. self.homo_heads = config.blocksparse_homo_head_pattern
  99. self.local_blocks = config.blocksparse_num_local_blocks
  100. self.vert_stride = config.blocksparse_vert_stride
  101. assert (config.blocksparse_block_size ==
  102. config.blocksparse_triton_kernel_block_size)
  103. self.hidden_size = config.hidden_size
  104. # Number of Query Heads
  105. self.num_heads = config.num_attention_heads
  106. self.head_dim = self.hidden_size // self.num_heads
  107. self.tp_size = get_tensor_model_parallel_world_size()
  108. # Number of total Key Value Heads before tensor parallel
  109. self.num_key_value_heads = config.num_key_value_heads
  110. self.num_q_per_kv = self.num_heads // self.num_key_value_heads
  111. if self.tp_size > 1:
  112. assert self.num_key_value_heads % self.tp_size == 0
  113. self.num_kv_heads_per_partion = max(
  114. 1, self.num_key_value_heads // self.tp_size)
  115. self.num_heads_per_partition = self.num_heads // self.tp_size
  116. self.max_position_embeddings = config.max_position_embeddings
  117. self.rope_embedding_base = config.rope_embedding_base
  118. self.rope_position_scale = config.rope_position_scale
  119. self.is_causal = True
  120. norm_factor = None
  121. if config.mup_use_scaling:
  122. norm_factor = self.head_dim / config.mup_attn_multiplier
  123. else:
  124. norm_factor = math.sqrt(self.head_dim)
  125. self.scale = 1 / norm_factor
  126. self.query_key_value = HeadMajorQKVParallelLinear(
  127. self.hidden_size,
  128. self.head_dim,
  129. self.num_heads,
  130. self.num_key_value_heads,
  131. bias=True,
  132. quant_config=quant_config,
  133. )
  134. self.dense = RowParallelLinear(self.hidden_size,
  135. self.hidden_size,
  136. bias=True,
  137. quant_config=quant_config)
  138. if getattr(self.config, "rope_scaling", None) is not None:
  139. rope_scaling = self.config.rope_scaling
  140. for key in rope_scaling:
  141. if isinstance(rope_scaling[key], list):
  142. rope_scaling[key] = tuple(rope_scaling[key])
  143. if "factor" not in rope_scaling:
  144. rope_scaling["factor"] = self.rope_position_scale
  145. else:
  146. rope_scaling = {
  147. "type": "linear",
  148. "factor": self.rope_position_scale,
  149. }
  150. self.rotary_emb = get_rope(
  151. self.head_dim,
  152. rotary_dim=self.head_dim,
  153. max_position=self.max_position_embeddings,
  154. base=self.rope_embedding_base,
  155. rope_scaling=rope_scaling,
  156. )
  157. # blocksparse params
  158. self.blocksparse_block_size = config.blocksparse_block_size
  159. self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
  160. self.blocksparse_vert_stride = config.blocksparse_vert_stride
  161. use_dense_attn = (getattr(self.config,
  162. "dense_attention_every_n_layers", None)
  163. and (self.layer_idx + 1) %
  164. self.config.dense_attention_every_n_layers == 0)
  165. bs_params = None
  166. if not use_dense_attn:
  167. bs_params = {
  168. 'max_seqlen': self.max_position_embeddings,
  169. 'num_heads': self.num_heads_per_partition,
  170. "num_kv_heads": self.num_kv_heads_per_partion,
  171. "block_size": self.sparse_block_size,
  172. "local_blocks": self.local_blocks,
  173. "vert_stride": self.vert_stride,
  174. "homo_head": self.homo_heads
  175. }
  176. self.attn = Attention(
  177. self.num_heads_per_partition,
  178. self.head_dim,
  179. self.scale,
  180. num_kv_heads=self.num_kv_heads_per_partion,
  181. cache_config=cache_config,
  182. quant_config=quant_config,
  183. blocksparse_params=bs_params,
  184. )
  185. def forward(
  186. self,
  187. positions: torch.Tensor,
  188. hidden_states: torch.Tensor,
  189. kv_cache: torch.Tensor,
  190. attn_metadata: AttentionMetadata,
  191. ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
  192. Optional[Tuple[torch.Tensor]]]:
  193. qkv, _ = self.query_key_value(hidden_states)
  194. qkv = qkv.view(qkv.shape[:-1] +
  195. (-1, (self.num_q_per_kv + 2), self.head_dim))
  196. q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
  197. # NOTE: this is required by RotaryEmbed, which indeed does not have to
  198. # TODO: allow 3D QK for rotary forward
  199. q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
  200. k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
  201. v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
  202. q, k = self.rotary_emb(positions, q, k)
  203. attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
  204. output, _ = self.dense(attn_output)
  205. return output
  206. class Phi3SmallDecoderLayer(nn.Module):
  207. def __init__(
  208. self,
  209. config: PretrainedConfig,
  210. layer_idx: int,
  211. cache_config: Optional[CacheConfig] = None,
  212. quant_config: Optional[QuantizationConfig] = None,
  213. ):
  214. super().__init__()
  215. self.hidden_size = config.hidden_size
  216. self.self_attn = Phi3SmallSelfAttention(config,
  217. layer_idx,
  218. cache_config=cache_config,
  219. quant_config=quant_config)
  220. self.mlp = Phi3SmallMLP(config, quant_config)
  221. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  222. eps=config.layer_norm_epsilon)
  223. self.post_attention_layernorm = nn.LayerNorm(
  224. config.hidden_size, eps=config.layer_norm_epsilon)
  225. def forward(
  226. self,
  227. positions: torch.Tensor,
  228. hidden_states: torch.Tensor,
  229. kv_cache: torch.Tensor,
  230. attn_metadata: AttentionMetadata,
  231. ) -> torch.Tensor:
  232. residual = hidden_states
  233. hidden_states = self.input_layernorm(hidden_states)
  234. hidden_states = self.self_attn(
  235. positions=positions,
  236. hidden_states=hidden_states,
  237. kv_cache=kv_cache,
  238. attn_metadata=attn_metadata,
  239. )
  240. hidden_states = residual + hidden_states
  241. residual = hidden_states
  242. hidden_states = self.post_attention_layernorm(hidden_states)
  243. hidden_states = self.mlp(hidden_states)
  244. hidden_states = residual + hidden_states
  245. return hidden_states
  246. class Phi3SmallModel(nn.Module):
  247. def __init__(
  248. self,
  249. config: PretrainedConfig,
  250. cache_config: Optional[CacheConfig] = None,
  251. quant_config: Optional[QuantizationConfig] = None,
  252. ):
  253. super().__init__()
  254. self.config = config
  255. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  256. config.hidden_size)
  257. self.mup_embedding_multiplier = config.mup_embedding_multiplier
  258. self.layers = nn.ModuleList([
  259. Phi3SmallDecoderLayer(config, layer_idx, cache_config,
  260. quant_config)
  261. for layer_idx in range(config.num_hidden_layers)
  262. ])
  263. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  264. eps=config.layer_norm_epsilon)
  265. def get_input_embeddings(self):
  266. return self.embed_tokens
  267. def set_input_embeddings(self, value):
  268. self.embed_tokens = value
  269. def forward(
  270. self,
  271. input_ids: torch.LongTensor,
  272. positions: Optional[torch.LongTensor],
  273. kv_caches: List[torch.Tensor],
  274. attn_metadata: AttentionMetadata = None,
  275. ):
  276. hidden_states = self.embed_tokens(input_ids)
  277. if (self.mup_embedding_multiplier is not None
  278. and self.mup_embedding_multiplier > 0.0):
  279. hidden_states = hidden_states * self.mup_embedding_multiplier
  280. for i in range(len(self.layers)):
  281. layer = self.layers[i]
  282. hidden_states = layer(
  283. positions,
  284. hidden_states,
  285. kv_caches[i],
  286. attn_metadata,
  287. )
  288. hidden_states = self.final_layernorm(hidden_states)
  289. return hidden_states
  290. class Phi3SmallForCausalLM(nn.Module):
  291. _tied_weights_keys = ["lm_head.weight"]
  292. def __init__(
  293. self,
  294. config,
  295. cache_config: Optional[CacheConfig] = None,
  296. quant_config: Optional[QuantizationConfig] = None,
  297. lora_config: Optional[LoRAConfig] = None,
  298. ):
  299. super().__init__()
  300. self.config = config
  301. self.quant_config = quant_config
  302. self.model = Phi3SmallModel(config, cache_config, quant_config)
  303. self.vocab_size = config.vocab_size
  304. self.mup_width_multiplier = config.mup_width_multiplier
  305. self.lm_head = ParallelLMHead(
  306. self.vocab_size,
  307. config.hidden_size,
  308. org_num_embeddings=config.vocab_size,
  309. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  310. quant_config=quant_config,
  311. )
  312. self.logits_processor = LogitsProcessor(config.vocab_size)
  313. self.sampler = Sampler()
  314. # tokens in tiktoken but not used
  315. if hasattr(config, 'dummy_token_indices'):
  316. device = self.lm_head.weight.device
  317. self.register_buffer('dummy_token_indices',
  318. torch.LongTensor(
  319. config.dummy_token_indices).to(device),
  320. persistent=False)
  321. else:
  322. self.dummy_token_indices = None
  323. def get_input_embeddings(self):
  324. return self.model.embed_tokens
  325. def set_input_embeddings(self, value):
  326. self.model.embed_tokens = value
  327. def get_output_embeddings(self):
  328. return self.lm_head
  329. def set_output_embeddings(self, value):
  330. self.lm_head = value
  331. def set_decoder(self, decoder):
  332. self.model = decoder
  333. def get_decoder(self):
  334. return self.model
  335. def compute_logits(
  336. self,
  337. hidden_states: torch.Tensor,
  338. sampling_metadata: SamplingMetadata,
  339. ) -> Optional[torch.Tensor]:
  340. logits = self.logits_processor(self.lm_head, hidden_states,
  341. sampling_metadata)
  342. if self.dummy_token_indices is not None and logits is not None:
  343. logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
  344. return logits
  345. def forward(
  346. self,
  347. input_ids: torch.LongTensor,
  348. positions: Optional[torch.LongTensor],
  349. kv_caches: List[torch.Tensor],
  350. attn_metadata: AttentionMetadata,
  351. intermediate_tensors: Optional[IntermediateTensors] = None,
  352. ) -> torch.Tensor:
  353. output_hidden_states = self.model(
  354. input_ids=input_ids,
  355. positions=positions,
  356. kv_caches=kv_caches,
  357. attn_metadata=attn_metadata,
  358. )
  359. output_hidden_states = output_hidden_states
  360. return output_hidden_states
  361. def sample(
  362. self,
  363. logits: torch.Tensor,
  364. sampling_metadata: SamplingMetadata,
  365. ) -> Optional[SamplerOutput]:
  366. next_tokens = self.sampler(logits / self.mup_width_multiplier,
  367. sampling_metadata)
  368. return next_tokens
  369. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  370. params_dict = dict(self.named_parameters())
  371. for name, loaded_weight in weights:
  372. if "rotary_emb.inv_freq" in name:
  373. continue
  374. if name.endswith(".bias") and name not in params_dict:
  375. continue
  376. param = params_dict[name]
  377. weight_loader = getattr(param, "weight_loader",
  378. default_weight_loader)
  379. weight_loader(param, loaded_weight)
  380. self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)