jamba.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. # coding=utf-8
  2. """Inference-only Jurassic model."""
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Optional, Tuple

import torch
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from torch import nn
from torch.nn.parameter import Parameter
from transformers import JambaConfig

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.attention.layer import Attention
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.task_handler.model_runner import _BATCH_SIZES_TO_CAPTURE
# Type alias for one attention layer's KV cache entry as used in
# `JambaForCausalLM.forward(kv_caches: List[KVCache])`: a pair of tensors
# (presumably key cache and value cache — TODO confirm against the backend).
KVCache = Tuple[torch.Tensor, torch.Tensor]
  37. @dataclass
  38. class MambaCacheParams:
  39. is_prompt: bool = False
  40. conv_state: torch.Tensor = torch.Tensor()
  41. ssm_state: torch.Tensor = torch.Tensor()
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)

    This mixer additionally applies RMSNorm to the time step, B and C
    after ``x_proj`` (``dt_layernorm``, ``b_layernorm``, ``c_layernorm``).
    The inner dimension (``mamba_expand * hidden_size``) is sharded across
    tensor-parallel ranks: ``A`` and ``D`` hold only this rank's
    ``intermediate_size // tp_size`` rows.
    """

    def __init__(self, config: JambaConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        # The conv1d weights are held in a ColumnParallelLinear
        # (kernel_size -> intermediate_size), presumably so that the channel
        # dimension is sharded across TP ranks like the other inner weights.
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.intermediate_size,
            bias=self.use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # Projects hidden states to the concatenation [x; gate]; the two
        # halves are split apart with `chunk(2)` in `mamba_forward`.
        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
                                                  [self.intermediate_size] * 2,
                                                  bias=self.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.intermediate_size,
            self.time_step_rank + self.ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
                                            self.intermediate_size,
                                            bias=True,
                                            skip_bias_add=True)

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            """Copy this TP rank's equal-size chunk (split along dim 0) of
            the full checkpoint tensor into ``param``."""
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
                                         dim=0)[tp_rank])

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            """Shard and store ``-exp(loaded_weight)``, computed in float32.

            The checkpoint tensor is presumably ``A_log`` as in Mamba; storing
            the transformed value lets `mamba_forward` use ``self.A`` as-is.
            """
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        # A and D are per-rank shards of the inner dimension, filled by the
        # custom loaders above.
        self.A = nn.Parameter(
            torch.empty(
                self.intermediate_size // tp_size,
                self.ssm_state_size,
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias,
            input_is_parallel=True,
        )
        # Activation name passed to the causal conv1d kernels.
        self.activation = config.hidden_act

        # Norms applied to dt, B and C right after `x_proj`.
        self.dt_layernorm = RMSNorm(self.time_step_rank,
                                    eps=config.rms_norm_eps)
        self.b_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)
        self.c_layernorm = RMSNorm(self.ssm_state_size,
                                   eps=config.rms_norm_eps)

    def mamba_forward(self,
                      hidden_states: torch.Tensor,
                      cache_params: Optional[MambaCacheParams] = None):
        """Run the Mamba mixer on a batch of sequences.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` input. `forward`
                passes ``(1, prompt_len, hidden)`` during prefill and
                ``(batch, 1, hidden)`` during decode.
            cache_params: conv/SSM state. With ``is_prompt=False`` a
                single-step decode update is done using the cached states;
                otherwise the whole sequence is processed and, if
                ``cache_params`` is given, the final conv/SSM states are
                copied into it.

        Returns:
            ``(batch, seq_len, hidden_size)`` output of ``out_proj``.
        """
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
        hidden_states, gate = projected_states.chunk(2, dim=1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))
        if cache_params is not None and not cache_params.is_prompt:
            # Decode: a single token per sequence, convolved against the
            # cached conv state (which the kernel presumably advances in place).
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                # Keep the last `conv_kernel_size` positions as the conv state:
                # F.pad zero-pads on the left for short prompts and crops
                # (negative pad) for prompts longer than the kernel.
                conv_states = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_state.copy_(conv_states)

            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
            )

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]

        time_step, B, C = torch.split(
            ssm_parameters,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        # dt_proj was built with skip_bias_add=True, so its bias is handed to
        # the scan kernels instead (as float32).
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)
        if cache_params is not None and not cache_params.is_prompt:
            # Decode: one recurrence step on the cached SSM state.
            scan_outputs = selective_state_update(
                cache_params.ssm_state,
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                self.A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            # Prefill / no cache: full selective scan that also returns the
            # final SSM state so it can be stored for later decode steps.
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                self.A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_state.copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
        return contextualized_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
    ):
        """Apply the mixer to a flattened batch of tokens.

        Args:
            hidden_states: ``(num_tokens, hidden_size)``. During prefill the
                prompts are laid out back to back.
            attn_metadata: used only to detect prefill and to read the
                per-prompt lengths (``prefill_metadata.seq_lens``).
            conv_state: this layer's conv cache, one slot per batch entry
                along dim 0.
            ssm_state: this layer's SSM cache, one slot per batch entry
                along dim 0.

        Returns:
            During prefill, ``hidden_states`` itself, overwritten in place
            prompt by prompt. During decode, a new tensor produced by
            `mamba_forward` with the sequence dimension squeezed away.
        """
        if attn_metadata.prefill_metadata is not None:
            # NOTE(review): only prompt tokens are processed in this branch; a
            # batch mixing prefill and decode tokens would leave the decode
            # tokens untouched — confirm chunked prefill is disabled for Jamba.
            offset = 0
            for i, prompt_len in enumerate(
                    attn_metadata.prefill_metadata.seq_lens):
                # Each prompt is scanned separately against its own cache slot
                # (views into conv_state/ssm_state, so updates land in place).
                cache = MambaCacheParams(True,
                                         conv_state=conv_state[i].unsqueeze(0),
                                         ssm_state=ssm_state[i].unsqueeze(0))
                hidden_states[offset:offset + prompt_len].copy_(
                    self.mamba_forward(hidden_states[offset:offset +
                                                     prompt_len].unsqueeze(0),
                                       cache_params=cache)[0])
                offset += prompt_len
        else:
            # Decode: every sequence contributes exactly one token.
            cache = MambaCacheParams(False,
                                     conv_state=conv_state,
                                     ssm_state=ssm_state)
            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
                                               cache_params=cache)
            hidden_states = hidden_states.squeeze(1)

        return hidden_states
  224. class JambaMLP(nn.Module):
  225. def __init__(
  226. self,
  227. config: JambaConfig,
  228. quant_config: Optional[QuantizationConfig] = None,
  229. ) -> None:
  230. super().__init__()
  231. hidden_size = config.hidden_size
  232. intermediate_size = config.intermediate_size
  233. hidden_act = config.hidden_act
  234. self.gate_up_proj = MergedColumnParallelLinear(
  235. hidden_size, [intermediate_size] * 2,
  236. bias=False,
  237. quant_config=quant_config)
  238. self.down_proj = RowParallelLinear(intermediate_size,
  239. hidden_size,
  240. bias=False,
  241. quant_config=quant_config)
  242. if hidden_act != "silu":
  243. raise ValueError(f"Unsupported activation: {hidden_act}. "
  244. "Only silu is supported for now.")
  245. self.act_fn = SiluAndMul()
  246. def forward(self, x):
  247. gate_up, _ = self.gate_up_proj(x)
  248. x = self.act_fn(gate_up)
  249. x, _ = self.down_proj(x)
  250. return x
  251. class JambaMoE(nn.Module):
  252. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  253. across all ranks.
  254. Each expert's weights are sharded across all ranks and a fused MoE
  255. kernel is used for the forward pass, and finally we reduce the outputs
  256. across ranks.
  257. """
  258. def __init__(
  259. self,
  260. config: JambaConfig,
  261. params_dtype: Optional[torch.dtype] = None,
  262. tp_size: Optional[int] = None,
  263. quant_config: Optional[QuantizationConfig] = None,
  264. ):
  265. super().__init__()
  266. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  267. self.num_total_experts = config.num_experts
  268. self.top_k = config.num_experts_per_tok
  269. self.hidden_size = config.hidden_size
  270. self.intermediate_size = config.intermediate_size // self.tp_size
  271. if params_dtype is None:
  272. params_dtype = torch.get_default_dtype()
  273. self.params_dtype = params_dtype
  274. self.router = ReplicatedLinear(self.hidden_size,
  275. self.num_total_experts,
  276. bias=False,
  277. params_dtype=self.params_dtype)
  278. self.ws = nn.Parameter(
  279. torch.empty(
  280. self.num_total_experts,
  281. 2 * self.intermediate_size,
  282. self.hidden_size,
  283. device="cuda",
  284. dtype=self.params_dtype,
  285. ))
  286. self.w2s = nn.Parameter(
  287. torch.empty(
  288. self.num_total_experts,
  289. self.hidden_size,
  290. self.intermediate_size,
  291. device="cuda",
  292. dtype=self.params_dtype,
  293. ))
  294. set_weight_attrs(
  295. self.ws,
  296. {
  297. "weight_loader": self.weight_loader,
  298. },
  299. )
  300. set_weight_attrs(
  301. self.w2s,
  302. {
  303. "weight_loader": self.weight_loader,
  304. },
  305. )
  306. def weight_loader(
  307. self,
  308. param: nn.Parameter,
  309. loaded_weight: torch.Tensor,
  310. weight_name: str,
  311. expert_id: int,
  312. ):
  313. tp_rank = get_tensor_model_parallel_rank()
  314. param_data = param.data
  315. shard_size = self.intermediate_size
  316. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  317. if weight_name.endswith("gate_proj.weight"):
  318. param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
  319. if weight_name.endswith("up_proj.weight"):
  320. param_data[expert_id,
  321. shard_size:2 * shard_size, :] = loaded_weight[shard, :]
  322. if weight_name.endswith("down_proj.weight"):
  323. param_data[expert_id, :, :] = loaded_weight[:, shard]
  324. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  325. num_tokens, hidden_size = hidden_states.shape
  326. hidden_states = hidden_states.view(-1, self.hidden_size)
  327. # router_logits: (batch * sequence_length, n_experts)
  328. router_logits, _ = self.router(hidden_states)
  329. final_hidden_states = fused_moe(
  330. hidden_states,
  331. self.ws,
  332. self.w2s,
  333. router_logits,
  334. self.top_k,
  335. renormalize=
  336. False, # Mixtral normalize the expert probs to 1. We don't!
  337. inplace=True,
  338. )
  339. if self.tp_size > 1:
  340. final_hidden_states = tensor_model_parallel_all_reduce(
  341. final_hidden_states)
  342. return final_hidden_states.view(num_tokens, hidden_size)
  343. class JambaMambaDecoderLayer(nn.Module):
  344. def __init__(self,
  345. config: JambaConfig,
  346. layer_idx: int,
  347. cache_config: Optional[CacheConfig] = None,
  348. quant_config: Optional[QuantizationConfig] = None) -> None:
  349. super().__init__()
  350. self.layer_idx = layer_idx
  351. self.config = config
  352. self.mamba = JambaMambaMixer(config, layer_idx)
  353. num_experts = config.layers_num_experts[layer_idx]
  354. ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
  355. self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
  356. self.input_layernorm = RMSNorm(config.hidden_size,
  357. eps=config.rms_norm_eps)
  358. self.pre_ff_layernorm = RMSNorm(config.hidden_size,
  359. eps=config.rms_norm_eps)
  360. def forward(
  361. self,
  362. hidden_states: torch.Tensor,
  363. attn_metadata: AttentionMetadata,
  364. residual: Optional[torch.Tensor],
  365. conv_state: torch.Tensor,
  366. ssm_state: torch.Tensor,
  367. **kwargs,
  368. ):
  369. if residual is None:
  370. residual = hidden_states
  371. hidden_states = self.input_layernorm(hidden_states)
  372. else:
  373. hidden_states, residual = self.input_layernorm(
  374. hidden_states, residual)
  375. hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
  376. ssm_state)
  377. # Fully Connected
  378. hidden_states, residual = self.pre_ff_layernorm(
  379. hidden_states, residual)
  380. hidden_states = self.feed_forward(hidden_states)
  381. return hidden_states, residual
  382. class JambaAttentionDecoderLayer(nn.Module):
  383. def __init__(
  384. self,
  385. config: JambaConfig,
  386. layer_idx: int,
  387. cache_config: Optional[CacheConfig] = None,
  388. quant_config: Optional[QuantizationConfig] = None,
  389. ) -> None:
  390. super().__init__()
  391. self.hidden_size = config.hidden_size
  392. tp_size = get_tensor_model_parallel_world_size()
  393. self.total_num_heads = config.num_attention_heads
  394. assert self.total_num_heads % tp_size == 0
  395. self.num_heads = self.total_num_heads // tp_size
  396. self.total_num_kv_heads = config.num_key_value_heads
  397. if self.total_num_kv_heads >= tp_size:
  398. # Number of KV heads is greater than TP size, so we partition
  399. # the KV heads across multiple tensor parallel GPUs.
  400. assert self.total_num_kv_heads % tp_size == 0
  401. else:
  402. # Number of KV heads is less than TP size, so we replicate
  403. # the KV heads across multiple tensor parallel GPUs.
  404. assert tp_size % self.total_num_kv_heads == 0
  405. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  406. self.head_dim = config.hidden_size // self.total_num_heads
  407. self.q_size = self.num_heads * self.head_dim
  408. self.kv_size = self.num_kv_heads * self.head_dim
  409. self.scaling = self.head_dim**-0.5
  410. self.qkv_proj = QKVParallelLinear(
  411. config.hidden_size,
  412. self.head_dim,
  413. self.total_num_heads,
  414. self.total_num_kv_heads,
  415. bias=False,
  416. quant_config=quant_config,
  417. )
  418. self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
  419. config.hidden_size,
  420. bias=False,
  421. quant_config=quant_config)
  422. self.attn = Attention(
  423. self.num_heads,
  424. self.head_dim,
  425. self.scaling,
  426. num_kv_heads=self.num_kv_heads,
  427. cache_config=cache_config,
  428. )
  429. num_experts = config.layers_num_experts[layer_idx]
  430. ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
  431. self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
  432. self.input_layernorm = RMSNorm(config.hidden_size,
  433. eps=config.rms_norm_eps)
  434. self.pre_ff_layernorm = RMSNorm(config.hidden_size,
  435. eps=config.rms_norm_eps)
  436. def self_attention(
  437. self,
  438. positions: torch.Tensor,
  439. hidden_states: torch.Tensor,
  440. kv_cache: torch.Tensor,
  441. attn_metadata: AttentionMetadata,
  442. **kwargs,
  443. ) -> torch.Tensor:
  444. qkv, _ = self.qkv_proj(hidden_states)
  445. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  446. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  447. output, _ = self.o_proj(attn_output)
  448. return output
  449. def forward(
  450. self,
  451. positions: torch.Tensor,
  452. hidden_states: torch.Tensor,
  453. kv_cache: torch.Tensor,
  454. attn_metadata: AttentionMetadata,
  455. residual: Optional[torch.Tensor],
  456. **kwargs,
  457. ):
  458. if residual is None:
  459. residual = hidden_states
  460. hidden_states = self.input_layernorm(hidden_states)
  461. else:
  462. hidden_states, residual = self.input_layernorm(
  463. hidden_states, residual)
  464. hidden_states = self.self_attention(
  465. positions=positions,
  466. hidden_states=hidden_states,
  467. kv_cache=kv_cache,
  468. attn_metadata=attn_metadata,
  469. )
  470. # Fully Connected
  471. hidden_states, residual = self.pre_ff_layernorm(
  472. hidden_states, residual)
  473. hidden_states = self.feed_forward(hidden_states)
  474. return hidden_states, residual
  475. ALL_DECODER_LAYER_TYPES = {
  476. "attention": JambaAttentionDecoderLayer,
  477. "mamba": JambaMambaDecoderLayer
  478. }
  479. class JambaModel(nn.Module):
  480. def __init__(
  481. self,
  482. config: JambaConfig,
  483. quant_config: Optional[QuantizationConfig] = None,
  484. cache_config: Optional[CacheConfig] = None,
  485. lora_config: Optional[LoRAConfig] = None,
  486. ) -> None:
  487. super().__init__()
  488. self.config = config
  489. self.padding_idx = config.pad_token_id
  490. lora_vocab = ((lora_config.lora_extra_vocab_size *
  491. (lora_config.max_loras or 1)) if lora_config else 0)
  492. self.vocab_size = config.vocab_size + lora_vocab
  493. self.org_vocab_size = config.vocab_size
  494. self.embed_tokens = VocabParallelEmbedding(
  495. self.vocab_size,
  496. config.hidden_size,
  497. org_num_embeddings=config.vocab_size,
  498. )
  499. decoder_layers = []
  500. for i in range(config.num_hidden_layers):
  501. layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
  502. decoder_layers.append(
  503. layer_class(config,
  504. layer_idx=i,
  505. cache_config=cache_config,
  506. quant_config=quant_config))
  507. self.layers = nn.ModuleList(decoder_layers)
  508. self.final_layernorm = RMSNorm(config.hidden_size,
  509. eps=config.rms_norm_eps)
  510. def forward(
  511. self,
  512. input_ids: torch.Tensor,
  513. positions: torch.Tensor,
  514. kv_caches: List[torch.Tensor],
  515. attn_metadata: AttentionMetadata,
  516. conv_state: torch.Tensor,
  517. ssm_state: torch.Tensor,
  518. ) -> torch.Tensor:
  519. hidden_states = self.embed_tokens(input_ids)
  520. residual = None
  521. for i in range(len(self.layers)):
  522. layer = self.layers[i]
  523. kv_cache = None
  524. current_ssm_state = None
  525. current_conv_state = None
  526. if isinstance(layer, JambaAttentionDecoderLayer):
  527. kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
  528. self.config.attn_layer_period]
  529. if isinstance(layer, JambaMambaDecoderLayer):
  530. current_state_layer = i - (1 +
  531. (i - self.config.attn_layer_offset)
  532. // self.config.attn_layer_period)
  533. current_ssm_state = ssm_state[current_state_layer]
  534. current_conv_state = conv_state[current_state_layer]
  535. hidden_states, residual = layer(
  536. positions=positions,
  537. hidden_states=hidden_states,
  538. kv_cache=kv_cache,
  539. attn_metadata=attn_metadata,
  540. residual=residual,
  541. conv_state=current_conv_state,
  542. ssm_state=current_ssm_state,
  543. )
  544. hidden_states, _ = self.final_layernorm(hidden_states, residual)
  545. return hidden_states
  546. class JambaForCausalLM(nn.Module):
    # Fused projection -> the separate per-projection names it packs
    # (presumably consumed by the LoRA / weight-loading machinery).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    # Modules that LoRA adapters may target.
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    # Module attribute name -> embedding role, for LoRA extra-vocab handling.
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    # Embedding modules whose vocab dimension is padded (see the
    # `padding_size` given to `lm_head` in `__init__`).
    embedding_padding_modules = ["lm_head"]
  566. def __init__(
  567. self,
  568. config: JambaConfig,
  569. cache_config: Optional[CacheConfig] = None,
  570. quant_config: Optional[QuantizationConfig] = None,
  571. lora_config: Optional[LoRAConfig] = None,
  572. ) -> None:
  573. super().__init__()
  574. self.config = config
  575. self.model = JambaModel(config,
  576. cache_config=cache_config,
  577. quant_config=quant_config,
  578. lora_config=lora_config)
  579. self.unpadded_vocab_size = config.vocab_size
  580. if lora_config:
  581. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  582. self.lm_head = ParallelLMHead(
  583. self.unpadded_vocab_size,
  584. config.hidden_size,
  585. org_num_embeddings=config.vocab_size,
  586. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  587. # We need bigger padding if using lora for kernel
  588. # compatibility
  589. if not lora_config else lora_config.lora_vocab_padding_size,
  590. )
  591. # Current step used indices
  592. self.current_indices: List[int] = []
  593. # Used to track and store by the Mamba cache between steps.
  594. self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
  595. # Used as an input_buffer for the CUDA graph runs.
  596. self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
  597. # Maps between the request id and a dict that maps between the seq_id
  598. # and its index inside the self.mamba_cache
  599. self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
  600. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  601. config.vocab_size)
  602. self.sampler = Sampler()
  603. def forward(self,
  604. input_ids: torch.Tensor,
  605. positions: torch.Tensor,
  606. kv_caches: List[KVCache],
  607. attn_metadata: AttentionMetadata,
  608. intermediate_tensors: Optional[IntermediateTensors] = None,
  609. **kwargs):
  610. if not self.mamba_cache:
  611. self._prepare_mamba_cache()
  612. if "seqlen_agnostic_capture_inputs" not in kwargs:
  613. # We get here only on Prefill/Eager mode runs
  614. assert all(
  615. key in kwargs
  616. for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
  617. request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
  618. batch_size = input_ids.shape[0]
  619. if attn_metadata.prefill_metadata:
  620. batch_size = len(request_ids_to_seq_ids)
  621. (
  622. current_seqlen_agnostic_cache,
  623. indices,
  624. ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
  625. batch_size)
  626. finished_requests_ids = kwargs["finished_requests_ids"]
  627. self._release_mamba_cache(finished_requests_ids)
  628. else:
  629. # CUDA graph capturing runs
  630. current_seqlen_agnostic_cache, indices = (
  631. kwargs["seqlen_agnostic_capture_inputs"],
  632. [],
  633. )
  634. self.current_indices = indices
  635. hidden_states = self.model(input_ids, positions, kv_caches,
  636. attn_metadata,
  637. current_seqlen_agnostic_cache[0],
  638. current_seqlen_agnostic_cache[1])
  639. if "seqlen_agnostic_capture_inputs" not in kwargs:
  640. self._copy_mamba_cache_by_indices(self.current_indices,
  641. current_seqlen_agnostic_cache)
  642. return hidden_states
  643. def _copy_mamba_cache_by_indices(
  644. self, indices: List[int],
  645. current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
  646. for i, offset in enumerate(indices):
  647. self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
  648. def _copy_mamba_cache(self, index_to: int, index_from: int,
  649. from_buffer: Tuple[torch.Tensor, torch.Tensor]):
  650. assert len(self.mamba_cache) > 0
  651. for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
  652. cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
  653. non_blocking=True)
  654. def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
  655. seqs_id: List[int]) -> List[int]:
  656. indices_for_current_run = []
  657. for seq_id in seqs_id:
  658. if cur_rid not in self.mamba_cache_indices_mapping:
  659. self.mamba_cache_indices_mapping[cur_rid] = {}
  660. first_free_index = self._first_free_index_in_mamba_cache()
  661. self.mamba_cache_indices_mapping[cur_rid][
  662. seq_id] = first_free_index
  663. index_for_current_run = first_free_index
  664. ## case of decoding n>1, copy prefill cache to decoding indices
  665. elif seq_id not in (seq_ids2indices :=
  666. self.mamba_cache_indices_mapping[cur_rid]):
  667. first_free_index = self._first_free_index_in_mamba_cache()
  668. index_exist = list(seq_ids2indices.values())[0]
  669. self._copy_mamba_cache(index_from=index_exist,
  670. index_to=first_free_index,
  671. from_buffer=self.mamba_cache)
  672. self.mamba_cache_indices_mapping[cur_rid][
  673. seq_id] = first_free_index
  674. index_for_current_run = first_free_index
  675. else:
  676. index_for_current_run = self.mamba_cache_indices_mapping[
  677. cur_rid][seq_id]
  678. indices_for_current_run.append(index_for_current_run)
  679. return indices_for_current_run
  680. def _prepare_current_run_mamba_cache(
  681. self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int
  682. ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
  683. indices_for_current_run = []
  684. for request_id, seqs_id in request_ids_to_seq_ids.items():
  685. indices_for_current_run += self._assign_seq_id_to_mamba_cache(
  686. request_id, seqs_id)
  687. ## Pad the batch in case of running batch that was not captured via CG
  688. padded_indices = indices_for_current_run.copy()
  689. pad_index = self._first_free_index_in_mamba_cache()
  690. for _ in range(batch_size - len(indices_for_current_run)):
  691. padded_indices.append(pad_index)
  692. conv_state = self.mamba_cache[0][:, padded_indices]
  693. temporal_state = self.mamba_cache[1][:, padded_indices]
  694. return (conv_state, temporal_state), indices_for_current_run
  695. def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
  696. """
  697. Copy the relevant Mamba cache into the CUDA graph input buffer
  698. that was provided during the capture runs
  699. (JambaForCausalLM.mamba_gc_cache_buffer).
  700. """
  701. assert all(
  702. key in kwargs
  703. for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
  704. request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
  705. batch_size = len(request_ids_to_seq_ids)
  706. (
  707. current_mamba_cache,
  708. indices,
  709. ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
  710. batch_size)
  711. self.current_indices = indices
  712. finished_requests_ids = kwargs["finished_requests_ids"]
  713. self._release_mamba_cache(finished_requests_ids)
  714. for input_buffer, current_cache_buffer in zip(
  715. input_buffers["seqlen_agnostic_capture_inputs"],
  716. current_mamba_cache):
  717. input_buffer.copy_(current_cache_buffer, non_blocking=True)
  718. def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
  719. """
  720. Copy the relevant Mamba cache from the CUDA graph input_buffers
  721. back to the JambaForCausalLM.mamba_cache after CUDA
  722. graph replay run is done.
  723. """
  724. self._copy_mamba_cache_by_indices(
  725. self.current_indices,
  726. input_buffers["seqlen_agnostic_capture_inputs"])
  def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
      """
      Return batch-sized slices of the CUDA graph Mamba buffers.

      Each tensor in ``self.mamba_gc_cache_buffer`` is sliced to its first
      ``batch_size`` entries along dim 1. The capture run for that batch
      size therefore works on a view of the full buffer that carries the
      Mamba cache across replay runs.

      Returns:
          A tuple with one sliced view per buffer, in the original order.
      """
      views = []
      for full_buffer in self.mamba_gc_cache_buffer:
          views.append(full_buffer[:, :batch_size])
      return tuple(views)
  735. def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
  736. for req_id in finished_seq_groups_req_ids:
  737. if req_id in self.mamba_cache_indices_mapping:
  738. self.mamba_cache_indices_mapping.pop(req_id)
  739. def _first_free_index_in_mamba_cache(self) -> int:
  740. if self.mamba_cache:
  741. max_possible_batch_size = self.mamba_cache[0].shape[1]
  742. occupied = [
  743. id for seq_ids in self.mamba_cache_indices_mapping.values()
  744. for id in seq_ids.values()
  745. ]
  746. first_free_index = [
  747. i not in occupied for i in range(max_possible_batch_size)
  748. ].index(True)
  749. return first_free_index
  750. return 0
  def _get_mamba_cache_shape(
      self
  ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
      """
      Compute the per-slot shapes of the conv and temporal Mamba states.

      Both shapes share a leading dimension of
      ``mamba_expand * hidden_size // world_size``, where ``world_size`` is
      the tensor-parallel world size. The trailing dimension is
      ``mamba_d_conv`` for the conv state and ``mamba_d_state`` for the
      temporal state.

      Returns:
          ``(conv_state_shape, temporal_state_shape)``.
      """
      tp_size = get_tensor_model_parallel_world_size()
      cfg = self.config
      inner_dim = cfg.mamba_expand * cfg.hidden_size // tp_size
      conv_shape = (inner_dim, cfg.mamba_d_conv)
      temporal_shape = (inner_dim, cfg.mamba_d_state)
      return conv_shape, temporal_shape
  def _prepare_mamba_cache(self):
      """
      Allocate the Mamba state buffers on the CUDA device.

      Sets two attributes, ``self.mamba_cache`` and
      ``self.mamba_gc_cache_buffer``. Each is a ``(conv_state,
      temporal_state)`` pair of uninitialised ``torch.empty`` tensors with
      shape ``(num_mamba_layers, max_batch_size) + state_shape``, in the
      dtype of ``lm_head``'s weight. ``max_batch_size`` is the last entry
      of ``_BATCH_SIZES_TO_CAPTURE`` (presumably the largest capture batch
      size; confirm) plus 10.
      """
      dtype = self.lm_head.weight.dtype
      num_mamba_layers = sum(
          layer_type == "mamba"
          for layer_type in self.config.layers_block_type)
      max_batch_size = _BATCH_SIZES_TO_CAPTURE[-1] + 10
      conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
      assert conv_state_shape is not None and temporal_state_shape is not None
      leading_dims = (num_mamba_layers, max_batch_size)
      for attr_name in ("mamba_cache", "mamba_gc_cache_buffer"):
          # Conv state first, temporal state second.
          state_pair = tuple(
              torch.empty(size=leading_dims + state_shape,
                          dtype=dtype,
                          device="cuda")
              for state_shape in (conv_state_shape, temporal_state_shape))
          setattr(self, attr_name, state_pair)
  def compute_logits(self, hidden_states: torch.Tensor,
                     sampling_metadata: SamplingMetadata) -> torch.Tensor:
      """
      Project hidden states to vocabulary logits.

      Delegates to ``self.logits_processor`` with ``self.lm_head`` as the
      projection and returns its result unchanged.
      """
      return self.logits_processor(self.lm_head, hidden_states,
                                   sampling_metadata)
  def sample(
      self,
      logits: Optional[torch.Tensor],
      sampling_metadata: SamplingMetadata,
  ) -> Optional[SamplerOutput]:
      """
      Choose next tokens from ``logits`` via ``self.sampler``.

      Returns the sampler's output for the given logits and metadata.
      """
      return self.sampler(logits, sampling_metadata)
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
      """
      Load checkpoint tensors into this model's parameters.

      Each ``(name, tensor)`` pair in ``weights`` is renamed to match this
      module's parameter names and handed to a weight loader:

      * ``rotary_emb.inv_freq`` entries are skipped.
      * ``A_log`` is renamed to ``A``, and ``.self_attn`` is removed from
        the name.
      * Non-expert ``q_proj``/``k_proj``/``v_proj`` and
        ``gate_proj``/``up_proj`` tensors are loaded as shards of the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters.
      * Per-expert ``gate_proj``/``up_proj`` weights go to the ``ws``
        parameter and ``down_proj`` weights go to ``w2s``, together with
        their ``expert_id``.
      * Everything else is loaded one-to-one. The parameter's
        ``weight_loader`` is used if present, otherwise
        ``default_weight_loader``.

      Bias tensors with no matching parameter are skipped (noted in the
      code as extra GPTQ biases).

      Raises:
          KeyError: If a (renamed) non-skipped tensor name has no entry in
              ``self.named_parameters()``.
      """
      stacked_params_mapping = [
          # (param_name, shard_name, shard_id)
          ("qkv_proj", "q_proj", "q"),
          ("qkv_proj", "k_proj", "k"),
          ("qkv_proj", "v_proj", "v"),
          ("gate_up_proj", "gate_proj", 0),
          ("gate_up_proj", "up_proj", 1),
      ]
      # One entry per (expert, projection). gate/up map to "ws",
      # down maps to "w2s".
      expert_params_mapping = [
          # (param_name, weight_name, expert_id)
          (
              "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight",
              expert_id,
          ) for expert_id in range(self.config.num_experts)
          for weight_name in ["down_proj", "up_proj", "gate_proj"]
      ]
      params_dict = dict(self.named_parameters())
      for name, loaded_weight in weights:
          if "rotary_emb.inv_freq" in name:
              continue
          # NOTE(review): only the name changes here. Presumably the
          # weight_loader of the "A" parameter handles any A_log -> A value
          # conversion; confirm.
          if "A_log" in name:
              name = name.replace("A_log", "A")
          if ".self_attn." in name:
              name = name.replace(".self_attn", "")
          for param_name, weight_name, shard_id in stacked_params_mapping:
              if weight_name not in name:
                  continue
              # Expert projections are handled by expert_params_mapping.
              if 'experts' in name:
                  continue
              name = name.replace(weight_name, param_name)
              # Skip loading extra bias for GPTQ models.
              # NOTE: `name` keeps the substitution above when this
              # `continue` fires. Later mappings and the `else` branch
              # below therefore see the renamed value.
              if name.endswith(".bias") and name not in params_dict:
                  continue
              param = params_dict[name]
              weight_loader = param.weight_loader
              weight_loader(param, loaded_weight, shard_id)
              break
          else:
              # Reached only when no fused (stacked) mapping hit `break`.
              for param_name, weight_name, expert_id in expert_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  # Expert loaders also receive the checkpoint weight name
                  # and the expert index.
                  weight_loader(param,
                                loaded_weight,
                                weight_name,
                                expert_id=expert_id)
                  break
              else:
                  # One-to-one parameter (neither fused nor expert).
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)