# coding=utf-8
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.transformers_utils.configs.dbrx import DbrxConfig


class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.layer(hidden_states)
        return router_logits


class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks. A fused MoE kernel
    is used for the forward pass, and the outputs are then reduced across
    ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        self.d_model = config.d_model
        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )
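
    # Note on the checkpoint layout assumed by weight_loader() below: each
    # expert's GLU has three matrices, w1, v1 and w2, each stored with shape
    # (ffn_hidden_size, d_model) and flattened across experts. w1 and v1 are
    # packed into `ws` along dim 1, w2 is transposed into `w2s`, and each
    # rank keeps only its (ffn_hidden_size // tp_size)-sized slice.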
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses GLU for each expert.
        # GLU has 3 linear layers: w1, v1 and w2.
        if weight_name.endswith("w1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
        if weight_name.endswith("v1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:,
                       shard_size:2 * shard_size, :] = loaded_weight[:,
                                                                     shard, :]
        if weight_name.endswith("w2"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            ).transpose(1, 2)
            param_data[:] = loaded_weight[:, :, shard]
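
    # Shape reference for forward() below (derived from the parameters
    # above, with intermediate_size = ffn_hidden_size // tp_size):
    #   hidden_states:  (num_tokens, d_model)
    #   router_logits:  (num_tokens, moe_num_experts)
    #   ws:   (moe_num_experts, 2 * intermediate_size, d_model)
    #   w2s:  (moe_num_experts, d_model, intermediate_size)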
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class DbrxAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is at least as large as the TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than the TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
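        # For example, with 8 total KV heads and tp_world_size=4, each rank
        # holds 8 // 4 = 2 KV heads; with tp_world_size=16, each rank holds
        # max(1, 8 // 16) = 1 KV head, and each KV head is replicated on
        # 16 // 8 = 2 ranks.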
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states


class DbrxFusedNormAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, cache_config, quant_config)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual


class DbrxBlock(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
                                                     quant_config)
        self.ffn = DbrxExperts(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DbrxModel(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            DbrxBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        expert_params_mapping = [(
            "ws" if weight_name in ["w1", "v1"] else "w2s",
            f"experts.mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
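        # Illustrative example (name layout inferred from the mapping above):
        # a checkpoint entry such as
        #   "transformer.blocks.0.ffn.experts.mlp.w1"
        # would be loaded into the packed parameter
        #   "transformer.blocks.0.ffn.ws"
        # via DbrxExperts.weight_loader, which dispatches on the "w1" suffix.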
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)