# chameleon.py

from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
                    TypedDict)

import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.common.utils import print_warning_once
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import (cached_get_tokenizer,
                                        repeat_and_pad_image_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsVision
# These configs are not part of the model config but of the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710
class ChameleonImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


def get_max_chameleon_image_tokens(ctx: InputContext):
    return CHAMELEON_IMAGE_SEQ_LENGTH
def dummy_seq_data_for_chameleon(
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_chameleon(
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = CHAMELEON_CROP_SIZE_WIDTH
    height = CHAMELEON_CROP_SIZE_HEIGHT
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}


def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
    seq_data = dummy_seq_data_for_chameleon(
        seq_len,
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
    )

    mm_data = dummy_image_for_chameleon()
    return seq_data, mm_data
def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
    """
    Processes the input prompt to insert the required tokens for the image
    placeholder.

    See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
        repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
        pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
        pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
    )

    # Append the sep token in chat mode to follow the default processor
    # behavior.
    if new_prompt is not None:
        new_prompt += tokenizer.sep_token
    new_token_ids += [CHAMELEON_SEP_TOKEN_ID]

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
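# For reference, a sketch of the expansion performed above (token IDs per the
# constants at the top of this file): a prompt containing a single image
# placeholder token 8711 is expanded to
#     [8197] + [8711] * CHAMELEON_IMAGE_SEQ_LENGTH + [8196]
# i.e. start-of-image, 1024 repeated image tokens, end-of-image, with the
# tokenizer's sep token appended to the prompt in chat mode.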
class ChameleonLayerNorm(nn.LayerNorm):

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        self.normalized_shape = (hidden_size[-1], )

    def forward(self, hidden_states):
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=1e-5)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states
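# NOTE: ChameleonLayerNorm is constructed with a (num_heads, head_dim) shape
# but normalizes only over the last (head_dim) axis, then applies its own
# per-head weight and bias. This is how the q/k-norm in ChameleonAttention
# below normalizes each attention head separately.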
# Copied from aphrodite.modeling.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
# Modified from aphrodite.modeling.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class ChameleonDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class ChameleonSwinDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, residual
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
            torch.sum(self.embedding.weight**2, dim=1) -
            2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
                             self.embedding.weight.transpose(0, 1)))

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # compute loss for embedding
        loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
                          2) + self.beta * torch.mean(
                              (hidden_state_quant - hidden_state.detach())**2)

        # preserve gradients
        hidden_state_quant = hidden_state + (hidden_state_quant -
                                             hidden_state).detach()

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
                                                        2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices
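# A minimal sketch of the nearest-codebook lookup above, assuming a flattened
# latent `z` of shape (N, embed_dim) and a codebook `e` of shape
# (num_embeddings, embed_dim); this is equivalent to the expanded
# z^2 + e^2 - 2*e*z form used in forward():
#
#     distances = torch.cdist(z, e) ** 2   # (N, num_embeddings) squared L2
#     indices = distances.argmin(dim=1)    # nearest codebook entry per latent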
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states,
                              pad=(0, 1, 0, 1),
                              mode="constant",
                              value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        query_states = query_states.reshape(batch_size, channels,
                                            height * width).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, height * width)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels)**(-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels,
                                            height * width)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states,
                                attn_weights).reshape(batch_size, channels,
                                                      height, width)

        attn_output = self.proj_out(attn_output)
        return residual + attn_output
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1], )
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](
                        hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(
                    hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state
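# NOTE: Every resolution level except the last ends with a stride-2
# downsample, so the encoder reduces the spatial size by a factor of
# 2 ** (num_resolutions - 1) overall before the mid and output blocks.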
# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
                                          config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
                                               config.latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices
# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        return sorted([
            val for name, val in self.vocab_map.items()
            if name.startswith("IMGIMG")
        ])

    @cached_property
    def bpe2img(self):
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            return "".join(
                img_tkn_chr_mapping.get(c, c)
                for c in old_name[len("IMGIMG"):-1])

        return {
            tok: int(remap(self.val2name[tok]))
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
            sorted(self.bpe2img.values()))

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
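# Illustrative sketch of the remapping above (the vocabulary entry is
# hypothetical): a BPE vocab name such as "IMGIMGBACZ" keeps only the payload
# "BAC" (the "IMGIMG" prefix and the trailing character are stripped by the
# slice), which is decoded digit-by-digit via A->0, B->1, ..., J->9 into the
# image-codebook index 102. bpe2img collects these pairs, img2bpe inverts
# them, and convert_img2bpe applies the inverse mapping to whole tensors of
# VQGAN indices via img2bpe_mapping_tensor.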
class ChameleonModel(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList([
            decoder_layer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with the VQGAN module, converts
        the obtained image tokens into BPE tokens, and wraps them with "boi"
        and "eoi" special tokens.
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks
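    # NOTE: With the default 512x512 crop, the VQGAN quantizer is expected to
    # yield CHAMELEON_IMAGE_SEQ_LENGTH (1024) BPE tokens per image, matching
    # the placeholder expansion done in input_processor_for_chameleon.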
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(
        self,
        config: ChameleonConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
                         CHAMELEON_CROP_SIZE_WIDTH)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            assert self.model.vqmodel is not None
            image_tokens = self.model.get_image_tokens(
                image_input["data"].to(self.config.torch_dtype))
            image_token_id = self.model.vocabulary_mapping.image_token_id
            special_image_mask = input_ids == image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask,
                                                 image_tokens)

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow image tokens, which do not include the special
        # begin-image and end-image tokens.
        image_tokens = self.model.vocabulary_mapping.image_tokens
        logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
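        # For example (layer index hypothetical), a checkpoint entry named
        # "model.layers.0.self_attn.q_proj.weight" is loaded into the fused
        # parameter "model.layers.0.self_attn.qkv_proj.weight" with shard_id
        # "q"; gate_proj/up_proj are merged into gate_up_proj the same way
        # with shard ids 0 and 1.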
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight.
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only shard the language model, not the VQ-VAE,
                    # for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not found the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)