from array import array
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict)

import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.common.utils import print_warning_once
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import (cached_get_tokenizer,
                                        repeat_and_pad_image_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal

# These values are not part of the model config but of the preprocessor and
# processor configuration files, so we hardcode them in the model file for
# now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
CHAMELEON_IMAGE_TOKEN_ID = 8711
CHAMELEON_IMAGE_START_TOKEN_ID = 8197
CHAMELEON_IMAGE_END_TOKEN_ID = 8196
CHAMELEON_SEP_TOKEN_ID = 8710


class ChameleonImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


def get_max_chameleon_image_tokens(ctx: InputContext):
    return CHAMELEON_IMAGE_SEQ_LENGTH


def dummy_seq_data_for_chameleon(
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH
    else:
        image_feature_size = image_feature_size_override

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_chameleon(
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = CHAMELEON_CROP_SIZE_WIDTH
    height = CHAMELEON_CROP_SIZE_HEIGHT
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_chameleon(
        seq_len,
        num_images,
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
    )

    mm_data = dummy_image_for_chameleon(num_images)
    return seq_data, mm_data


def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
    """
    Processes the input prompt to insert the required tokens for the image
    placeholder.

    See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
        repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
        pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
        pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
    )

    # Append the sep token for chat mode to follow the default processor
    # behavior.
    if new_prompt is not None:
        new_prompt += tokenizer.sep_token
    new_token_ids += [CHAMELEON_SEP_TOKEN_ID]

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
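
# Illustrative sketch of the expansion performed above, using the token-ID
# constants defined at the top of this file: a prompt containing a single
# image placeholder
#     [..., 8711, ...]
# becomes
#     [..., 8197, 8711 (repeated 1024 times), 8196, ..., 8710]
# i.e. the image token is repeated CHAMELEON_IMAGE_SEQ_LENGTH times, wrapped
# with the image-start/image-end tokens, and the sep token is appended for
# chat mode.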


class ChameleonLayerNorm(nn.LayerNorm):

    def __init__(self, hidden_size, *args, **kwargs):
        super().__init__(hidden_size, *args, **kwargs)
        self.normalized_shape = (hidden_size[-1], )

        set_weight_attrs(self.weight,
                         {"weight_loader": row_parallel_weight_loader})
        set_weight_attrs(self.bias,
                         {"weight_loader": row_parallel_weight_loader})

    def forward(self, hidden_states):
        hidden_states = F.layer_norm(hidden_states,
                                     self.normalized_shape,
                                     None,
                                     None,
                                     eps=1e-5)
        hidden_states = hidden_states * self.weight + self.bias
        return hidden_states
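
# NOTE: ChameleonLayerNorm normalizes only over the last dimension (head_dim;
# see the `normalized_shape` override above), while the learned scale and bias
# keep the per-head shape passed to nn.LayerNorm and are applied manually
# afterwards. The row-parallel weight loader is attached so that these
# per-head parameters can be sharded along the head dimension under tensor
# parallelism.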


# Copied from aphrodite.modeling.models.llama.LlamaMLP -> ChameleonMLP
class ChameleonMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                           output_size=hidden_size,
                                           bias=bias,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


# Modified from aphrodite.modeling.models.llama.LlamaAttention -> ChameleonAttention #noqa
class ChameleonAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 4096,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least as large as the TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
        )
        self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
        self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim))
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # reshape for layernorm
        q = q.reshape(-1, self.num_heads, self.head_dim)
        k = k.reshape(-1, self.num_kv_heads, self.head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(*q.shape[:-2], -1)
        k = k.view(*k.shape[:-2], -1)
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class ChameleonDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class ChameleonSwinDecoderLayer(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          4096)

        self.self_attn = ChameleonAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=False,
            cache_config=cache_config,
        )
        self.mlp = ChameleonMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = hidden_states + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, residual
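
# NOTE: ChameleonDecoderLayer is the usual pre-norm residual block (RMSNorm
# before each sub-layer), whereas ChameleonSwinDecoderLayer normalizes *after*
# attention/MLP and then adds the residual (post-norm). Which variant is used
# is selected via `config.swin_norm` in ChameleonModel below.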


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa
class ChameleonVQVAEVectorQuantizer(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.num_embeddings = config.num_embeddings
        self.embedding_dim = config.embed_dim
        self.beta = getattr(config, "beta", 0.25)

        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.re_embed = self.num_embeddings

    def forward(self, hidden_state: torch.Tensor):
        hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
        hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        distances = (
            torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) +
            torch.sum(self.embedding.weight**2, dim=1) -
            2 * torch.einsum("bd,dn->bn", hidden_state_flattened,
                             self.embedding.weight.transpose(0, 1)))

        min_encoding_indices = torch.argmin(distances, dim=1)
        hidden_state_quant = self.embedding(min_encoding_indices).view(
            hidden_state.shape)

        # compute loss for embedding
        loss = torch.mean((hidden_state_quant.detach() - hidden_state)**
                          2) + self.beta * torch.mean(
                              (hidden_state_quant - hidden_state.detach())**2)

        # preserve gradients
        hidden_state_quant = hidden_state + (hidden_state_quant -
                                             hidden_state).detach()

        # reshape back to match original input shape
        hidden_state_quant = hidden_state_quant.permute(0, 3, 1,
                                                        2).contiguous()

        return hidden_state_quant, loss, min_encoding_indices
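
# NOTE: At inference time the quantizer above only serves as a
# nearest-neighbor lookup into the codebook: ChameleonModel.get_image_tokens
# consumes just `min_encoding_indices`, the returned codebook loss is
# discarded, and the whole VQ model is kept frozen (see ChameleonVQVAE).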


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa
class ChameleonVQVAEEncoderConvDownsample(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels,
                              in_channels,
                              kernel_size=3,
                              stride=2,
                              padding=0)

    def forward(self, hidden_states: torch.Tensor):
        # no asymmetric padding in torch conv, must do it ourselves
        hidden_states = F.pad(hidden_states,
                              pad=(0, 1, 0, 1),
                              mode="constant",
                              value=0)
        hidden_states = self.conv(hidden_states)
        return hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa
class ChameleonVQVAEEncoderResnetBlock(nn.Module):

    def __init__(
        self,
        config: ChameleonVQVAEConfig,
        in_channels: int,
        out_channels=None,
        conv_shortcut=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None \
            else out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=in_channels,
                                        eps=1e-6,
                                        affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     self.out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = torch.nn.GroupNorm(num_groups=32,
                                        num_channels=self.out_channels,
                                        eps=1e-6,
                                        affine=True)
        self.dropout = torch.nn.Dropout(config.dropout)
        self.conv2 = torch.nn.Conv2d(self.out_channels,
                                     self.out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     self.out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    self.out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states *= torch.sigmoid(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock #noqa
class ChameleonVQVAEEncoderAttnBlock(nn.Module):

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        query_states = self.q(hidden_states)
        key_states = self.k(hidden_states)
        value_states = self.v(hidden_states)

        # compute attention
        batch_size, channels, height, width = query_states.shape
        query_states = query_states.reshape(batch_size, channels,
                                            height * width).permute(0, 2, 1)
        key_states = key_states.reshape(batch_size, channels, height * width)
        attn_weights = torch.bmm(query_states, key_states)
        attn_weights = attn_weights * (int(channels)**(-0.5))
        attn_weights = F.softmax(attn_weights, dim=2)

        # attend to values
        value_states = value_states.reshape(batch_size, channels,
                                            height * width)
        attn_weights = attn_weights.permute(0, 2, 1)
        attn_output = torch.bmm(value_states,
                                attn_weights).reshape(batch_size, channels,
                                                      height, width)

        attn_output = self.proj_out(attn_output)
        return residual + attn_output


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa
class ChameleonVQVAEEncoder(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()

        self.num_resolutions = len(config.channel_multiplier)
        self.num_res_blocks = config.num_res_blocks
        base_channels = config.base_channels
        resolution = config.resolution
        in_channels = config.in_channels
        double_latent = config.double_latent
        latent_channels = config.latent_channels
        channel_multiplier = config.channel_multiplier

        self.conv_in = torch.nn.Conv2d(in_channels,
                                       base_channels,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_channel_multiplier = (1, ) + tuple(channel_multiplier)
        self.in_channel_multiplier = in_channel_multiplier
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = base_channels * in_channel_multiplier[i_level]
            block_out = base_channels * channel_multiplier[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ChameleonVQVAEEncoderResnetBlock(
                        config=config,
                        in_channels=block_in,
                        out_channels=block_out,
                    ))
                block_in = block_out
                if (config.attn_resolutions is not None
                        and curr_res in config.attn_resolutions
                        and config.attn_type == "vanilla"):
                    attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))

            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        self.mid = nn.Module()
        self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )
        self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(
            block_in) if config.attn_type == "vanilla" else nn.Identity()
        self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
            config=config,
            in_channels=block_in,
            out_channels=block_in,
        )

        self.norm_out = torch.nn.GroupNorm(num_groups=32,
                                           num_channels=block_in,
                                           eps=1e-6,
                                           affine=True)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * latent_channels if double_latent else latent_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values: torch.Tensor):
        pixel_values = pixel_values.to(self.conv_in.weight.dtype)

        # downsampling
        hidden_states = [self.conv_in(pixel_values)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                hidden_state = self.down[i_level].block[i_block](
                    hidden_states[-1], )
                if len(self.down[i_level].attn) > 0:
                    hidden_state = self.down[i_level].attn[i_block](
                        hidden_state)
                hidden_states.append(hidden_state)
            if i_level != self.num_resolutions - 1:
                hidden_states.append(self.down[i_level].downsample(
                    hidden_states[-1]))

        # middle
        last_hidden_state = hidden_states[-1]
        last_hidden_state = self.mid.block_1(last_hidden_state)
        last_hidden_state = self.mid.attn_1(last_hidden_state)
        last_hidden_state = self.mid.block_2(last_hidden_state)

        # end
        last_hidden_state = self.norm_out(last_hidden_state)
        last_hidden_state *= torch.sigmoid(last_hidden_state)
        last_hidden_state = self.conv_out(last_hidden_state)
        return last_hidden_state


# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa
class ChameleonVQVAE(nn.Module):

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__()
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        self.quant_conv = torch.nn.Conv2d(config.latent_channels,
                                          config.embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim,
                                               config.latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(
        self, pixel_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices


# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa
class ChameleonImageVocabularyMapping:
    """
    A class for mapping discrete image tokens from VQGAN to BPE tokens.
    """

    def __init__(self, vocab_map: Dict[str, int]):
        self.vocab_map = vocab_map
        self.image_token_id = vocab_map.get("<image>")

    @cached_property
    def val2name(self):
        return {v: k for k, v in self.vocab_map.items()}

    @cached_property
    def image_tokens(self):
        return sorted([
            val for name, val in self.vocab_map.items()
            if name.startswith("IMGIMG")
        ])

    @cached_property
    def bpe2img(self):
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            return "".join(
                img_tkn_chr_mapping.get(c, c)
                for c in old_name[len("IMGIMG"):-1])

        return {
            tok: int(remap(self.val2name[tok]))
            for tok in self.image_tokens
        }

    @cached_property
    def img2bpe(self):
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self):
        return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(
            sorted(self.bpe2img.values()))

    @cached_property
    def img2bpe_mapping_tensor(self):
        mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
        device = img_batch.device
        img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
        return img_tokens.to(device)
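
# Illustrative example of the "IMGIMG" remapping above (the token name is
# hypothetical): a vocabulary entry named "IMGIMGBDZ" has its "IMGIMG" prefix
# and final character stripped and the remaining letters mapped A->0, B->1,
# ..., J->9, so "BD" becomes the image-codebook index int("13") = 13.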


class ChameleonModel(nn.Module):

    def __init__(
        self,
        config: ChameleonConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(
            config.vocabulary_map)
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \
            else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList([
            decoder_layer(config=config,
                          cache_config=cache_config,
                          quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vqmodel = ChameleonVQVAE(config.vq_config)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Tokenizes images into discrete tokens with the VQGAN module and
        converts the obtained image tokens into BPE tokens. (The surrounding
        image-start/image-end special tokens are inserted by
        `input_processor_for_chameleon`, not here.)
        """
        batch_size = pixel_values.shape[0]
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks
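
    # NOTE: With the 512x512 crops used by this model (CHAMELEON_CROP_SIZE_*),
    # the VQGAN yields CHAMELEON_IMAGE_SEQ_LENGTH (= 1024, i.e. a 32x32 latent
    # grid) discrete tokens per image, matching the number of image
    # placeholders inserted by input_processor_for_chameleon.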

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(
        self,
        config: ChameleonConfig,
        multimodal_config: MultiModalConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config
        self.model = ChameleonModel(config, cache_config, quant_config)
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
        )
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
                         CHAMELEON_CROP_SIZE_WIDTH)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return ChameleonImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs,
    ) -> torch.Tensor:
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            assert self.model.vqmodel is not None
            image_tokens = self.model.get_image_tokens(image_input["data"].to(
                self.config.torch_dtype))
            image_token_id = self.model.vocabulary_mapping.image_token_id
            special_image_mask = input_ids == image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask,
                                                 image_tokens)

        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        # Disallow image tokens. This set does not include the special
        # begin-image and end-image tokens.
        if logits is not None:
            image_tokens = self.model.vocabulary_mapping.image_tokens
            logits[:, image_tokens] = torch.finfo(logits.dtype).min

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
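        # For example (parameter names illustrative), a checkpoint tensor
        # "model.layers.0.self_attn.q_proj.weight" is renamed to
        # "model.layers.0.self_attn.qkv_proj.weight" and written into the "q"
        # shard of the fused projection, while "gate_proj"/"up_proj" tensors
        # go into shards 0/1 of "gate_up_proj".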
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # With tie_word_embeddings, we can skip lm_head.weight.
            # The weight might still appear in the checkpoint files if the
            # model was processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            use_default_weight_loading = False
            if "vqmodel" in name:
                if self.model.vqmodel is not None:
                    # We only shard the language model weights, not the
                    # VQVAE, for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remap the name of the FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint (e.g. "
                                f"{name}), but not the expected name in "
                                f"the model (e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)

            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)