llava.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. from typing import Iterable, List, Optional, Tuple
  2. import torch
  3. from torch import nn
  4. # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
  5. # transformers' impl.
  6. from transformers import CLIPVisionModel, LlavaConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import VisionLanguageConfig
  9. from aphrodite.common.sequence import SamplerOutput
  10. from aphrodite.modeling.layers.activation import get_act_fn
  11. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  12. from aphrodite.modeling.layers.sampler import Sampler
  13. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  14. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  15. from aphrodite.modeling.models.llama import LlamaModel
  16. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  17. from aphrodite.quantization.base_config import QuantizationConfig
  18. _KEYS_TO_MODIFY_MAPPING = {
  19. "language_model.lm_head": "lm_head",
  20. "language_model.model": "language_model",
  21. }
  # TODO(xwjiang): Run benchmark and decide if TP.
  class LlavaMultiModalProjector(nn.Module):
      """Two-layer MLP mapping vision features into the text hidden size.

      Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` maps
      ``vision_hidden_size -> text_hidden_size``, and ``linear_2`` maps
      ``text_hidden_size -> text_hidden_size``. Both layers have a bias.
      """

      def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                   projector_hidden_act: str):
          super().__init__()
          # Submodules are registered in the order linear_1, act, linear_2.
          self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size,
                                    bias=True)
          # The activation is resolved from its name via get_act_fn.
          self.act = get_act_fn(projector_hidden_act)
          self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size,
                                    bias=True)

      def forward(self, image_features):
          """Apply linear_1, then the activation, then linear_2."""
          return self.linear_2(self.act(self.linear_1(image_features)))
  39. def _merge_vision_embeddings(input_ids: torch.Tensor,
  40. inputs_embeds: torch.Tensor,
  41. vision_embeddings: torch.Tensor,
  42. image_token_id: int):
  43. """In place merges in vision_embeddings with inputs_embeds."""
  44. mask = (input_ids == image_token_id)
  45. inputs_embeds[mask] = vision_embeddings.view(-1,
  46. vision_embeddings.shape[-1])
  class LlavaForConditionalGeneration(nn.Module):
      """LLaVA model built from four parts.

      The parts are an optional CLIP vision tower, a multi-modal projector, a
      Llama language model and a vocab-parallel LM head.

      When an image input is given, its projected embeddings overwrite the
      embeddings of the image placeholder tokens in ``input_ids`` before the
      language model runs.
      """

      def __init__(self,
                   config: "LlavaConfig",
                   vision_language_config: VisionLanguageConfig,
                   quant_config: Optional["QuantizationConfig"] = None) -> None:
          """Build all sub-modules.

          Args:
              config: LLaVA config. ``vision_config`` sizes the vision tower
                  and projector input; ``text_config`` configures the Llama
                  model, the LM head and the vocab sizes.
              vision_language_config: Image input settings (input type,
                  expected input shape, placeholder token id). Required.
              quant_config: Optional quantization config. It is stored and
                  passed to the language model.

          Raises:
              ValueError: If ``vision_language_config`` is not provided.
          """
          super().__init__()
          self.config = config

          self.vision_language_config = vision_language_config
          # An explicit raise, unlike ``assert``, still fires under
          # ``python -O`` instead of failing later with an AttributeError.
          if not self.vision_language_config:
              raise ValueError(
                  "Provide `image_input_type` and other vision "
                  "related configurations through LLM entrypoint "
                  "or engine arguments.")

          # The CLIP tower is only needed for PIXEL_VALUES input. Otherwise
          # forward() treats image_input as already-extracted features.
          if self.vision_language_config.image_input_type == (
                  VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
              self.vision_tower = CLIPVisionModel(config.vision_config)
          else:
              self.vision_tower = None

          self.multi_modal_projector = LlavaMultiModalProjector(
              vision_hidden_size=config.vision_config.hidden_size,
              text_hidden_size=config.text_config.hidden_size,
              projector_hidden_act=config.projector_hidden_act)

          self.quant_config = quant_config
          self.language_model = LlamaModel(config.text_config, quant_config)
          self.unpadded_vocab_size = config.text_config.vocab_size
          self.lm_head = ParallelLMHead(
              self.unpadded_vocab_size,
              config.text_config.hidden_size,
              org_num_embeddings=self.language_model.org_vocab_size)
          logit_scale = getattr(config, "logit_scale", 1.0)
          # Read the vocab size from text_config, as unpadded_vocab_size does
          # above. The top-level LlavaConfig.vocab_size is deprecated in
          # newer transformers releases.
          self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                  config.text_config.vocab_size,
                                                  logit_scale)
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
          image_input: Optional[torch.Tensor] = None
      ) -> torch.Tensor:  # noqa: E501
          """Run forward pass for Llava 1.5.

          One key thing to understand is the `input_ids` already accounts for
          the positions of the to-be-inserted image embeddings.

          Concretely, consider a text prompt:
          "<image>\nUSER: What's the content of the image?\nASSISTANT:".
          Tokenizer outputs:
          [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
          2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
          The to-be-inserted image has a size of 576 (24 * 24) along the
          context length dimension.
          `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
          1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
          9047, 13566, 29901].
          There will be 576 `32000` in the `input_ids`.
          (32000 is the token id for `<image>`.)

          This way, the `positions` and `attn_metadata` are consistent
          with the `input_ids`.

          The model takes two types of image inputs:
          PIXEL_VALUES and IMAGE_FEATURES.
          The following shows how each maps to huggingface implementation.
          PIXEL_VALUES:
          - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
          IMAGE_FEATURES:
          - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
          before going through the multi modal projector.

          Args:
              input_ids: Flattened (concatenated) input_ids corresponding to a
                  batch.
              positions: Passed through to the language model.
              kv_caches: Passed through to the language model.
              attn_metadata: Passed through to the language model.
              image_input: A batch of image inputs.
                  For PIXEL_VALUES, expecting [1, 3, 336, 336].
                  For IMAGE_FEATURES, expecting [1, 576, 1024].

          Returns:
              The hidden states produced by the language model. Logits are
              computed separately by ``compute_logits``.

          Raises:
              ValueError: If ``image_input`` does not match the configured
                  ``image_input_shape`` (ignoring the batch dimension), or if
                  the configured ``vision_feature_select_strategy`` is unknown.
          """
          if image_input is None:
              inputs_embeds = None
          else:
              self._validate_image_input(image_input)
              image_features = self._image_input_to_features(image_input)
              vision_embeddings = self.multi_modal_projector(image_features)
              inputs_embeds = self.language_model.get_input_embeddings(
                  input_ids)
              _merge_vision_embeddings(
                  input_ids, inputs_embeds, vision_embeddings,
                  self.vision_language_config.image_token_id)
              # The merged embeddings now carry the prompt, so the language
              # model consumes them instead of the raw token ids.
              input_ids = None

          hidden_states = self.language_model(input_ids,
                                              positions,
                                              kv_caches,
                                              attn_metadata,
                                              inputs_embeds=inputs_embeds)
          return hidden_states

      def _validate_image_input(self, image_input: torch.Tensor) -> None:
          """Check ``image_input`` against the configured image shape.

          Only the non-batch dimensions are compared, against
          ``vision_language_config.image_input_shape[1:]``.

          Raises:
              ValueError: On a shape mismatch.
          """
          expected_shape = self.vision_language_config.image_input_shape[1:]
          if list(image_input.shape[1:]) != list(expected_shape):
              raise ValueError(
                  f"The expected image tensor shape is batch dimension "
                  f"plus "
                  f"{self.vision_language_config.image_input_shape[1:]}."
                  f" You supplied {image_input.shape}. "
                  f"If you are using vLLM's entrypoint, make sure your "
                  f"supplied image input is consistent with "
                  f"image_input_shape in engine args.")

      def _image_input_to_features(self,
                                   image_input: torch.Tensor) -> torch.Tensor:
          """Return the features to feed into the multi-modal projector.

          Without a vision tower (IMAGE_FEATURES input), ``image_input`` is
          returned unchanged. Otherwise:

          - ``image_input`` is run through the vision tower.
          - The hidden states at ``config.vision_feature_layer`` are taken.
          - ``config.vision_feature_select_strategy`` picks the positions:
            ``"default"`` drops position 0 along dim 1, and ``"full"`` keeps
            all of them.

          Raises:
              ValueError: For any other select strategy.
          """
          if self.vision_tower is None:
              return image_input

          # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
          image_outputs = self.vision_tower(image_input,
                                            output_hidden_states=True)
          image_features = image_outputs.hidden_states[
              self.config.vision_feature_layer]

          # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
          strategy = self.config.vision_feature_select_strategy
          if strategy == "default":
              # Presumably drops CLIP's class-token position, as the HF code
              # does.
              return image_features[:, 1:]
          if strategy == "full":
              return image_features
          raise ValueError(f"Unexpected select feature strategy: {strategy}")

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Project hidden states to vocabulary logits via the LM head."""
          logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                         sampling_metadata)
          return logits

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Sample next tokens from ``logits`` with the model's sampler."""
          next_tokens = self.sampler(logits, sampling_metadata)
          return next_tokens

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          """Load ``(name, tensor)`` checkpoint pairs into this module.

          Handling of each name:

          - Names containing ``rotary_emb.inv_freq`` are skipped.
          - Each name is rewritten with ``_KEYS_TO_MODIFY_MAPPING``.
          - Names containing ``vision`` are loaded unsharded only when a
            vision tower exists; otherwise they are dropped.
          - ``q/k/v_proj`` and ``gate/up_proj`` weights are loaded as shards
            of the fused ``qkv_proj`` / ``gate_up_proj`` parameters.
          - Everything else uses the parameter's ``weight_loader`` or
            ``default_weight_loader``.

          Raises:
              KeyError: If a (rewritten) name has no matching parameter.
          """
          # only doing this for language model part for now.
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
              for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                  if key_to_modify in name:
                      name = name.replace(key_to_modify, new_key)
              use_default_weight_loading = False
              if "vision" in name:
                  if self.vision_tower is not None:
                      # We only do sharding for language model and
                      # not vision model for now.
                      use_default_weight_loading = True
              else:
                  for (param_name, weight_name,
                       shard_id) in stacked_params_mapping:
                      if weight_name not in name:
                          continue
                      param = params_dict[name.replace(weight_name, param_name)]
                      weight_loader = param.weight_loader
                      weight_loader(param, loaded_weight, shard_id)
                      break
                  else:
                      # Not a fused shard: load the tensor as-is.
                      use_default_weight_loading = True
              if use_default_weight_loading:
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)