dbrx.py
# coding=utf-8
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.transformers_utils.configs.dbrx import DbrxConfig


class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            linear_method=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.layer(hidden_states)
        return router_logits


class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    is used for the forward pass, and the outputs are reduced across ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        self.d_model = config.d_model
        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses GLU for each expert.
        # GLU has 3 linear layers: w1, v1 and w2.
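        # Per-expert GLU sketch (an assumption about the fused kernel's
        # convention, not taken from this file):
        #     y = (act(x @ w1.T) * (x @ v1.T)) @ w2.T
        # w1 fills rows [0:shard_size] and v1 rows [shard_size:2*shard_size]
        # of `ws`, while w2 is stored transposed in `w2s`.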
        if weight_name.endswith("w1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
        if weight_name.endswith("v1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:,
                       shard_size:2 * shard_size, :] = loaded_weight[:,
                                                                     shard, :]
        if weight_name.endswith("w2"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            ).transpose(1, 2)
            param_data[:] = loaded_weight[:, :, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
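        # Assumption about the fused_moe contract: renormalize=True rescales
        # the selected top-k routing weights to sum to 1 before combining
        # expert outputs, and inplace=True lets the kernel reuse the
        # hidden_states buffer for its output.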
        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class DbrxAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
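        # Worked example (illustrative numbers): with 8 KV heads and
        # tp_world_size=4, each rank owns 8 // 4 = 2 KV heads; with
        # tp_world_size=16, num_kv_heads below becomes max(1, 8 // 16) = 1
        # and each KV head is replicated on 16 / 8 = 2 ranks.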
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states


class DbrxFusedNormAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, linear_method)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual


class DbrxBlock(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
        self.ffn = DbrxExperts(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DbrxModel(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        expert_params_mapping = [(
            "ws" if weight_name in ["w1", "v1"] else "w2s",
            f"experts.mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
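        # Checkpoint tensors whose names contain "experts.mlp.w1", ".v1" or
        # ".w2" are remapped onto the stacked "ws" / "w2s" parameters and go
        # through DbrxExperts.weight_loader; everything else falls back to
        # default_weight_loader below.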
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)