mpt.py

# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs.mpt import MPTConfig
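
# ALiBi ("Attention with Linear Biases") replaces learned position embeddings
# with a per-head slope that scales a linear distance penalty inside the
# attention kernel. The slopes form a geometric sequence; when the head count
# is not a power of two, the interleaved selection below reproduces the
# reference ALiBi slope schedule.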
def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes
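
# Multi-head attention with ALiBi biases. The fused QKV projection and the
# output projection are sharded across tensor-parallel ranks, and each rank
# keeps only the ALiBi slopes for the heads it owns.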
class MPTAttention(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config["kv_n_heads"]
        else:
            self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads)
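
    # NOTE: position_ids is accepted for interface compatibility but unused;
    # ALiBi encodes position through the attention bias instead.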
    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output
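
# Position-wise feed-forward network: an up-projection to
# expansion_ratio * d_model, GELU, and a down-projection back to d_model,
# with both projections sharded for tensor parallelism.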
class MPTMLP(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x
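
# A single transformer layer: pre-LayerNorm attention and feed-forward
# sub-blocks, each wrapped in a residual connection.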
class MPTBlock(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config, linear_method)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states
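
# The MPT backbone: token embedding, a stack of MPTBlocks, and a final
# LayerNorm. When the checkpoint was trained without biases (no_bias), the
# bias parameters are stripped from every Linear/LayerNorm after construction.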
class MPTModel(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.d_model,
                                          linear_method=linear_method)
        self.blocks = nn.ModuleList(
            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(
                        module.bias, nn.Parameter):
                    # Remove the bias term in Linear and LayerNorm.
                    module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states
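
# Causal-LM wrapper: runs the backbone, projects hidden states to vocabulary
# logits, and samples next tokens. MPT ties the input embedding and the LM
# head, so load_weights() copies the "wte" weight into lm_head as well.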
class MPTForCausalLM(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.linear_method = linear_method

        self.transformer = MPTModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
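
    # Weight loading walks the HuggingFace checkpoint, skips extra bias
    # tensors from GPTQ checkpoints that have no matching parameter, and
    # routes each tensor through its parameter's weight_loader (falling back
    # to default_weight_loader).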
    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if "wte" in name:
                # Copy word embedding to lm_head.
                head_name = name.replace("transformer.wte", "lm_head")
                if head_name in params_dict:
                    lm_head_param = params_dict[head_name]
                    weight_loader = getattr(lm_head_param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(lm_head_param, loaded_weight)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)