from array import array
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict)

import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.common.utils import print_warning_once
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal

# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710


class ChameleonImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


def get_max_chameleon_image_tokens(ctx: InputContext):
    return CHAMELEON_IMAGE_SEQ_LENGTH
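

# The two dummy_* helpers below are used for memory profiling: the engine
# builds a worst-case sequence (every image expanded to
# CHAMELEON_IMAGE_SEQ_LENGTH placeholder tokens, the rest padded with token
# id 0) plus blank 512x512 images to estimate peak memory usage before
# serving.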
def dummy_seq_data_for_chameleon(
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_chameleon(
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = CHAMELEON_CROP_SIZE_WIDTH
    height = CHAMELEON_CROP_SIZE_HEIGHT
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_chameleon(
        seq_len,
        num_images,
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
    )

    mm_data = dummy_image_for_chameleon(num_images)
    return seq_data, mm_data
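

# Sketch of what the input processor below produces (illustrative, not the
# output of a real run): a single image placeholder in the prompt, e.g.
#   [..., 8711, ...]
# is expanded to
#   [..., 8197, <8711 repeated 1024 times>, 8196, ...]
# (start token, 1024 image slots, end token), and a trailing separator token
# (8710) is appended, following the HF ChameleonProcessor behavior linked in
# the docstring.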
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
    """
    Processes the input prompt to insert the tokens required for the image
    placeholder.

    See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
        repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
        pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
        pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
    )

    # Append the sep token for chat mode to follow the default processor
    # behavior.
    if new_prompt is not None:
        new_prompt += tokenizer.sep_token
    new_token_ids += [CHAMELEON_SEP_TOKEN_ID]

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
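

# LayerNorm over the last (head_dim) axis only, with the affine transform
# applied manually afterwards. The weight and bias are registered with the
# row-parallel loader so each tensor-parallel rank only holds the scales for
# its own shard of attention heads (see the q_norm/k_norm usage below).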
class ChameleonLayerNorm(nn.LayerNorm):

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        self.normalized_shape = (hidden_size[-1], )

        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=1e-5)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states


# Copied from aphrodite.modeling.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
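

# The main difference from LlamaAttention is the per-head LayerNorm applied
# to the query and key projections (q_norm/k_norm) before rotary embeddings.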
# Modified from aphrodite.modeling.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
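

# Standard pre-norm decoder layer (RMSNorm before attention and MLP, fused
# with the residual add). Used when config.swin_norm is False.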
class ChameleonDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
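

# Post-norm ("swin norm") variant: normalization is applied to the attention
# and MLP outputs before the residual add. Used when config.swin_norm is True.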
class ChameleonSwinDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual
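

# The VQ-VAE below is only run as a frozen image tokenizer at inference time:
# of the quantizer's three outputs, only `min_encoding_indices` is consumed
# (see ChameleonModel.get_image_tokens). The commitment loss and the
# straight-through estimator are kept to stay close to the upstream
# implementation.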
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
            torch.sum(self.embedding.weight**2, dim=1) -
            2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
                             self.embedding.weight.transpose(0, 1)))

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # compute loss for embedding
        loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
                          2) + self.beta * torch.mean(
                              (hidden_state_quant - hidden_state.detach())**2)

        # preserve gradients
        hidden_state_quant = hidden_state + (hidden_state_quant -
                                             hidden_state).detach()

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
                                                        2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states,
                              pad=(0, 1, 0, 1),
                              mode="constant",
                              value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        query_states = query_states.reshape(batch_size, channels,
                                            height * width).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, height * width)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels)**(-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels,
                                            height * width)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states,
                                attn_weights).reshape(batch_size, channels,
                                                      height, width)

        attn_output = self.proj_out(attn_output)
        return residual + attn_output
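

# Convolutional encoder of the VQ-VAE. With the released Chameleon VQ config,
# the 512x512 crop is reduced to a 32x32 latent grid, i.e. 1024 codebook
# indices per image, which matches CHAMELEON_IMAGE_SEQ_LENGTH above (an
# assumption about the stock config; nothing in this file enforces it).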
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1], )
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](
                        hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(
                    hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
                                          config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
                                               config.latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices
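

# In the Chameleon BPE vocabulary, each VQ codebook entry has a token whose
# name starts with "IMGIMG" and encodes the codebook index with the letters
# A-J standing for the digits 0-9 (see `remap` below). The cached properties
# build lookup tables between the two id spaces; img2bpe_mapping_tensor is
# the dense table used on the hot path in convert_img2bpe.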
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        return sorted([
            val for name, val in self.vocab_map.items()
            if name.startswith("IMGIMG")
        ])

    @cached_property
    def bpe2img(self):
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            return "".join(
                img_tkn_chr_mapping.get(c, c)
                for c in old_name[len("IMGIMG"):-1])

        return {
            tok: int(remap(self.val2name[tok]))
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
            sorted(self.bpe2img.values()))

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
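

# Decoder-only language model plus the frozen VQ image tokenizer. Images are
# converted to discrete BPE token ids (get_image_tokens) and spliced into
# input_ids by ChameleonForConditionalGeneration.forward before the usual
# embedding lookup, so the transformer itself only ever sees token ids.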
class ChameleonModel(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList([
            decoder_layer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with the VQGAN module and
        converts the resulting image tokens into BPE token ids. The
        surrounding "boi"/"eoi" special tokens are added by the input
        processor, not here.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
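

# The registrations below wire Chameleon into the multimodal pipeline: the
# default image input mapper, the per-image token budget, dummy data for
# memory profiling, and the placeholder-expanding input processor defined at
# the top of this file.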
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: ChameleonConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
                         CHAMELEON_CROP_SIZE_WIDTH)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        # Remove the N dimension until multiple images are supported.
        pixel_values = pixel_values.squeeze(1)

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            assert self.model.vqmodel is not None
            image_tokens = self.model.get_image_tokens(image_input["data"].to(
                self.config.torch_dtype))
            image_token_id = self.model.vocabulary_mapping.image_token_id
            special_image_mask = input_ids == image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask,
                                                 image_tokens)

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow image tokens; this set does not include the special
        # begin-image and end-image tokens.
        if logits is not None:
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
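
    # Checkpoints store separate q/k/v and gate/up projection weights;
    # stacked_params_mapping below folds them into the fused qkv_proj and
    # gate_up_proj parameters. VQ-VAE ("vqmodel") weights are loaded with the
    # default (unsharded) loader since only the language model is
    # tensor-parallelized.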
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight.
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only do sharding for language model and
                    # not vqvae for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not found the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)