# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs.mpt import MPTConfig


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
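    """Compute the per-head ALiBi slopes.

    Slopes decrease geometrically with the head index; when the number of
    heads is not a power of two, slopes are computed for the next power of
    two and then interleaved and truncated. For example, 8 heads with
    ``alibi_bias_max=8`` yield ``[1/2, 1/4, 1/8, ..., 1/256]``.
    """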
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):
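    """Multi-head attention for MPT with ALiBi position biases.

    The query, key, and value projections are fused into a single
    tensor-parallel ``Wqkv`` layer, and position information is injected
    through the per-head ALiBi slopes handed to the ``Attention`` backend,
    so ``position_ids`` are unused in ``forward``.
    """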

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config['kv_n_heads']
        else:
            self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):
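    """MPT feed-forward network: up-projection, GELU, down-projection."""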

    def __init__(
        self,
        config: MPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):
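    """A single MPT transformer block: pre-norm attention and feed-forward
    sub-layers, each wrapped in a residual connection."""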

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config, cache_config, quant_config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):
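    """The MPT decoder stack: token embeddings, ``n_layers`` transformer
    blocks, and a final LayerNorm. When ``config.no_bias`` is set, bias
    parameters are stripped from every submodule after construction."""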

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            MPTBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(
                        module.bias, nn.Parameter):
                    # Remove the bias term in Linear and LayerNorm.
                    module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):
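    """MPT model with a language-modeling head.

    The output projection is tied to the input embedding table
    (``self.lm_head = self.transformer.wte``), matching the
    ``tie_word_embeddings`` assertion on the config.
    """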

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.quant_config = quant_config

        self.transformer = MPTModel(config, cache_config, quant_config)
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
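        """Load checkpoint weights into the model by parameter name.

        Extra bias tensors emitted by GPTQ checkpoints are skipped; each
        parameter's ``weight_loader`` attribute is used when present,
        otherwise ``default_weight_loader`` is applied.
        """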
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)