# coding=utf-8
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.attention.layer import Attention
from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
from aphrodite.common.sequence import IntermediateTensors
# yapf: disable
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
# yapf: enable
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
                                             causal_conv1d_update,
                                             selective_scan_fn,
                                             selective_state_update)
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import HasInnerState
from aphrodite.modeling.models.mamba_cache import MambaCacheManager
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                           _get_graph_batch_size)

from .interfaces import SupportsLoRA

KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
    is_prompt: bool = False
    conv_state: torch.Tensor = torch.Tensor()
    ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D, the state space parameters, and compute
    the `contextualized_states`. A and D are input-independent
    (see Mamba paper [1], Section 3.5.2 "Interpretation of A"
    for why A isn't selective), while ∆, B, and C are input-dependent
    (this is a key difference between Mamba and the linear
    time-invariant S4, and is why Mamba is called
    **selective** state spaces).
    """
    def __init__(self, config: JambaConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=self.use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow overriding it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
                                                  [self.intermediate_size] * 2,
                                                  bias=self.use_bias)
        # selective projection used to make dt, B and C input-dependent
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward pass we apply dt_proj without its bias,
        # as the bias is added inside the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
                                            self.intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias,
            input_is_parallel=True,
        )
        self.activation = config.hidden_act

        self.dt_layernorm = RMSNorm(self.time_step_rank,
                                    eps=config.rms_norm_eps)
        self.b_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
        self.c_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
    def mamba_forward(self,
                      hidden_states: torch.Tensor,
                      cache_params: MambaCacheParams = None):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if cache_params is not None and not cache_params.is_prompt:
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                conv_states = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_state.copy_(conv_states)

            hidden_states, _ = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
            )

        # 3. State Space Model sequence transformation
        # 3.a. input-varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)
        if cache_params is not None and not cache_params.is_prompt:
            scan_outputs = selective_state_update(
                cache_params.ssm_state,
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                self.A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                self.A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_state.copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
        return contextualized_states
    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ):
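        # Prefill runs process each prompt separately through the full
        # convolution + selective scan and write that sequence's conv/ssm
        # states into its cache slot; decode runs carry one token per
        # sequence and use the incremental update kernels instead.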
        if attn_metadata.prefill_metadata is not None:
            offset = 0
            for i, prompt_len in enumerate(
                    attn_metadata.prefill_metadata.seq_lens):
                cache = MambaCacheParams(True,
                                         conv_state=conv_state[i].unsqueeze(0),
                                         ssm_state=ssm_state[i].unsqueeze(0))
                hidden_states[offset:offset + prompt_len].copy_(
                    self.mamba_forward(hidden_states[offset:offset +
                                                     prompt_len].unsqueeze(0),
                                       cache_params=cache)[0])
                offset += prompt_len
        else:
            cache = MambaCacheParams(False,
                                     conv_state=conv_state,
                                     ssm_state=ssm_state)
            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
                                               cache_params=cache)
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
class JambaMoE(nn.Module):

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        if self.num_total_experts > 1:
            self.router = ReplicatedLinear(self.hidden_size,
                                           self.num_total_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        self.experts = FusedMoE(self.num_total_experts,
                                self.top_k,
                                self.hidden_size,
                                self.intermediate_size,
                                tp_size=tp_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=False,
                                use_grouped_topk=False,
                                quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        if self.num_total_experts > 1:
            router_logits, _ = self.router(hidden_states)
        else:
            router_logits = torch.ones((hidden_states.shape[0], 1),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
        hidden_states = self.experts(hidden_states, router_logits)
        return hidden_states.view(orig_shape)
class JambaMLP(JambaMoE):

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(config,
                         num_experts=1,
                         top_k=1,
                         params_dtype=params_dtype,
                         tp_size=tp_size,
                         quant_config=quant_config)
class JambaMambaDecoderLayer(nn.Module):

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.mamba = JambaMambaMixer(config, layer_idx)

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
                                   ssm_state)
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
class JambaAttentionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
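        # Illustrative (hypothetical) sizings: with 8 KV heads and tp_size=2,
        # each rank holds 8 // 2 = 4 KV heads; with 2 KV heads and tp_size=8,
        # each rank holds max(1, 2 // 8) = 1 replicated KV head.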
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
ALL_DECODER_LAYER_TYPES = {
    "attention": JambaAttentionDecoderLayer,
    "mamba": JambaMambaDecoderLayer
}
class JambaModel(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(
                layer_class(config,
                            layer_idx=i,
                            cache_config=cache_config,
                            quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
            kv_cache = None
            current_ssm_state = None
            current_conv_state = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
                                     self.config.attn_layer_period]
            if isinstance(layer, JambaMambaDecoderLayer):
                current_state_layer = i - (1 +
                                           (i - self.config.attn_layer_offset)
                                           // self.config.attn_layer_period)
                current_ssm_state = ssm_state[current_state_layer]
                current_conv_state = conv_state[current_state_layer]
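            # Illustrative mapping, assuming the stock Jamba layout of
            # attn_layer_offset=4 and attn_layer_period=8: attention layer
            # i=12 reads kv_caches[(12 - 4) // 8] = kv_caches[1], while Mamba
            # layer i=5 uses state slot 5 - (1 + (5 - 4) // 8) = 4, i.e. its
            # index among the Mamba-only layers.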
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                conv_state=current_conv_state,
                ssm_state=current_ssm_state,
            )
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
    def __init__(
        self,
        config: JambaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        assert not scheduler_config.chunked_prefill_enabled, \
            "Jamba currently does not support chunked prefill"
        assert not cache_config.enable_prefix_caching, \
            "Jamba currently does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(config,
                                cache_config=cache_config,
                                quant_config=quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using LoRA for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Lazily-initialized Mamba cache, carried between inference steps.
        self.mamba_cache: Optional[MambaCacheManager] = None

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()
    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
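        # Two entry paths: eager/prefill runs pass request bookkeeping via the
        # `request_ids_to_seq_ids` / `finished_requests_ids` kwargs, while
        # CUDA-graph replay passes pre-allocated cache tensors via
        # `seqlen_agnostic_capture_inputs`.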
        if self.mamba_cache is None:
            max_batch_size = (_get_graph_batch_size(
                self.scheduler_config.max_num_seqs) if self.scheduler_config
                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
            layers_type = self.config.layers_block_type
            num_mamba_layers = sum(
                [layer_type == "mamba" for layer_type in layers_type])
            self.mamba_cache = MambaCacheManager(
                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                *self._get_mamba_cache_shape())

        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            assert all(
                key in kwargs
                for key in ["request_ids_to_seq_ids", "finished_requests_ids"])

            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]
            self.mamba_cache.release_finished_requests(finished_requests_ids)

            batch_size = input_ids.shape[0]
            if attn_metadata.prefill_metadata:
                batch_size = len(request_ids_to_seq_ids)
            mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
                request_ids_to_seq_ids, batch_size, finished_requests_ids)
        else:
            # CUDA graph capturing runs
            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, mamba_cache_tensors[0],
                                   mamba_cache_tensors[1])
        return hidden_states
    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def _get_mamba_cache_shape(
            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        conv_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_conv,
        )
        temporal_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape
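
    # Illustrative (hypothetical) sizing: with hidden_size=4096,
    # mamba_expand=2, mamba_d_conv=4, mamba_d_state=16 and a single TP rank,
    # each layer's per-sequence conv state is (8192, 4) and its SSM
    # ("temporal") state is (8192, 16).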
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "A_log" in name:
                name = name.replace("A_log", "A")

            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            if "feed_forward" in name and not _is_moe_layer(name):
                # Map dense MLP layers onto the expert with ID 0
                name = name.replace("feed_forward", "feed_forward.experts.0")
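                # e.g. a (hypothetical) checkpoint name
                # "model.layers.0.feed_forward.gate_proj.weight" becomes
                # "model.layers.0.feed_forward.experts.0.gate_proj.weight",
                # so dense MLP weights load through the same FusedMoE path.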
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if 'experts' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)


def _is_moe_layer(name: str):
    return any(
        [experts_name in name for experts_name in [
            "experts",
            "router",
        ]])