Procházet zdrojové kódy

feat: add support for audio models (#891)

* feat: add support for ultravox model

* fix import errors

* add offline audio inference example

* update the openai api example

* update all model tests
AlpinDale před 2 měsíci
rodič
revize
653d1a08d4
34 změnil soubory, kde provedl 1032 přidání a 269 odebrání
  1. 21 0
      aphrodite/assets/audio.py
  2. 4 2
      aphrodite/endpoints/chat_utils.py
  3. 1 0
      aphrodite/modeling/models/__init__.py
  4. 4 4
      aphrodite/modeling/models/blip.py
  5. 4 4
      aphrodite/modeling/models/chameleon.py
  6. 4 4
      aphrodite/modeling/models/clip.py
  7. 2 2
      aphrodite/modeling/models/fuyu.py
  8. 1 1
      aphrodite/modeling/models/internvl.py
  9. 2 2
      aphrodite/modeling/models/minicpmv.py
  10. 1 1
      aphrodite/modeling/models/paligemma.py
  11. 1 1
      aphrodite/modeling/models/phi3v.py
  12. 4 4
      aphrodite/modeling/models/siglip.py
  13. 432 0
      aphrodite/modeling/models/ultravox.py
  14. 6 85
      aphrodite/multimodal/image.py
  15. 92 6
      aphrodite/multimodal/utils.py
  16. 3 1
      aphrodite/transformers_utils/config.py
  17. 2 0
      aphrodite/transformers_utils/configs/__init__.py
  18. 85 0
      aphrodite/transformers_utils/configs/ultravox.py
  19. 1 0
      docs/pages/usage/models.md
  20. 92 0
      examples/audio/audio_example.py
  21. binární
      examples/audio/mary_had_lamb.ogg
  22. 55 0
      examples/openai_api/audio.py
  23. 19 12
      tests/conftest.py
  24. 2 1
      tests/distributed/test_basic_distributed_correctness_enc_dec.py
  25. 23 125
      tests/endpoints/openai/test_audio.py
  26. 2 1
      tests/models/test_bart.py
  27. 3 2
      tests/models/test_blip2.py
  28. 2 2
      tests/models/test_chameleon.py
  29. 3 2
      tests/models/test_llava.py
  30. 3 2
      tests/models/test_llava_image_embeds.py
  31. 3 2
      tests/models/test_llava_next.py
  32. 3 2
      tests/models/test_paligemma.py
  33. 1 1
      tests/models/test_qwen.py
  34. 151 0
      tests/models/test_ultravox.py

+ 21 - 0
aphrodite/assets/audio.py

@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import Literal, Tuple
+from urllib.parse import urljoin
+
+import librosa
+import numpy as np
+
+from aphrodite.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL
+
+ASSET_DIR = "multimodal_asset"
+@dataclass(frozen=True)
+class AudioAsset:
+    name: Literal["winning_call", "mary_had_lamb"]
+    @property
+    def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]:
+        audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
+                                            s3_prefix=ASSET_DIR)
+        return librosa.load(audio_path, sr=None)
+    @property
+    def url(self) -> str:
+        return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")

+ 4 - 2
aphrodite/endpoints/chat_utils.py

@@ -124,8 +124,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
                   modality: Literal["image", "audio"]) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
+    model_type = model_config.hf_config.model_type
     if modality == "image":
-        model_type = model_config.hf_config.model_type
         if model_type == "phi3_v":
             # Workaround since this token is not defined in the tokenizer
             return "<|image_1|>"
@@ -141,7 +141,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
 
         raise TypeError(f"Unknown model type: {model_type}")
     elif modality == "audio":
-        raise TypeError("No audio models are supported yet.")
+        if model_type == "ultravox":
+            return "<|reserved_special_token_0|>"
+        raise TypeError(f"Unknown model type: {model_type}")
     else:
         raise TypeError(f"Unknown modality: {modality}")
 

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -85,6 +85,7 @@ _MULTIMODAL_MODELS = {
     "PaliGemmaForConditionalGeneration": ("paligemma",
                                           "PaliGemmaForConditionalGeneration"),
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
+    "UltravoxModel": ("ultravox", "UltravoxModel"),
 }
 
 _CONDITIONAL_GENERATION_MODELS = {

+ 4 - 4
aphrodite/modeling/models/blip.py

@@ -16,8 +16,8 @@ from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
-from aphrodite.multimodal.image import (cached_get_tokenizer,
-                                        repeat_and_pad_image_tokens)
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_placeholder_tokens)
 from aphrodite.quantization import QuantizationConfig
 
 
@@ -98,11 +98,11 @@ def input_processor_for_blip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=image_token_id,
+        placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 

+ 4 - 4
aphrodite/modeling/models/chameleon.py

@@ -32,8 +32,8 @@ from aphrodite.modeling.model_loader.weight_utils import (
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.image import (cached_get_tokenizer,
-                                        repeat_and_pad_image_tokens)
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_placeholder_tokens)
 from aphrodite.quantization.base_config import QuantizationConfig
 
 from .interfaces import SupportsMultiModal
@@ -122,11 +122,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
 
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
+        placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
         repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
         pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
         pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,

+ 4 - 4
aphrodite/modeling/models/clip.py

@@ -16,8 +16,8 @@ from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
-from aphrodite.multimodal.image import (cached_get_tokenizer,
-                                        repeat_and_pad_image_tokens)
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_placeholder_tokens)
 from aphrodite.quantization import QuantizationConfig
 
 
@@ -103,11 +103,11 @@ def input_processor_for_clip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=image_token_id,
+        placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 

+ 2 - 2
aphrodite/modeling/models/fuyu.py

@@ -37,8 +37,8 @@ from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.multimodal.base import MultiModalInputs
-from aphrodite.multimodal.image import (cached_get_image_processor,
-                                        cached_get_tokenizer)
+from aphrodite.multimodal.image import cached_get_image_processor
+from aphrodite.multimodal.utils import cached_get_tokenizer
 from aphrodite.quantization.base_config import QuantizationConfig
 
 from .interfaces import SupportsMultiModal

+ 1 - 1
aphrodite/modeling/models/internvl.py

@@ -23,7 +23,7 @@ from aphrodite.modeling.models.intern_vit import InternVisionModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.multimodal.base import MultiModalInputs
-from aphrodite.multimodal.image import cached_get_tokenizer
+from aphrodite.multimodal.utils import cached_get_tokenizer
 from aphrodite.quantization import QuantizationConfig
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,

+ 2 - 2
aphrodite/modeling/models/minicpmv.py

@@ -55,8 +55,8 @@ from aphrodite.modeling.models.minicpm import MiniCPMModel
 from aphrodite.modeling.models.qwen2 import Qwen2Model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.image import (cached_get_image_processor,
-                                        cached_get_tokenizer)
+from aphrodite.multimodal.image import cached_get_image_processor
+from aphrodite.multimodal.utils import cached_get_tokenizer
 from aphrodite.quantization.base_config import QuantizationConfig
 
 from .idefics2_vision_model import Idefics2VisionTransformer

+ 1 - 1
aphrodite/modeling/models/paligemma.py

@@ -16,7 +16,7 @@ from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.gemma import GemmaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.image import cached_get_tokenizer
+from aphrodite.multimodal.utils import cached_get_tokenizer
 from aphrodite.quantization.base_config import QuantizationConfig
 
 from .interfaces import SupportsMultiModal

+ 1 - 1
aphrodite/modeling/models/phi3v.py

@@ -38,7 +38,7 @@ from aphrodite.modeling.models.clip import CLIPVisionModel
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.image import cached_get_tokenizer
+from aphrodite.multimodal.utils import cached_get_tokenizer
 from aphrodite.quantization.base_config import QuantizationConfig
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,

+ 4 - 4
aphrodite/modeling/models/siglip.py

@@ -25,8 +25,8 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.multimodal.image import (cached_get_tokenizer,
-                                        repeat_and_pad_image_tokens)
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_placeholder_tokens)
 from aphrodite.quantization import QuantizationConfig
 
 
@@ -113,11 +113,11 @@ def input_processor_for_siglip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=image_token_id,
+        placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 

+ 432 - 0
aphrodite/modeling/models/ultravox.py

@@ -0,0 +1,432 @@
+# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
+"""PyTorch Ultravox model."""
+
+import itertools
+import math
+from array import array
+from functools import lru_cache
+from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+                    TypedDict, Union, cast)
+
+import librosa
+import numpy as np
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import functional as F
+from transformers.models.whisper import WhisperFeatureExtractor
+from transformers.models.whisper.modeling_whisper import WhisperEncoder
+
+from aphrodite.attention import AttentionMetadata
+from aphrodite.common.config import CacheConfig, MultiModalConfig
+from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                                       SamplerOutput, SequenceData)
+from aphrodite.inputs import INPUT_REGISTRY
+from aphrodite.inputs.data import LLMInputs
+from aphrodite.inputs.registry import InputContext
+from aphrodite.modeling.layers.activation import SiluAndMul, get_act_fn
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
+from aphrodite.modeling.models.utils import (filter_weights,
+                                             init_aphrodite_registered_model,
+                                             merge_multimodal_embeddings)
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.multimodal import MULTIMODAL_REGISTRY
+from aphrodite.multimodal.base import MultiModalInputs
+from aphrodite.multimodal.utils import (cached_get_tokenizer,
+                                        repeat_and_pad_placeholder_tokens)
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.transformers_utils.configs.ultravox import UltravoxConfig
+
+_AUDIO_PLACEHOLDER_TOKEN = 128002
+_AUDIO_TOKENS_PER_SECOND = 6.25
+
+
+class UltravoxAudioFeatureInputs(TypedDict):
+    type: Literal["audio_features"]
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """Shape: `(batch_size, 80, M)"""
+
+
+class UltravoxAudioEmbeddingInputs(TypedDict):
+    type: Literal["audio_embeds"]
+    data: torch.Tensor
+
+
+UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
+                            UltravoxAudioEmbeddingInputs]
+
+
+@lru_cache
+def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
+    return WhisperFeatureExtractor.from_pretrained(model_id)
+
+
+def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
+    return cached_feature_extractor(
+        ctx.get_hf_config(UltravoxConfig).audio_model_id)
+
+
+def get_ultravox_max_audio_tokens(ctx: InputContext):
+    feature_extractor = whisper_feature_extractor(ctx)
+    return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
+
+
+def dummy_data_for_ultravox(
+    ctx: InputContext,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+):
+    feature_extractor = whisper_feature_extractor(ctx)
+
+    audio_count = mm_counts["audio"]
+
+    audio_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [
+        _AUDIO_PLACEHOLDER_TOKEN
+    ]) * get_ultravox_max_audio_tokens(ctx) * audio_count
+    other_token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
+                            [0]) * (seq_len - len(audio_token_ids))
+
+    audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
+    mm_dict = {
+        "audio":
+        audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
+    }
+
+    return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
+
+
+def input_mapper_for_ultravox(ctx: InputContext, data: object):
+    if isinstance(data, tuple):
+        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
+        feature_extractor = whisper_feature_extractor(ctx)
+
+        if sr != feature_extractor.sampling_rate:
+            audio = librosa.resample(audio,
+                                     orig_sr=sr,
+                                     target_sr=feature_extractor.sampling_rate)
+            sr = feature_extractor.sampling_rate
+
+        minimum_audio_length = feature_extractor.n_fft // 2 + 1
+        if len(audio) < minimum_audio_length:
+            # Not enough audio; pad it.
+            audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
+
+        return MultiModalInputs({
+            "audio_features":
+            feature_extractor(audio,
+                              sampling_rate=sr,
+                              padding="longest",
+                              return_tensors="pt")["input_features"]
+        })
+
+    raise NotImplementedError(f"Unsupported data type: {type(data)}")
+
+
+def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "audio" not in multi_modal_data:
+        return llm_inputs
+
+    feature_extractor = whisper_feature_extractor(ctx)
+    audio_data, sample_rate = multi_modal_data["audio"]
+
+    audio_length = audio_data.shape[0]
+    if sample_rate != feature_extractor.sampling_rate:
+        # Account for resampling.
+        adjustment = feature_extractor.sampling_rate / sample_rate
+        audio_length = math.ceil(adjustment * audio_length)
+
+    feature_extractor_output_length = math.ceil(
+        (audio_length -
+         (feature_extractor.hop_length - 1)) / feature_extractor.hop_length)
+
+    uv_config = ctx.get_hf_config(UltravoxConfig)
+    audio_num_tokens = min(
+        max(
+            1,
+            math.ceil(feature_extractor_output_length /
+                      (uv_config.stack_factor * 2))),
+        get_ultravox_max_audio_tokens(ctx))
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+        tokenizer,
+        llm_inputs.get("prompt"),
+        llm_inputs["prompt_token_ids"],
+        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
+        repeat_count=audio_num_tokens,
+    )
+
+    # NOTE: Create a defensive copy of the original inputs
+    return LLMInputs(prompt_token_ids=new_token_ids,
+                     prompt=new_prompt,
+                     multi_modal_data=multi_modal_data)
+
+
+class StackAudioFrames(nn.Module):
+    """
+    Stack the audio embedding frames to reduce the sequence length by a factor
+    of `stack_factor`.
+    """
+
+    def __init__(self, stack_factor: int = 8):
+        super().__init__()
+        self.stack_factor = stack_factor
+
+    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
+        B, T, C = audio_embeds.shape
+        T_pad = (T + self.stack_factor -
+                 1) // self.stack_factor * self.stack_factor
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
+        B, T, C = audio_embeds.shape
+        audio_embeds = audio_embeds.view(B, T // self.stack_factor,
+                                         C * self.stack_factor)
+        return audio_embeds
+
+
+class FlippedSiluAndMul(SiluAndMul):
+    """Ultravox is trained with SwiGLU with flipped halves."""
+
+    def forward(self, x: torch.Tensor):
+        a, b = x.chunk(2, dim=-1)
+        flipped = torch.cat((b, a), dim=-1)
+        return super().forward(flipped)
+
+
+class UltravoxProjector(nn.Module):
+
+    def __init__(self, config: UltravoxConfig):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self._pad_and_stack = StackAudioFrames(config.stack_factor)
+        dim = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim)
+        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
+        dim = self.hidden_dim
+
+        if config.projector_act == "swiglu":
+            self.act = FlippedSiluAndMul()
+            dim = dim // 2
+        else:
+            self.act = get_act_fn(config.projector_act)
+
+        self.linear_2 = nn.Linear(dim,
+                                  config.text_config.hidden_size,
+                                  bias=False)
+        self.ln_post = RMSNorm(config.text_config.hidden_size)
+
+    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+        audio_features = self._pad_and_stack(audio_features)
+        audio_features = self.ln_pre(audio_features)
+        hidden_states = self.linear_1(audio_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        hidden_states = self.ln_post(hidden_states)
+        return hidden_states
+
+
+class ModifiedWhisperEncoder(WhisperEncoder):
+    """
+    Encoder portion of OpenAI's Whisper model.
+
+    This implementation is a slightly modified version of HF Transformers'
+    Whisper Encoder, with only a few fixes:
+    1. base_model_prefix updated to allow for doing `.from_pretrained`
+       directly on the encoder
+    2. allow less than 30 second of audio padding to be passed in:
+        - relaxed ValueError check for `input_features` length to be less
+           than or equal to `expected_seq_length` instead of strictly equal
+        - embed_pos is now sliced to match the length of `inputs_embeds`
+
+    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
+    See commentary: https://github.com/huggingface/transformers/issues/25744
+    """
+
+    base_model_prefix = "model.encoder"
+
+    def forward(
+        self,
+        input_features,
+    ):
+        expected_seq_length = (self.config.max_source_positions *
+                               self.conv1.stride[0] * self.conv2.stride[0])
+        if input_features.shape[-1] > expected_seq_length:
+            raise ValueError(
+                f"Whisper expects the mel input features to be of length "
+                f"{expected_seq_length} or less, but found "
+                f"{input_features.shape[-1]}. Make sure to pad the input mel "
+                f"features to {expected_seq_length}.")
+
+        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
+        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
+
+        inputs_embeds = inputs_embeds.permute(0, 2, 1)
+        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = nn.functional.dropout(hidden_states,
+                                              p=self.dropout,
+                                              training=self.training)
+
+        for encoder_layer in self.layers:
+            layer_outputs = encoder_layer(
+                hidden_states,
+                None,
+                layer_head_mask=None,
+            )
+
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.layer_norm(hidden_states)
+        return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "audio", get_ultravox_max_audio_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
+class UltravoxModel(nn.Module, SupportsMultiModal):
+
+    def __init__(self,
+                 config: UltravoxConfig,
+                 multimodal_config: MultiModalConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional["QuantizationConfig"] = None):
+        super().__init__()
+        self.config = config
+        self.multi_modal_config = multimodal_config
+        assert self.multi_modal_config
+
+        if config.audio_model_id is not None:
+            self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
+                config.audio_model_id)
+        else:
+            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
+        self.multi_modal_projector = UltravoxProjector(config)
+        self.language_model = init_aphrodite_registered_model(
+            config.text_config, cache_config, quant_config)
+
+    def _audio_features_to_embeddings(
+            self, input_features: torch.Tensor) -> torch.Tensor:
+        audio_input = input_features.to(self.audio_tower.dtype)
+        audio_features = self.audio_tower(audio_input)
+        audio_features = audio_features.to(self.audio_tower.dtype)
+        audio_embeddings = self.multi_modal_projector(audio_features)
+        return audio_embeddings
+
+    def _parse_and_validate_audio_input(
+            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
+        audio_features = kwargs.pop("audio_features", None)
+        audio_embeds = kwargs.pop("audio_embeds", None)
+
+        if audio_features is None and audio_embeds is None:
+            return None
+
+        if audio_features is not None:
+            if not isinstance(audio_features, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio features. "
+                                 f"Got type: {type(audio_features)}")
+
+            return UltravoxAudioFeatureInputs(type="audio_features",
+                                              data=audio_features)
+
+        if audio_embeds is not None:
+            if not isinstance(audio_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of audio embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+
+            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
+                                                data=audio_embeds)
+
+        raise AssertionError("This line should be unreachable.")
+
+    def _process_audio_input(
+        self, audio_input: UltravoxAudioInputs
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if audio_input["type"] == "audio_embeds":
+            return audio_input["data"]
+
+        audio_features = audio_input["data"]
+        if isinstance(audio_features, list):
+            # TODO: Batch these through the encoder/projector instead of
+            # serializing them.
+            return [
+                self._audio_features_to_embeddings(
+                    features.unsqueeze(0)).squeeze(0)
+                for features in audio_features
+            ]
+        else:
+            return self._audio_features_to_embeddings(audio_features)
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[torch.Tensor],
+                **kwargs) -> SamplerOutput:
+        """Run forward pass for Ultravox
+
+        One key thing to understand is the `input_ids` already accounts for the
+        positions of the to-be-inserted audio embeddings. The to-be-inserted
+        audio has a size that is essentially 6.25 tokens per second of audio.
+
+        This way, the `positions` and `attn_metadata` are consistent
+        with the `input_ids`.
+
+        Args:
+            input_features: A batch of audio inputs, [1, 80, M].
+        """
+        audio_input = self._parse_and_validate_audio_input(**kwargs)
+        if audio_input is not None:
+            audio_embeddings = self._process_audio_input(audio_input)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
+
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, audio_embeddings,
+                _AUDIO_PLACEHOLDER_TOKEN)
+            input_ids = None
+        else:
+            inputs_embeds = None
+
+        hidden_states = self.language_model.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        return self.language_model.sample(logits, sampling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # prepare weight iterators for components
+        projector_weights, llm_weights = itertools.tee(weights, 2)
+
+        # load projector weights
+        projector_weights = filter_weights(projector_weights,
+                                           "multi_modal_projector")
+        projector_params_dict = dict(
+            self.multi_modal_projector.named_parameters())
+        for name, loaded_weight in projector_weights:
+            param = projector_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        llm_weights = filter_weights(llm_weights, "language_model")
+        self.language_model.load_weights(llm_weights)

+ 6 - 85
aphrodite/multimodal/image.py

@@ -1,103 +1,21 @@
 from functools import lru_cache
-from typing import List, Optional, Tuple, TypeVar
 
 import torch
 from loguru import logger
 from PIL import Image
-from transformers import PreTrainedTokenizerBase
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.common.utils import is_list_of
 from aphrodite.inputs.registry import InputContext
 from aphrodite.transformers_utils.image_processor import get_image_processor
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
 
 from .base import MultiModalData, MultiModalInputs, MultiModalPlugin
 
 cached_get_image_processor = lru_cache(get_image_processor)
-cached_get_tokenizer = lru_cache(get_tokenizer)
-
-# Utilities for image input processors
-_T = TypeVar("_T", str, int)
-
-
-def repeat_and_pad_token(
-    token: _T,
-    *,
-    repeat_count: int = 1,
-    pad_token_left: Optional[_T] = None,
-    pad_token_right: Optional[_T] = None,
-) -> List[_T]:
-    replacement = [token] * repeat_count
-    if pad_token_left is not None:
-        replacement = [pad_token_left] + replacement
-    if pad_token_right is not None:
-        replacement = replacement + [pad_token_right]
-
-    return replacement
-
-
-def repeat_and_pad_image_tokens(
-    tokenizer: PreTrainedTokenizerBase,
-    prompt: Optional[str],
-    prompt_token_ids: List[int],
-    *,
-    image_token_id: int,
-    repeat_count: int = 1,
-    pad_token_left: Optional[int] = None,
-    pad_token_right: Optional[int] = None,
-) -> Tuple[Optional[str], List[int]]:
-    if prompt is None:
-        new_prompt = None
-    else:
-        image_token_str = tokenizer.decode(image_token_id)
-        pad_token_str_left = (None if pad_token_left is None else
-                              tokenizer.decode(pad_token_left))
-        pad_token_str_right = (None if pad_token_right is None else
-                               tokenizer.decode(pad_token_right))
-        replacement_str = "".join(
-            repeat_and_pad_token(
-                image_token_str,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_str_left,
-                pad_token_right=pad_token_str_right,
-            ))
-
-        image_token_count = prompt.count(image_token_str)
-        # This is an arbitrary number to distinguish between the two cases
-        if image_token_count > 16:
-            logger.warning("Please follow the prompt format that is "
-                           "documented on HuggingFace which does not involve "
-                           f"repeating {image_token_str} tokens.")
-        elif image_token_count > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
-
-        # The image tokens are removed to be consistent with HuggingFace
-        new_prompt = prompt.replace(image_token_str, replacement_str, 1)
-
-    new_token_ids: List[int] = []
-    for i, token in enumerate(prompt_token_ids):
-        if token == image_token_id:
-            replacement_ids = repeat_and_pad_token(
-                image_token_id,
-                repeat_count=repeat_count,
-                pad_token_left=pad_token_left,
-                pad_token_right=pad_token_right,
-            )
-            new_token_ids.extend(replacement_ids)
-
-            # No need to further scan the list since we only replace once
-            new_token_ids.extend(prompt_token_ids[i + 1:])
-            break
-        else:
-            new_token_ids.append(token)
-
-    return new_prompt, new_token_ids
 
 
 class ImagePlugin(MultiModalPlugin):
+    """Plugin for image data."""
 
     def get_data_key(self) -> str:
         return "image"
@@ -114,20 +32,23 @@ class ImagePlugin(MultiModalPlugin):
     ) -> MultiModalInputs:
         model_config = ctx.model_config
 
+        # PIL image
         if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
-                raise RuntimeError("No HuggingFace processor is available"
+                raise RuntimeError("No HuggingFace processor is available "
                                    "to process the image object")
             try:
                 batch_data = image_processor \
                     .preprocess(data, return_tensors="pt") \
                     .data
             except Exception:
-                logger.error(f"Failed to process image ({data})")
+                logger.error("Failed to process image (%s)", data)
                 raise
 
             return MultiModalInputs(batch_data)
+
+        # Image embedding
         elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
             return MultiModalInputs({"image_embeds": data})
 

+ 92 - 6
aphrodite/multimodal/utils.py

@@ -1,19 +1,21 @@
 import base64
+from functools import lru_cache
 from io import BytesIO
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, TypeVar, Union
 
 import librosa
 import numpy as np
 import soundfile
+from loguru import logger
 from PIL import Image
 
-import aphrodite.common.envs as envs
-from aphrodite.common.connections import global_http_connection
+from aphrodite.common.envs import (APHRODITE_AUDIO_FETCH_TIMEOUT,
+                                   APHRODITE_IMAGE_FETCH_TIMEOUT)
+from aphrodite.connections import global_http_connection
 from aphrodite.multimodal.base import MultiModalDataDict
+from aphrodite.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 
-APHRODITE_IMAGE_FETCH_TIMEOUT = envs.APHRODITE_IMAGE_FETCH_TIMEOUT
-
-APHRODITE_AUDIO_FETCH_TIMEOUT = envs.APHRODITE_AUDIO_FETCH_TIMEOUT
+cached_get_tokenizer = lru_cache(get_tokenizer)
 
 
 def _load_image_from_bytes(b: bytes):
@@ -31,6 +33,7 @@ def _load_image_from_data_url(image_url: str):
 def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
     """
     Load a PIL image from a HTTP or base64 data URL.
+
     By default, the image is converted into RGB format.
     """
     if image_url.startswith('http'):
@@ -52,6 +55,7 @@ async def async_fetch_image(image_url: str,
                             image_mode: str = "RGB") -> Image.Image:
     """
     Asynchronously load a PIL image from a HTTP or base64 data URL.
+
     By default, the image is converted into RGB format.
     """
     if image_url.startswith('http'):
@@ -132,6 +136,7 @@ def encode_image_base64(
 ) -> str:
     """
     Encode a pillow image to base64 format.
+
     By default, the image is converted into RGB format before being encoded.
     """
     buffered = BytesIO()
@@ -155,3 +160,84 @@ def rescale_image_size(image: Image.Image,
     if transpose >= 0:
         image = image.transpose(Image.Transpose(transpose))
     return image
+
+
+# Utilities for input processors
+_T = TypeVar("_T", str, int)
+
+
+def repeat_and_pad_token(
+    token: _T,
+    *,
+    repeat_count: int = 1,
+    pad_token_left: Optional[_T] = None,
+    pad_token_right: Optional[_T] = None,
+) -> List[_T]:
+    replacement = [token] * repeat_count
+    if pad_token_left is not None:
+        replacement = [pad_token_left] + replacement
+    if pad_token_right is not None:
+        replacement = replacement + [pad_token_right]
+
+    return replacement
+
+
+def repeat_and_pad_placeholder_tokens(
+    tokenizer: AnyTokenizer,
+    prompt: Optional[str],
+    prompt_token_ids: List[int],
+    *,
+    placeholder_token_id: int,
+    repeat_count: int = 1,
+    pad_token_left: Optional[int] = None,
+    pad_token_right: Optional[int] = None,
+) -> Tuple[Optional[str], List[int]]:
+    if prompt is None:
+        new_prompt = None
+    else:
+        placeholder_token_str = tokenizer.decode(placeholder_token_id)
+        pad_token_str_left = (None if pad_token_left is None else
+                              tokenizer.decode(pad_token_left))
+        pad_token_str_right = (None if pad_token_right is None else
+                               tokenizer.decode(pad_token_right))
+        replacement_str = "".join(
+            repeat_and_pad_token(
+                placeholder_token_str,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_str_left,
+                pad_token_right=pad_token_str_right,
+            ))
+
+        placeholder_token_count = prompt.count(placeholder_token_str)
+        # This is an arbitrary number to distinguish between the two cases
+        if placeholder_token_count > 16:
+            logger.warning(
+                "Please follow the prompt format that is "
+                "documented on HuggingFace which does not involve "
+                "repeating %s tokens.", placeholder_token_str)
+        elif placeholder_token_count > 1:
+            logger.warning("Multiple multi-modal input is not supported yet, "
+                           "so any extra placeholder tokens will be treated "
+                           "as plain text.")
+
+        # The image tokens are removed to be consistent with HuggingFace
+        new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
+
+    new_token_ids: List[int] = []
+    for i, token in enumerate(prompt_token_ids):
+        if token == placeholder_token_id:
+            replacement_ids = repeat_and_pad_token(
+                placeholder_token_id,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_left,
+                pad_token_right=pad_token_right,
+            )
+            new_token_ids.extend(replacement_ids)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + 1:])
+            break
+        else:
+            new_token_ids.append(token)
+
+    return new_prompt, new_token_ids

+ 3 - 1
aphrodite/transformers_utils/config.py

@@ -18,7 +18,8 @@ from aphrodite.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                                   InternVLChatConfig,
                                                   JAISConfig, MedusaConfig,
                                                   MLPSpeculatorConfig,
-                                                  MPTConfig, RWConfig)
+                                                  MPTConfig, RWConfig,
+                                                  UltravoxConfig)
 from aphrodite.transformers_utils.utils import check_gguf_file
 
 APHRODITE_USE_MODELSCOPE = envs.APHRODITE_USE_MODELSCOPE
@@ -40,6 +41,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "mlp_speculator": MLPSpeculatorConfig,
     "medusa": MedusaConfig,
     "internvl_chat": InternVLChatConfig,
+    "ultravox": UltravoxConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():

+ 2 - 0
aphrodite/transformers_utils/configs/__init__.py

@@ -10,6 +10,7 @@ from aphrodite.transformers_utils.configs.medusa import MedusaConfig
 from aphrodite.transformers_utils.configs.mlp_speculator import (
     MLPSpeculatorConfig)
 from aphrodite.transformers_utils.configs.mpt import MPTConfig
+from aphrodite.transformers_utils.configs.ultravox import UltravoxConfig
 
 __all__ = [
     "ChatGLMConfig",
@@ -20,4 +21,5 @@ __all__ = [
     "JAISConfig",
     "MLPSpeculatorConfig",
     "MedusaConfig",
+    "UltravoxConfig",
 ]

+ 85 - 0
aphrodite/transformers_utils/configs/ultravox.py

@@ -0,0 +1,85 @@
+# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py
+from typing import Any, Dict, Optional
+
+import transformers
+
+
+class UltravoxConfig(transformers.PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a
+    [`UltravoxForConditionalGeneration`]. It is used to instantiate an
+    Ultravox model according to the specified arguments, defining the model
+    architecture.
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to
+    control the model outputs. Read the documentation from [`PretrainedConfig`]
+    for more information.
+    Args:
+        audio_config (`Union[AutoConfig, dict]`,  *optional*):
+            Custom audio config or dict
+        text_config (`Union[AutoConfig, dict]`, *optional*):
+            The config object of the text backbone. Can be any of `LlamaConfig`
+            or `MistralConfig`.
+        ignore_index (`int`, *optional*, defaults to -100):
+            The ignore index for the loss function.
+        audio_token_index (`int`, *optional*, defaults to 32000):
+            The audio token index to encode the audio prompt.
+        stack_factor (`int`, *optional*, defaults to 8):
+            Audio downsampling factor for the multimodal projector.
+        norm_init (`float`, *optional*, defaults to 0.4):
+            The initialization value for the layer normalization.
+        projector_act (`str`, *optional*, defaults to `"swiglu"`):
+            The activation function used by the multimodal projector.
+        text_model_lora_config (`LoraConfigSimplified`, *optional*):
+            The LoRA configuration for finetuning the text model.
+        audio_model_lora_config (`LoraConfigSimplified`, *optional*):
+            The LoRA configuration for finetuning the audio model.
+    """
+    model_type = "ultravox"
+    is_composition = False
+    def __init__(
+        self,
+        audio_config: Optional[Dict[str, Any]] = None,
+        text_config: Optional[Dict[str, Any]] = None,
+        audio_model_id: Optional[str] = None,
+        text_model_id: Optional[str] = None,
+        ignore_index: int = -100,
+        audio_token_index: int = 32000,
+        hidden_size: int = 4096,
+        stack_factor: int = 8,
+        norm_init: float = 0.4,
+        projector_act: str = "swiglu",
+        text_model_lora_config: Optional[Dict[str, Any]] = None,
+        audio_model_lora_config: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        self.ignore_index = ignore_index
+        self.audio_model_id = audio_model_id
+        self.text_model_id = text_model_id
+        self.audio_token_index = audio_token_index
+        self.hidden_size = hidden_size
+        self.stack_factor = stack_factor
+        self.norm_init = norm_init
+        self.projector_act = projector_act
+        if text_model_id is not None:
+            # Avoid circular import
+            from aphrodite.transformers_utils.config import get_config
+            self.text_config = get_config(text_model_id,
+                                          trust_remote_code=False)
+        else:
+            text_config = text_config or {}
+            self.text_config = transformers.CONFIG_MAPPING[text_config.get(
+                "model_type", "llama")](**text_config)
+        if audio_model_id is not None:
+            # Avoid circular import
+            from aphrodite.transformers_utils.config import get_config
+            self.audio_config = get_config(audio_model_id,
+                                           trust_remote_code=False)
+        else:
+            audio_config = audio_config or {}
+            self.audio_config = transformers.CONFIG_MAPPING[audio_config.get(
+                "model_type", "whisper")](**audio_config)
+        self.text_model_lora_config = text_model_lora_config or {}
+        self.audio_model_lora_config = audio_model_lora_config or {}
+        self.vocab_size = self.text_config.vocab_size
+        self.initializer_range = self.text_config.initializer_range
+        super().__init__(**kwargs)

+ 1 - 0
docs/pages/usage/models.md

@@ -71,6 +71,7 @@ On ROCm platforms, Mistral and Mixtral are capped to 4096 max context length due
 | `PaliGemmaForConditionalGeneration` |        Image         |        `google/paligemma-3b-pt-224` |
 | `Phi3VForCausalLM`                  |        Image         | `microsoft/Phi-3.5-vision-instruct` |
 | `MiniCPMV`                          |        Image         |             `openbmb/MiniCPM-V-2_6` |
+| `UltravoxModel`                     |        Audio         |            `fixie-ai/ultravox-v0_3` |
 
 
 If your model uses any of the architectures above, you can seamlessly run your model with Aphrodite.

+ 92 - 0
examples/audio/audio_example.py

@@ -0,0 +1,92 @@
+"""
+This example shows how to use vLLM for running offline inference 
+with the correct prompt format on vision language models.
+
+For most models, the prompt format should follow corresponding examples
+on HuggingFace model repository.
+"""
+import os
+
+import librosa
+from transformers import AutoTokenizer
+
+from aphrodite import LLM, SamplingParams
+from aphrodite.common.utils import FlexibleArgumentParser
+
+# Input audio and question
+audio_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+                         "mary_had_lamb.ogg")
+audio_and_sample_rate = librosa.load(audio_path, sr=None)
+question = "What is recited in the audio?"
+
+
+# Ultravox 0.3
+def run_ultravox(question):
+    model_name = "fixie-ai/ultravox-v0_3"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    messages = [{
+        'role': 'user',
+        'content': f"<|reserved_special_token_0|>\n{question}"
+    }]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+    llm = LLM(model=model_name)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
+model_example_map = {
+    "ultravox": run_ultravox,
+}
+
+
+def main(args):
+    model = args.model_type
+    if model not in model_example_map:
+        raise ValueError(f"Model type {model} is not supported.")
+    llm, prompt, stop_token_ids = model_example_map[model](question)
+    # We set temperature to 0.2 so that outputs can be different
+    # even when all prompts are identical when running batch inference.
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=64,
+                                     stop_token_ids=stop_token_ids)
+    assert args.num_prompts > 0
+    if args.num_prompts == 1:
+        # Single inference
+        inputs = {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "audio": audio_and_sample_rate
+            },
+        }
+    else:
+        # Batch inference
+        inputs = [{
+            "prompt": prompt,
+            "multi_modal_data": {
+                "audio": audio_and_sample_rate
+            },
+        } for _ in range(args.num_prompts)]
+    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description='Demo on using Aphrodite for offline inference with '
+        'audio language models')
+    parser.add_argument('--model-type',
+                        '-m',
+                        type=str,
+                        default="ultravox",
+                        choices=model_example_map.keys(),
+                        help='Huggingface "model_type".')
+    parser.add_argument('--num-prompts',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to run.')
+    args = parser.parse_args()
+    main(args)

binární
examples/audio/mary_had_lamb.ogg


+ 55 - 0
examples/openai_api/audio.py

@@ -0,0 +1,55 @@
+"""An example showing how to use aphrodite to serve VLMs.
+Launch the aphrodite server with the following command:
+aphrodite serve fixie-ai/ultravox-v0_3
+"""
+import base64
+import os
+
+from openai import OpenAI
+
+# Get path to the audio file in ../audio directory
+audio_path = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)),
+    "..",
+    "audio",
+    "mary_had_lamb.ogg",
+)
+
+# Modify OpenAI's API key and API base to use aphrodite's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:2242/v1"
+client = OpenAI(
+    # defaults to os.environ.get("OPENAI_API_KEY")
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+models = client.models.list()
+model = models.data[0].id
+
+def encode_audio_base64_from_file(file_path: str) -> str:
+    """Encode an audio file to base64 format."""
+    with open(file_path, "rb") as f:
+        return base64.b64encode(f.read()).decode("utf-8")
+
+# Use base64 encoded audio in the payload
+audio_base64 = encode_audio_base64_from_file(audio_path)
+chat_completion = client.chat.completions.create(
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What's in this audio?"},
+                {
+                    "type": "audio_url",
+                    "audio_url": {
+                        "url": f"data:audio/ogg;base64,{audio_base64}"
+                    },
+                },
+            ],
+        }  # type: ignore
+    ],
+    model=model,
+    max_tokens=128,
+)
+result = chat_completion.choices[0].message.content
+print(f"Chat completion output: {result}")

+ 19 - 12
tests/conftest.py

@@ -9,6 +9,7 @@ from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                     TypeVar, Union)
 
+import numpy as np
 import pytest
 import torch
 import torch.nn as nn
@@ -16,8 +17,7 @@ import torch.nn.functional as F
 from huggingface_hub import snapshot_download
 from loguru import logger
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
-                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
+from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                           BatchFeature)
 
 from aphrodite import LLM, SamplingParams
@@ -198,8 +198,7 @@ class HfRunner:
         *,
         model_kwargs: Optional[Dict[str, Any]] = None,
         is_embedding_model: bool = False,
-        is_vision_model: bool = False,
-        is_encoder_decoder_model: bool = False,
+        auto_cls=AutoModelForCausalLM,
         postprocess_inputs: Callable[[BatchEncoding],
                                      BatchEncoding] = identity,
     ) -> None:
@@ -216,13 +215,6 @@ class HfRunner:
                     device="cpu",
                 ).to(dtype=torch_dtype))
         else:
-            if is_vision_model:
-                auto_cls = AutoModelForVision2Seq
-            elif is_encoder_decoder_model:
-                auto_cls = AutoModelForSeq2SeqLM
-            else:
-                auto_cls = AutoModelForCausalLM
-
             model_kwargs = model_kwargs if model_kwargs is not None else {}
             self.model = self.wrap_device(
                 auto_cls.from_pretrained(
@@ -414,6 +406,7 @@ class HfRunner:
         max_tokens: int,
         num_logprobs: int,
         images: Optional[List[Image.Image]] = None,
+        audios: Optional[List[Tuple[np.ndarray, int]]] = None,
         **kwargs: Any,
     ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
         all_logprobs: List[List[Dict[int, float]]] = []
@@ -428,6 +421,11 @@ class HfRunner:
             if images is not None and images[i] is not None:
                 processor_kwargs["images"] = images[i]
 
+            if audios is not None:
+                audio, sr = audios[i]
+                processor_kwargs["audio"] = audio
+                processor_kwargs["sampling_rate"] = sr
+
             inputs = self.processor(**processor_kwargs)
             inputs = self.postprocess_inputs(inputs)
 
@@ -609,6 +607,8 @@ class AphroditeRunner:
         sampling_params: SamplingParams,
         images: Optional[Union[List[Image.Image],
                                List[List[Image.Image]]]] = None,
+        audios: Optional[Union[List[Tuple[np.ndarray, int]],
+                               List[List[Tuple[np.ndarray, int]]]]] = None
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
 
@@ -620,6 +620,10 @@ class AphroditeRunner:
             for i, image in enumerate(images):
                 inputs[i]["multi_modal_data"] = {"image": image}
 
+        if audios is not None:
+            for i, audio in enumerate(audios):
+                inputs[i]["multi_modal_data"] = {"audio": audio}
+
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
         return self._final_steps_generate_w_logprobs(req_outputs)
@@ -656,6 +660,8 @@ class AphroditeRunner:
         num_logprobs: int,
         images: Optional[Union[List[Image.Image],
                                List[List[Image.Image]]]] = None,
+        audios: Optional[Union[List[Tuple[np.ndarray, int]],
+                               List[List[Tuple[np.ndarray, int]]]]] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
@@ -664,7 +670,8 @@ class AphroditeRunner:
                                                 stop_token_ids=stop_token_ids)
         outputs = self.generate_w_logprobs(prompts,
                                            greedy_logprobs_params,
-                                           images=images)
+                                           images=images,
+                                           audios=audios)
 
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]

+ 2 - 1
tests/distributed/test_basic_distributed_correctness_enc_dec.py

@@ -10,6 +10,7 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
 """
 
 import pytest
+from transformers import AutoModelForSeq2SeqLM
 
 from aphrodite.common.utils import cuda_device_count_stateless
 
@@ -86,7 +87,7 @@ def test_models(
     }
 
     with hf_runner(model, dtype=dtype,
-                   is_encoder_decoder_model=True) as hf_model:
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
         hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
             test_prompts,
             max_tokens,

+ 23 - 125
tests/endpoints/openai/test_audio.py

@@ -1,138 +1,36 @@
-import math
-import sys
-import time
-from typing import Dict, List, Optional, Tuple, Union, cast
-from unittest.mock import patch
-
-import librosa
-import numpy as np
+from typing import Dict, List
+
 import openai
 import pytest
-import requests
-import torch
-
-from aphrodite import ModelRegistry
-from aphrodite.common.config import MultiModalConfig
-from aphrodite.common.utils import get_open_port
-from aphrodite.inputs import INPUT_REGISTRY
-from aphrodite.inputs.data import LLMInputs
-from aphrodite.inputs.registry import InputContext
-from aphrodite.modeling.models.interfaces import SupportsMultiModal
-from aphrodite.modeling.models.opt import OPTForCausalLM
-from aphrodite.multimodal import MULTIMODAL_REGISTRY
-from aphrodite.multimodal.base import MultiModalInputs
-from aphrodite.multimodal.image import (cached_get_tokenizer,
-                                        repeat_and_pad_image_tokens)
-from aphrodite.multimodal.utils import encode_audio_base64, fetch_audio
 
-from ...utils import APHRODITE_PATH
+from aphrodite.assets.audio import AudioAsset
+from aphrodite.multimodal.utils import encode_audio_base64, fetch_audio
 
-chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
-assert chatml_jinja_path.exists()
+from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "facebook/opt-125m"
+MODEL_NAME = "fixie-ai/ultravox-v0_3"
 TEST_AUDIO_URLS = [
-    "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
+    AudioAsset("winning_call").url,
 ]
 
 
-def server_function(port):
-
-    def fake_input_mapper(ctx: InputContext, data: object):
-        assert isinstance(data, tuple)
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
-
-        # Resample it to 1 sample per second
-        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
-        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
-
-    def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
-        multi_modal_data = llm_inputs.get("multi_modal_data")
-        if multi_modal_data is None or "audio" not in multi_modal_data:
-            return llm_inputs
-
-        audio, sr = multi_modal_data.get("audio")
-        audio_duration = math.ceil(len(audio) / sr)
-
-        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
-            cached_get_tokenizer(ctx.model_config.tokenizer),
-            llm_inputs.get("prompt"),
-            llm_inputs["prompt_token_ids"],
-            image_token_id=62,  # "_"
-            repeat_count=audio_duration)
-
-        return LLMInputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data)
-
-    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
-    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-        "audio", lambda *_, **__: 100)
-    @INPUT_REGISTRY.register_input_processor(fake_input_processor)
-    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
-
-        def __init__(self, *args, multimodal_config: MultiModalConfig,
-                     **kwargs):
-            assert multimodal_config is not None
-            super().__init__(*args, **kwargs)
-
-        def forward(
-            self,
-            *args,
-            processed_audio: Optional[torch.Tensor] = None,
-            **kwargs,
-        ) -> torch.Tensor:
-            return super().forward(*args, **kwargs)
-
-    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
-
-    with patch(
-            "aphrodite.endpoints.chat_utils._mm_token_str",
-            lambda *_, **__: "_"), patch(
-                "aphrodite.modeling.models.ModelRegistry.is_multimodal_model"
-            ) as mock:
-        mock.return_value = True
-        sys.argv = ["placeholder.py"] + \
-            (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
-            "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
-            f"--port {port} --chat-template {chatml_jinja_path} "
-            "--disable-frontend-multiprocessing").split()
-        import runpy
-        runpy.run_module('aphrodite.endpoints.openai.api_server',
-                         run_name='__main__')
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "4096",
+        "--enforce-eager",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")
-def client():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context("spawn")
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = openai.AsyncOpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    # run health check
-    health_url = f"http://localhost:{port}/health"
-    start = time.time()
-    while True:
-        try:
-            if requests.get(health_url).status_code == 200:
-                break
-        except Exception as err:
-            result = server.exitcode
-            if result is not None:
-                raise RuntimeError("Server exited unexpectedly.") from err
-
-            time.sleep(0.5)
-            if time.time() - start > MAX_SERVER_START_WAIT_S:
-                raise RuntimeError("Server failed to start in time.") from err
-
-    try:
-        yield client
-    finally:
-        server.kill()
+def client(server):
+    return server.get_async_client()
 
 
 @pytest.fixture(scope="session")
@@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+        completion_tokens=10, prompt_tokens=202, total_tokens=212)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+        completion_tokens=10, prompt_tokens=202, total_tokens=212)
 
     message = choice.message
     message = chat_completion.choices[0].message

+ 2 - 1
tests/models/test_bart.py

@@ -13,6 +13,7 @@ if not is_cpu():
     # (xFormers, etc.)
 
     import pytest
+    from transformers import AutoModelForSeq2SeqLM
 
     from aphrodite.common.sequence import SampleLogprobs
 
@@ -133,7 +134,7 @@ if not is_cpu():
         }
 
         with hf_runner(model, dtype=dtype,
-                       is_encoder_decoder_model=True) as hf_model:
+                       auto_cls=AutoModelForSeq2SeqLM) as hf_model:
             hf_outputs = (
                 hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                     test_case_prompts,

+ 3 - 2
tests/models/test_blip2.py

@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple
 
 import pytest
-from transformers import AutoTokenizer
+from transformers import AutoModelForVision2Seq, AutoTokenizer
 
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.multimodal.utils import rescale_image_size
@@ -81,7 +81,8 @@ def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
             for prompts, images in inputs_per_image
         ]
 
-    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

+ 2 - 2
tests/models/test_chameleon.py

@@ -1,7 +1,7 @@
 from typing import List, Optional, Type
 
 import pytest
-from transformers import BatchEncoding
+from transformers import AutoModelForVision2Seq, BatchEncoding
 
 from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
 from aphrodite.multimodal.utils import rescale_image_size
@@ -74,7 +74,7 @@ def run_test(
     with hf_runner(model,
                    dtype=dtype,
                    postprocess_inputs=process,
-                   is_vision_model=True) as hf_model:
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

+ 3 - 2
tests/models/test_llava.py

@@ -1,7 +1,8 @@
 from typing import List, Optional, Tuple, Type
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer, BatchEncoding
+from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
+                          BatchEncoding)
 
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
@@ -124,7 +125,7 @@ def run_test(
     with hf_runner(model,
                    dtype=dtype,
                    postprocess_inputs=process,
-                   is_vision_model=True) as hf_model:
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

+ 3 - 2
tests/models/test_llava_image_embeds.py

@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple, Type
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 
 from aphrodite.common.sequence import SampleLogprobs
 
@@ -105,7 +105,8 @@ def run_test(
             for prompts, images in aphrodite_inputs_per_image
         ]
 
-    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

+ 3 - 2
tests/models/test_llava_next.py

@@ -1,7 +1,7 @@
 from typing import List, Optional, Tuple, Type, overload
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.multimodal.utils import rescale_image_size
@@ -129,7 +129,8 @@ def run_test(
             for prompts, images in inputs_per_image
         ]
 
-    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

+ 3 - 2
tests/models/test_paligemma.py

@@ -2,7 +2,7 @@ import os
 from typing import List, Optional, Tuple, Type
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer
+from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
 
 from aphrodite.common.sequence import SampleLogprobs
 from aphrodite.common.utils import is_hip
@@ -102,7 +102,8 @@ def run_test(
             for prompts, images in inputs_per_image
         ]
 
-    with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForVision2Seq) as hf_model:
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,

+ 1 - 1
tests/models/test_qwen.py

@@ -26,7 +26,7 @@ def test_text_only_qwen_model(
     # for qwen-vl is still unsupported in Aphrodite. In the near-future, the
     # implementation and this test will be extended to consider
     # visual inputs as well.
-    with hf_runner(model, dtype=dtype, is_vision_model=False) as hf_model:
+    with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts,
             max_tokens,

+ 151 - 0
tests/models/test_ultravox.py

@@ -0,0 +1,151 @@
+from typing import List, Optional, Tuple, Type
+
+import librosa
+import numpy as np
+import pytest
+from transformers import AutoModel, AutoTokenizer, BatchEncoding
+
+from aphrodite.assets.audio import AudioAsset
+from aphrodite.common.sequence import SampleLogprobs
+from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
+
+from ..conftest import HfRunner, AphroditeRunner
+from .utils import check_logprobs_close
+
+pytestmark = pytest.mark.vlm
+
+MODEL_NAME = "fixie-ai/ultravox-v0_3"
+
+AudioTuple = Tuple[np.ndarray, int]
+
+
+@pytest.fixture(scope="session")
+def audio_and_sample_rate():
+    return AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+
+@pytest.fixture
+def prompts_and_audios(audio_and_sample_rate):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
+    aphrodite_placeholder = "<|reserved_special_token_0|>"
+    hf_placeholder = "<|audio|>"
+
+    question = "What's in the audio?"
+    aphrodite_prompt = tokenizer.apply_chat_template(
+        [{
+            'role': 'user',
+            'content': f"{aphrodite_placeholder}\n{question}"
+        }],
+        tokenize=False,
+        add_generation_prompt=True)
+    hf_prompt = tokenizer.apply_chat_template(
+        [{
+            'role': 'user',
+            'content': f"{hf_placeholder}\n{question}"
+        }],
+        tokenize=False,
+        add_generation_prompt=True)
+
+    return [(aphrodite_prompt, hf_prompt, audio_and_sample_rate)]
+
+
+def aphrodite_to_hf_output(aphrodite_output: Tuple[List[int], str,
+                                         Optional[SampleLogprobs]],
+                      model: str):
+    """Sanitize aphrodite output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = aphrodite_output
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+
+    hf_output_ids = output_ids[:]
+    hf_output_str = output_str
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
+def run_test(
+    hf_runner: Type[HfRunner],
+    aphrodite_runner: Type[AphroditeRunner],
+    prompts_and_audios: List[Tuple[str, str, AudioTuple]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and aphrodite."""
+    torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
+
+    # NOTE: take care of the order. run Aphrodite first, and then run HF.
+    # Aphrodite needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    with aphrodite_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as aphrodite_model:
+        aphrodite_outputs_per_audio = [
+            aphrodite_model.generate_greedy_logprobs([aphrodite_prompt],
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                audios=[audio])
+            for aphrodite_prompt, _, audio in prompts_and_audios
+        ]
+
+    def process(hf_inputs: BatchEncoding):
+        hf_inputs["audio_values"] = hf_inputs["audio_values"] \
+            .to(torch_dtype)  # type: ignore
+        return hf_inputs
+
+    with hf_runner(model,
+                   dtype=dtype,
+                   postprocess_inputs=process,
+                   auto_cls=AutoModel) as hf_model:
+
+        hf_outputs_per_audio = [
+            hf_model.generate_greedy_logprobs_limit(
+                [hf_prompt],
+                max_tokens,
+                num_logprobs=num_logprobs,
+                audios=[(librosa.resample(audio[0],
+                                          orig_sr=audio[1],
+                                          target_sr=16000), 16000)])
+            for _, hf_prompt, audio in prompts_and_audios
+        ]
+
+    for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_audio,
+                                        aphrodite_outputs_per_audio):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                aphrodite_to_hf_output(aphrodite_output, model)
+                for aphrodite_output in aphrodite_outputs
+            ],
+            name_0="hf",
+            name_1="aphrodite",
+        )
+
+
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_models(hf_runner, aphrodite_runner, prompts_and_audios, dtype: str,
+                max_tokens: int, num_logprobs: int) -> None:
+    run_test(
+        hf_runner,
+        aphrodite_runner,
+        prompts_and_audios,
+        MODEL_NAME,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )