
feat: add fuyu vision model and persimmon language model support

AlpinDale committed 6 months ago
Parent commit: e13a66925c
4 changed files with 695 additions and 0 deletions
  1. aphrodite/modeling/models/__init__.py (+2, -0)
  2. aphrodite/modeling/models/fuyu.py (+325, -0)
  3. aphrodite/modeling/models/persimmon.py (+333, -0)
  4. examples/vision/fuyu_example.py (+35, -0)

+ 2 - 0
aphrodite/modeling/models/__init__.py

@@ -65,6 +65,8 @@ _GENERATION_MODELS = {
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "MedusaModel": ("medusa", "Medusa"),
+    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
+    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
 }
 
 _EMBEDDING_MODELS = {
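
Note: with these registry entries, the Persimmon language model can also be loaded on its own as a text-only model. A minimal usage sketch, assuming the adept/persimmon-8b-chat checkpoint and the prompt template from its model card (neither is part of this commit):

    from aphrodite import LLM, SamplingParams

    # "PersimmonForCausalLM" in the checkpoint's config.json is resolved to the
    # new persimmon module through the _GENERATION_MODELS mapping added above.
    llm = LLM(model="adept/persimmon-8b-chat")
    params = SamplingParams(temperature=0.7, max_tokens=64)
    outputs = llm.generate("human: What is a persimmon?\n\nadept:",
                           sampling_params=params)
    print(outputs[0].outputs[0].text)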

+ 325 - 0
aphrodite/modeling/models/fuyu.py

@@ -0,0 +1,325 @@
+# coding=utf-8
+# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
+# Copyright 2023 The vLLM team.
+# Copyright 2023 HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch Fuyu model."""
+import math
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from PIL import Image
+from transformers import FuyuConfig, FuyuImageProcessor
+
+from aphrodite.attention import AttentionMetadata
+from aphrodite.common.config import CacheConfig, MultiModalConfig
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceData)
+from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from aphrodite.modeling.layers.linear import ColumnParallelLinear
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal.base import MultiModalInputs
+from aphrodite.multimodal.image import (cached_get_image_processor,
+                                        cached_get_tokenizer)
+from aphrodite.quantization.base_config import QuantizationConfig
+
+from .interfaces import SupportsVision
+from .utils import merge_vision_embeddings
+
+# The following two token IDs are not exposed in the HF config.
+_IMAGE_TOKEN_ID = 71011
+_NEWLINE_TOKEN_ID = 71019
+
+MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
+MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
+
+
+class FuyuImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """
+    Shape: 
+    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
+    """
+
+
+def _calculate_num_image_tokens(
+    height: int,
+    width: int,
+) -> Tuple[int, int]:
+    """
+    Calculate the number of image tokens needed for a given image size.
+    The expected Fuyu image prompt is in the format:
+        (image_token * ncols + newline_token) * nrows
+    args:
+        height, width: int - height and width of the image in pixels
+    returns:
+        ncols: int - number of image tokens in the x direction
+        nrows: int - number of image tokens in the y direction
+    """
+    ncol = math.ceil(width / 30)
+    nrow = math.ceil(height / 30)
+    return ncol, nrow
+
+
+def get_max_fuyu_image_feature_size():
+
+    return _calculate_num_image_tokens(
+        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+    )
+
+
+def get_max_fuyu_image_tokens(ctx: InputContext):
+    ncol, nrow = get_max_fuyu_image_feature_size()
+    return (ncol + 1) * nrow
+
+
+def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
+    ncol, nrow = get_max_fuyu_image_feature_size()
+    image_feature_size = get_max_fuyu_image_tokens(ctx)
+
+    token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
+    token_ids += [0] * (seq_len - image_feature_size)
+    return SequenceData(token_ids)
+
+
+def dummy_image_for_fuyu(
+    image_width: int,
+    image_height: int,
+):
+    image = Image.new("RGB", (image_width, image_height), color=0)
+    return {"image": image}
+
+
+def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
+    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
+    mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
+                                   MAX_IMAGE_FEATURE_SIZE_HEIGHT)
+    return seq_data, mm_data
+
+
+def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
+                           data: Image.Image):
+    image_encoding = image_processor.preprocess(data, return_tensors="pt")
+    batch_images = torch.stack([img[0] for img in image_encoding["images"]
+                                ]).unsqueeze(1)
+    image_unpadded_heights = torch.tensor(
+        image_encoding["image_unpadded_heights"])
+    image_unpadded_widths = torch.tensor(
+        image_encoding["image_unpadded_widths"])
+
+    batch_size = len(image_encoding["images"])
+    image_present = torch.ones(batch_size, 1, 1)
+    model_image_input = image_processor.preprocess_with_tokenizer_info(
+        image_input=batch_images,
+        image_present=image_present,
+        image_unpadded_h=image_unpadded_heights,
+        image_unpadded_w=image_unpadded_widths,
+        image_placeholder_id=_IMAGE_TOKEN_ID,
+        image_newline_id=_NEWLINE_TOKEN_ID,
+        variable_sized=True,
+    )
+    return model_image_input
+
+
+def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    image_data = multi_modal_data["image"]
+    new_multi_modal_data = {}
+    # process image data
+    if isinstance(image_data, Image.Image):
+        # Fuyu's image_processor can also finish token padding
+        image_processor: FuyuImageProcessor = cached_get_image_processor(
+            model_config.model)
+
+        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
+        image_patches = torch.stack([
+            image_patch[0]
+            for image_patch in model_image_input["image_patches"]
+        ])
+        new_multi_modal_data["image"] = image_patches
+
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    # process prompts
+    prompt = llm_inputs["prompt"]
+    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    tokenizer = cached_get_tokenizer(model_config.model)
+    # dim0 is batch_size, dim1 is subseq_size which will always be 1
+    image_input_ids: List[List[
+        torch.Tensor]] = model_image_input["image_input_ids"]
+    image_input_ids = image_input_ids[0][0].tolist()
+    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
+    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
+
+    new_prompt = prompt + "\x04"
+    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
+        1:] + boa_token
+
+    return LLMInputs(prompt=new_prompt,
+                     prompt_token_ids=new_prompt_token_ids,
+                     multi_modal_data=new_multi_modal_data)
+
+
+def input_mapper_for_fuyu(ctx: InputContext, data: object):
+    model_config = ctx.model_config
+    if isinstance(data, Image.Image):
+        # Fuyu's image_processor can also finish token padding
+        image_processor: FuyuImageProcessor = cached_get_image_processor(
+            model_config.model)
+
+        model_image_input = _fuyu_image_preprocess(image_processor, data)
+        data = torch.stack([
+            image_patch[0]
+            for image_patch in model_image_input["image_patches"]
+        ])
+
+    # image has been processed with prompt in input processor
+    return MultiModalInputs({"image_patches": data})
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
+class FuyuForCausalLM(nn.Module, SupportsVision):
+
+    def __init__(self,
+                 config: FuyuConfig,
+                 multimodal_config: MultiModalConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.image_token_id = _IMAGE_TOKEN_ID
+        self.image_feature_size = config.patch_size**2 * config.num_channels
+
+        self.vision_embed_tokens = ColumnParallelLinear(
+            self.image_feature_size,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.language_model = PersimmonForCausalLM(config,
+                                                   cache_config=cache_config,
+                                                   quant_config=quant_config)
+
+    def _parse_and_validate_image_input(self, **kwargs: object):
+        image_patches = kwargs.pop("image_patches", None)
+
+        if isinstance(image_patches, torch.Tensor):
+            expected_feature_size = self.image_feature_size
+            if image_patches.size(-1) != expected_feature_size:
+                raise ValueError(
+                    f"Expected image patches to have the last dimension of "
+                    f"{expected_feature_size}, got {image_patches.size(-1)}")
+            image_patches = image_patches.to(
+                self.vision_embed_tokens.weight.dtype)
+            return FuyuImagePixelInputs(type="pixel_values",
+                                        data=image_patches)
+        return None
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if image_input is not None:
+            vision_embeddings, _ = self.vision_embed_tokens(
+                image_input["data"])
+            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
+                                                    vision_embeddings,
+                                                    self.image_token_id)
+
+        else:
+            inputs_embeds = None
+
+        hidden_states = self.language_model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.language_model.logits_processor(
+            self.language_model.lm_head, hidden_states, sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.language_model.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # Copied from vllm/model_executor/models/bloom.py
+                # NOTE: Fuyu's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
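
A quick sanity check on the dummy-data sizing defined earlier in this file (an illustrative sketch, not part of the commit): with a 30-pixel patch grid, the 1920x1080 maximum profiling image yields 64 columns and 36 rows, and each row is terminated by one newline token.

    import math

    # Mirrors _calculate_num_image_tokens / get_max_fuyu_image_tokens for the
    # MAX_IMAGE_FEATURE_SIZE_* constants used when profiling memory.
    ncol = math.ceil(1920 / 30)            # 64 image tokens per row
    nrow = math.ceil(1080 / 30)            # 36 rows
    print(ncol, nrow, (ncol + 1) * nrow)   # 64 36 2340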

+ 333 - 0
aphrodite/modeling/models/persimmon.py

@@ -0,0 +1,333 @@
+# coding=utf-8
+# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
+# Copyright 2023 The PygmalionAI team.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only persimmon model compatible with HuggingFace weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PersimmonConfig
+from transformers.activations import ReLUSquaredActivation
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+
+
+class PersimmonMLP(nn.Module):
+
+    def __init__(self,
+                 config: PersimmonConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
+                                                  config.intermediate_size,
+                                                  quant_config=quant_config)
+        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
+                                               config.hidden_size,
+                                               quant_config=quant_config)
+        self.act = ReLUSquaredActivation()
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_states, _ = self.dense_h_to_4h(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.dense_4h_to_h(hidden_states)
+        return hidden_states
+
+
+class PersimmonAttention(nn.Module):
+
+    def __init__(self,
+                 config: PersimmonConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        tensor_parallel_world_size = get_tensor_model_parallel_world_size()
+
+        self.hidden_size = config.hidden_size
+        self.total_num_heads = config.num_attention_heads
+        self.num_heads = self.total_num_heads // tensor_parallel_world_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.partial_rotary_factor = config.partial_rotary_factor
+        self.is_causal = True
+
+        assert (self.head_dim * self.total_num_heads) == self.hidden_size
+        assert self.total_num_heads % tensor_parallel_world_size == 0
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.dense = RowParallelLinear(
+            self.num_heads * self.head_dim,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.is_qk_layernorm = config.qk_layernorm
+
+        if self.is_qk_layernorm:
+            self.q_layernorm = nn.LayerNorm(self.head_dim)
+            self.k_layernorm = nn.LayerNorm(self.head_dim)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=int(self.partial_rotary_factor * self.head_dim),
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim]
+        seq_length = x.shape[0]
+        return x.view(seq_length, self.num_heads, self.head_dim)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        # [seq_length, num_heads, head_dim] -> [seq_length, hidden_size]
+        seq_length = x.shape[0]
+        return x.view(seq_length, self.num_heads * self.head_dim)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # [seq_length, 3 x hidden_size]
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+
+        if self.is_qk_layernorm:
+            # [seq_length, num_heads, head_dim]
+            q = self._split_heads(q)
+            k = self._split_heads(k)
+
+            q = self.q_layernorm(q)
+            k = self.k_layernorm(k)
+
+            q = self._merge_heads(q)
+            k = self._merge_heads(k)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class PersimmonDecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: PersimmonConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = PersimmonAttention(config=config,
+                                            cache_config=cache_config,
+                                            quant_config=quant_config)
+        self.mlp = PersimmonMLP(config, quant_config=quant_config)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+
+        hidden_states = hidden_states + residual
+
+        outputs = hidden_states
+        return outputs
+
+
+class PersimmonModel(nn.Module):
+
+    def __init__(self,
+                 config: PersimmonConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
+            PersimmonDecoderLayer(config,
+                                  cache_config=cache_config,
+                                  quant_config=quant_config)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.embed_tokens(input_ids)
+        for i in range(len(self.layers)):
+            hidden_states = self.layers[i](
+                positions,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+            )
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class PersimmonForCausalLM(nn.Module):
+
+    def __init__(self,
+                 config,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.model = PersimmonModel(config,
+                                    cache_config=cache_config,
+                                    quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      bias=False)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ):
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # Copied from vllm/model_executor/models/bloom.py
+                # NOTE: Persimmon's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
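
The fused-QKV reshuffle in load_weights above (the same logic appears in fuyu.py) is easier to see with concrete numbers. A toy sketch with made-up sizes (4 heads, head_size 2, hidden 8), purely illustrative:

    import torch

    num_heads, head_size, hidden = 4, 2, 8
    # Checkpoint layout: output rows grouped per head as [q, k, v]; output_dim is 0.
    w = torch.randn(num_heads * 3 * head_size, hidden)
    shape = w.shape
    w = w.view((num_heads, 3, -1) + shape[1:])  # (num_heads, 3, head_size, hidden)
    w = w.transpose(0, 1)                       # (3, num_heads, head_size, hidden)
    w = w.reshape(shape)                        # rows now grouped as all-q, all-k, all-v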

+ 35 - 0
examples/vision/fuyu_example.py

@@ -0,0 +1,35 @@
+import os
+from PIL import Image
+
+from aphrodite import LLM, SamplingParams
+
+
+def run_fuyu():
+    llm = LLM(model="adept/fuyu-8b", max_model_len=4096)
+
+    # single-image prompt
+    prompt = "What is the content of this image?\n"
+    image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                              "burg.jpg")
+    image = Image.open(image_path)
+
+    sampling_params = SamplingParams(temperature=1.1,
+                                     min_p=0.06,
+                                     max_tokens=512)
+
+    outputs = llm.generate(
+        {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "image": image
+            },
+        },
+        sampling_params=sampling_params)
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+if __name__ == "__main__":
+    run_fuyu()
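
The script expects a burg.jpg file next to it. For a quick smoke test without a local image, a hedged variant that substitutes a synthetic PIL image (illustrative only, not part of the commit):

    from PIL import Image
    from aphrodite import LLM, SamplingParams

    llm = LLM(model="adept/fuyu-8b", max_model_len=4096)
    image = Image.new("RGB", (512, 512), color=(200, 80, 40))  # stand-in for burg.jpg
    outputs = llm.generate(
        {"prompt": "What is the content of this image?\n",
         "multi_modal_data": {"image": image}},
        sampling_params=SamplingParams(temperature=1.1, min_p=0.06, max_tokens=64))
    print(outputs[0].outputs[0].text)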