jamba.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764
  1. # coding=utf-8
  2. """Inference-only Jamba model."""
  3. from dataclasses import dataclass
  4. from typing import Iterable, List, Optional, Tuple
  5. import torch
  6. from torch import nn
  7. from torch.nn.parameter import Parameter
  8. from transformers import JambaConfig
  9. from aphrodite.attention.backends.abstract import AttentionMetadata
  10. from aphrodite.attention.layer import Attention
  11. from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
  12. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  13. # yapf: disable
  14. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  15. get_tensor_model_parallel_world_size)
  16. # yapf: enable
  17. from aphrodite.modeling.layers.fused_moe import FusedMoE
  18. from aphrodite.modeling.layers.layernorm import RMSNorm
  19. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  20. MergedColumnParallelLinear,
  21. QKVParallelLinear,
  22. ReplicatedLinear,
  23. RowParallelLinear)
  24. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  25. from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
  26. causal_conv1d_update,
  27. selective_scan_fn,
  28. selective_state_update)
  29. from aphrodite.modeling.layers.sampler import Sampler
  30. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  31. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  32. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  33. from aphrodite.modeling.models.interfaces import HasInnerState
  34. from aphrodite.modeling.models.mamba_cache import MambaCacheManager
  35. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  36. from aphrodite.modeling.utils import set_weight_attrs
  37. from aphrodite.quantization.base_config import QuantizationConfig
  38. from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
  39. _get_graph_batch_size)
  # Cache-related types shared by the layers below.
from dataclasses import field

# A (key, value) pair of cache tensors, used in JambaForCausalLM.forward's
# ``kv_caches`` annotation.
KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
    """Cache tensors handed to one ``JambaMambaMixer.mamba_forward`` call.

    ``is_prompt`` selects the prefill path (full causal conv + selective
    scan, whose final states are copied into the tensors below) versus the
    decode path (state-update kernels that take the tensors below as their
    state arguments).
    """

    # True for prefill calls, False for decode steps.
    is_prompt: bool = False
    # Convolution state buffer written by the mixer.
    # default_factory gives every instance its own (empty) tensor instead of
    # one tensor created at import time and shared by all instances.
    conv_state: torch.Tensor = field(default_factory=torch.Tensor)
    # Selective-scan (SSM) state buffer written by the mixer.
    ssm_state: torch.Tensor = field(default_factory=torch.Tensor)
  # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)

    Tensor-parallel layout: the intermediate (``mamba_expand * hidden_size``)
    dimension is sharded across TP ranks. ``A`` and ``D`` hold only this
    rank's ``intermediate_size // tp_size`` rows, and their weight loaders
    copy the matching slice of the full checkpoint tensor.
    """

    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        # The depthwise conv1d is stored as a linear layer so that it gets
        # column-parallel sharding of its output channels for free.
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=self.use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projects to [x, gate], each intermediate_size wide (chunked apart
        # in mamba_forward).
        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
                                                  [self.intermediate_size] * 2,
                                                  bias=self.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
                                            self.intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            # Copy this rank's contiguous slice (dim 0) of the full tensor.
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            # The checkpoint stores A_log (load_weights renames "A_log" to
            # "A"); materialize A = -exp(A_log) in float32 before sharding.
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        # A is kept in float32 regardless of the model dtype.
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias,
            input_is_parallel=True,
        )
        # Activation name handed to the causal-conv kernels.
        self.activation = config.hidden_act

        # Jamba-specific RMS norms applied to dt, B and C after x_proj.
        self.dt_layernorm = RMSNorm(self.time_step_rank,
                                    eps=config.rms_norm_eps)
        self.b_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
        self.c_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)

    def mamba_forward(self,
                      hidden_states: torch.Tensor,
                      cache_params: Optional[MambaCacheParams] = None
                      ) -> torch.Tensor:
        """Run the Mamba block on a 3-D batch.

        Args:
            hidden_states: 3-D input. ``forward`` passes either one prompt as
                ``(1, prompt_len, hidden)`` or a decode batch as
                ``(batch, 1, hidden)``.
            cache_params: State buffers. With ``is_prompt=False`` the
                single-token update kernels advance the state in place. With
                ``is_prompt=True`` (or no cache) the full-sequence kernels run,
                and when a cache is given the final states are copied into it.

        Returns:
            ``out_proj`` output for the contextualized states (presumably
            ``(batch, seq_len, hidden)`` -- TODO confirm).
        """
        # 1. Gated MLP's linear projection
        # After transpose the channel dim is dim 1, so chunk splits x / gate.
        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        # Undo the unsqueeze(1) done in __init__ to get 2-D conv weights.
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if cache_params is not None and not cache_params.is_prompt:
            # Decode: one new token per sequence (seq dim squeezed away).
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                # Keep the last conv_kernel_size inputs as the conv state:
                # left-pads short prompts; a negative pad crops long ones.
                conv_states = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_state.copy_(conv_states)

            hidden_states, _ = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
            )

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

        # dt_proj skips its bias here; the bias is handed to the kernels.
        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
        # 3.b perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)
        if cache_params is not None and not cache_params.is_prompt:
            # Decode: advance cache_params.ssm_state by one step.
            # NOTE(review): D is passed without .float() here but with
            # .float() in the prefill path below -- confirm kernel dtype needs.
            scan_outputs = selective_state_update(
                cache_params.ssm_state,
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                self.A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                self.A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            # Persist the final SSM state for subsequent decode steps.
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_state.copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
        return contextualized_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the mixer to a flattened batch of tokens.

        Prefill: tokens of all prompts are concatenated along dim 0 and
        sliced by ``prefill_metadata.seq_lens``. Each prompt is processed
        separately using state row ``i``, and its result is written back into
        ``hidden_states`` in place. Decode: every row is one token of a
        different sequence and the whole batch is processed at once.
        (Mixed prefill/decode batches are not handled here; chunked prefill
        is rejected in JambaForCausalLM.__init__.)
        """
        if attn_metadata.prefill_metadata is not None:
            offset = 0
            for i, prompt_len in enumerate(
                    attn_metadata.prefill_metadata.seq_lens):
                cache = MambaCacheParams(True,
                                         conv_state=conv_state[i].unsqueeze(0),
                                         ssm_state=ssm_state[i].unsqueeze(0))
                hidden_states[offset:offset + prompt_len].copy_(
                    self.mamba_forward(hidden_states[offset:offset +
                                                     prompt_len].unsqueeze(0),
                                       cache_params=cache)[0])
                offset += prompt_len
        else:
            cache = MambaCacheParams(False,
                                     conv_state=conv_state,
                                     ssm_state=ssm_state)
            # Add a length-1 sequence dim for the decode kernels.
            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
                                               cache_params=cache)
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
  # Mixture-of-experts feed-forward block (also reused for dense MLPs).
class JambaMoE(nn.Module):
    """Feed-forward block backed by the fused mixture-of-experts layer.

    Tokens are scored by a replicated linear ``router`` and dispatched
    through ``FusedMoE``. The router is only created when there is more than
    one expert. With a single expert, constant logits are fed to the experts
    layer instead.

    Args:
        config: Jamba model configuration.
        num_experts: Expert count; falsy values (None/0) fall back to
            ``config.num_experts``.
        top_k: Experts per token; falsy values fall back to
            ``config.num_experts_per_tok``.
        params_dtype: Parameter dtype forwarded to router and experts.
        tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
        quant_config: Quantization config for the experts (the router is
            always built with ``quant_config=None``).
    """

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        n_experts = self.num_total_experts
        if n_experts > 1:
            self.router = ReplicatedLinear(self.hidden_size,
                                           n_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        self.experts = FusedMoE(n_experts,
                                self.top_k,
                                self.hidden_size,
                                self.intermediate_size,
                                params_dtype=params_dtype,
                                tp_size=tp_size,
                                quant_config=quant_config,
                                use_grouped_topk=False,
                                renormalize=False,
                                reduce_results=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the block token-wise; the output has the input's shape."""
        input_shape = hidden_states.shape
        flat = hidden_states.view(-1, self.hidden_size)
        num_tokens = flat.shape[0]
        if self.num_total_experts <= 1:
            # No router: one constant logit per token for the lone expert.
            router_logits = torch.ones((num_tokens, 1),
                                       device=flat.device,
                                       dtype=flat.dtype)
        else:
            # router_logits: (batch * sequence_length, n_experts)
            router_logits, _ = self.router(flat)
        expert_out = self.experts(flat, router_logits)
        return expert_out.view(input_shape)
  # Dense MLP expressed through the MoE machinery.
class JambaMLP(JambaMoE):
    """Dense feed-forward block implemented as a one-expert ``JambaMoE``.

    Using a single, always-selected expert lets dense checkpoint MLP weights
    be loaded into the fused-expert layout as expert 0. ``load_weights``
    rewrites ``feed_forward`` to ``feed_forward.experts.0`` for them.
    """

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(config,
                         quant_config=quant_config,
                         tp_size=tp_size,
                         params_dtype=params_dtype,
                         top_k=1,
                         num_experts=1)
  # Decoder layer whose token mixer is a Mamba block.
class JambaMambaDecoderLayer(nn.Module):
    """Pre-norm decoder layer: Mamba mixer, then a dense MLP or MoE.

    The feed-forward type is selected per layer: ``JambaMoE`` when
    ``config.layers_num_experts[layer_idx] > 1``, otherwise ``JambaMLP``.
    ``cache_config`` is accepted for signature parity with the attention
    layer and is not used here.
    """

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.mamba = JambaMambaMixer(config, layer_idx)

        uses_moe = config.layers_num_experts[layer_idx] > 1
        ffn_cls = JambaMoE if uses_moe else JambaMLP
        self.feed_forward = ffn_cls(config, quant_config=quant_config)

        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
        **kwargs,
    ):
        """Return ``(ffn_output, residual)``.

        Extra keyword arguments (e.g. ``positions``/``kv_cache`` passed by
        ``JambaModel``) are accepted and ignored.
        """
        if residual is not None:
            # Two-argument RMSNorm returns (normed, updated residual).
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        mixed = self.mamba(normed, attn_metadata, conv_state, ssm_state)

        # Fully Connected
        ffn_input, residual = self.pre_ff_layernorm(mixed, residual)
        return self.feed_forward(ffn_input), residual
  # Decoder layer whose token mixer is self-attention.
class JambaAttentionDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, then a dense MLP or MoE.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split when there are at least as many as ranks, and replicated
    otherwise. No positional encoding is applied in this layer.
    """

    def __init__(
        self,
        config: JambaConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than ranks: each head is replicated, so the
            # rank count must be a multiple of the head count.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Enough KV heads to partition them evenly across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

        uses_moe = config.layers_num_experts[layer_idx] > 1
        ffn_cls = JambaMoE if uses_moe else JambaMLP
        self.feed_forward = ffn_cls(config, quant_config=quant_config)

        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Project to q/k/v, attend, and project back.

        ``positions`` is accepted for interface compatibility but unused.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        attended = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(attended)
        return projected

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        """Return ``(ffn_output, residual)``.

        Extra keyword arguments (e.g. ``conv_state``/``ssm_state`` passed by
        ``JambaModel``) are accepted and ignored.
        """
        if residual is not None:
            normed, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)

        attn_out = self.self_attention(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        ffn_input, residual = self.pre_ff_layernorm(attn_out, residual)
        return self.feed_forward(ffn_input), residual
  # Maps each entry of ``config.layers_block_type`` to its layer class.
ALL_DECODER_LAYER_TYPES = dict(
    attention=JambaAttentionDecoderLayer,
    mamba=JambaMambaDecoderLayer,
)
  # Jamba decoder stack (embeddings + hybrid attention/mamba layers).
class JambaModel(nn.Module):
    """Embedding, interleaved attention / Mamba decoder layers, final norm.

    The type of layer ``i`` comes from ``config.layers_block_type[i]``.
    """

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        # Extra vocabulary rows reserved for LoRA adapters, if any.
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(
                layer_class(config,
                            layer_idx=i,
                            cache_config=cache_config,
                            quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run all decoder layers, apply the final norm.

        Args:
            input_ids: Token ids for this step.
            positions: Token positions; forwarded to every layer, but none of
                the layer types in this file uses them.
            kv_caches: One cache per attention layer, in model order
                (presumably allocated per attention layer -- verify against
                the cache engine).
            attn_metadata: Attention metadata for the current batch.
            conv_state: Convolution states, indexed along dim 0 by the
                layer's ordinal among mamba layers.
            ssm_state: SSM states, indexed the same way as ``conv_state``.

        Returns:
            The final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        # Each cache is indexed by the layer's ordinal among layers of the
        # same type, so keep one running counter per type. For the periodic
        # layout given by attn_layer_offset / attn_layer_period this equals
        # the closed-form indices, but it depends only on the actual layer
        # types. Those are the same source used to size the mamba cache
        # (the count of "mamba" in config.layers_block_type).
        attn_layer_idx = 0
        mamba_layer_idx = 0
        for layer in self.layers:
            kv_cache = None
            current_ssm_state = None
            current_conv_state = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache = kv_caches[attn_layer_idx]
                attn_layer_idx += 1
            if isinstance(layer, JambaMambaDecoderLayer):
                current_ssm_state = ssm_state[mamba_layer_idx]
                current_conv_state = conv_state[mamba_layer_idx]
                mamba_layer_idx += 1
            # Both layer types accept the full keyword set and ignore the
            # arguments they do not use via **kwargs.
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                conv_state=current_conv_state,
                ssm_state=current_ssm_state,
            )
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states
  # Causal-LM wrapper: LM head, sampling, weight loading and Mamba state cache.
class JambaForCausalLM(nn.Module, HasInnerState):
    """Jamba causal language model for inference.

    Wraps :class:`JambaModel` with an LM head, logits processor and sampler.
    It also owns a :class:`MambaCacheManager` that stores the convolution and
    SSM states of the mamba layers between forward calls. The manager is
    created lazily on the first ``forward`` call.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: JambaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        """Build the model.

        Raises:
            AssertionError: if chunked prefill (``scheduler_config``) or
                prefix caching (``cache_config``) is enabled.
        """
        # Both configs default to None, so only validate them when they are
        # provided; forward() already handles scheduler_config=None.
        if scheduler_config is not None:
            assert not scheduler_config.chunked_prefill_enabled, \
                "Jamba currently does not support chunked prefill"
        if cache_config is not None:
            assert not cache_config.enable_prefix_caching, \
                "Jamba currently does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(config,
                                cache_config=cache_config,
                                quant_config=quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
        """Run the decoder stack and return hidden states (not logits).

        ``kwargs`` must contain either ``seqlen_agnostic_capture_inputs``
        (CUDA-graph capture runs) or both ``request_ids_to_seq_ids`` and
        ``finished_requests_ids`` (prefill / eager runs).
        """
        if self.mamba_cache is None:
            max_batch_size = (_get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
            layers_type = self.config.layers_block_type
            num_mamba_layers = sum(layer_type == "mamba"
                                   for layer_type in layers_type)
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                *self._get_mamba_cache_shape())

        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            assert all(
                key in kwargs
                for key in ["request_ids_to_seq_ids", "finished_requests_ids"])

            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]
            self.mamba_cache.release_finished_requests(finished_requests_ids)
            batch_size = input_ids.shape[0]
            if attn_metadata.prefill_metadata:
                # Prefill tokens of all prompts are flattened along dim 0,
                # so count requests rather than tokens.
                batch_size = len(request_ids_to_seq_ids)
            mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
                request_ids_to_seq_ids, batch_size, finished_requests_ids)
        else:
            # CUDA graph capturing runs
            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]

        # mamba_cache_tensors[0] -> conv_state, [1] -> ssm_state.
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, mamba_cache_tensors[0],
                                   mamba_cache_tensors[1])
        return hidden_states

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Delegate to the Mamba cache manager (requires a prior forward)."""
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Delegate to the Mamba cache manager (requires a prior forward)."""
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        """Return per-rank ``(conv_state_shape, temporal_state_shape)``.

        The intermediate size (``mamba_expand * hidden_size``) is divided by
        the tensor-parallel world size, matching the mixer's sharding.
        """
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        conv_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_conv,
        )
        temporal_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Name rewrites: ``A_log`` -> ``A`` (converted by the mixer's loader),
        ``.self_attn`` is dropped, and dense MLP ``feed_forward`` weights are
        mapped to expert 0. q/k/v go to the fused ``qkv_proj`` and expert
        weights go to the fused MoE parameters.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "A_log" in name:
                name = name.replace("A_log", "A")

            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            if "feed_forward" in name and not _is_moe_layer(name):
                ## map MLP layers to expert with ID=0
                name = name.replace("feed_forward", "feed_forward.experts.0")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if 'experts' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.

                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
  674. def _is_moe_layer(name: str):
  675. return any(
  676. [experts_name in name for experts_name in [
  677. "experts",
  678. "router",
  679. ]])