chameleon.py 41 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066
  1. from functools import cached_property
  2. from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
  3. Tuple, TypedDict)
  4. import torch
  5. import torch.nn.functional as F
  6. from PIL import Image
  7. from torch import nn
  8. from transformers import ChameleonConfig, ChameleonVQVAEConfig
  9. from aphrodite.attention import Attention, AttentionMetadata
  10. from aphrodite.common.config import CacheConfig, MultiModalConfig
  11. from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
  12. SequenceData)
  13. from aphrodite.common.utils import print_warning_once
  14. from aphrodite.distributed import get_tensor_model_parallel_world_size
  15. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  16. from aphrodite.modeling.layers.activation import SiluAndMul
  17. from aphrodite.modeling.layers.layernorm import RMSNorm
  18. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  19. QKVParallelLinear,
  20. RowParallelLinear)
  21. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  22. from aphrodite.modeling.layers.rotary_embedding import get_rope
  23. from aphrodite.modeling.layers.sampler import Sampler
  24. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  25. ParallelLMHead, VocabParallelEmbedding)
  26. from aphrodite.modeling.model_loader.weight_utils import (
  27. default_weight_loader, row_parallel_weight_loader)
  28. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  29. from aphrodite.modeling.utils import set_weight_attrs
  30. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  31. from aphrodite.multimodal.image import (cached_get_tokenizer,
  32. repeat_and_pad_image_tokens)
  33. from aphrodite.quantization.base_config import QuantizationConfig
  34. from .interfaces import SupportsMultiModal
  35. # These configs are not part of the model config but the preprocessor
  36. # and processor files, so we hardcode them in the model file for now.
  37. CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
  38. CHAMELEON_IMAGE_SEQ_LENGTH = 1024
  39. CHAMELEON_IMAGE_TOKEN_ID = 8711
  40. CHAMELEON_IMAGE_START_TOKEN_ID = 8197
  41. CHAMELEON_IMAGE_END_TOKEN_ID = 8196
  42. CHAMELEON_SEP_TOKEN_ID = 8710
  43. class ChameleonImagePixelInputs(TypedDict):
  44. type: Literal["pixel_values"]
  45. data: torch.Tensor
  46. """Shape: `(batch_size, num_channels, height, width)`"""
  47. def get_max_chameleon_image_tokens(ctx: InputContext):
  48. return CHAMELEON_IMAGE_SEQ_LENGTH
  49. def dummy_seq_data_for_chameleon(
  50. seq_len: int,
  51. num_images: int,
  52. *,
  53. image_token_id: int,
  54. image_feature_size_override: Optional[int] = None,
  55. ):
  56. if image_feature_size_override is None:
  57. image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
  58. else:
  59. image_feature_size = image_feature_size_override
  60. token_ids = [image_token_id] * image_feature_size * num_images
  61. token_ids += [0] * (seq_len - image_feature_size * num_images)
  62. return SequenceData(token_ids)
  63. def dummy_image_for_chameleon(
  64. num_images: int,
  65. *,
  66. image_width_override: Optional[int] = None,
  67. image_height_override: Optional[int] = None,
  68. ):
  69. width = CHAMELEON_CROP_SIZE_WIDTH
  70. height = CHAMELEON_CROP_SIZE_HEIGHT
  71. if image_width_override is not None:
  72. width = image_width_override
  73. if image_height_override is not None:
  74. height = image_height_override
  75. image = Image.new("RGB", (width, height), color=0)
  76. return {"image": image if num_images == 1 else [image] * num_images}
  77. def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
  78. mm_counts: Mapping[str, int]):
  79. num_images = mm_counts["image"]
  80. seq_data = dummy_seq_data_for_chameleon(
  81. seq_len,
  82. num_images,
  83. image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
  84. )
  85. mm_data = dummy_image_for_chameleon(num_images)
  86. return seq_data, mm_data
  87. def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
  88. """
  89. Processing input prompt to insert required tokens for image placeholder.
  90. See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
  91. """ # noqa
  92. multi_modal_data = llm_inputs.get("multi_modal_data")
  93. if multi_modal_data is None or "image" not in multi_modal_data:
  94. return llm_inputs
  95. model_config = ctx.model_config
  96. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  97. new_prompt, new_token_ids = repeat_and_pad_image_tokens(
  98. tokenizer,
  99. llm_inputs.get("prompt"),
  100. llm_inputs["prompt_token_ids"],
  101. image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
  102. repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
  103. pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
  104. pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
  105. )
  106. # Appending sep token for chat mode to follow default processor
  107. # behavior
  108. if new_prompt is not None:
  109. new_prompt += tokenizer.sep_token
  110. new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
  111. # NOTE: Create a defensive copy of the original inputs
  112. return LLMInputs(prompt_token_ids=new_token_ids,
  113. prompt=new_prompt,
  114. multi_modal_data=multi_modal_data)
  115. class ChameleonLayerNorm(nn.LayerNorm):
  116. def __init__(self, hidden_size, *args, **kwargs):
  117. super().__init__(hidden_size, *args, **kwargs)
  118. self.normalized_shape = (hidden_size[-1], )
  119. set_weight_attrs(self.weight,
  120. {"weight_loader": row_parallel_weight_loader})
  121. set_weight_attrs(self.bias,
  122. {"weight_loader": row_parallel_weight_loader})
  123. def forward(self, hidden_states):
  124. hidden_states = F.layer_norm(hidden_states,
  125. self.normalized_shape,
  126. None,
  127. None,
  128. eps=1e-5)
  129. hidden_states = hidden_states * self.weight + self.bias
  130. return hidden_states
  131. # Copied from aphrodite.modeling.models.llama.LlamaMLP -> ChameleonMLP
  132. class ChameleonMLP(nn.Module):
  133. def __init__(
  134. self,
  135. hidden_size: int,
  136. intermediate_size: int,
  137. hidden_act: str,
  138. quant_config: Optional[QuantizationConfig] = None,
  139. bias: bool = False,
  140. ) -> None:
  141. super().__init__()
  142. self.gate_up_proj = MergedColumnParallelLinear(
  143. input_size=hidden_size,
  144. output_sizes=[intermediate_size] * 2,
  145. bias=bias,
  146. quant_config=quant_config)
  147. self.down_proj = RowParallelLinear(input_size=intermediate_size,
  148. output_size=hidden_size,
  149. bias=bias,
  150. quant_config=quant_config)
  151. if hidden_act != "silu":
  152. raise ValueError(f"Unsupported activation: {hidden_act}. "
  153. "Only silu is supported for now.")
  154. self.act_fn = SiluAndMul()
  155. def forward(self, x):
  156. gate_up, _ = self.gate_up_proj(x)
  157. x = self.act_fn(gate_up)
  158. x, _ = self.down_proj(x)
  159. return x
  160. # Modified from aphrodite.modeling.models.llama.LlamaAttention -> ChameleonAttention #noqa
  161. class ChameleonAttention(nn.Module):
  162. def __init__(
  163. self,
  164. hidden_size: int,
  165. num_heads: int,
  166. num_kv_heads: int,
  167. rope_theta: float = 10000,
  168. rope_scaling: Optional[Dict[str, Any]] = None,
  169. max_position_embeddings: int = 4096,
  170. quant_config: Optional[QuantizationConfig] = None,
  171. bias: bool = False,
  172. cache_config: Optional[CacheConfig] = None,
  173. ) -> None:
  174. super().__init__()
  175. self.hidden_size = hidden_size
  176. tp_size = get_tensor_model_parallel_world_size()
  177. self.total_num_heads = num_heads
  178. assert self.total_num_heads % tp_size == 0
  179. self.num_heads = self.total_num_heads // tp_size
  180. self.total_num_kv_heads = num_kv_heads
  181. if self.total_num_kv_heads >= tp_size:
  182. # Number of KV heads is greater than TP size, so we partition
  183. # the KV heads across multiple tensor parallel GPUs.
  184. assert self.total_num_kv_heads % tp_size == 0
  185. else:
  186. # Number of KV heads is less than TP size, so we replicate
  187. # the KV heads across multiple tensor parallel GPUs.
  188. assert tp_size % self.total_num_kv_heads == 0
  189. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  190. self.head_dim = hidden_size // self.total_num_heads
  191. self.q_size = self.num_heads * self.head_dim
  192. self.kv_size = self.num_kv_heads * self.head_dim
  193. self.scaling = self.head_dim**-0.5
  194. self.rope_theta = rope_theta
  195. self.max_position_embeddings = max_position_embeddings
  196. self.qkv_proj = QKVParallelLinear(
  197. hidden_size=hidden_size,
  198. head_size=self.head_dim,
  199. total_num_heads=self.total_num_heads,
  200. total_num_kv_heads=self.total_num_kv_heads,
  201. bias=bias,
  202. quant_config=quant_config,
  203. )
  204. self.o_proj = RowParallelLinear(
  205. input_size=self.total_num_heads * self.head_dim,
  206. output_size=hidden_size,
  207. bias=bias,
  208. quant_config=quant_config,
  209. )
  210. self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
  211. self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
  212. self.rotary_emb = get_rope(
  213. self.head_dim,
  214. rotary_dim=self.head_dim,
  215. max_position=max_position_embeddings,
  216. base=rope_theta,
  217. rope_scaling=rope_scaling,
  218. )
  219. self.attn = Attention(self.num_heads,
  220. self.head_dim,
  221. self.scaling,
  222. num_kv_heads=self.num_kv_heads,
  223. cache_config=cache_config,
  224. quant_config=quant_config)
  225. def _apply_qk_norm(self, q: torch.Tensor,
  226. k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  227. # reshape for layernorm
  228. q = q.reshape(-1, self.num_heads, self.head_dim)
  229. k = k.reshape(-1, self.num_kv_heads, self.head_dim)
  230. q = self.q_norm(q)
  231. k = self.k_norm(k)
  232. q = q.view(*q.shape[:-2], -1)
  233. k = k.view(*k.shape[:-2], -1)
  234. return q, k
  235. def forward(
  236. self,
  237. positions: torch.Tensor,
  238. hidden_states: torch.Tensor,
  239. kv_cache: torch.Tensor,
  240. attn_metadata: AttentionMetadata,
  241. ) -> torch.Tensor:
  242. qkv, _ = self.qkv_proj(hidden_states)
  243. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  244. q, k = self._apply_qk_norm(q, k)
  245. q, k = self.rotary_emb(positions, q, k)
  246. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  247. output, _ = self.o_proj(attn_output)
  248. return output
  249. class ChameleonDecoderLayer(nn.Module):
  250. def __init__(
  251. self,
  252. config: ChameleonConfig,
  253. cache_config: Optional[CacheConfig] = None,
  254. quant_config: Optional[QuantizationConfig] = None,
  255. ) -> None:
  256. super().__init__()
  257. self.hidden_size = config.hidden_size
  258. rope_theta = getattr(config, "rope_theta", 10000)
  259. rope_scaling = getattr(config, "rope_scaling", None)
  260. if rope_scaling is not None and getattr(
  261. config, "original_max_position_embeddings", None):
  262. rope_scaling["original_max_position_embeddings"] = (
  263. config.original_max_position_embeddings)
  264. max_position_embeddings = getattr(config, "max_position_embeddings",
  265. 4096)
  266. self.self_attn = ChameleonAttention(
  267. hidden_size=self.hidden_size,
  268. num_heads=config.num_attention_heads,
  269. num_kv_heads=getattr(config, "num_key_value_heads",
  270. config.num_attention_heads),
  271. rope_theta=rope_theta,
  272. rope_scaling=rope_scaling,
  273. max_position_embeddings=max_position_embeddings,
  274. quant_config=quant_config,
  275. bias=False,
  276. cache_config=cache_config,
  277. )
  278. self.mlp = ChameleonMLP(
  279. hidden_size=self.hidden_size,
  280. intermediate_size=config.intermediate_size,
  281. hidden_act=config.hidden_act,
  282. quant_config=quant_config,
  283. bias=getattr(config, "mlp_bias", False),
  284. )
  285. self.input_layernorm = RMSNorm(config.hidden_size,
  286. eps=config.rms_norm_eps)
  287. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  288. eps=config.rms_norm_eps)
  289. def forward(
  290. self,
  291. positions: torch.Tensor,
  292. hidden_states: torch.Tensor,
  293. kv_cache: torch.Tensor,
  294. attn_metadata: AttentionMetadata,
  295. residual: Optional[torch.Tensor],
  296. ) -> Tuple[torch.Tensor, torch.Tensor]:
  297. if residual is None:
  298. residual = hidden_states
  299. hidden_states = self.input_layernorm(hidden_states)
  300. else:
  301. hidden_states, residual = self.input_layernorm(
  302. hidden_states, residual)
  303. hidden_states = self.self_attn(
  304. positions=positions,
  305. hidden_states=hidden_states,
  306. kv_cache=kv_cache,
  307. attn_metadata=attn_metadata,
  308. )
  309. # Fully Connected
  310. hidden_states, residual = self.post_attention_layernorm(
  311. hidden_states, residual)
  312. hidden_states = self.mlp(hidden_states)
  313. return hidden_states, residual
  314. class ChameleonSwinDecoderLayer(nn.Module):
  315. def __init__(
  316. self,
  317. config: ChameleonConfig,
  318. cache_config: Optional[CacheConfig] = None,
  319. quant_config: Optional[QuantizationConfig] = None,
  320. ) -> None:
  321. super().__init__()
  322. self.hidden_size = config.hidden_size
  323. rope_theta = getattr(config, "rope_theta", 10000)
  324. rope_scaling = getattr(config, "rope_scaling", None)
  325. if rope_scaling is not None and getattr(
  326. config, "original_max_position_embeddings", None):
  327. rope_scaling["original_max_position_embeddings"] = (
  328. config.original_max_position_embeddings)
  329. max_position_embeddings = getattr(config, "max_position_embeddings",
  330. 4096)
  331. self.self_attn = ChameleonAttention(
  332. hidden_size=self.hidden_size,
  333. num_heads=config.num_attention_heads,
  334. num_kv_heads=getattr(config, "num_key_value_heads",
  335. config.num_attention_heads),
  336. rope_theta=rope_theta,
  337. rope_scaling=rope_scaling,
  338. max_position_embeddings=max_position_embeddings,
  339. quant_config=quant_config,
  340. bias=False,
  341. cache_config=cache_config,
  342. )
  343. self.mlp = ChameleonMLP(
  344. hidden_size=self.hidden_size,
  345. intermediate_size=config.intermediate_size,
  346. hidden_act=config.hidden_act,
  347. quant_config=quant_config,
  348. bias=getattr(config, "mlp_bias", False),
  349. )
  350. self.input_layernorm = RMSNorm(config.hidden_size,
  351. eps=config.rms_norm_eps)
  352. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  353. eps=config.rms_norm_eps)
  354. def forward(
  355. self,
  356. positions: torch.Tensor,
  357. hidden_states: torch.Tensor,
  358. kv_cache: torch.Tensor,
  359. attn_metadata: AttentionMetadata,
  360. residual: Optional[torch.Tensor],
  361. ) -> Tuple[torch.Tensor, torch.Tensor]:
  362. residual = hidden_states
  363. hidden_states = self.self_attn(
  364. positions=positions,
  365. hidden_states=hidden_states,
  366. kv_cache=kv_cache,
  367. attn_metadata=attn_metadata,
  368. )
  369. hidden_states = self.input_layernorm(hidden_states)
  370. hidden_states = hidden_states + residual
  371. # Fully Connected
  372. residual = hidden_states
  373. hidden_states = self.mlp(hidden_states)
  374. hidden_states = self.post_attention_layernorm(hidden_states)
  375. hidden_states = residual + hidden_states
  376. return hidden_states, residual
  377. # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
  378. class ChameleonVQVAEVectorQuantizer(nn.Module):
  379. def __init__(self, config: ChameleonVQVAEConfig):
  380. super().__init__()
  381. self.num_embeddings = config.num_embeddings
  382. self.embedding_dim = config.embed_dim
  383. self.beta = getattr(config, "beta", 0.25)
  384. self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
  385. self.re_embed = self.num_embeddings
  386. def forward(self, hidden_state: torch.Tensor):
  387. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  388. hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
  389. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  390. distances = (
  391. torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
  392. torch.sum(self.embedding.weight**2, dim=1) -
  393. 2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
  394. self.embedding.weight.transpose(0, 1)))
  395. min_encoding_indices = torch.argmin(distances, dim=1)
  396. hidden_state_quant = self.embedding(min_encoding_indices).view(
  397. hidden_state.shape)
  398. # compute loss for embedding
  399. loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
  400. 2) + self.beta * torch.mean(
  401. (hidden_state_quant - hidden_state.detach())**2)
  402. # preserve gradients
  403. hidden_state_quant = hidden_state + (hidden_state_quant -
  404. hidden_state).detach()
  405. # reshape back to match original input shape
  406. hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
  407. 2).contiguous()
  408. return hidden_state_quant, loss, min_encoding_indices
  409. # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
  410. class ChameleonVQVAEEncoderConvDownsample(nn.Module):
  411. def __init__(self, in_channels: int):
  412. super().__init__()
  413. self.conv = nn.Conv2d(in_channels,
  414. in_channels,
  415. kernel_size=3,
  416. stride=2,
  417. padding=0)
  418. def forward(self, hidden_states: torch.Tensor):
  419. # no asymmetric padding in torch conv, must do it ourselves
  420. hidden_states = F.pad(hidden_states,
  421. pad=(0, 1, 0, 1),
  422. mode="constant",
  423. value=0)
  424. hidden_states = self.conv(hidden_states)
  425. return hidden_states
  426. # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
  427. class ChameleonVQVAEEncoderResnetBlock(nn.Module):
  428. def __init__(
  429. self,
  430. config: ChameleonVQVAEConfig,
  431. in_channels: int,
  432. out_channels=None,
  433. conv_shortcut=False,
  434. ):
  435. super().__init__()
  436. self.in_channels = in_channels
  437. self.out_channels = in_channels if out_channels is None \
  438. else out_channels
  439. self.use_conv_shortcut = conv_shortcut
  440. self.norm1 = torch.nn.GroupNorm(num_groups=32,
  441. num_channels=in_channels,
  442. eps=1e-6,
  443. affine=True)
  444. self.conv1 = torch.nn.Conv2d(in_channels,
  445. out_channels,
  446. kernel_size=3,
  447. stride=1,
  448. padding=1)
  449. self.norm2 = torch.nn.GroupNorm(num_groups=32,
  450. num_channels=out_channels,
  451. eps=1e-6,
  452. affine=True)
  453. self.dropout = torch.nn.Dropout(config.dropout)
  454. self.conv2 = torch.nn.Conv2d(out_channels,
  455. out_channels,
  456. kernel_size=3,
  457. stride=1,
  458. padding=1)
  459. if self.in_channels != self.out_channels:
  460. if self.use_conv_shortcut:
  461. self.conv_shortcut = torch.nn.Conv2d(in_channels,
  462. out_channels,
  463. kernel_size=3,
  464. stride=1,
  465. padding=1)
  466. else:
  467. self.nin_shortcut = torch.nn.Conv2d(in_channels,
  468. out_channels,
  469. kernel_size=1,
  470. stride=1,
  471. padding=0)
  472. def forward(self, hidden_states: torch.Tensor):
  473. residual = hidden_states
  474. hidden_states = self.norm1(hidden_states)
  475. hidden_states *= torch.sigmoid(hidden_states)
  476. hidden_states = self.conv1(hidden_states)
  477. hidden_states = self.norm2(hidden_states)
  478. hidden_states *= torch.sigmoid(hidden_states)
  479. hidden_states = self.dropout(hidden_states)
  480. hidden_states = self.conv2(hidden_states)
  481. if self.in_channels != self.out_channels:
  482. if self.use_conv_shortcut:
  483. residual = self.conv_shortcut(residual)
  484. else:
  485. residual = self.nin_shortcut(residual)
  486. return residual + hidden_states
  487. # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
  488. class ChameleonVQVAEEncoderAttnBlock(nn.Module):
  489. def __init__(self, in_channels: int):
  490. super().__init__()
  491. self.in_channels = in_channels
  492. self.norm = torch.nn.GroupNorm(num_groups=32,
  493. num_channels=in_channels,
  494. eps=1e-6,
  495. affine=True)
  496. self.q = torch.nn.Conv2d(in_channels,
  497. in_channels,
  498. kernel_size=1,
  499. stride=1,
  500. padding=0)
  501. self.k = torch.nn.Conv2d(in_channels,
  502. in_channels,
  503. kernel_size=1,
  504. stride=1,
  505. padding=0)
  506. self.v = torch.nn.Conv2d(in_channels,
  507. in_channels,
  508. kernel_size=1,
  509. stride=1,
  510. padding=0)
  511. self.proj_out = torch.nn.Conv2d(in_channels,
  512. in_channels,
  513. kernel_size=1,
  514. stride=1,
  515. padding=0)
  516. def forward(self, hidden_states: torch.Tensor):
  517. residual = hidden_states
  518. hidden_states = self.norm(hidden_states)
  519. query_states = self.q(hidden_states)
  520. key_states = self.k(hidden_states)
  521. value_states = self.v(hidden_states)
  522. # compute attention
  523. batch_size, channels, height, width = query_states.shape
  524. query_states = query_states.reshape(batch_size, channels,
  525. height * width).permute(0, 2, 1)
  526. key_states = key_states.reshape(batch_size, channels, height * width)
  527. attn_weights = torch.bmm(query_states, key_states)
  528. attn_weights = attn_weights * (int(channels)**(-0.5))
  529. attn_weights = F.softmax(attn_weights, dim=2)
  530. # attend to values
  531. value_states = value_states.reshape(batch_size, channels,
  532. height * width)
  533. attn_weights = attn_weights.permute(0, 2, 1)
  534. attn_output = torch.bmm(value_states,
  535. attn_weights).reshape(batch_size, channels,
  536. height, width)
  537. attn_output = self.proj_out(attn_output)
  538. return residual + attn_output
  539. # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
  540. class ChameleonVQVAEEncoder(nn.Module):
  541. def __init__(self, config: ChameleonVQVAEConfig):
  542. super().__init__()
  543. self.num_resolutions = len(config.channel_multiplier)
  544. self.num_res_blocks = config.num_res_blocks
  545. base_channels = config.base_channels
  546. resolution = config.resolution
  547. in_channels = config.in_channels
  548. double_latent = config.double_latent
  549. latent_channels = config.latent_channels
  550. channel_multiplier = config.channel_multiplier
  551. self.conv_in = torch.nn.Conv2d(in_channels,
  552. base_channels,
  553. kernel_size=3,
  554. stride=1,
  555. padding=1)
  556. curr_res = resolution
  557. in_channel_multiplier = (1, ) + tuple(channel_multiplier)
  558. self.in_channel_multiplier = in_channel_multiplier
  559. self.down = nn.ModuleList()
  560. for i_level in range(self.num_resolutions):
  561. block = nn.ModuleList()
  562. attn = nn.ModuleList()
  563. block_in = base_channels * in_channel_multiplier[i_level]
  564. block_out = base_channels * channel_multiplier[i_level]
  565. for i_block in range(self.num_res_blocks):
  566. block.append(
  567. ChameleonVQVAEEncoderResnetBlock(
  568. config=config,
  569. in_channels=block_in,
  570. out_channels=block_out,
  571. ))
  572. block_in = block_out
  573. if (config.attn_resolutions is not None
  574. and curr_res in config.attn_resolutions
  575. and config.attn_type == "vanilla"):
  576. attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
  577. down = nn.Module()
  578. down.block = block
  579. down.attn = attn
  580. if i_level != self.num_resolutions - 1:
  581. down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
  582. curr_res = curr_res // 2
  583. self.down.append(down)
  584. self.mid = nn.Module()
  585. self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
  586. config=config,
  587. in_channels=block_in,
  588. out_channels=block_in,
  589. )
  590. self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
  591. block_in) if config.attn_type == "vanilla" else nn.Identity()
  592. self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
  593. config=config,
  594. in_channels=block_in,
  595. out_channels=block_in,
  596. )
  597. self.norm_out = torch.nn.GroupNorm(num_groups=32,
  598. num_channels=block_in,
  599. eps=1e-6,
  600. affine=True)
  601. self.conv_out = torch.nn.Conv2d(
  602. block_in,
  603. 2 * latent_channels if double_latent else latent_channels,
  604. kernel_size=3,
  605. stride=1,
  606. padding=1,
  607. )
  608. def forward(self, pixel_values: torch.Tensor):
  609. pixel_values = pixel_values.to(self.conv_in.weight.dtype)
  610. # downsampling
  611. hidden_states = [self.conv_in(pixel_values)]
  612. for i_level in range(self.num_resolutions):
  613. for i_block in range(self.num_res_blocks):
  614. hidden_state = self.down[i_level].block[i_block](
  615. hidden_states[-1], )
  616. if len(self.down[i_level].attn) > 0:
  617. hidden_state = self.down[i_level].attn[i_block](
  618. hidden_state)
  619. hidden_states.append(hidden_state)
  620. if i_level != self.num_resolutions - 1:
  621. hidden_states.append(self.down[i_level].downsample(
  622. hidden_states[-1]))
  623. # middle
  624. last_hidden_state = hidden_states[-1]
  625. last_hidden_state = self.mid.block_1(last_hidden_state)
  626. last_hidden_state = self.mid.attn_1(last_hidden_state)
  627. last_hidden_state = self.mid.block_2(last_hidden_state)
  628. # end
  629. last_hidden_state = self.norm_out(last_hidden_state)
  630. last_hidden_state *= torch.sigmoid(last_hidden_state)
  631. last_hidden_state = self.conv_out(last_hidden_state)
  632. return last_hidden_state
  633. # Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
  634. class ChameleonVQVAE(nn.Module):
  635. def __init__(self, config: ChameleonVQVAEConfig):
  636. super().__init__()
  637. self.encoder = ChameleonVQVAEEncoder(config)
  638. self.quantize = ChameleonVQVAEVectorQuantizer(config)
  639. self.quant_conv = torch.nn.Conv2d(config.latent_channels,
  640. config.embed_dim, 1)
  641. self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
  642. config.latent_channels, 1)
  643. self.eval() # Chameleon's VQ model is frozen
  644. def encode(
  645. self, pixel_values: torch.Tensor
  646. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  647. hidden_states = self.encoder(pixel_values)
  648. hidden_states = self.quant_conv(hidden_states)
  649. quant, emb_loss, indices = self.quantize(hidden_states)
  650. return quant, emb_loss, indices
  651. # Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
  652. class ChameleonImageVocabularyMapping:
  653. """
  654. A class for mapping discrete image tokens from VQGAN to BPE tokens.
  655. """
  656. def __init__(self, vocab_map: Dict[str, int]):
  657. self.vocab_map = vocab_map
  658. self.image_token_id = vocab_map.get("<image>")
  659. @cached_property
  660. def val2name(self):
  661. return {v: k for k, v in self.vocab_map.items()}
  662. @cached_property
  663. def image_tokens(self):
  664. return sorted([
  665. val for name, val in self.vocab_map.items()
  666. if name.startswith("IMGIMG")
  667. ])
  668. @cached_property
  669. def bpe2img(self):
  670. img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
  671. def remap(old_name: str) -> str:
  672. return "".join(
  673. img_tkn_chr_mapping.get(c, c)
  674. for c in old_name[len("IMGIMG"):-1])
  675. return {
  676. tok: int(remap(self.val2name[tok]))
  677. for tok in self.image_tokens
  678. }
  679. @cached_property
  680. def img2bpe(self):
  681. return {v: k for k, v in self.bpe2img.items()}
  682. @cached_property
  683. def bpe2img_search_tensors(self):
  684. return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
  685. sorted(self.bpe2img.values()))
  686. @cached_property
  687. def img2bpe_mapping_tensor(self):
  688. mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
  689. for k, v in self.img2bpe.items():
  690. mapping[k] = v
  691. return mapping
  692. def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
  693. device = img_batch.device
  694. img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
  695. return img_tokens.to(device)
  696. class ChameleonModel(nn.Module):
  697. def __init__(
  698. self,
  699. config: ChameleonConfig,
  700. cache_config: Optional[CacheConfig] = None,
  701. quant_config: Optional[QuantizationConfig] = None,
  702. ) -> None:
  703. super().__init__()
  704. self.config = config
  705. self.padding_idx = config.pad_token_id
  706. self.vocab_size = config.vocab_size
  707. self.embed_tokens = VocabParallelEmbedding(
  708. self.vocab_size,
  709. config.hidden_size,
  710. )
  711. self.vocabulary_mapping = ChameleonImageVocabularyMapping(
  712. config.vocabulary_map)
  713. decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
  714. else ChameleonSwinDecoderLayer
  715. self.layers = nn.ModuleList([
  716. decoder_layer(config=config,
  717. cache_config=cache_config,
  718. quant_config=quant_config)
  719. for _ in range(config.num_hidden_layers)
  720. ])
  721. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  722. self.vqmodel = ChameleonVQVAE(config.vq_config)
  723. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  724. return self.embed_tokens(input_ids)
  725. def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
  726. """
  727. Tokenizes images into discrete tokens with VQGAN module. Converts
  728. obtained image tokens into BPE tokens and wraps with "boi" and "eoi"
  729. special tokens.
  730. """
  731. batch_size = pixel_values.shape[0]
  732. _, _, image_toks = self.vqmodel.encode(pixel_values)
  733. bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
  734. bpe_toks = bpe_toks.view(batch_size, -1)
  735. return bpe_toks
  736. def forward(
  737. self,
  738. input_ids: Optional[torch.Tensor],
  739. positions: torch.Tensor,
  740. kv_caches: List[torch.Tensor],
  741. attn_metadata: AttentionMetadata,
  742. inputs_embeds: Optional[torch.Tensor] = None,
  743. ) -> torch.Tensor:
  744. if inputs_embeds is not None:
  745. hidden_states = inputs_embeds
  746. else:
  747. hidden_states = self.get_input_embeddings(input_ids)
  748. residual = None
  749. for i in range(len(self.layers)):
  750. layer = self.layers[i]
  751. hidden_states, residual = layer(
  752. positions,
  753. hidden_states,
  754. kv_caches[i],
  755. attn_metadata,
  756. residual,
  757. )
  758. hidden_states, _ = self.norm(hidden_states, residual)
  759. return hidden_states
  760. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  761. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
  762. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
  763. @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
  764. class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
  765. def __init__(
  766. self,
  767. config: ChameleonConfig,
  768. multimodal_config: MultiModalConfig,
  769. cache_config: Optional[CacheConfig] = None,
  770. quant_config: Optional[QuantizationConfig] = None,
  771. ) -> None:
  772. super().__init__()
  773. self.config = config
  774. self.multimodal_config = multimodal_config
  775. self.model = ChameleonModel(config, cache_config, quant_config)
  776. self.unpadded_vocab_size = config.vocab_size
  777. self.lm_head = ParallelLMHead(
  778. self.unpadded_vocab_size,
  779. config.hidden_size,
  780. )
  781. if config.tie_word_embeddings:
  782. self.lm_head.weight = self.model.embed_tokens.weight
  783. logit_scale = getattr(config, "logit_scale", 1.0)
  784. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  785. config.vocab_size, logit_scale)
  786. self.sampler = Sampler()
  787. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  788. expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
  789. CHAMELEON_CROP_SIZE_WIDTH)
  790. actual_dims = tuple(data.shape[1:])
  791. if actual_dims != expected_dims:
  792. expected_expr = ("batch_size", *map(str, expected_dims))
  793. raise ValueError(
  794. f"The expected shape of pixel values is {expected_expr}. "
  795. f"You supplied {tuple(data.shape)}.")
  796. return data
  797. def _parse_and_validate_image_input(
  798. self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
  799. pixel_values = kwargs.pop("pixel_values", None)
  800. if pixel_values is None:
  801. return None
  802. if not isinstance(pixel_values, torch.Tensor):
  803. raise ValueError("Incorrect type of pixel values. "
  804. f"Got type: {type(pixel_values)}")
  805. return ChameleonImagePixelInputs(
  806. type="pixel_values",
  807. data=self._validate_pixel_values(pixel_values),
  808. )
  809. def forward(
  810. self,
  811. input_ids: torch.Tensor,
  812. positions: torch.Tensor,
  813. kv_caches: List[torch.Tensor],
  814. attn_metadata: AttentionMetadata,
  815. intermediate_tensors: Optional[IntermediateTensors] = None,
  816. **kwargs,
  817. ) -> torch.Tensor:
  818. image_input = self._parse_and_validate_image_input(**kwargs)
  819. if image_input is not None:
  820. assert self.model.vqmodel is not None
  821. image_tokens = self.model.get_image_tokens(image_input["data"].to(
  822. self.config.torch_dtype))
  823. image_token_id = self.model.vocabulary_mapping.image_token_id
  824. special_image_mask = input_ids == image_token_id
  825. image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
  826. input_ids = input_ids.masked_scatter(special_image_mask,
  827. image_tokens)
  828. hidden_states = self.model(input_ids, positions, kv_caches,
  829. attn_metadata)
  830. return hidden_states
  831. def compute_logits(
  832. self,
  833. hidden_states: torch.Tensor,
  834. sampling_metadata: SamplingMetadata,
  835. ) -> Optional[torch.Tensor]:
  836. logits = self.logits_processor(self.lm_head, hidden_states,
  837. sampling_metadata)
  838. # Disallow image tokens which does not include special
  839. # begin-image and end-image tokens
  840. if logits is not None:
  841. image_tokens = self.model.vocabulary_mapping.image_tokens
  842. logits[:, image_tokens] = torch.finfo(logits.dtype).min
  843. return logits
  844. def sample(
  845. self,
  846. logits: torch.Tensor,
  847. sampling_metadata: SamplingMetadata,
  848. ) -> Optional[SamplerOutput]:
  849. next_tokens = self.sampler(logits, sampling_metadata)
  850. return next_tokens
  851. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  852. stacked_params_mapping = [
  853. # (param_name, shard_name, shard_id)
  854. (".qkv_proj", ".q_proj", "q"),
  855. (".qkv_proj", ".k_proj", "k"),
  856. (".qkv_proj", ".v_proj", "v"),
  857. (".gate_up_proj", ".gate_proj", 0),
  858. (".gate_up_proj", ".up_proj", 1),
  859. ]
  860. params_dict = dict(self.named_parameters())
  861. for name, loaded_weight in weights:
  862. if "rotary_emb.inv_freq" in name:
  863. continue
  864. if ("rotary_emb.cos_cached" in name
  865. or "rotary_emb.sin_cached" in name):
  866. # Models trained using ColossalAI may include these tensors in
  867. # the checkpoint. Skip them.
  868. continue
  869. # With tie_word_embeddings, we can skip lm_head.weight
  870. # The weight might appear unnecessarily in the files if the model is
  871. # processed with quantization, LoRA, fine-tuning, etc.
  872. if self.config.tie_word_embeddings and "lm_head.weight" in name:
  873. continue
  874. use_default_weight_loading = False
  875. if "vqmodel" in name:
  876. if self.model.vqmodel is not None:
  877. # We only do sharding for language model and
  878. # not vqvae for now.
  879. use_default_weight_loading = True
  880. else:
  881. for (param_name, weight_name,
  882. shard_id) in stacked_params_mapping:
  883. if weight_name not in name:
  884. continue
  885. name = name.replace(weight_name, param_name)
  886. # Skip loading extra bias for GPTQ models.
  887. if name.endswith(".bias") and name not in params_dict:
  888. continue
  889. param = params_dict[name]
  890. weight_loader = param.weight_loader
  891. weight_loader(param, loaded_weight, shard_id)
  892. break
  893. else:
  894. # Skip loading extra bias for GPTQ models.
  895. if name.endswith(".bias") and name not in params_dict:
  896. continue
  897. # Remapping the name of FP8 kv-scale.
  898. if name.endswith("kv_scale"):
  899. remapped_kv_scale_name = name.replace(
  900. ".kv_scale", ".attn.kv_scale")
  901. if remapped_kv_scale_name not in params_dict:
  902. print_warning_once(
  903. "Found kv scale in the checkpoint (e.g. "
  904. f"{name}), but not found the expected name in "
  905. f"the model (e.g. {remapped_kv_scale_name}). "
  906. "kv-scale is not loaded.")
  907. continue
  908. else:
  909. name = remapped_kv_scale_name
  910. param = params_dict[name]
  911. weight_loader = getattr(param, "weight_loader",
  912. default_weight_loader)
  913. weight_loader(param, loaded_weight)
  914. if use_default_weight_loading and name in params_dict:
  915. param = params_dict[name]
  916. weight_loader = getattr(param, "weight_loader",
  917. default_weight_loader)
  918. weight_loader(param, loaded_weight)