# jamba.py
# coding=utf-8
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.attention.layer import Attention
from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
# yapf: disable
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
# yapf: enable
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
                                             causal_conv1d_update,
                                             selective_scan_fn,
                                             selective_state_update)
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import HasInnerState
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                                 _get_graph_batch_size)

KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
    is_prompt: bool = False
    conv_state: torch.Tensor = torch.Tensor()
    ssm_state: torch.Tensor = torch.Tensor()
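
# Note (added commentary): MambaCacheParams carries the per-sequence Mamba
# state between steps. `conv_state` holds the sliding window of inputs for the
# causal conv1d (length `mamba_d_conv`), `ssm_state` holds the SSM hidden
# state, and `is_prompt` selects between the prefill kernels (full-sequence
# scan) and the single-token decode kernels.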


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective). ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces).
    """

    def __init__(self, config: JambaConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=self.use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow overriding it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
                                                  [self.intermediate_size] * 2,
                                                  bias=self.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
                                            self.intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias,
            input_is_parallel=True,
        )
        self.activation = config.hidden_act

        self.dt_layernorm = RMSNorm(self.time_step_rank,
                                    eps=config.rms_norm_eps)
        self.b_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
        self.c_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
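
    # Note (added commentary): the selective-scan kernels used in
    # `mamba_forward` below implement, per channel, roughly the discretized
    # recurrence from the Mamba paper:
    #     h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
    #     y_t = C_t . h_t + D * x_t
    # where dt_t = softplus(dt_proj(time_step_t) + dt_bias) (the softplus is
    # applied inside the kernels via `dt_softplus`/`delta_softplus`). A and D
    # are learned, input-independent parameters, while dt_t, B_t and C_t are
    # produced from the current input via `x_proj`.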

    def mamba_forward(self,
                      hidden_states: torch.Tensor,
                      cache_params: MambaCacheParams = None):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if cache_params is not None and not cache_params.is_prompt:
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                conv_states = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_state.copy_(conv_states)

            hidden_states, _ = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
            )

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)
        if cache_params is not None and not cache_params.is_prompt:
            scan_outputs = selective_state_update(
                cache_params.ssm_state,
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                self.A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                self.A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_state.copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
        return contextualized_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ):
        if attn_metadata.prefill_metadata is not None:
            offset = 0
            for i, prompt_len in enumerate(
                    attn_metadata.prefill_metadata.seq_lens):
                cache = MambaCacheParams(True,
                                         conv_state=conv_state[i].unsqueeze(0),
                                         ssm_state=ssm_state[i].unsqueeze(0))
                hidden_states[offset:offset + prompt_len].copy_(
                    self.mamba_forward(hidden_states[offset:offset +
                                                     prompt_len].unsqueeze(0),
                                       cache_params=cache)[0])
                offset += prompt_len
        else:
            cache = MambaCacheParams(False,
                                     conv_state=conv_state,
                                     ssm_state=ssm_state)
            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
                                               cache_params=cache)
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
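
# Note on JambaMambaMixer.forward above (added commentary): during prefill
# each prompt in the flattened batch is scanned separately with its own
# (conv_state, ssm_state) slice, so states from different sequences never mix;
# during decode the whole batch goes through the single-token update kernels
# in one call.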


class JambaMLP(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        hidden_act = config.hidden_act
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class JambaMoE(nn.Module):

    def __init__(self,
                 config: JambaConfig,
                 num_experts: Optional[int] = None,
                 top_k: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.num_total_experts = num_experts or config.num_experts
        self.top_k = top_k or config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        if self.num_total_experts > 1:
            self.router = ReplicatedLinear(self.hidden_size,
                                           self.num_total_experts,
                                           bias=False,
                                           quant_config=None,
                                           params_dtype=params_dtype)

        self.experts = FusedMoE(self.num_total_experts,
                                self.top_k,
                                self.hidden_size,
                                self.intermediate_size,
                                tp_size=tp_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=False,
                                use_grouped_topk=False,
                                quant_config=quant_config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        if self.num_total_experts > 1:
            router_logits, _ = self.router(hidden_states)
        else:
            router_logits = torch.ones((hidden_states.shape[0], 1),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
        hidden_states = self.experts(hidden_states, router_logits)
        return hidden_states.view(orig_shape)
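
# Note on JambaMoE (added commentary): when it is built with a single expert
# the router is not created at all; constant logits of 1.0 are fed to FusedMoE
# instead, so the layer behaves like a single dense expert MLP.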


class JambaMambaDecoderLayer(nn.Module):

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.mamba = JambaMambaMixer(config, layer_idx)

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
                                   ssm_state)
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class JambaAttentionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {
    "attention": JambaAttentionDecoderLayer,
    "mamba": JambaMambaDecoderLayer
}
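
# Note (added commentary): `config.layers_block_type` interleaves the two
# layer types. With the published Jamba config defaults (attn_layer_period=8,
# attn_layer_offset=4, cited here only as an example of the layout), every 8th
# layer starting at index 4 is an attention layer and the rest are Mamba
# layers; `layers_num_experts` similarly selects MoE vs. dense FFN per layer.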


class JambaModel(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(
                layer_class(config,
                            layer_idx=i,
                            cache_config=cache_config,
                            quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
            kv_cache = None
            current_ssm_state = None
            current_conv_state = None
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
                                     self.config.attn_layer_period]
            if isinstance(layer, JambaMambaDecoderLayer):
                current_state_layer = i - (1 +
                                           (i - self.config.attn_layer_offset)
                                           // self.config.attn_layer_period)
                current_ssm_state = ssm_state[current_state_layer]
                current_conv_state = conv_state[current_state_layer]

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                conv_state=current_conv_state,
                ssm_state=current_ssm_state,
            )
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states
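
# Note on the index arithmetic in JambaModel.forward above (added
# commentary): KV caches are allocated only for attention layers, so they are
# indexed by (i - attn_layer_offset) // attn_layer_period, i.e. by the order
# in which attention layers appear. Mamba layers index the conv/ssm state
# buffers by i minus the number of attention layers at or below i (that is
# what 1 + (i - attn_layer_offset) // attn_layer_period computes), so those
# buffers are sized by the number of Mamba layers only.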


class JambaForCausalLM(nn.Module, HasInnerState):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: JambaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        assert not scheduler_config.chunked_prefill_enabled, \
            "Jamba currently does not support chunked prefill"
        assert not cache_config.enable_prefix_caching, \
            "Jamba currently does not support prefix caching"

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(config,
                                cache_config=cache_config,
                                quant_config=quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Indices of the Mamba cache slots used in the current step
        self.current_indices: List[int] = []
        # The persistent Mamba cache, tracked and stored between steps.
        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
        # Used as an input_buffer for the CUDA graph runs.
        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
        if not self.mamba_cache:
            self._prepare_mamba_cache()

        if "seqlen_agnostic_capture_inputs" not in kwargs:
            # We get here only on Prefill/Eager mode runs
            assert all(
                key in kwargs
                for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
            finished_requests_ids = kwargs["finished_requests_ids"]
            self._release_mamba_cache(finished_requests_ids)
            batch_size = input_ids.shape[0]
            if attn_metadata.prefill_metadata:
                batch_size = len(request_ids_to_seq_ids)
            (
                current_seqlen_agnostic_cache,
                indices,
            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
                                                      batch_size,
                                                      finished_requests_ids)
        else:
            # CUDA graph capturing runs
            current_seqlen_agnostic_cache, indices = (
                kwargs["seqlen_agnostic_capture_inputs"],
                [],
            )
        self.current_indices = indices

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata,
                                   current_seqlen_agnostic_cache[0],
                                   current_seqlen_agnostic_cache[1])

        if "seqlen_agnostic_capture_inputs" not in kwargs:
            self._copy_mamba_cache_by_indices(self.current_indices,
                                              current_seqlen_agnostic_cache)

        return hidden_states
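
    # Note on `forward` above (added commentary): for prefill/eager runs the
    # per-request cache slots are gathered into a contiguous "current run"
    # copy, the model runs on that copy, and the updated states are scattered
    # back afterwards. For CUDA graph runs the pre-captured
    # `seqlen_agnostic_capture_inputs` buffer is used directly and the
    # copy-back is handled in `copy_outputs_after_cuda_graphs`.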

    def _copy_mamba_cache_by_indices(
            self, indices: List[int],
            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
        for i, offset in enumerate(indices):
            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)

    def _copy_mamba_cache(self, index_to: int, index_from: int,
                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
        assert len(self.mamba_cache) > 0
        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
                                       non_blocking=True)

    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
                                      seqs_id: List[int]) -> List[int]:
        indices_for_current_run = []
        for seq_id in seqs_id:
            if cur_rid not in self.mamba_cache_indices_mapping:
                self.mamba_cache_indices_mapping[cur_rid] = {}
                first_free_index = self._first_free_index_in_mamba_cache()
                self.mamba_cache_indices_mapping[cur_rid][
                    seq_id] = first_free_index
                index_for_current_run = first_free_index
            # Case of decoding with n>1: copy the prefill cache to the
            # decoding indices
            elif seq_id not in (seq_ids2indices :=
                                self.mamba_cache_indices_mapping[cur_rid]):
                first_free_index = self._first_free_index_in_mamba_cache()
                index_exist = list(seq_ids2indices.values())[0]
                self._copy_mamba_cache(index_from=index_exist,
                                       index_to=first_free_index,
                                       from_buffer=self.mamba_cache)
                self.mamba_cache_indices_mapping[cur_rid][
                    seq_id] = first_free_index
                index_for_current_run = first_free_index
            else:
                index_for_current_run = self.mamba_cache_indices_mapping[
                    cur_rid][seq_id]

            indices_for_current_run.append(index_for_current_run)
        return indices_for_current_run

    def _prepare_current_run_mamba_cache(
        self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
        finished_requests_ids: List[str]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
        indices_for_current_run = []
        for request_id, seqs_id in request_ids_to_seq_ids.items():
            if request_id in finished_requests_ids:
                # Do not allocate cache for requests that run
                # and finish right after
                continue
            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
                request_id, seqs_id)
        # Pad the batch in case the running batch was not captured via CUDA
        # graphs
        padded_indices = indices_for_current_run.copy()
        pad_index = self._first_free_index_in_mamba_cache()
        for _ in range(batch_size - len(indices_for_current_run)):
            padded_indices.append(pad_index)

        conv_state = self.mamba_cache[0][:, padded_indices]
        temporal_state = self.mamba_cache[1][:, padded_indices]

        return (conv_state, temporal_state), indices_for_current_run
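
    # Note (added commentary): `_prepare_current_run_mamba_cache` returns a
    # gathered *copy* of the per-sequence states (advanced indexing along the
    # batch dimension copies), which is why `_copy_mamba_cache_by_indices`
    # must write the results back into `self.mamba_cache` after the forward
    # pass.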

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant Mamba cache into the CUDA graph input buffer
        that was provided during the capture runs
        (JambaForCausalLM.mamba_gc_cache_buffer).
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        self._release_mamba_cache(finished_requests_ids)
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        cg_batch_size = input_buffers['input_ids'].shape[0]
        (
            current_mamba_cache,
            indices,
        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
                                                  cg_batch_size,
                                                  finished_requests_ids)
        self.current_indices = indices

        for input_buffer, current_cache_buffer in zip(
                input_buffers["seqlen_agnostic_capture_inputs"],
                current_mamba_cache):
            input_buffer.copy_(current_cache_buffer, non_blocking=True)

    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant Mamba cache from the CUDA graph input_buffers
        back to the JambaForCausalLM.mamba_cache after the CUDA
        graph replay run is done.
        """
        self._copy_mamba_cache_by_indices(
            self.current_indices,
            input_buffers["seqlen_agnostic_capture_inputs"])

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer of the adjusted
        size. The buffer is used to maintain the Mamba cache during the CUDA
        graph replay runs.
        """
        return tuple(buffer[:, :batch_size]
                     for buffer in self.mamba_gc_cache_buffer)

    def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.mamba_cache_indices_mapping:
                self.mamba_cache_indices_mapping.pop(req_id)

    def _first_free_index_in_mamba_cache(self) -> int:
        if self.mamba_cache:
            max_possible_batch_size = self.mamba_cache[0].shape[1]
            occupied = [
                id for seq_ids in self.mamba_cache_indices_mapping.values()
                for id in seq_ids.values()
            ]
            first_free_index = [
                i not in occupied for i in range(max_possible_batch_size)
            ].index(True)
            return first_free_index
        return 0

    def _get_mamba_cache_shape(
            self
    ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        conv_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_conv,
        )
        temporal_state_shape = (
            self.config.mamba_expand * self.config.hidden_size // world_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape
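
    # Worked example for `_get_mamba_cache_shape` (values from the released
    # Jamba-v0.1 config, shown here only for illustration): with
    # hidden_size=4096, mamba_expand=2, mamba_d_conv=4, mamba_d_state=16 and
    # world_size=1, this gives per layer and per cache slot
    #     conv_state_shape     = (8192, 4)
    #     temporal_state_shape = (8192, 16)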

    def _prepare_mamba_cache(self):
        dtype = self.lm_head.weight.dtype
        layers_type = self.config.layers_block_type
        mamba_layers = sum(
            [layer_type == "mamba" for layer_type in layers_type])
        max_batch_size = (_get_graph_batch_size(
            self.scheduler_config.max_num_seqs) if self.scheduler_config else
                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
        conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
        assert conv_state_shape is not None and temporal_state_shape is not None
        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
                                  conv_state_shape,
                                  dtype=dtype,
                                  device="cuda"),
                      torch.empty(size=(mamba_layers, max_batch_size) +
                                  temporal_state_shape,
                                  dtype=dtype,
                                  device="cuda"))
            setattr(self, buffername, buffer)
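
    # Note on `_prepare_mamba_cache` above (added commentary): two identically
    # shaped buffer pairs are allocated, one conv/temporal tensor per Mamba
    # layer and cache slot. `mamba_cache` is the persistent store indexed via
    # `mamba_cache_indices_mapping`, while `mamba_gc_cache_buffer` serves as
    # the fixed input buffer for CUDA graph capture/replay. The `+ 10` adds
    # spare slots beyond the largest captured graph batch size (the exact
    # margin is a heuristic in this code, presumably headroom for the padding
    # slot and transient allocations).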

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "A_log" in name:
                name = name.replace("A_log", "A")

            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if 'experts' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
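
# Note on `load_weights` above (added commentary): checkpoint names are
# normalized before lookup ("A_log" becomes "A", with the -exp() transform
# applied by `A_weight_loader` in JambaMambaMixer.__init__, and ".self_attn"
# is dropped). Each weight is then matched first against the stacked
# qkv/gate_up mapping, next against the FusedMoE expert mapping, and otherwise
# loaded with `default_weight_loader`.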