dbrx.py

# coding=utf-8
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.transformers_utils.configs.dbrx import DbrxConfig


class DbrxRouter(nn.Module):
    """A Router implementation for DBRX that returns logits for each expert
    per token.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.d_model = config.d_model
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            linear_method=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.layer(hidden_states)
        return router_logits
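
# The router is a plain replicated projection from (num_tokens, d_model) to
# (num_tokens, moe_num_experts): every tensor-parallel rank computes identical
# routing logits, so no extra communication is needed before dispatching
# tokens to the sharded experts below.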


class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    is used for the forward pass, and the outputs are finally reduced across
    ranks.
    """

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        self.d_model = config.d_model
        self.intermediate_size = (config.ffn_config.ffn_hidden_size //
                                  self.tp_size)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )
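
    # Fused per-rank weight layout: `ws` stacks each expert's w1 and v1 along
    # dim 1, shape (num_experts, 2 * intermediate_size, d_model), and `w2s`
    # holds w2 transposed, shape (num_experts, d_model, intermediate_size),
    # where intermediate_size has already been divided by the TP world size.
    # weight_loader below fills only the local shard of each checkpoint
    # tensor.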

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # DBRX uses a GLU for each expert.
        # The GLU has 3 linear layers: w1, v1 and w2.
        if weight_name.endswith("w1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
        if weight_name.endswith("v1"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            )
            param_data[:, shard_size:2 * shard_size, :] = (
                loaded_weight[:, shard, :])
        if weight_name.endswith("w2"):
            loaded_weight = torch.reshape(
                loaded_weight,
                [-1, self.intermediate_size * self.tp_size, self.d_model],
            ).transpose(1, 2)
            param_data[:] = loaded_weight[:, :, shard]
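
    def _expert_glu_reference(self, x: torch.Tensor, e: int) -> torch.Tensor:
        # Reference-only sketch added for documentation; it is not called
        # anywhere. It spells out the per-expert GLU that fused_moe evaluates
        # for expert `e`, assuming the kernel's SiLU gating. With TP > 1 this
        # gives a per-rank partial output that forward() sums via all-reduce.
        w1 = self.ws[e, :self.intermediate_size, :]  # (intermediate, d_model)
        v1 = self.ws[e, self.intermediate_size:, :]  # (intermediate, d_model)
        w2 = self.w2s[e]                             # (d_model, intermediate)
        gated = torch.nn.functional.silu(x @ w1.t()) * (x @ v1.t())
        return gated @ w2.t()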

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)
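
# Note: fused_moe evaluates only each token's top_k experts; with
# renormalize=True the selected routing weights are rescaled to sum to one.
# Under tensor parallelism every rank holds a slice of each expert, so the
# partial outputs are summed with an all-reduce before being returned.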


class DbrxAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = config.attn_config.kv_n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.rope_theta = config.attn_config.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            linear_method=linear_method,
        )
        is_neox_style = (True if linear_method is None
                         or linear_method.quant_config.rope_style() is None
                         else linear_method.quant_config.rope_style())
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=is_neox_style,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        self.tp_size = tp_world_size
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )
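
    # Per-rank projection sizes: q_size = num_heads * head_dim and
    # kv_size = num_kv_heads * head_dim; the fused Wqkv output is split into
    # q, k, v along the last dimension in forward(), and clip_qkv (when set in
    # the config) clamps that output to [-clip_qkv, clip_qkv] before RoPE and
    # attention.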

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        hidden_states, _ = self.out_proj(attn_output)
        return hidden_states


class DbrxFusedNormAttention(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, linear_method)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + x
        residual = hidden_states
        hidden_states = self.norm_2(hidden_states)
        return hidden_states, residual
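
# norm_attn_norm returns both the post-norm_2 hidden states and the running
# residual (input + attention output): DbrxBlock feeds the normalized states
# through the MoE FFN and then adds the residual back, i.e. the usual
# pre-norm transformer block arrangement.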


class DbrxBlock(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
        self.ffn = DbrxExperts(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class DbrxModel(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.d_model,
                                          linear_method=linear_method)
        self.blocks = nn.ModuleList(
            [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias,
                                                      nn.Parameter):
                # Remove the bias term in Linear and LayerNorm.
                module.register_parameter("bias", None)
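
    # DBRX layers are bias-free; nn.LayerNorm creates a bias parameter by
    # default, so the loop above de-registers every bias before checkpoint
    # weights are loaded.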

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class DbrxForCausalLM(nn.Module):

    def __init__(
        self,
        config: DbrxConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.d_model,
                                      org_num_embeddings=config.vocab_size,
                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                                      linear_method=linear_method)
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            min(config.vocab_size, config.tokenizer_vocab_size))
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
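
    # The engine drives these entry points in order on every step: forward()
    # produces hidden states, compute_logits() maps them to vocabulary logits
    # through lm_head, and sample() picks next tokens per sampling_metadata.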

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        expert_params_mapping = [(
            "ws" if weight_name in ["w1", "v1"] else "w2s",
            f"experts.mlp.{weight_name}",
        ) for weight_name in ["w1", "v1", "w2"]]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision,
                self.config):
            for param_name, weight_name in expert_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, weight_name)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
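
    # Checkpoint tensors whose names contain "experts.mlp.w1", "experts.mlp.v1"
    # or "experts.mlp.w2" are redirected into the fused `ws` / `w2s` parameters
    # and handed to DbrxExperts.weight_loader, which reshapes them and copies
    # only this rank's shard; every other tensor goes through
    # default_weight_loader unchanged.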