1
0

mamba.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. # coding=utf-8
  2. """PyTorch MAMBA model."""
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import MambaConfig

from aphrodite.attention.backends.abstract import AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import HasInnerState
from aphrodite.modeling.models.mamba_cache import MambaCacheManager
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
                                           _get_graph_batch_size)
# Type alias for a pair of tensors; used as the element type of the
# `kv_caches` argument of `MambaForCausalLM.forward`.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  36. @dataclass
  37. class MambaCacheParams:
  38. is_prompt: bool = False
  39. conv_state: torch.Tensor = torch.Tensor()
  40. ssm_state: torch.Tensor = torch.Tensor()
  41. # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
  42. class MambaMixer(nn.Module):
  43. """
  44. Compute ∆, A, B, C, and D the state space parameters and compute
  45. the `contextualized_states`. A, D are input independent
  46. (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
  47. for why A isn't selective) ∆, B, C are input-dependent
  48. (this is a key difference between Mamba and the linear time
  49. invariant S4, and is why Mamba is called
  50. **selective** state spaces)
  51. """
  52. def __init__(self, config: MambaConfig, layer_idx):
  53. super().__init__()
  54. self.config = config
  55. self.layer_idx = layer_idx
  56. self.hidden_size = config.hidden_size
  57. self.ssm_state_size = config.state_size
  58. self.conv_kernel_size = config.conv_kernel
  59. self.intermediate_size = config.intermediate_size
  60. self.time_step_rank = int(config.time_step_rank)
  61. self.use_conv_bias = config.use_conv_bias
  62. # TODO: ??
  63. #self.use_bias = config.mamba_proj_bias
  64. self.use_bias = False
  65. self.conv1d = ColumnParallelLinear(
  66. input_size=self.conv_kernel_size,
  67. output_size=self.intermediate_size,
  68. bias=self.use_conv_bias,
  69. )
  70. # unsqueeze to fit conv1d weights shape into the linear weights shape.
  71. # Can't do this in `weight_loader` since it already exists in
  72. # `ColumnParallelLinear` and `set_weight_attrs`
  73. # doesn't allow to override it
  74. self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
  75. self.in_proj = MergedColumnParallelLinear(self.hidden_size,
  76. [self.intermediate_size] * 2,
  77. bias=self.use_bias)
  78. # selective projection used to make dt, B and C input dependent
  79. self.x_proj = RowParallelLinear(
  80. self.intermediate_size,
  81. self.time_step_rank + self.ssm_state_size * 2,
  82. bias=False,
  83. )
  84. # time step projection (discretization) -
  85. # In the forward we need to apply dt_proj without the bias,
  86. # as the bias is added in the selective scan kernel.
  87. self.dt_proj = ColumnParallelLinear(self.time_step_rank,
  88. self.intermediate_size,
  89. bias=True,
  90. skip_bias_add=True)
  91. def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
  92. tp_rank = get_tensor_model_parallel_rank()
  93. tp_size = get_tensor_model_parallel_world_size()
  94. param.data.copy_(
  95. loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
  96. dim=0)[tp_rank])
  97. def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
  98. weight_loader(param, -torch.exp(loaded_weight.float()))
  99. tp_size = get_tensor_model_parallel_world_size()
  100. self.A = nn.Parameter(
  101. torch.empty(
  102. self.intermediate_size // tp_size,
  103. self.ssm_state_size,
  104. dtype=torch.float32,
  105. ))
  106. self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
  107. set_weight_attrs(self.D, {"weight_loader": weight_loader})
  108. set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
  109. self.out_proj = RowParallelLinear(
  110. self.intermediate_size,
  111. self.hidden_size,
  112. bias=self.use_bias,
  113. input_is_parallel=True,
  114. )
  115. self.activation = config.hidden_act
  116. def mamba_forward(self,
  117. hidden_states: torch.Tensor,
  118. cache_params: MambaCacheParams = None):
  119. # 1. Gated MLP's linear projection
  120. projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
  121. hidden_states, gate = projected_states.chunk(2, dim=1)
  122. # 2. Convolution sequence transformation
  123. conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
  124. self.conv1d.weight.size(2))
  125. if cache_params is not None and not cache_params.is_prompt:
  126. hidden_states = causal_conv1d_update(
  127. hidden_states.squeeze(-1),
  128. cache_params.conv_state,
  129. conv_weights,
  130. self.conv1d.bias,
  131. self.activation,
  132. )
  133. hidden_states = hidden_states.unsqueeze(-1)
  134. else:
  135. if cache_params is not None:
  136. conv_states = nn.functional.pad(
  137. hidden_states,
  138. (self.conv_kernel_size - hidden_states.shape[-1], 0))
  139. cache_params.conv_state.copy_(conv_states)
  140. hidden_states, _ = causal_conv1d_fn(
  141. hidden_states,
  142. conv_weights,
  143. self.conv1d.bias,
  144. activation=self.activation,
  145. )
  146. # 3. State Space Model sequence transformation
  147. # 3.a. input varying initialization of time_step, B and C
  148. ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
  149. time_step, B, C = torch.split(
  150. ssm_parameters,
  151. [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
  152. dim=-1,
  153. )
  154. # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
  155. discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
  156. # 3.c perform the recurrence y ← SSM(A, B, C)(x)
  157. time_proj_bias = (self.dt_proj.bias.float() if hasattr(
  158. self.dt_proj, "bias") else None)
  159. if cache_params is not None and not cache_params.is_prompt:
  160. scan_outputs = selective_state_update(
  161. cache_params.ssm_state,
  162. hidden_states[..., 0],
  163. discrete_time_step[..., 0],
  164. self.A,
  165. B[:, 0],
  166. C[:, 0],
  167. self.D,
  168. gate[..., 0],
  169. time_proj_bias,
  170. dt_softplus=True,
  171. ).unsqueeze(-1)
  172. else:
  173. scan_outputs, ssm_state = selective_scan_fn(
  174. hidden_states,
  175. discrete_time_step,
  176. self.A,
  177. B.transpose(1, 2),
  178. C.transpose(1, 2),
  179. self.D.float(),
  180. gate,
  181. time_proj_bias,
  182. delta_softplus=True,
  183. return_last_state=True,
  184. )
  185. if ssm_state is not None and cache_params is not None:
  186. cache_params.ssm_state.copy_(ssm_state)
  187. # 4. Final linear projection
  188. contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
  189. return contextualized_states
  190. def forward(
  191. self,
  192. hidden_states: torch.Tensor,
  193. attn_metadata: AttentionMetadata,
  194. conv_state: torch.Tensor,
  195. ssm_state: torch.Tensor,
  196. ):
  197. if attn_metadata.prefill_metadata is not None:
  198. offset = 0
  199. for i, prompt_len in enumerate(
  200. attn_metadata.prefill_metadata.seq_lens):
  201. cache = MambaCacheParams(True,
  202. conv_state=conv_state[i].unsqueeze(0),
  203. ssm_state=ssm_state[i].unsqueeze(0))
  204. hidden_states[offset:offset + prompt_len].copy_(
  205. self.mamba_forward(hidden_states[offset:offset +
  206. prompt_len].unsqueeze(0),
  207. cache_params=cache)[0])
  208. offset += prompt_len
  209. else:
  210. cache = MambaCacheParams(False,
  211. conv_state=conv_state,
  212. ssm_state=ssm_state)
  213. hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
  214. cache_params=cache)
  215. hidden_states = hidden_states.squeeze(1)
  216. return hidden_states
  217. class MambaMLP(nn.Module):
  218. def __init__(
  219. self,
  220. config: MambaConfig,
  221. quant_config: Optional[QuantizationConfig] = None,
  222. ) -> None:
  223. super().__init__()
  224. hidden_size = config.hidden_size
  225. intermediate_size = config.intermediate_size
  226. hidden_act = config.hidden_act
  227. self.gate_up_proj = MergedColumnParallelLinear(
  228. hidden_size, [intermediate_size] * 2,
  229. bias=False,
  230. quant_config=quant_config)
  231. self.down_proj = RowParallelLinear(intermediate_size,
  232. hidden_size,
  233. bias=False,
  234. quant_config=quant_config)
  235. if hidden_act != "silu":
  236. raise ValueError(f"Unsupported activation: {hidden_act}. "
  237. "Only silu is supported for now.")
  238. self.act_fn = SiluAndMul()
  239. def forward(self, x):
  240. gate_up, _ = self.gate_up_proj(x)
  241. x = self.act_fn(gate_up)
  242. x, _ = self.down_proj(x)
  243. return x
  244. class MambaDecoderLayer(nn.Module):
  245. def __init__(self,
  246. config: MambaConfig,
  247. layer_idx: int,
  248. cache_config: Optional[CacheConfig] = None,
  249. quant_config: Optional[QuantizationConfig] = None) -> None:
  250. super().__init__()
  251. self.layer_idx = layer_idx
  252. self.config = config
  253. self.mixer = MambaMixer(config, layer_idx)
  254. self.feed_forward = MambaMLP(config, quant_config=quant_config)
  255. self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  256. self.pre_ff_layernorm = RMSNorm(config.hidden_size,
  257. eps=config.layer_norm_epsilon)
  258. def forward(
  259. self,
  260. hidden_states: torch.Tensor,
  261. attn_metadata: AttentionMetadata,
  262. residual: Optional[torch.Tensor],
  263. conv_state: torch.Tensor,
  264. ssm_state: torch.Tensor,
  265. **kwargs,
  266. ):
  267. if residual is None:
  268. residual = hidden_states
  269. hidden_states = self.norm(hidden_states)
  270. else:
  271. hidden_states, residual = self.norm(hidden_states, residual)
  272. hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
  273. ssm_state)
  274. # Fully Connected
  275. hidden_states, residual = self.pre_ff_layernorm(
  276. hidden_states, residual)
  277. hidden_states = self.feed_forward(hidden_states)
  278. return hidden_states, residual
  279. class MambaModel(nn.Module):
  280. def __init__(
  281. self,
  282. config: MambaConfig,
  283. quant_config: Optional[QuantizationConfig] = None,
  284. cache_config: Optional[CacheConfig] = None,
  285. lora_config: Optional[LoRAConfig] = None,
  286. ) -> None:
  287. super().__init__()
  288. self.config = config
  289. self.padding_idx = config.pad_token_id
  290. lora_vocab = ((lora_config.lora_extra_vocab_size *
  291. (lora_config.max_loras or 1)) if lora_config else 0)
  292. self.vocab_size = config.vocab_size + lora_vocab
  293. self.org_vocab_size = config.vocab_size
  294. self.embeddings = VocabParallelEmbedding(
  295. self.vocab_size,
  296. config.hidden_size,
  297. org_num_embeddings=config.vocab_size,
  298. )
  299. decoder_layers = []
  300. for i in range(config.num_hidden_layers):
  301. decoder_layers.append(
  302. MambaDecoderLayer(config,
  303. layer_idx=i,
  304. cache_config=cache_config,
  305. quant_config=quant_config))
  306. self.layers = nn.ModuleList(decoder_layers)
  307. self.norm_f = RMSNorm(config.hidden_size,
  308. eps=config.layer_norm_epsilon)
  309. def forward(
  310. self,
  311. input_ids: torch.Tensor,
  312. positions: torch.Tensor,
  313. kv_caches: List[torch.Tensor],
  314. attn_metadata: AttentionMetadata,
  315. conv_state: torch.Tensor,
  316. ssm_state: torch.Tensor,
  317. ) -> torch.Tensor:
  318. hidden_states = self.embeddings(input_ids)
  319. residual = None
  320. for i in range(len(self.layers)):
  321. layer = self.layers[i]
  322. current_ssm_state = ssm_state[i]
  323. current_conv_state = conv_state[i]
  324. hidden_states, residual = layer(
  325. positions=positions,
  326. hidden_states=hidden_states,
  327. attn_metadata=attn_metadata,
  328. residual=residual,
  329. conv_state=current_conv_state,
  330. ssm_state=current_ssm_state,
  331. )
  332. hidden_states, _ = self.norm_f(hidden_states, residual)
  333. return hidden_states
  334. class MambaForCausalLM(nn.Module, HasInnerState):
  335. packed_modules_mapping = {
  336. "qkv_proj": [
  337. "q_proj",
  338. "k_proj",
  339. "v_proj",
  340. ],
  341. }
  342. # LoRA specific attributes
  343. supported_lora_modules = [
  344. "qkv_proj",
  345. "o_proj",
  346. "embed_tokens",
  347. "lm_head",
  348. ]
  349. embedding_modules = {
  350. "embeddings": "input_embeddings",
  351. "lm_head": "output_embeddings",
  352. }
  353. embedding_padding_modules = ["lm_head"]
  354. def __init__(
  355. self,
  356. config: MambaConfig,
  357. cache_config: Optional[CacheConfig] = None,
  358. quant_config: Optional[QuantizationConfig] = None,
  359. lora_config: Optional[LoRAConfig] = None,
  360. scheduler_config: Optional[SchedulerConfig] = None,
  361. ) -> None:
  362. super().__init__()
  363. self.config = config
  364. self.scheduler_config = scheduler_config
  365. self.backbone = MambaModel(config,
  366. cache_config=cache_config,
  367. quant_config=quant_config,
  368. lora_config=lora_config)
  369. self.unpadded_vocab_size = config.vocab_size
  370. if lora_config:
  371. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  372. self.lm_head = self.backbone.embeddings
  373. # Used to track and store by the Mamba cache between steps.
  374. self.mamba_cache: Optional[MambaCacheManager] = None
  375. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  376. config.vocab_size)
  377. self.sampler = Sampler()
  378. def forward(self,
  379. input_ids: torch.Tensor,
  380. positions: torch.Tensor,
  381. kv_caches: List[KVCache],
  382. attn_metadata: AttentionMetadata,
  383. intermediate_tensors: Optional[IntermediateTensors] = None,
  384. **kwargs):
  385. if self.mamba_cache is None:
  386. max_batch_size = (_get_graph_batch_size(
  387. self.scheduler_config.max_num_seqs) if self.scheduler_config
  388. else max(_BATCH_SIZES_TO_CAPTURE) + 2)
  389. self.mamba_cache = MambaCacheManager(
  390. self.lm_head.weight.dtype, self.config.num_hidden_layers,
  391. max_batch_size, *self._get_mamba_cache_shape())
  392. if "seqlen_agnostic_capture_inputs" not in kwargs:
  393. # We get here only on Prefill/Eager mode runs
  394. assert all(
  395. key in kwargs
  396. for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
  397. request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
  398. finished_requests_ids = kwargs["finished_requests_ids"]
  399. self.mamba_cache.release_finished_requests(finished_requests_ids)
  400. batch_size = input_ids.shape[0]
  401. if attn_metadata.prefill_metadata:
  402. batch_size = len(request_ids_to_seq_ids)
  403. mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
  404. request_ids_to_seq_ids, batch_size, finished_requests_ids)
  405. else:
  406. # CUDA graph capturing runs
  407. mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
  408. hidden_states = self.backbone(input_ids, positions, kv_caches,
  409. attn_metadata, mamba_cache_tensors[0],
  410. mamba_cache_tensors[1])
  411. return hidden_states
  412. def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
  413. return self.mamba_cache.copy_inputs_before_cuda_graphs(
  414. input_buffers, **kwargs)
  415. def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
  416. return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
  417. def _get_mamba_cache_shape(
  418. self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
  419. world_size = get_tensor_model_parallel_world_size()
  420. conv_state_shape = (
  421. self.config.intermediate_size // world_size,
  422. self.config.conv_kernel,
  423. )
  424. temporal_state_shape = (
  425. self.config.intermediate_size // world_size,
  426. self.config.state_size,
  427. )
  428. return conv_state_shape, temporal_state_shape
  429. def compute_logits(
  430. self,
  431. hidden_states: torch.Tensor,
  432. sampling_metadata: SamplingMetadata,
  433. ) -> Optional[torch.Tensor]:
  434. logits = self.logits_processor(self.lm_head, hidden_states,
  435. sampling_metadata)
  436. return logits
  437. def sample(
  438. self,
  439. logits: Optional[torch.Tensor],
  440. sampling_metadata: SamplingMetadata,
  441. ) -> Optional[SamplerOutput]:
  442. next_tokens = self.sampler(logits, sampling_metadata)
  443. return next_tokens
  444. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  445. stacked_params_mapping = [
  446. # (param_name, shard_name, shard_id)
  447. ("qkv_proj", "q_proj", "q"),
  448. ("qkv_proj", "k_proj", "k"),
  449. ("qkv_proj", "v_proj", "v"),
  450. ("gate_up_proj", "gate_proj", 0),
  451. ("gate_up_proj", "up_proj", 1),
  452. ]
  453. params_dict = dict(self.named_parameters())
  454. for name, loaded_weight in weights:
  455. if "rotary_emb.inv_freq" in name:
  456. continue
  457. if "A_log" in name:
  458. name = name.replace("A_log", "A")
  459. if ".self_attn." in name:
  460. name = name.replace(".self_attn", "")
  461. for param_name, weight_name, shard_id in stacked_params_mapping:
  462. if weight_name not in name:
  463. continue
  464. name = name.replace(weight_name, param_name)
  465. # Skip loading extra bias for GPTQ models.
  466. if name.endswith(".bias") and name not in params_dict:
  467. continue
  468. param = params_dict[name]
  469. weight_loader = param.weight_loader
  470. weight_loader(param, loaded_weight, shard_id)
  471. break
  472. else:
  473. # Skip loading extra bias for GPTQ models.
  474. if name.endswith(".bias") and name not in params_dict:
  475. continue
  476. param = params_dict[name]
  477. weight_loader = getattr(param, "weight_loader",
  478. default_weight_loader)
  479. weight_loader(param, loaded_weight)