dbrx.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. # coding=utf-8
  2. from typing import Iterable, List, Optional, Tuple
  3. import torch
  4. import torch.nn as nn
  5. from aphrodite.attention import Attention, AttentionMetadata
  6. from aphrodite.common.config import CacheConfig
  7. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  8. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  9. get_tensor_model_parallel_world_size,
  10. tensor_model_parallel_all_reduce)
  11. from aphrodite.modeling.layers.fused_moe import fused_moe
  12. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  13. ReplicatedLinear,
  14. RowParallelLinear)
  15. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  16. from aphrodite.modeling.layers.rotary_embedding import get_rope
  17. from aphrodite.modeling.layers.sampler import Sampler
  18. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  19. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  20. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  21. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  22. from aphrodite.modeling.utils import set_weight_attrs
  23. from aphrodite.quantization.base_config import QuantizationConfig
  24. from aphrodite.transformers_utils.configs.dbrx import DbrxConfig
  25. class DbrxRouter(nn.Module):
  26. """A Router implementation for DBRX that returns logits for each expert
  27. per token.
  28. """
  29. def __init__(
  30. self,
  31. config: DbrxConfig,
  32. params_dtype: Optional[torch.dtype] = None,
  33. ):
  34. super().__init__()
  35. self.tp_size = get_tensor_model_parallel_world_size()
  36. self.num_total_experts = config.ffn_config.moe_num_experts
  37. self.d_model = config.d_model
  38. self.layer = ReplicatedLinear(
  39. self.d_model,
  40. self.num_total_experts,
  41. bias=False,
  42. params_dtype=params_dtype,
  43. quant_config=None,
  44. )
  45. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  46. router_logits, _ = self.layer(hidden_states)
  47. return router_logits
  48. class DbrxExperts(nn.Module):
  49. """A tensor-parallel MoE implementation for DBRX.
  50. Each expert's weights are sharded across all ranks and a fused MoE
  51. kernel is used for the forward pass, and finally we reduce the outputs
  52. across ranks.
  53. """
  54. def __init__(
  55. self,
  56. config: DbrxConfig,
  57. quant_config: Optional[QuantizationConfig] = None,
  58. params_dtype: Optional[torch.dtype] = None,
  59. ):
  60. super().__init__()
  61. self.tp_size = get_tensor_model_parallel_world_size()
  62. self.num_total_experts = config.ffn_config.moe_num_experts
  63. self.top_k = config.ffn_config.moe_top_k
  64. self.d_model = config.d_model
  65. self.intermediate_size = (config.ffn_config.ffn_hidden_size //
  66. self.tp_size)
  67. if params_dtype is None:
  68. params_dtype = torch.get_default_dtype()
  69. self.params_dtype = params_dtype
  70. self.router = DbrxRouter(config, self.params_dtype)
  71. self.ws = nn.Parameter(
  72. torch.empty(
  73. self.num_total_experts,
  74. 2 * self.intermediate_size,
  75. self.d_model,
  76. device="cuda",
  77. dtype=self.params_dtype,
  78. ))
  79. self.w2s = nn.Parameter(
  80. torch.empty(
  81. self.num_total_experts,
  82. self.d_model,
  83. self.intermediate_size,
  84. device="cuda",
  85. dtype=self.params_dtype,
  86. ))
  87. set_weight_attrs(
  88. self.ws,
  89. {
  90. "weight_loader": self.weight_loader,
  91. },
  92. )
  93. set_weight_attrs(
  94. self.w2s,
  95. {
  96. "weight_loader": self.weight_loader,
  97. },
  98. )
  99. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
  100. weight_name: str):
  101. tp_rank = get_tensor_model_parallel_rank()
  102. param_data = param.data
  103. shard_size = self.intermediate_size
  104. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  105. # DBRX uses GLU for each experts.
  106. # GLU has 3 linear layers: w1, v1 and w2.
  107. if weight_name.endswith("w1"):
  108. loaded_weight = torch.reshape(
  109. loaded_weight,
  110. [-1, self.intermediate_size * self.tp_size, self.d_model],
  111. )
  112. param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
  113. if weight_name.endswith("v1"):
  114. loaded_weight = torch.reshape(
  115. loaded_weight,
  116. [-1, self.intermediate_size * self.tp_size, self.d_model],
  117. )
  118. param_data[:,
  119. shard_size:2 * shard_size, :] = loaded_weight[:,
  120. shard, :]
  121. if weight_name.endswith("w2"):
  122. loaded_weight = torch.reshape(
  123. loaded_weight,
  124. [-1, self.intermediate_size * self.tp_size, self.d_model],
  125. ).transpose(1, 2)
  126. param_data[:] = loaded_weight[:, :, shard]
  127. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  128. num_tokens, hidden_size = hidden_states.shape
  129. hidden_states = hidden_states.view(-1, self.d_model)
  130. # router_logits: (num_tokens, n_experts)
  131. router_logits = self.router(hidden_states)
  132. final_hidden_states = fused_moe(
  133. hidden_states,
  134. self.ws,
  135. self.w2s,
  136. router_logits,
  137. self.top_k,
  138. renormalize=True,
  139. inplace=True,
  140. )
  141. if self.tp_size > 1:
  142. final_hidden_states = tensor_model_parallel_all_reduce(
  143. final_hidden_states)
  144. return final_hidden_states.view(num_tokens, hidden_size)
  145. class DbrxAttention(nn.Module):
  146. def __init__(
  147. self,
  148. config: DbrxConfig,
  149. cache_config: Optional[CacheConfig] = None,
  150. quant_config: Optional[QuantizationConfig] = None,
  151. ):
  152. super().__init__()
  153. self.d_model = config.d_model
  154. self.total_num_heads = config.n_heads
  155. self.head_dim = self.d_model // self.total_num_heads
  156. self.total_num_kv_heads = config.attn_config.kv_n_heads
  157. self.clip_qkv = config.attn_config.clip_qkv
  158. self.rope_theta = config.attn_config.rope_theta
  159. self.max_position = config.max_seq_len
  160. # pylint: disable=invalid-name
  161. self.Wqkv = QKVParallelLinear(
  162. self.d_model,
  163. self.head_dim,
  164. self.total_num_heads,
  165. self.total_num_kv_heads,
  166. bias=False,
  167. quant_config=quant_config,
  168. )
  169. self.out_proj = RowParallelLinear(
  170. self.d_model,
  171. self.d_model,
  172. bias=False,
  173. quant_config=quant_config,
  174. )
  175. self.rotary_emb = get_rope(
  176. self.head_dim,
  177. rotary_dim=self.head_dim,
  178. max_position=self.max_position,
  179. base=int(self.rope_theta),
  180. is_neox_style=True,
  181. )
  182. tp_world_size = get_tensor_model_parallel_world_size()
  183. self.tp_size = tp_world_size
  184. assert self.total_num_heads % tp_world_size == 0
  185. self.num_heads = self.total_num_heads // tp_world_size
  186. if self.total_num_kv_heads >= tp_world_size:
  187. # Number of KV heads is greater than TP size, so we partition
  188. # the KV heads across multiple tensor parallel GPUs.
  189. assert self.total_num_kv_heads % tp_world_size == 0
  190. else:
  191. # Number of KV heads is less than TP size, so we replicate
  192. # the KV heads across multiple tensor parallel GPUs.
  193. assert tp_world_size % self.total_num_kv_heads == 0
  194. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
  195. self.q_size = self.num_heads * self.head_dim
  196. self.kv_size = self.num_kv_heads * self.head_dim
  197. self.scaling = self.head_dim**-0.5
  198. self.attn = Attention(self.num_heads,
  199. self.head_dim,
  200. self.scaling,
  201. num_kv_heads=self.num_kv_heads,
  202. cache_config=cache_config,
  203. quant_config=quant_config)
  204. def forward(
  205. self,
  206. position_ids: torch.Tensor,
  207. hidden_states: torch.Tensor,
  208. kv_cache: torch.Tensor,
  209. attn_metadata: AttentionMetadata,
  210. ) -> torch.Tensor:
  211. qkv, _ = self.Wqkv(hidden_states)
  212. if self.clip_qkv is not None:
  213. qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
  214. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  215. q, k = self.rotary_emb(position_ids, q, k)
  216. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  217. hidden_states, _ = self.out_proj(attn_output)
  218. return hidden_states
  219. class DbrxFusedNormAttention(nn.Module):
  220. def __init__(
  221. self,
  222. config: DbrxConfig,
  223. cache_config: Optional[CacheConfig] = None,
  224. quant_config: Optional[QuantizationConfig] = None,
  225. ):
  226. super().__init__()
  227. self.d_model = config.d_model
  228. self.attn = DbrxAttention(config, cache_config, quant_config)
  229. self.norm_1 = nn.LayerNorm(self.d_model)
  230. self.norm_2 = nn.LayerNorm(self.d_model)
  231. def forward(
  232. self,
  233. position_ids: torch.Tensor,
  234. hidden_states: torch.Tensor,
  235. kv_cache: torch.Tensor,
  236. attn_metadata: AttentionMetadata,
  237. ) -> torch.Tensor:
  238. residual = hidden_states
  239. hidden_states = self.norm_1(hidden_states)
  240. x = self.attn(
  241. position_ids=position_ids,
  242. hidden_states=hidden_states,
  243. kv_cache=kv_cache,
  244. attn_metadata=attn_metadata,
  245. )
  246. hidden_states = residual + x
  247. residual = hidden_states
  248. hidden_states = self.norm_2(hidden_states)
  249. return hidden_states, residual
  250. class DbrxBlock(nn.Module):
  251. def __init__(
  252. self,
  253. config: DbrxConfig,
  254. cache_config: Optional[CacheConfig] = None,
  255. quant_config: Optional[QuantizationConfig] = None,
  256. ):
  257. super().__init__()
  258. self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
  259. quant_config)
  260. self.ffn = DbrxExperts(config, quant_config)
  261. def forward(
  262. self,
  263. position_ids: torch.Tensor,
  264. hidden_states: torch.Tensor,
  265. kv_cache: torch.Tensor,
  266. attn_metadata: AttentionMetadata,
  267. ) -> torch.Tensor:
  268. hidden_states, residual = self.norm_attn_norm(
  269. position_ids=position_ids,
  270. hidden_states=hidden_states,
  271. kv_cache=kv_cache,
  272. attn_metadata=attn_metadata,
  273. )
  274. hidden_states = self.ffn(hidden_states)
  275. hidden_states = hidden_states + residual
  276. return hidden_states
  277. class DbrxModel(nn.Module):
  278. def __init__(
  279. self,
  280. config: DbrxConfig,
  281. cache_config: Optional[CacheConfig] = None,
  282. quant_config: Optional[QuantizationConfig] = None,
  283. ):
  284. super().__init__()
  285. self.wte = VocabParallelEmbedding(
  286. config.vocab_size,
  287. config.d_model,
  288. )
  289. self.blocks = nn.ModuleList([
  290. DbrxBlock(config, cache_config, quant_config)
  291. for _ in range(config.n_layers)
  292. ])
  293. self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
  294. for module in self.modules():
  295. if hasattr(module, "bias") and isinstance(module.bias,
  296. nn.Parameter):
  297. # Remove the bias term in Linear and LayerNorm.
  298. module.register_parameter("bias", None)
  299. def forward(
  300. self,
  301. input_ids: torch.Tensor,
  302. position_ids: torch.Tensor,
  303. kv_caches: List[torch.Tensor],
  304. attn_metadata: AttentionMetadata,
  305. ) -> torch.Tensor:
  306. hidden_states = self.wte(input_ids)
  307. for i in range(len(self.blocks)):
  308. block = self.blocks[i]
  309. hidden_states = block(
  310. position_ids,
  311. hidden_states,
  312. kv_caches[i],
  313. attn_metadata,
  314. )
  315. hidden_states = self.norm_f(hidden_states)
  316. return hidden_states
  317. class DbrxForCausalLM(nn.Module):
  318. def __init__(
  319. self,
  320. config: DbrxConfig,
  321. cache_config: Optional[CacheConfig] = None,
  322. quant_config: Optional[QuantizationConfig] = None,
  323. ):
  324. super().__init__()
  325. self.config = config
  326. self.quant_config = quant_config
  327. self.unpadded_vocab_size = config.vocab_size
  328. self.transformer = DbrxModel(config, cache_config, quant_config)
  329. self.lm_head = ParallelLMHead(
  330. config.vocab_size,
  331. config.d_model,
  332. org_num_embeddings=config.vocab_size,
  333. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  334. quant_config=quant_config,
  335. )
  336. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  337. config.vocab_size)
  338. self.sampler = Sampler()
  339. def forward(
  340. self,
  341. input_ids: torch.Tensor,
  342. positions: torch.Tensor,
  343. kv_caches: List[torch.Tensor],
  344. attn_metadata: AttentionMetadata,
  345. intermediate_tensors: Optional[IntermediateTensors] = None,
  346. ) -> torch.Tensor:
  347. hidden_states = self.transformer(input_ids, positions, kv_caches,
  348. attn_metadata)
  349. return hidden_states
  350. def compute_logits(self, hidden_states: torch.Tensor,
  351. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  352. logits = self.logits_processor(self.lm_head, hidden_states,
  353. sampling_metadata)
  354. return logits
  355. def sample(
  356. self,
  357. logits: Optional[torch.Tensor],
  358. sampling_metadata: SamplingMetadata,
  359. ) -> Optional[SamplerOutput]:
  360. next_tokens = self.sampler(logits, sampling_metadata)
  361. return next_tokens
  362. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  363. expert_params_mapping = [(
  364. "ws" if weight_name in ["w1", "v1"] else "w2s",
  365. f"experts.mlp.{weight_name}",
  366. ) for weight_name in ["w1", "v1", "w2"]]
  367. params_dict = dict(self.named_parameters(remove_duplicate=False))
  368. for name, loaded_weight in weights:
  369. for param_name, weight_name in expert_params_mapping:
  370. if weight_name not in name:
  371. continue
  372. name = name.replace(weight_name, param_name)
  373. param = params_dict[name]
  374. weight_loader = param.weight_loader
  375. weight_loader(param, loaded_weight, weight_name)
  376. break
  377. else:
  378. param = params_dict[name]
  379. weight_loader = getattr(param, "weight_loader",
  380. default_weight_loader)
  381. weight_loader(param, loaded_weight)