# coding=utf-8
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs.dbrx import DbrxConfig


class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.layer(hidden_states)
        return router_logits
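
# NOTE: Illustrative sketch only (not executed here) of the top-k selection
# that FusedMoE is assumed to apply to these router logits when constructed
# with renormalize=True, as DbrxExperts does below; the real routing is fused
# into the MoE kernel inside aphrodite.modeling.layers.fused_moe:
#
#   routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
#   topk_weights, topk_ids = torch.topk(routing_weights, k=moe_top_k, dim=-1)
#   topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)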


class DbrxExperts(FusedMoE):

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__(
            num_experts=config.ffn_config.moe_num_experts,
            top_k=config.ffn_config.moe_top_k,
            hidden_size=config.d_model,
            intermediate_size=config.ffn_config.ffn_hidden_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=get_tensor_model_parallel_world_size(),
        )
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.d_model = config.d_model
        self.intermediate_size = (self.config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

    # Custom weight loader for the DBRX model.
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses a GLU for each expert.
        # The GLU has 3 linear layers: w1, v1, and w2.
        if weight_name.endswith("w1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
        if weight_name.endswith("v1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, shard_size:2 * shard_size, :] = (
                loaded_weight[:, shard, :])
        if weight_name.endswith("w2"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            ).transpose(1, 2)
            param_data[:] = loaded_weight[:, :, shard]
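
# NOTE: Assumed checkpoint layout behind the weight_loader above (inferred
# from its reshape calls; shard sizes are per tensor-parallel rank):
#   mlp.w1 / mlp.v1: reshaped to (num_experts, ffn_hidden_size, d_model) and
#       row-sharded on the ffn_hidden_size dim into the first / second half
#       of the fused w13_weight parameter created by FusedMoE.
#   mlp.w2: reshaped the same way, transposed to (num_experts, d_model,
#       ffn_hidden_size), and sharded on its last dim into w2_weight.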


class DbrxMoE(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks, a fused MoE kernel is
    used for the forward pass, and the outputs are finally reduced across
    ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        self.experts = DbrxExperts(config=config,
                                   quant_config=quant_config,
                                   params_dtype=self.params_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
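
# NOTE: Rough shape walk-through for DbrxMoE.forward above (comment only):
#   hidden_states (..., d_model) -> view(-1, d_model): (num_tokens, d_model)
#   router_logits: (num_tokens, moe_num_experts)
#   experts(hidden_states, router_logits): (num_tokens, d_model), already
#       all-reduced across TP ranks because DbrxExperts sets
#       reduce_results=True
#   return value: reshaped back to the original hidden_states shape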


class DbrxAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # The number of KV heads is at least the TP size, so we partition
            # the KV heads across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # The number of KV heads is less than the TP size, so we replicate
            # the KV heads across the tensor-parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states
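
# NOTE: With grouped-query attention, the per-rank split sizes used above are
# q_size = num_heads * head_dim and kv_size = num_kv_heads * head_dim, so the
# fused Wqkv output is sliced into (q, k, v) along the last dimension. When
# total_num_kv_heads < tp_world_size, num_kv_heads is clamped to 1 and each
# rank works on a replicated copy of a KV head.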


class DbrxFusedNormAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, cache_config, quant_config)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual
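
# NOTE: The class above uses a pre-LayerNorm arrangement: norm_1 feeds
# attention, the attention output is added back onto the input, and norm_2 is
# applied before the MoE FFN. Returning (hidden_states, residual) lets
# DbrxBlock add the FFN output onto the post-attention residual stream.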


class DbrxBlock(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                     quant_config)
        self.ffn = DbrxMoE(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DbrxModel(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            DbrxBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)
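    # NOTE: The bias-stripping loop above is assumed to reflect that DBRX
    # checkpoints carry no bias tensors; unregistering the biases created by
    # nn.LayerNorm keeps named_parameters() aligned with the checkpoint keys
    # consumed by DbrxForCausalLM.load_weights.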

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        if config.tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models.")
        self.quant_config = quant_config
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        expert_params_mapping = [(
            "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
            f"mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
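
# NOTE: Hedged example of the expert-weight renaming done in load_weights,
# assuming the usual DBRX checkpoint key form
# "transformer.blocks.{i}.ffn.experts.mlp.{w1,v1,w2}":
#   "transformer.blocks.0.ffn.experts.mlp.w1"
#       -> param "transformer.blocks.0.ffn.experts.w13_weight", loaded via
#          DbrxExperts.weight_loader with weight_name="mlp.w1"
#   "transformer.blocks.0.ffn.experts.mlp.w2"
#       -> param "transformer.blocks.0.ffn.experts.w2_weight"
# Every other weight falls through to default_weight_loader.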