1
0

phi3_small.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. import math
  2. from typing import Iterable, List, Optional, Tuple
  3. import torch
  4. from torch import nn
  5. from transformers.configuration_utils import PretrainedConfig
  6. from aphrodite.attention import Attention, AttentionMetadata
  7. from aphrodite.common.config import CacheConfig, LoRAConfig
  8. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  9. from aphrodite.common.utils import progress_bar
  10. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  11. get_tensor_model_parallel_world_size)
  12. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  13. QKVParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  16. from aphrodite.modeling.layers.rotary_embedding import get_rope
  17. from aphrodite.modeling.layers.sampler import Sampler
  18. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  19. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  20. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  21. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  22. from aphrodite.quantization.base_config import QuantizationConfig
  23. def load_column_parallel_weight(param: torch.nn.Parameter,
  24. loaded_weight: torch.Tensor):
  25. tp = get_tensor_model_parallel_world_size()
  26. rk = get_tensor_model_parallel_rank()
  27. assert param.size(0) * tp == loaded_weight.size(0)
  28. s = rk * param.size(0)
  29. e = (rk + 1) * param.size(0)
  30. loaded_weight = loaded_weight[s:e]
  31. assert param.shape == loaded_weight.shape
  32. param.data.copy_(loaded_weight)
  33. class HeadMajorQKVParallelLinear(QKVParallelLinear):
  34. def weight_loader(self, param: torch.nn.Parameter,
  35. loaded_weight: torch.Tensor):
  36. return load_column_parallel_weight(param, loaded_weight)
  37. class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
  38. def weight_loader(self, param: torch.nn.Parameter,
  39. loaded_weight: torch.Tensor):
  40. return load_column_parallel_weight(param, loaded_weight)
  41. @torch.jit.script
  42. def quick_gelu(x):
  43. return x * torch.sigmoid(1.702 * x)
  44. @torch.jit.script
  45. def gegelu(input, limit: Optional[float] = None):
  46. a_gelu, a_linear = input[..., ::2], input[..., 1::2]
  47. if limit is not None:
  48. a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
  49. a_gelu.clamp(min=None, max=limit))
  50. a_linear = torch.where(
  51. torch.isinf(a_linear),
  52. a_linear,
  53. a_linear.clamp(min=-limit, max=limit),
  54. )
  55. out_gelu = quick_gelu(a_gelu)
  56. return out_gelu * (a_linear + 1)
  57. class Phi3SmallMLP(nn.Module):
  58. def __init__(
  59. self,
  60. config: PretrainedConfig,
  61. quant_config: Optional[QuantizationConfig] = None,
  62. ) -> None:
  63. super().__init__()
  64. self.config = config
  65. assert (self.config.hidden_act == "gegelu"
  66. ), "Only `gegelu` is supported for the 4.7 series of models .."
  67. self.hidden_size = config.hidden_size
  68. self.gegelu_limit = config.gegelu_limit
  69. self.intermediate_size = config.intermediate_size
  70. self.up_proj = HeadMajorColumnParallelLinear(
  71. self.hidden_size,
  72. 2 * [self.intermediate_size],
  73. bias=True,
  74. quant_config=quant_config,
  75. )
  76. self.down_proj = RowParallelLinear(
  77. self.intermediate_size,
  78. self.hidden_size,
  79. bias=True,
  80. quant_config=quant_config,
  81. )
  82. def forward(self, x):
  83. gate_up, _ = self.up_proj(x)
  84. x = gegelu(gate_up)
  85. x, _ = self.down_proj(x)
  86. return x
  87. class Phi3SmallSelfAttention(nn.Module):
  88. def __init__(
  89. self,
  90. config: PretrainedConfig,
  91. layer_idx: int,
  92. cache_config: Optional[CacheConfig] = None,
  93. quant_config: Optional[QuantizationConfig] = None,
  94. ) -> None:
  95. super().__init__()
  96. self.layer_idx = layer_idx
  97. self.config = config
  98. self.sparse_block_size = config.blocksparse_block_size
  99. self.homo_heads = config.blocksparse_homo_head_pattern
  100. self.local_blocks = config.blocksparse_num_local_blocks
  101. self.vert_stride = config.blocksparse_vert_stride
  102. assert (config.blocksparse_block_size ==
  103. config.blocksparse_triton_kernel_block_size)
  104. self.hidden_size = config.hidden_size
  105. # Number of Query Heads
  106. self.num_heads = config.num_attention_heads
  107. self.head_dim = self.hidden_size // self.num_heads
  108. self.tp_size = get_tensor_model_parallel_world_size()
  109. # Number of total Key Value Heads before tensor parallel
  110. self.num_key_value_heads = config.num_key_value_heads
  111. self.num_q_per_kv = self.num_heads // self.num_key_value_heads
  112. if self.tp_size > 1:
  113. assert self.num_key_value_heads % self.tp_size == 0
  114. self.num_kv_heads_per_partion = max(
  115. 1, self.num_key_value_heads // self.tp_size)
  116. self.num_heads_per_partition = self.num_heads // self.tp_size
  117. self.max_position_embeddings = config.max_position_embeddings
  118. self.rope_embedding_base = config.rope_embedding_base
  119. self.rope_position_scale = config.rope_position_scale
  120. self.is_causal = True
  121. norm_factor = None
  122. if config.mup_use_scaling:
  123. norm_factor = self.head_dim / config.mup_attn_multiplier
  124. else:
  125. norm_factor = math.sqrt(self.head_dim)
  126. self.scale = 1 / norm_factor
  127. self.query_key_value = HeadMajorQKVParallelLinear(
  128. self.hidden_size,
  129. self.head_dim,
  130. self.num_heads,
  131. self.num_key_value_heads,
  132. bias=True,
  133. quant_config=quant_config,
  134. )
  135. self.dense = RowParallelLinear(self.hidden_size,
  136. self.hidden_size,
  137. bias=True,
  138. quant_config=quant_config)
  139. if getattr(self.config, "rope_scaling", None) is not None:
  140. rope_scaling = self.config.rope_scaling
  141. for key in rope_scaling:
  142. if isinstance(rope_scaling[key], list):
  143. rope_scaling[key] = tuple(rope_scaling[key])
  144. if "factor" not in rope_scaling:
  145. rope_scaling["factor"] = self.rope_position_scale
  146. else:
  147. rope_scaling = {
  148. "type": "linear",
  149. "factor": self.rope_position_scale,
  150. }
  151. self.rotary_emb = get_rope(
  152. self.head_dim,
  153. rotary_dim=self.head_dim,
  154. max_position=self.max_position_embeddings,
  155. base=self.rope_embedding_base,
  156. rope_scaling=rope_scaling,
  157. )
  158. # blocksparse params
  159. self.blocksparse_block_size = config.blocksparse_block_size
  160. self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
  161. self.blocksparse_vert_stride = config.blocksparse_vert_stride
  162. use_dense_attn = (getattr(self.config,
  163. "dense_attention_every_n_layers", None)
  164. and (self.layer_idx + 1) %
  165. self.config.dense_attention_every_n_layers == 0)
  166. bs_params = None
  167. if not use_dense_attn:
  168. bs_params = {
  169. 'max_seqlen': self.max_position_embeddings,
  170. 'num_heads': self.num_heads_per_partition,
  171. "num_kv_heads": self.num_kv_heads_per_partion,
  172. "block_size": self.sparse_block_size,
  173. "local_blocks": self.local_blocks,
  174. "vert_stride": self.vert_stride,
  175. "homo_head": self.homo_heads
  176. }
  177. self.attn = Attention(
  178. self.num_heads_per_partition,
  179. self.head_dim,
  180. self.scale,
  181. num_kv_heads=self.num_kv_heads_per_partion,
  182. cache_config=cache_config,
  183. quant_config=quant_config,
  184. blocksparse_params=bs_params,
  185. )
  186. def forward(
  187. self,
  188. positions: torch.Tensor,
  189. hidden_states: torch.Tensor,
  190. kv_cache: torch.Tensor,
  191. attn_metadata: AttentionMetadata,
  192. ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
  193. Optional[Tuple[torch.Tensor]]]:
  194. qkv, _ = self.query_key_value(hidden_states)
  195. qkv = qkv.view(qkv.shape[:-1] +
  196. (-1, (self.num_q_per_kv + 2), self.head_dim))
  197. q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
  198. # NOTE: this is required by RotaryEmbed, which indeed does not have to
  199. # TODO: allow 3D QK for rotary forward
  200. q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
  201. k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
  202. v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
  203. q, k = self.rotary_emb(positions, q, k)
  204. attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
  205. output, _ = self.dense(attn_output)
  206. return output
  207. class Phi3SmallDecoderLayer(nn.Module):
  208. def __init__(
  209. self,
  210. config: PretrainedConfig,
  211. layer_idx: int,
  212. cache_config: Optional[CacheConfig] = None,
  213. quant_config: Optional[QuantizationConfig] = None,
  214. ):
  215. super().__init__()
  216. self.hidden_size = config.hidden_size
  217. self.self_attn = Phi3SmallSelfAttention(config,
  218. layer_idx,
  219. cache_config=cache_config,
  220. quant_config=quant_config)
  221. self.mlp = Phi3SmallMLP(config, quant_config)
  222. self.input_layernorm = nn.LayerNorm(config.hidden_size,
  223. eps=config.layer_norm_epsilon)
  224. self.post_attention_layernorm = nn.LayerNorm(
  225. config.hidden_size, eps=config.layer_norm_epsilon)
  226. def forward(
  227. self,
  228. positions: torch.Tensor,
  229. hidden_states: torch.Tensor,
  230. kv_cache: torch.Tensor,
  231. attn_metadata: AttentionMetadata,
  232. ) -> torch.Tensor:
  233. residual = hidden_states
  234. hidden_states = self.input_layernorm(hidden_states)
  235. hidden_states = self.self_attn(
  236. positions=positions,
  237. hidden_states=hidden_states,
  238. kv_cache=kv_cache,
  239. attn_metadata=attn_metadata,
  240. )
  241. hidden_states = residual + hidden_states
  242. residual = hidden_states
  243. hidden_states = self.post_attention_layernorm(hidden_states)
  244. hidden_states = self.mlp(hidden_states)
  245. hidden_states = residual + hidden_states
  246. return hidden_states
  247. class Phi3SmallModel(nn.Module):
  248. def __init__(
  249. self,
  250. config: PretrainedConfig,
  251. cache_config: Optional[CacheConfig] = None,
  252. quant_config: Optional[QuantizationConfig] = None,
  253. ):
  254. super().__init__()
  255. self.config = config
  256. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  257. config.hidden_size)
  258. self.mup_embedding_multiplier = config.mup_embedding_multiplier
  259. self.layers = nn.ModuleList([
  260. Phi3SmallDecoderLayer(config, layer_idx, cache_config,
  261. quant_config)
  262. for layer_idx in range(config.num_hidden_layers)
  263. ])
  264. self.final_layernorm = nn.LayerNorm(config.hidden_size,
  265. eps=config.layer_norm_epsilon)
  266. def get_input_embeddings(self):
  267. return self.embed_tokens
  268. def set_input_embeddings(self, value):
  269. self.embed_tokens = value
  270. def forward(
  271. self,
  272. input_ids: torch.LongTensor,
  273. positions: Optional[torch.LongTensor],
  274. kv_caches: List[torch.Tensor],
  275. attn_metadata: AttentionMetadata = None,
  276. ):
  277. hidden_states = self.embed_tokens(input_ids)
  278. if (self.mup_embedding_multiplier is not None
  279. and self.mup_embedding_multiplier > 0.0):
  280. hidden_states = hidden_states * self.mup_embedding_multiplier
  281. for i in range(len(self.layers)):
  282. layer = self.layers[i]
  283. hidden_states = layer(
  284. positions,
  285. hidden_states,
  286. kv_caches[i],
  287. attn_metadata,
  288. )
  289. hidden_states = self.final_layernorm(hidden_states)
  290. return hidden_states
  291. class Phi3SmallForCausalLM(nn.Module):
  292. _tied_weights_keys = ["lm_head.weight"]
  293. def __init__(
  294. self,
  295. config,
  296. cache_config: Optional[CacheConfig] = None,
  297. quant_config: Optional[QuantizationConfig] = None,
  298. lora_config: Optional[LoRAConfig] = None,
  299. ):
  300. super().__init__()
  301. self.config = config
  302. self.quant_config = quant_config
  303. self.model = Phi3SmallModel(config, cache_config, quant_config)
  304. self.vocab_size = config.vocab_size
  305. self.mup_width_multiplier = config.mup_width_multiplier
  306. self.lm_head = ParallelLMHead(
  307. self.vocab_size,
  308. config.hidden_size,
  309. org_num_embeddings=config.vocab_size,
  310. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  311. quant_config=quant_config,
  312. )
  313. self.logits_processor = LogitsProcessor(config.vocab_size)
  314. self.sampler = Sampler()
  315. # tokens in tiktoken but not used
  316. if hasattr(config, 'dummy_token_indices'):
  317. device = self.lm_head.weight.device
  318. self.register_buffer('dummy_token_indices',
  319. torch.LongTensor(
  320. config.dummy_token_indices).to(device),
  321. persistent=False)
  322. else:
  323. self.dummy_token_indices = None
  324. def get_input_embeddings(self):
  325. return self.model.embed_tokens
  326. def set_input_embeddings(self, value):
  327. self.model.embed_tokens = value
  328. def get_output_embeddings(self):
  329. return self.lm_head
  330. def set_output_embeddings(self, value):
  331. self.lm_head = value
  332. def set_decoder(self, decoder):
  333. self.model = decoder
  334. def get_decoder(self):
  335. return self.model
  336. def compute_logits(
  337. self,
  338. hidden_states: torch.Tensor,
  339. sampling_metadata: SamplingMetadata,
  340. ) -> Optional[torch.Tensor]:
  341. logits = self.logits_processor(self.lm_head, hidden_states,
  342. sampling_metadata)
  343. if self.dummy_token_indices is not None and logits is not None:
  344. logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
  345. return logits
  346. def forward(
  347. self,
  348. input_ids: torch.LongTensor,
  349. positions: Optional[torch.LongTensor],
  350. kv_caches: List[torch.Tensor],
  351. attn_metadata: AttentionMetadata,
  352. intermediate_tensors: Optional[IntermediateTensors] = None,
  353. ) -> torch.Tensor:
  354. output_hidden_states = self.model(
  355. input_ids=input_ids,
  356. positions=positions,
  357. kv_caches=kv_caches,
  358. attn_metadata=attn_metadata,
  359. )
  360. output_hidden_states = output_hidden_states
  361. return output_hidden_states
  362. def sample(
  363. self,
  364. logits: torch.Tensor,
  365. sampling_metadata: SamplingMetadata,
  366. ) -> Optional[SamplerOutput]:
  367. next_tokens = self.sampler(logits / self.mup_width_multiplier,
  368. sampling_metadata)
  369. return next_tokens
  370. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  371. params_dict = dict(self.named_parameters())
  372. weights_list = list(weights)
  373. for name, loaded_weight in progress_bar(weights_list,
  374. desc="Loading modules..."):
  375. if "rotary_emb.inv_freq" in name:
  376. continue
  377. if name.endswith(".bias") and name not in params_dict:
  378. continue
  379. param = params_dict[name]
  380. weight_loader = getattr(param, "weight_loader",
  381. default_weight_loader)
  382. weight_loader(param, loaded_weight)
  383. self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)