llava.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. from typing import List, Optional
  2. import torch
  3. from torch import nn
  4. # TODO: We should port CLIPVisionModel's code over to not depend on
  5. # transformers' impl.
  6. from transformers import CLIPVisionModel, LlavaConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import VisionLanguageConfig
  9. from aphrodite.modeling.layers.activation import get_act_fn
  10. from aphrodite.modeling.layers.linear import LinearMethodBase
  11. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  12. from aphrodite.modeling.layers.sampler import Sampler
  13. from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
  14. from aphrodite.modeling.models.llama import LlamaModel
  15. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  16. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  17. hf_model_weights_iterator)
  18. from aphrodite.common.sequence import SamplerOutput
  19. _KEYS_TO_MODIFY_MAPPING = {
  20. "language_model.lm_head": "lm_head",
  21. "language_model.model": "language_model",
  22. }
  23. # TODO: Run benchmark and decide if TP.
  24. class LlavaMultiModalProjector(nn.Module):
  25. def __init__(self, vision_hidden_size: int, text_hidden_size: int,
  26. projector_hidden_act: str):
  27. super().__init__()
  28. self.linear_1 = nn.Linear(vision_hidden_size,
  29. text_hidden_size,
  30. bias=True)
  31. self.act = get_act_fn(projector_hidden_act)
  32. self.linear_2 = nn.Linear(text_hidden_size,
  33. text_hidden_size,
  34. bias=True)
  35. def forward(self, image_features):
  36. hidden_states = self.linear_1(image_features)
  37. hidden_states = self.act(hidden_states)
  38. hidden_states = self.linear_2(hidden_states)
  39. return hidden_states
  40. def _merge_vision_embeddings(input_ids: torch.Tensor,
  41. inputs_embeds: torch.Tensor,
  42. vision_embeddings: torch.Tensor,
  43. image_token_id: int):
  44. """In place merges in vision_embeddings with inputs_embeds."""
  45. mask = (input_ids == image_token_id)
  46. inputs_embeds[mask] = vision_embeddings.view(-1,
  47. vision_embeddings.shape[-1])
  48. class LlavaForConditionalGeneration(nn.Module):
  49. def __init__(self,
  50. config: "LlavaConfig",
  51. vision_language_config: VisionLanguageConfig,
  52. linear_method: Optional["LinearMethodBase"] = None) -> None:
  53. super().__init__()
  54. self.config = config
  55. self.vision_language_config = vision_language_config
  56. assert self.vision_language_config, (
  57. "Provide `image_input_type` and other vision "
  58. "related configurations through LLM entrypoint "
  59. "or engine arguments.")
  60. if self.vision_language_config.image_input_type == (
  61. VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
  62. self.vision_tower = CLIPVisionModel(config.vision_config)
  63. else:
  64. self.vision_tower = None
  65. self.multi_modal_projector = LlavaMultiModalProjector(
  66. vision_hidden_size=config.vision_config.hidden_size,
  67. text_hidden_size=config.text_config.hidden_size,
  68. projector_hidden_act=config.projector_hidden_act)
  69. self.linear_method = linear_method
  70. self.language_model = LlamaModel(config.text_config, linear_method)
  71. self.unpadded_vocab_size = config.text_config.vocab_size
  72. self.lm_head = ParallelLMHead(
  73. self.unpadded_vocab_size,
  74. config.text_config.hidden_size,
  75. org_num_embeddings=self.language_model.org_vocab_size)
  76. logit_scale = getattr(config, "logit_scale", 1.0)
  77. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  78. config.vocab_size, logit_scale)
  79. self.sampler = Sampler()
  80. def forward(
  81. self,
  82. input_ids: torch.Tensor,
  83. positions: torch.Tensor,
  84. kv_caches: List[torch.Tensor],
  85. attn_metadata: AttentionMetadata,
  86. image_input: Optional[torch.Tensor] = None
  87. ) -> SamplerOutput: # noqa: E501
  88. """Run forward pass for Llava 1.5.
  89. One key thing to understand is the `input_ids` already accounts for the
  90. positions of the to-be-inserted image embeddings.
  91. Concretely, consider a text prompt:
  92. "<image>\nUSER: What's the content of the image?\nASSISTANT:".
  93. Tokenizer outputs:
  94. [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
  95. 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
  96. The to-be-inserted image has a size of 576 (24 * 24) along the context
  97. length dimension.
  98. `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
  99. 1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
  100. 9047, 13566, 29901].
  101. There will be 576 `32000` in the `input_ids`.
  102. (32000 is the token id for `<image>`.)
  103. This way, the `positions` and `attn_metadata` are consistent
  104. with the `input_ids`.
  105. The model takes two types of image inputs:
  106. PIXEL_VALUES and IMAGE_FEATURES.
  107. The following shows how each maps to huggingface implementation.
  108. PIXEL_VALUES:
  109. - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
  110. IMAGE_FEATURES:
  111. - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
  112. before going through the multi modal projector.
  113. Args:
  114. input_ids: Flattened (concatenated) input_ids corresponding to a
  115. batch.
  116. image_input: A batch of image inputs.
  117. For PIXEL_VALUES, expecting [1, 3, 336, 336].
  118. For IMAGE_FEATURES, expecting [1, 576, 1024].
  119. """
  120. if image_input is not None:
  121. if list(image_input.shape[1:]) != list(
  122. self.vision_language_config.image_input_shape[1:]):
  123. raise ValueError(
  124. f"The expected image tensor shape is batch dimension "
  125. f"plus "
  126. f"{self.vision_language_config.image_input_shape[1:]}."
  127. f" You supplied {image_input.shape}. "
  128. f"If you are using Aphrodite's endpoint, make sure your "
  129. f"supplied image input is consistent with "
  130. f"image_input_shape in engine args.")
  131. if self.vision_tower is not None:
  132. # TODO: Maybe port minimal CLIPVisionModel over.
  133. image_outputs = self.vision_tower(image_input,
  134. output_hidden_states=True)
  135. image_features = image_outputs.hidden_states[
  136. self.config.vision_feature_layer]
  137. # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
  138. if self.config.vision_feature_select_strategy == "default":
  139. image_features = image_features[:, 1:]
  140. elif self.config.vision_feature_select_strategy == "full":
  141. image_features = image_features
  142. else:
  143. raise ValueError(
  144. f"Unexpected select feature strategy: "
  145. f"{self.config.vision_feature_select_strategy}")
  146. else:
  147. image_features = image_input
  148. vision_embeddings = self.multi_modal_projector(image_features)
  149. inputs_embeds = self.language_model.get_input_embeddings(input_ids)
  150. _merge_vision_embeddings(
  151. input_ids, inputs_embeds, vision_embeddings,
  152. self.vision_language_config.image_token_id)
  153. input_ids = None
  154. else:
  155. inputs_embeds = None
  156. hidden_states = self.language_model(input_ids,
  157. positions,
  158. kv_caches,
  159. attn_metadata,
  160. inputs_embeds=inputs_embeds)
  161. return hidden_states
  162. def compute_logits(self, hidden_states: torch.Tensor,
  163. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  164. logits = self.logits_processor(self.lm_head, hidden_states,
  165. sampling_metadata)
  166. return logits
  167. def sample(
  168. self,
  169. logits: torch.Tensor,
  170. sampling_metadata: SamplingMetadata,
  171. ) -> Optional[SamplerOutput]:
  172. next_tokens = self.sampler(logits, sampling_metadata)
  173. return next_tokens
  174. def load_weights(self,
  175. model_name_or_path: str,
  176. cache_dir: Optional[str] = None,
  177. load_format: str = "auto",
  178. revision: Optional[str] = None):
  179. # only doing this for language model part for now.
  180. stacked_params_mapping = [
  181. # (param_name, shard_name, shard_id)
  182. ("qkv_proj", "q_proj", "q"),
  183. ("qkv_proj", "k_proj", "k"),
  184. ("qkv_proj", "v_proj", "v"),
  185. ("gate_up_proj", "gate_proj", 0),
  186. ("gate_up_proj", "up_proj", 1),
  187. ]
  188. params_dict = dict(self.named_parameters())
  189. for name, loaded_weight in hf_model_weights_iterator(
  190. model_name_or_path, cache_dir, load_format, revision):
  191. if "rotary_emb.inv_freq" in name:
  192. continue
  193. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  194. if key_to_modify in name:
  195. name = name.replace(key_to_modify, new_key)
  196. use_default_weight_loading = False
  197. if "vision" in name:
  198. if self.vision_tower is not None:
  199. # We only do sharding for language model and
  200. # not vision model for now.
  201. use_default_weight_loading = True
  202. else:
  203. for (param_name, weight_name,
  204. shard_id) in stacked_params_mapping:
  205. if weight_name not in name:
  206. continue
  207. param = params_dict[name.replace(weight_name, param_name)]
  208. weight_loader = param.weight_loader
  209. weight_loader(param, loaded_weight, shard_id)
  210. break
  211. else:
  212. use_default_weight_loading = True
  213. if use_default_weight_loading:
  214. param = params_dict[name]
  215. weight_loader = getattr(param, "weight_loader",
  216. default_weight_loader)
  217. weight_loader(param, loaded_weight)