mpt.py

# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs.mpt import MPTConfig


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes
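
# Editor's note (illustrative, not part of the upstream source): with an
# MPT-7B-style setting of total_num_heads=32 and alibi_bias_max=8,
# next_power_of_2 == 32, so the slopes are the geometric sequence
# 2**-0.25, 2**-0.5, ..., 2**-8 (roughly 0.84 down to 0.0039). For a head
# count that is not a power of two, e.g. 6 heads with alibi_bias_max=8, the
# slopes for 8 heads are computed first and then interleaved and truncated,
# which would yield [1/4, 1/16, 1/64, 1/256, 1/2, 1/8], matching the ALiBi
# paper's handling of non-power-of-two head counts.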


class MPTAttention(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config["kv_n_heads"]
        else:
            self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than or equal to the TP size, so
            # we partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than the TP size, so we replicate
            # the KV heads across the tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()
        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output
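
# Editor's note (illustrative, not part of the upstream source): Wqkv emits a
# fused projection of shape [num_tokens, q_size + 2 * kv_size] for this TP
# rank, which the split() above separates into query, key, and value tensors.
# Assuming MPT-7B-like dimensions (d_model=4096, 32 query heads, 32 KV heads)
# and no tensor parallelism, head_dim is 128, q_size == kv_size == 4096, and
# the fused output has 4096 + 2 * 4096 = 12288 columns.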


class MPTMLP(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x
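
# Editor's note (illustrative, not part of the upstream source): MPT uses a
# plain GELU MLP rather than a gated (SwiGLU-style) block, so the
# intermediate size is simply expansion_ratio * d_model. For MPT-7B's
# published config (d_model=4096, expansion_ratio=4) that gives an
# intermediate size of 16384.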


class MPTBlock(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config, cache_config, quant_config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            MPTBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(
                        module.bias, nn.Parameter):
                    # Remove the bias term in Linear and LayerNorm.
                    module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.quant_config = quant_config

        self.transformer = MPTModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
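
# Editor's note (illustrative, not part of the upstream source): MPT
# checkpoints already store the attention projection as a single fused
# "Wqkv" tensor whose name matches the QKVParallelLinear parameter above, so
# load_weights needs no stacked-parameter name remapping; each checkpoint
# tensor is handed directly to the parameter's weight_loader, which handles
# any tensor-parallel sharding.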