# chameleon.py

from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
                    TypedDict)

import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.common.utils import print_warning_once, progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import (cached_get_tokenizer,
                                        repeat_and_pad_image_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal

# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710


class ChameleonImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


def get_max_chameleon_image_tokens(ctx: InputContext):
    return CHAMELEON_IMAGE_SEQ_LENGTH


def dummy_seq_data_for_chameleon(
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_chameleon(
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = CHAMELEON_CROP_SIZE_WIDTH
    height = CHAMELEON_CROP_SIZE_HEIGHT
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
    seq_data = dummy_seq_data_for_chameleon(
        seq_len,
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
    )

    mm_data = dummy_image_for_chameleon()
    return seq_data, mm_data


def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
    """
    Processes the input prompt to insert the tokens required for the image
    placeholder.

    See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
        repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
        pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
        pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
    )

    # Append the sep token for chat mode to follow the default processor
    # behavior
    if new_prompt is not None:
        new_prompt += tokenizer.sep_token
    new_token_ids += [CHAMELEON_SEP_TOKEN_ID]

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
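
# The processor above expands every image placeholder so that, roughly,
# one <image> token in the prompt becomes
#   [CHAMELEON_IMAGE_START_TOKEN_ID]
#   + [CHAMELEON_IMAGE_TOKEN_ID] * CHAMELEON_IMAGE_SEQ_LENGTH
#   + [CHAMELEON_IMAGE_END_TOKEN_ID]
# and CHAMELEON_SEP_TOKEN_ID is appended to the prompt, mirroring the
# HuggingFace ChameleonProcessor behavior linked in the docstring.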


class ChameleonLayerNorm(nn.LayerNorm):

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        self.normalized_shape = (hidden_size[-1], )

        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=1e-5)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states
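
# ChameleonLayerNorm normalizes only over the trailing head_dim axis (note
# the override of `normalized_shape` above) and applies its affine weight
# and bias outside of F.layer_norm, so that the per-head parameters can be
# sharded across tensor-parallel ranks via row_parallel_weight_loader.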


# Copied from aphrodite.modeling.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


# Modified from aphrodite.modeling.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
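
# The attention path above follows the usual Llama-style flow (fused QKV
# projection -> rotary embedding -> paged attention -> output projection),
# with one Chameleon-specific addition: q and k are reshaped to
# (num_tokens, num_heads, head_dim) and passed through ChameleonLayerNorm
# before RoPE is applied (see `_apply_qk_norm`).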


class ChameleonDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class ChameleonSwinDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, residual
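
# ChameleonSwinDecoderLayer differs from ChameleonDecoderLayer only in
# normalization order: when `config.swin_norm` is set, each norm is applied
# to the sub-layer output before the residual addition (post-norm), rather
# than to the sub-layer input.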


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
            torch.sum(self.embedding.weight**2, dim=1) -
            2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
                             self.embedding.weight.transpose(0, 1)))

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # compute loss for embedding
        loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
                          2) + self.beta * torch.mean(
                              (hidden_state_quant - hidden_state.detach())**2)

        # preserve gradients
        hidden_state_quant = hidden_state + (hidden_state_quant -
                                             hidden_state).detach()

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
                                                        2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices
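
# At inference time only `min_encoding_indices` is used downstream: each
# spatial position of the encoder output is snapped to its nearest codebook
# entry via the expanded distance z^2 + e^2 - 2*z*e, and those indices are
# later remapped to BPE token ids. The quantized tensor and commitment loss
# are kept for parity with the original transformers implementation.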


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states,
                              pad=(0, 1, 0, 1),
                              mode="constant",
                              value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        query_states = query_states.reshape(batch_size, channels,
                                            height * width).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, height * width)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels)**(-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels,
                                            height * width)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states,
                                attn_weights).reshape(batch_size, channels,
                                                      height, width)

        attn_output = self.proj_out(attn_output)
        return residual + attn_output


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1], )
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](
                        hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(
                    hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
                                          config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
                                               config.latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices
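
# Only the encoder half of the VQ-VAE is needed here: `encode` turns pixel
# values into codebook indices for image understanding, while image
# generation (which would require the decoder) is not wired up in this
# implementation.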


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        return sorted([
            val for name, val in self.vocab_map.items()
            if name.startswith("IMGIMG")
        ])

    @cached_property
    def bpe2img(self):
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            return "".join(
                img_tkn_chr_mapping.get(c, c)
                for c in old_name[len("IMGIMG"):-1])

        return {
            tok: int(remap(self.val2name[tok]))
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
            sorted(self.bpe2img.values()))

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
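
# The BPE vocabulary names its image tokens "IMGIMG" followed by letters:
# per `remap` above, 'A'..'J' encode the digits 0..9 of the VQGAN codebook
# index and the final character is dropped. `bpe2img`/`img2bpe` invert that
# naming, and `img2bpe_mapping_tensor` flattens the dict into a lookup
# table so `convert_img2bpe` can translate a whole batch of codebook
# indices with a single tensor indexing operation.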


class ChameleonModel(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList([
            decoder_layer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with the VQGAN module and
        converts the obtained image tokens into BPE tokens. The "boi" and
        "eoi" wrapper tokens are added by the input processor, not here.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks
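
    # With the 512x512 crop assumed by this file, the VQGAN yields a 32x32
    # grid of codebook indices per image, i.e. CHAMELEON_IMAGE_SEQ_LENGTH
    # (1024) BPE tokens, matching the placeholder positions reserved by the
    # input processor.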

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: ChameleonConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
                         CHAMELEON_CROP_SIZE_WIDTH)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            assert self.model.vqmodel is not None
            image_tokens = self.model.get_image_tokens(image_input["data"].to(
                self.config.torch_dtype))
            image_token_id = self.model.vocabulary_mapping.image_token_id
            special_image_mask = input_ids == image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask,
                                                 image_tokens)

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states
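
    # The image path works purely at the token-id level: the placeholder
    # positions reserved by the input processor (all equal to the <image>
    # token id) are overwritten via `masked_scatter` with the BPE ids
    # derived from the pixel values, and the language model then embeds
    # them like any other tokens.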

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow image tokens; these do not include the special
        # begin-image and end-image tokens.
        if logits is not None:
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only do sharding for language model and
                    # not vqvae for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not found the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)