# coding=utf-8
"""Inference-only Jamba model."""
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.attention.layer import Attention
from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
                                             causal_conv1d_update,
                                             selective_scan_fn,
                                             selective_state_update)
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import HasInnerState
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                                 _get_graph_batch_size)

KVCache = Tuple[torch.Tensor, torch.Tensor]


@dataclass
class MambaCacheParams:
    is_prompt: bool = False
    conv_state: torch.Tensor = torch.Tensor()
    ssm_state: torch.Tensor = torch.Tensor()
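    # `conv_state` caches the last `mamba_d_conv` input columns for the
    # causal conv1d and `ssm_state` holds the recurrent SSM state;
    # `is_prompt` selects between the prefill scan and the single-token
    # decode update in the mixer below.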


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D, the state space parameters, and compute
    the `contextualized_states`. A and D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective), while ∆, B, and C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces).
    """

    def __init__(self, config: JambaConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=self.use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
                                                  [self.intermediate_size] * 2,
                                                  bias=self.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
                                            self.intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias,
            input_is_parallel=True,
        )
        self.activation = config.hidden_act

        self.dt_layernorm = RMSNorm(self.time_step_rank,
                                    eps=config.rms_norm_eps)
        self.b_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
        self.c_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
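    # For prompts (is_prompt=True) the full-sequence kernels are used
    # (causal_conv1d_fn / selective_scan_fn); for decode, the single-token
    # update kernels (causal_conv1d_update / selective_state_update) advance
    # the cached conv/SSM state in place.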
    def mamba_forward(self,
                      hidden_states: torch.Tensor,
                      cache_params: MambaCacheParams = None):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if cache_params is not None and not cache_params.is_prompt:
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                conv_states = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_state.copy_(conv_states)

            hidden_states, _ = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
            )

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)
        if cache_params is not None and not cache_params.is_prompt:
            scan_outputs = selective_state_update(
                cache_params.ssm_state,
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                self.A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                self.A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_state.copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
        return contextualized_states
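    # Prefill processes each sequence separately (prompts are packed along
    # dim 0 and sliced out via seq_lens); decode handles the whole batch in
    # one step with a single token per sequence.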
    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ):
        if attn_metadata.prefill_metadata is not None:
            offset = 0
            for i, prompt_len in enumerate(
                    attn_metadata.prefill_metadata.seq_lens):
                cache = MambaCacheParams(True,
                                         conv_state=conv_state[i].unsqueeze(0),
                                         ssm_state=ssm_state[i].unsqueeze(0))
                hidden_states[offset:offset + prompt_len].copy_(
                    self.mamba_forward(hidden_states[offset:offset +
                                                     prompt_len].unsqueeze(0),
                                       cache_params=cache)[0])
                offset += prompt_len
        else:
            cache = MambaCacheParams(False,
                                     conv_state=conv_state,
                                     ssm_state=ssm_state)
            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
                                               cache_params=cache)
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class JambaMLP(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        hidden_act = config.hidden_act
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class JambaMoE(nn.Module):
    """A tensor-parallel MoE implementation for Jamba that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    is used for the forward pass, and finally the outputs are reduced
    across ranks.
    """

    def __init__(
        self,
        config: JambaConfig,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = ReplicatedLinear(self.hidden_size,
                                       self.num_total_experts,
                                       bias=False,
                                       params_dtype=self.params_dtype)

        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.hidden_size,
                device="cuda",
                dtype=self.params_dtype,
            ))
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.hidden_size,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            ))

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("gate_proj.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("up_proj.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("down_proj.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
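    # `ws` stacks each expert's gate_proj and up_proj shards along its second
    # dim (2 * intermediate_size x hidden_size), while `w2s` holds down_proj;
    # both stacked tensors are consumed directly by the fused_moe kernel.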
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.router(hidden_states)

        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            # Mixtral normalizes the expert probs to 1; Jamba does not.
            renormalize=False,
            inplace=True,
        )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class JambaMambaDecoderLayer(nn.Module):

    def __init__(self,
                 config: JambaConfig,
                 layer_idx: int,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.mamba = JambaMambaMixer(config, layer_idx)

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
                                   ssm_state)
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class JambaAttentionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        config.hidden_size,
                                        bias=False,
                                        quant_config=quant_config)

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
        )

        num_experts = config.layers_num_experts[layer_idx]
        ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)

    def self_attention(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
        **kwargs,
    ):
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {
    "attention": JambaAttentionDecoderLayer,
    "mamba": JambaMambaDecoderLayer
}
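# JambaModel indexes this mapping with config.layers_block_type[i] to build
# the interleaved attention/Mamba decoder stack.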


class JambaModel(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        decoder_layers = []
        for i in range(config.num_hidden_layers):
            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
            decoder_layers.append(
                layer_class(config,
                            layer_idx=i,
                            cache_config=cache_config,
                            quant_config=quant_config))
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
            kv_cache = None
            current_ssm_state = None
            current_conv_state = None
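            # Only attention layers have a KV cache and only Mamba layers
            # have conv/ssm state; each cache is indexed by the layer's
            # position within its own group, derived from attn_layer_offset
            # and attn_layer_period.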
            if isinstance(layer, JambaAttentionDecoderLayer):
                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
                                     self.config.attn_layer_period]
            if isinstance(layer, JambaMambaDecoderLayer):
                current_state_layer = i - (1 +
                                           (i - self.config.attn_layer_offset)
                                           // self.config.attn_layer_period)
                current_ssm_state = ssm_state[current_state_layer]
                current_conv_state = conv_state[current_state_layer]

            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                attn_metadata=attn_metadata,
                residual=residual,
                conv_state=current_conv_state,
                ssm_state=current_ssm_state,
            )
        hidden_states, _ = self.final_layernorm(hidden_states, residual)
        return hidden_states


class JambaForCausalLM(nn.Module, HasInnerState):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: JambaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        scheduler_config: Optional[SchedulerConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = JambaModel(config,
                                cache_config=cache_config,
                                quant_config=quant_config,
                                lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # Indices used in the current step.
        self.current_indices: List[int] = []
        # Used to track and store the Mamba cache between steps.
        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
        # Used as an input_buffer for the CUDA graph runs.
        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
        # Maps between the request id and a dict that maps between the seq_id
        # and its index inside the self.mamba_cache
        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[KVCache],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs):
        if not self.mamba_cache:
            self._prepare_mamba_cache()
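        # Eager/prefill runs manage cache slots here and copy the state back
        # after the step. During CUDA graph capture the pre-sized buffers
        # arrive via `seqlen_agnostic_capture_inputs`; around graph replay,
        # copy_inputs_before_cuda_graphs / copy_outputs_after_cuda_graphs
        # move the state in and out of those buffers.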
  617. if "seqlen_agnostic_capture_inputs" not in kwargs:
  618. # We get here only on Prefill/Eager mode runs
  619. assert all(
  620. key in kwargs
  621. for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
  622. request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
  623. finished_requests_ids = kwargs["finished_requests_ids"]
  624. self._release_mamba_cache(finished_requests_ids)
  625. batch_size = input_ids.shape[0]
  626. if attn_metadata.prefill_metadata:
  627. batch_size = len(request_ids_to_seq_ids)
  628. (
  629. current_seqlen_agnostic_cache,
  630. indices,
  631. ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
  632. batch_size,
  633. finished_requests_ids)
  634. else:
  635. # CUDA graph capturing runs
  636. current_seqlen_agnostic_cache, indices = (
  637. kwargs["seqlen_agnostic_capture_inputs"],
  638. [],
  639. )
  640. self.current_indices = indices
  641. hidden_states = self.model(input_ids, positions, kv_caches,
  642. attn_metadata,
  643. current_seqlen_agnostic_cache[0],
  644. current_seqlen_agnostic_cache[1])
  645. if "seqlen_agnostic_capture_inputs" not in kwargs:
  646. self._copy_mamba_cache_by_indices(self.current_indices,
  647. current_seqlen_agnostic_cache)
  648. return hidden_states

    def _copy_mamba_cache_by_indices(
            self, indices: List[int],
            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
        for i, offset in enumerate(indices):
            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
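    # Both cache tensors are laid out as (num_mamba_layers, batch_slot, ...),
    # so per-sequence slots are selected along dim 1.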
    def _copy_mamba_cache(self, index_to: int, index_from: int,
                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
        assert len(self.mamba_cache) > 0
        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
                                       non_blocking=True)

    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
                                      seqs_id: List[int]) -> List[int]:
        indices_for_current_run = []
        for seq_id in seqs_id:
            if cur_rid not in self.mamba_cache_indices_mapping:
                self.mamba_cache_indices_mapping[cur_rid] = {}
                first_free_index = self._first_free_index_in_mamba_cache()
                self.mamba_cache_indices_mapping[cur_rid][
                    seq_id] = first_free_index
                index_for_current_run = first_free_index
            ## case of decoding n>1, copy prefill cache to decoding indices
            elif seq_id not in (seq_ids2indices :=
                                self.mamba_cache_indices_mapping[cur_rid]):
                first_free_index = self._first_free_index_in_mamba_cache()
                index_exist = list(seq_ids2indices.values())[0]
                self._copy_mamba_cache(index_from=index_exist,
                                       index_to=first_free_index,
                                       from_buffer=self.mamba_cache)
                self.mamba_cache_indices_mapping[cur_rid][
                    seq_id] = first_free_index
                index_for_current_run = first_free_index
            else:
                index_for_current_run = self.mamba_cache_indices_mapping[
                    cur_rid][seq_id]

            indices_for_current_run.append(index_for_current_run)
        return indices_for_current_run

    def _prepare_current_run_mamba_cache(
        self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
        finished_requests_ids: List[str]
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
        indices_for_current_run = []
        for request_id, seqs_id in request_ids_to_seq_ids.items():
            if request_id in finished_requests_ids:
                # Do not allocate cache for requests that run
                # and finish right after
                continue
            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
                request_id, seqs_id)
        ## Pad the indices in case the running batch was padded to a
        ## CUDA-graph-captured batch size
        padded_indices = indices_for_current_run.copy()
        pad_index = self._first_free_index_in_mamba_cache()

        for _ in range(batch_size - len(indices_for_current_run)):
            padded_indices.append(pad_index)

        conv_state = self.mamba_cache[0][:, padded_indices]
        temporal_state = self.mamba_cache[1][:, padded_indices]

        return (conv_state, temporal_state), indices_for_current_run

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant Mamba cache into the CUDA graph input buffer
        that was provided during the capture runs
        (JambaForCausalLM.mamba_gc_cache_buffer).
        """
        assert all(
            key in kwargs
            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
        finished_requests_ids = kwargs["finished_requests_ids"]
        self._release_mamba_cache(finished_requests_ids)
        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
        cg_batch_size = input_buffers['input_ids'].shape[0]
        (
            current_mamba_cache,
            indices,
        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
                                                  cg_batch_size,
                                                  finished_requests_ids)
        self.current_indices = indices

        for input_buffer, current_cache_buffer in zip(
                input_buffers["seqlen_agnostic_capture_inputs"],
                current_mamba_cache):
            input_buffer.copy_(current_cache_buffer, non_blocking=True)

    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
        """
        Copy the relevant Mamba cache from the CUDA graph input_buffers
        back to the JambaForCausalLM.mamba_cache after the CUDA
        graph replay run is done.
        """
        self._copy_mamba_cache_by_indices(
            self.current_indices,
            input_buffers["seqlen_agnostic_capture_inputs"])

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """
        Provide the CUDA graph capture runs with a buffer of adjusted size.
        The buffer is used to maintain the Mamba cache during the CUDA graph
        replay runs.
        """
        return tuple(buffer[:, :batch_size]
                     for buffer in self.mamba_gc_cache_buffer)

    def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
        for req_id in finished_seq_groups_req_ids:
            if req_id in self.mamba_cache_indices_mapping:
                self.mamba_cache_indices_mapping.pop(req_id)

    def _first_free_index_in_mamba_cache(self) -> int:
        if self.mamba_cache:
            max_possible_batch_size = self.mamba_cache[0].shape[1]
            occupied = [
                id for seq_ids in self.mamba_cache_indices_mapping.values()
                for id in seq_ids.values()
            ]
            first_free_index = [
                i not in occupied for i in range(max_possible_batch_size)
            ].index(True)
            return first_free_index
        return 0

    def _get_mamba_cache_shape(
            self
    ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        conv_state_shape = (
            self.config.mamba_expand * hidden_size // world_size,
            self.config.mamba_d_conv,
        )
        temporal_state_shape = (
            self.config.mamba_expand * self.config.hidden_size // world_size,
            self.config.mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape
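    # Both buffers are allocated once up front, sized for the largest
    # CUDA-graph batch plus a small margin of extra slots, so they do not
    # need to be reallocated between steps.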
    def _prepare_mamba_cache(self):
        dtype = self.lm_head.weight.dtype
        layers_type = self.config.layers_block_type
        mamba_layers = sum(
            [layer_type == "mamba" for layer_type in layers_type])
        max_batch_size = (_get_graph_batch_size(
            self.scheduler_config.max_num_seqs) if self.scheduler_config else
                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
        conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
        assert conv_state_shape is not None and temporal_state_shape is not None

        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
                                  conv_state_shape,
                                  dtype=dtype,
                                  device="cuda"),
                      torch.empty(size=(mamba_layers, max_batch_size) +
                                  temporal_state_shape,
                                  dtype=dtype,
                                  device="cuda"))
            setattr(self, buffername, buffer)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            ) for expert_id in range(self.config.num_experts)
            for weight_name in ["down_proj", "up_proj", "gate_proj"]
        ]

        params_dict = dict(self.named_parameters())
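        # Checkpoint names are remapped before lookup: Mamba's `A_log`
        # becomes `A` (A_weight_loader applies the -exp() transform), the
        # `.self_attn.` prefix is dropped, q/k/v and gate/up shards go
        # through stacked_params_mapping, and per-expert weights go through
        # expert_params_mapping.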
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if "A_log" in name:
                name = name.replace("A_log", "A")

            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if 'experts' in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)