# coding=utf-8
# dbrx.py
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs.dbrx import DbrxConfig


  25. class DbrxRouter(nn.Module):
  26. """A Router implementation for DBRX that returns logits for each expert
  27. per token.
  28. """
  29. def __init__(
  30. self,
  31. config: DbrxConfig,
  32. params_dtype: Optional[torch.dtype] = None,
  33. ):
  34. super().__init__()
  35. self.tp_size = get_tensor_model_parallel_world_size()
  36. self.num_total_experts = config.ffn_config.moe_num_experts
  37. self.d_model = config.d_model
  38. self.layer = ReplicatedLinear(
  39. self.d_model,
  40. self.num_total_experts,
  41. bias=False,
  42. params_dtype=params_dtype,
  43. linear_method=None,
  44. )
  45. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  46. router_logits, _ = self.layer(hidden_states)
  47. return router_logits
  48. class DbrxExperts(nn.Module):
  49. """A tensor-parallel MoE implementation for DBRX.
  50. Each expert's weights are sharded across all ranks and a fused MoE
  51. kernel is used for the forward pass, and finally we reduce the outputs
  52. across ranks.
  53. """
  54. def __init__(
  55. self,
  56. config: DbrxConfig,
  57. linear_method: Optional[LinearMethodBase] = None,
  58. params_dtype: Optional[torch.dtype] = None,
  59. ):
  60. super().__init__()
  61. self.tp_size = get_tensor_model_parallel_world_size()
  62. self.num_total_experts = config.ffn_config.moe_num_experts
  63. self.top_k = config.ffn_config.moe_top_k
  64. self.d_model = config.d_model
  65. self.intermediate_size = (config.ffn_config.ffn_hidden_size //
  66. self.tp_size)
  67. if params_dtype is None:
  68. params_dtype = torch.get_default_dtype()
  69. self.params_dtype = params_dtype
  70. self.router = DbrxRouter(config, self.params_dtype)
  71. self.ws = nn.Parameter(
  72. torch.empty(
  73. self.num_total_experts,
  74. 2 * self.intermediate_size,
  75. self.d_model,
  76. device="cuda",
  77. dtype=self.params_dtype,
  78. ))
  79. self.w2s = nn.Parameter(
  80. torch.empty(
  81. self.num_total_experts,
  82. self.d_model,
  83. self.intermediate_size,
  84. device="cuda",
  85. dtype=self.params_dtype,
  86. ))
  87. set_weight_attrs(
  88. self.ws,
  89. {
  90. "weight_loader": self.weight_loader,
  91. },
  92. )
  93. set_weight_attrs(
  94. self.w2s,
  95. {
  96. "weight_loader": self.weight_loader,
  97. },
  98. )
  99. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
  100. weight_name: str):
  101. tp_rank = get_tensor_model_parallel_rank()
  102. param_data = param.data
  103. shard_size = self.intermediate_size
  104. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  105. # DBRX uses GLU for each experts.
  106. # GLU has 3 linear layers: w1, v1 and w2.
  107. if weight_name.endswith("w1"):
  108. loaded_weight = torch.reshape(
  109. loaded_weight,
  110. [-1, self.intermediate_size * self.tp_size, self.d_model],
  111. )
  112. param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
  113. if weight_name.endswith("v1"):
  114. loaded_weight = torch.reshape(
  115. loaded_weight,
  116. [-1, self.intermediate_size * self.tp_size, self.d_model],
  117. )
  118. param_data[:,
  119. shard_size:2 * shard_size, :] = loaded_weight[:,
  120. shard, :]
  121. if weight_name.endswith("w2"):
  122. loaded_weight = torch.reshape(
  123. loaded_weight,
  124. [-1, self.intermediate_size * self.tp_size, self.d_model],
  125. ).transpose(1, 2)
  126. param_data[:] = loaded_weight[:, :, shard]
  127. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  128. num_tokens, hidden_size = hidden_states.shape
  129. hidden_states = hidden_states.view(-1, self.d_model)
  130. # router_logits: (num_tokens, n_experts)
  131. router_logits = self.router(hidden_states)
  132. final_hidden_states = fused_moe(
  133. hidden_states,
  134. self.ws,
  135. self.w2s,
  136. router_logits,
  137. self.top_k,
  138. renormalize=True,
  139. inplace=True,
  140. )
  141. if self.tp_size > 1:
  142. final_hidden_states = tensor_model_parallel_all_reduce(
  143. final_hidden_states)
  144. return final_hidden_states.view(num_tokens, hidden_size)
  145. class DbrxAttention(nn.Module):
  146. def __init__(
  147. self,
  148. config: DbrxConfig,
  149. linear_method: Optional[LinearMethodBase] = None,
  150. ):
  151. super().__init__()
  152. self.d_model = config.d_model
  153. self.total_num_heads = config.n_heads
  154. self.head_dim = self.d_model // self.total_num_heads
  155. self.total_num_kv_heads = config.attn_config.kv_n_heads
  156. self.clip_qkv = config.attn_config.clip_qkv
  157. self.rope_theta = config.attn_config.rope_theta
  158. self.max_position = config.max_seq_len
  159. # pylint: disable=invalid-name
  160. self.Wqkv = QKVParallelLinear(
  161. self.d_model,
  162. self.head_dim,
  163. self.total_num_heads,
  164. self.total_num_kv_heads,
  165. bias=False,
  166. linear_method=linear_method,
  167. )
  168. self.out_proj = RowParallelLinear(
  169. self.d_model,
  170. self.d_model,
  171. bias=False,
  172. linear_method=linear_method,
  173. )
  174. self.rotary_emb = get_rope(
  175. self.head_dim,
  176. rotary_dim=self.head_dim,
  177. max_position=self.max_position,
  178. base=int(self.rope_theta),
  179. is_neox_style=True,
  180. )
  181. tp_world_size = get_tensor_model_parallel_world_size()
  182. self.tp_size = tp_world_size
  183. assert self.total_num_heads % tp_world_size == 0
  184. self.num_heads = self.total_num_heads // tp_world_size
  185. if self.total_num_kv_heads >= tp_world_size:
  186. # Number of KV heads is greater than TP size, so we partition
  187. # the KV heads across multiple tensor parallel GPUs.
  188. assert self.total_num_kv_heads % tp_world_size == 0
  189. else:
  190. # Number of KV heads is less than TP size, so we replicate
  191. # the KV heads across multiple tensor parallel GPUs.
  192. assert tp_world_size % self.total_num_kv_heads == 0
  193. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
  194. self.q_size = self.num_heads * self.head_dim
  195. self.kv_size = self.num_kv_heads * self.head_dim
  196. self.scaling = self.head_dim**-0.5
  197. self.attn = Attention(
  198. self.num_heads,
  199. self.head_dim,
  200. self.scaling,
  201. num_kv_heads=self.num_kv_heads,
  202. )
  203. def forward(
  204. self,
  205. position_ids: torch.Tensor,
  206. hidden_states: torch.Tensor,
  207. kv_cache: torch.Tensor,
  208. attn_metadata: AttentionMetadata,
  209. ) -> torch.Tensor:
  210. qkv, _ = self.Wqkv(hidden_states)
  211. if self.clip_qkv is not None:
  212. qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
  213. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  214. q, k = self.rotary_emb(position_ids, q, k)
  215. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  216. hidden_states, _ = self.out_proj(attn_output)
  217. return hidden_states
  218. class DbrxFusedNormAttention(nn.Module):
  219. def __init__(
  220. self,
  221. config: DbrxConfig,
  222. linear_method: Optional[LinearMethodBase] = None,
  223. ):
  224. super().__init__()
  225. self.d_model = config.d_model
  226. self.attn = DbrxAttention(config, linear_method)
  227. self.norm_1 = nn.LayerNorm(self.d_model)
  228. self.norm_2 = nn.LayerNorm(self.d_model)
  229. def forward(
  230. self,
  231. position_ids: torch.Tensor,
  232. hidden_states: torch.Tensor,
  233. kv_cache: torch.Tensor,
  234. attn_metadata: AttentionMetadata,
  235. ) -> torch.Tensor:
  236. residual = hidden_states
  237. hidden_states = self.norm_1(hidden_states)
  238. x = self.attn(
  239. position_ids=position_ids,
  240. hidden_states=hidden_states,
  241. kv_cache=kv_cache,
  242. attn_metadata=attn_metadata,
  243. )
  244. hidden_states = residual + x
  245. residual = hidden_states
  246. hidden_states = self.norm_2(hidden_states)
  247. return hidden_states, residual
  248. class DbrxBlock(nn.Module):
  249. def __init__(
  250. self,
  251. config: DbrxConfig,
  252. linear_method: Optional[LinearMethodBase] = None,
  253. ):
  254. super().__init__()
  255. self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
  256. self.ffn = DbrxExperts(config, linear_method)
  257. def forward(
  258. self,
  259. position_ids: torch.Tensor,
  260. hidden_states: torch.Tensor,
  261. kv_cache: torch.Tensor,
  262. attn_metadata: AttentionMetadata,
  263. ) -> torch.Tensor:
  264. hidden_states, residual = self.norm_attn_norm(
  265. position_ids=position_ids,
  266. hidden_states=hidden_states,
  267. kv_cache=kv_cache,
  268. attn_metadata=attn_metadata,
  269. )
  270. hidden_states = self.ffn(hidden_states)
  271. hidden_states = hidden_states + residual
  272. return hidden_states
  273. class DbrxModel(nn.Module):
  274. def __init__(
  275. self,
  276. config: DbrxConfig,
  277. linear_method: Optional[LinearMethodBase] = None,
  278. ):
  279. super().__init__()
  280. self.wte = VocabParallelEmbedding(config.vocab_size,
  281. config.d_model,
  282. linear_method=linear_method)
  283. self.blocks = nn.ModuleList(
  284. [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
  285. self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
  286. for module in self.modules():
  287. if hasattr(module, "bias") and isinstance(module.bias,
  288. nn.Parameter):
  289. # Remove the bias term in Linear and LayerNorm.
  290. module.register_parameter("bias", None)
  291. def forward(
  292. self,
  293. input_ids: torch.Tensor,
  294. position_ids: torch.Tensor,
  295. kv_caches: List[torch.Tensor],
  296. attn_metadata: AttentionMetadata,
  297. ) -> torch.Tensor:
  298. hidden_states = self.wte(input_ids)
  299. for i in range(len(self.blocks)):
  300. block = self.blocks[i]
  301. hidden_states = block(
  302. position_ids,
  303. hidden_states,
  304. kv_caches[i],
  305. attn_metadata,
  306. )
  307. hidden_states = self.norm_f(hidden_states)
  308. return hidden_states
  309. class DbrxForCausalLM(nn.Module):
  310. def __init__(
  311. self,
  312. config: DbrxConfig,
  313. linear_method: Optional[LinearMethodBase] = None,
  314. ):
  315. super().__init__()
  316. self.config = config
  317. self.linear_method = linear_method
  318. self.unpadded_vocab_size = config.vocab_size
  319. self.transformer = DbrxModel(config, linear_method)
  320. self.lm_head = ParallelLMHead(config.vocab_size,
  321. config.d_model,
  322. org_num_embeddings=config.vocab_size,
  323. padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  324. linear_method=linear_method)
  325. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  326. config.vocab_size)
  327. self.sampler = Sampler()
  328. def forward(
  329. self,
  330. input_ids: torch.Tensor,
  331. positions: torch.Tensor,
  332. kv_caches: List[torch.Tensor],
  333. attn_metadata: AttentionMetadata,
  334. ) -> torch.Tensor:
  335. hidden_states = self.transformer(input_ids, positions, kv_caches,
  336. attn_metadata)
  337. return hidden_states
  338. def compute_logits(self, hidden_states: torch.Tensor,
  339. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  340. logits = self.logits_processor(self.lm_head, hidden_states,
  341. sampling_metadata)
  342. return logits
  343. def sample(
  344. self,
  345. logits: Optional[torch.Tensor],
  346. sampling_metadata: SamplingMetadata,
  347. ) -> Optional[SamplerOutput]:
  348. next_tokens = self.sampler(logits, sampling_metadata)
  349. return next_tokens
  350. def load_weights(
  351. self,
  352. model_name_or_path: str,
  353. cache_dir: Optional[str] = None,
  354. load_format: str = "auto",
  355. revision: Optional[str] = None,
  356. ):
  357. expert_params_mapping = [(
  358. "ws" if weight_name in ["w1", "v1"] else "w2s",
  359. f"experts.mlp.{weight_name}",
  360. ) for weight_name in ["w1", "v1", "w2"]]
  361. params_dict = dict(self.named_parameters(remove_duplicate=False))
  362. for name, loaded_weight in hf_model_weights_iterator(
  363. model_name_or_path, cache_dir, load_format, revision,
  364. self.config):
  365. for param_name, weight_name in expert_params_mapping:
  366. if weight_name not in name:
  367. continue
  368. name = name.replace(weight_name, param_name)
  369. param = params_dict[name]
  370. weight_loader = param.weight_loader
  371. weight_loader(param, loaded_weight, weight_name)
  372. break
  373. else:
  374. param = params_dict[name]
  375. weight_loader = getattr(param, "weight_loader",
  376. default_weight_loader)
  377. weight_loader(param, loaded_weight)