# chameleon.py

from functools import cached_property
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import ChameleonConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import print_warning_once
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

class ChameleonLayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes over the head dimension only.

    ``hidden_size`` is a ``(num_heads, head_dim)`` tuple, so the affine
    parameters keep that full shape while the statistics are computed over
    the last dimension (``head_dim``) alone.
    """

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        self.normalized_shape = (hidden_size[-1], )

    def forward(self, hidden_states):
        # Normalize without the built-in affine transform, then apply the
        # per-head weight and bias manually.
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=1e-5)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states

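# A minimal usage sketch (illustrative only, not exercised by the model code):
# with ``hidden_size`` given as a ``(num_heads, head_dim)`` tuple, statistics
# are taken over ``head_dim`` while the affine parameters stay per-head, e.g.
#
#     ln = ChameleonLayerNorm((8, 128))   # 8 heads, head_dim 128 (made-up sizes)
#     x = torch.randn(4, 8, 128)          # (tokens, heads, head_dim)
#     y = ln(x)                           # normalized over the last dim only
#
# in contrast to a plain ``nn.LayerNorm((8, 128))``, which would normalize
# over both trailing dimensions.
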
# Copied from aphrodite.modeling.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

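# Illustrative sketch only: ``gate_up_proj`` produces the gate and up
# projections concatenated along the last dimension, and ``SiluAndMul``
# splits them back apart, i.e. roughly
#
#     gate, up = gate_up.chunk(2, dim=-1)
#     x = F.silu(gate) * up
#
# (assuming SiluAndMul's usual split-in-half behaviour).
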
# Modified from aphrodite.modeling.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Reshape to (num_tokens, num_heads, head_dim) for the per-head
        # layer norm, then flatten back.
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

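# Data flow through ChameleonAttention.forward (illustrative shapes, assuming
# a flattened batch of ``num_tokens`` tokens on this TP rank):
#
#     hidden_states: (num_tokens, hidden_size)
#     qkv:           (num_tokens, q_size + 2 * kv_size)
#     q, k:          qk-normalized per head, then rotated by RoPE
#     attn_output:   (num_tokens, num_heads * head_dim)
#     output:        (num_tokens, hidden_size)
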
class ChameleonDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)
        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

class ChameleonSwinDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)
        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, residual

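# Note on the two layer variants: ChameleonDecoderLayer uses the usual
# pre-norm ordering (RMSNorm -> sublayer -> residual add, via the fused
# add+norm form of RMSNorm), while ChameleonSwinDecoderLayer applies
# "swin norm": sublayer first, RMSNorm on its output, and only then the
# residual add. Which variant is instantiated is controlled by
# ``config.swin_norm`` in ChameleonModel below.
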
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        return sorted([
            val for name, val in self.vocab_map.items()
            if name.startswith("IMGIMG")
        ])

    @cached_property
    def bpe2img(self):
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            return "".join(
                img_tkn_chr_mapping.get(c, c)
                for c in old_name[len("IMGIMG"):-1])

        return {
            tok: int(remap(self.val2name[tok]))
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
            sorted(self.bpe2img.values()))

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)

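# Illustrative sketch only (hypothetical vocabulary entries): BPE tokens whose
# names look like "IMGIMG" + letters encode VQGAN codebook indices, with the
# letters A-J standing for the digits 0-9 and the trailing character dropped.
# For example, with
#
#     vocab = {"<image>": 3, "IMGIMGBAZ": 7}
#     mapping = ChameleonImageVocabularyMapping(vocab)
#
# "IMGIMGBAZ" remaps to "BA" -> "10", so mapping.bpe2img == {7: 10} and
# mapping.img2bpe == {10: 7}; convert_img2bpe then translates a batch of
# VQGAN indices into the corresponding BPE token ids.
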
class ChameleonModel(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList([
            decoder_layer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # TODO: Support image input
        # self.vqmodel = ChameleonVQModel(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

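# Note: ``residual`` is threaded through the decoder layers so that each
# RMSNorm call can fuse the residual add with the normalization (the
# two-argument form returns ``(normed, new_residual)``); the final
# ``self.norm`` call folds in the last residual the same way.
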
class ChameleonForConditionalGeneration(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        # TODO: Support image input
        # image_tokens = self.process_image_input(**kwargs)
        # image_mask = input_ids == self.vocabulary_mapping.image_token_id
        # input_ids[special_image_mask] = image_tokens.flatten().to(input_ids.dtype) #noqa

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow the image tokens (this set does not include the special
        # begin-image and end-image tokens).
        image_tokens = self.model.vocabulary_mapping.image_tokens
        logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Skip loading the vqgan weights.
            # TODO: add support for the vision model
            if "vqmodel" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remap the name of the FP8 kv-scale parameter.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        print_warning_once(
                            f"Found kv scale in the checkpoint (e.g. {name}), "
                            "but could not find the expected name in the "
                            f"model (e.g. {remapped_kv_scale_name}). kv-scale "
                            "is not loaded.")
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
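
# Illustrative sketch of the stacked-parameter remapping above (checkpoint
# names are hypothetical but follow the usual HF layout): a checkpoint entry
#
#     "model.layers.0.self_attn.q_proj.weight"
#
# is renamed to "model.layers.0.self_attn.qkv_proj.weight" and loaded into the
# fused QKVParallelLinear with shard_id "q"; likewise "gate_proj"/"up_proj"
# load into the merged ``gate_up_proj`` with shard ids 0 and 1.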