feat: support for Audio modality (#698)

* add audio plugin

* modify interfaces to accommodate multiple modalities

* import fixes

* chat utils stuff

* update docs
AlpinDale, 6 months ago
commit 3693028340

+ 69 - 33
aphrodite/endpoints/chat_utils.py

@@ -3,7 +3,8 @@ import tempfile
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
+from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
+                    Union, cast)
 
 import requests
 from loguru import logger
@@ -23,10 +24,25 @@ from typing_extensions import Required, TypedDict
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.multimodal import MultiModalDataDict
-from aphrodite.multimodal.utils import async_get_and_parse_image
+from aphrodite.multimodal.utils import (async_get_and_parse_audio,
+                                        async_get_and_parse_image)
 from aphrodite.transformers_utils.tokenizer import AnyTokenizer
 
 
+class AudioURL(TypedDict, total=False):
+    url: Required[str]
+    """
+    Either a URL of the audio or a data URL with base64 encoded audio data.
+    """
+
+
+class ChatCompletionContentPartAudioParam(TypedDict, total=False):
+    audio_url: Required[AudioURL]
+
+    type: Required[Literal["audio_url"]]
+    """The type of the content part."""
+
+
 class CustomChatCompletionContentPartParam(TypedDict, total=False):
     __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
 
@@ -35,6 +51,7 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
 
 
 ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
+                                       ChatCompletionContentPartAudioParam,
                                        CustomChatCompletionContentPartParam]
 
 
@@ -103,35 +120,41 @@ def load_chat_template(
 
 
 @lru_cache(maxsize=None)
-def _image_token_str(model_config: ModelConfig,
-                     tokenizer: PreTrainedTokenizer) -> Optional[str]:
+def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
+                  modality: Literal["image", "audio"]) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
-    model_type = model_config.hf_config.model_type
-    if model_type == "phi3_v":
-        # Workaround since this token is not defined in the tokenizer
-        return "<|image_1|>"
-    if model_type == "minicpmv":
-        return "(<image>./</image>)"
-    if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
-        # These models do not use image tokens in the prompt
-        return None
-    if model_type.startswith("llava"):
-        return tokenizer.decode(model_config.hf_config.image_token_index)
-    if model_type in ("chameleon", "internvl_chat"):
-        return "<image>"
-
-    raise TypeError(f"Unknown model type: {model_type}")
-
-
-# TODO: Let user specify how to insert image tokens into prompt
+    if modality == "image":
+        model_type = model_config.hf_config.model_type
+        if model_type == "phi3_v":
+            # Workaround since this token is not defined in the tokenizer
+            return "<|image_1|>"
+        if model_type == "minicpmv":
+            return "(<image>./</image>)"
+        if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
+            # These models do not use image tokens in the prompt
+            return None
+        if model_type.startswith("llava"):
+            return tokenizer.decode(model_config.hf_config.image_token_index)
+        if model_type in ("chameleon", "internvl_chat"):
+            return "<image>"
+
+        raise TypeError(f"Unknown model type: {model_type}")
+    elif modality == "audio":
+        raise TypeError("No audio models are supported yet.")
+    else:
+        raise TypeError(f"Unknown modality: {modality}")
+
+
+# TODO: Let user specify how to insert multimodal tokens into prompt
 # (similar to chat template)
-def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
-    """Combine image and text prompts for vision language model"""
+def _get_full_multimodal_text_prompt(placeholder_token_str: str,
+                                     text_prompt: str) -> str:
+    """Combine multimodal prompts for a multimodal language model"""
 
     # NOTE: For now we assume all model architectures use the same
-    # image + text prompt format. This may change in the future.
-    return f"{image_token_str}\n{text_prompt}"
+    # placeholder + text prompt format. This may change in the future.
+    return f"{placeholder_token_str}\n{text_prompt}"
 
 
 def _parse_chat_message_content_parts(
@@ -142,6 +165,7 @@ def _parse_chat_message_content_parts(
 ) -> ChatMessageParseResult:
     texts: List[str] = []
     mm_futures: List[Awaitable[MultiModalDataDict]] = []
+    modality: Literal["image", "audio"] = "image"
 
     for part in parts:
         part_type = part["type"]
@@ -149,9 +173,10 @@ def _parse_chat_message_content_parts(
             text = cast(ChatCompletionContentPartTextParam, part)["text"]
             texts.append(text)
         elif part_type == "image_url":
+            modality = "image"
             if len(mm_futures) > 0:
                 raise NotImplementedError(
-                    "Multiple 'image_url' input is currently not supported.")
+                    "Multiple multimodal inputs are currently not supported.")
 
             image_url = cast(ChatCompletionContentPartImageParam,
                              part)["image_url"]
@@ -163,21 +188,32 @@ def _parse_chat_message_content_parts(
 
             image_future = async_get_and_parse_image(image_url["url"])
             mm_futures.append(image_future)
+        elif part_type == "audio_url":
+            modality = "audio"
+            if len(mm_futures) > 0:
+                raise NotImplementedError(
+                    "Multiple multimodal inputs are currently not supported.")
+
+            audio_url = cast(ChatCompletionContentPartAudioParam,
+                             part)["audio_url"]
+            audio_future = async_get_and_parse_audio(audio_url["url"])
+            mm_futures.append(audio_future)
         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")
 
     text_prompt = "\n".join(texts)
 
     if mm_futures:
-        image_token_str = _image_token_str(model_config, tokenizer)
-        if image_token_str is not None:
-            if image_token_str in text_prompt:
+        placeholder_token_str = _mm_token_str(model_config, tokenizer,
+                                              modality)
+        if placeholder_token_str is not None:
+            if placeholder_token_str in text_prompt:
                 logger.warning(
-                    "Detected image token string in the text prompt. "
+                    "Detected multi-modal token string in the text prompt. "
                     "Skipping prompt formatting.")
             else:
-                text_prompt = _get_full_image_text_prompt(
-                    image_token_str=image_token_str,
+                text_prompt = _get_full_multimodal_text_prompt(
+                    placeholder_token_str=placeholder_token_str,
                     text_prompt=text_prompt,
                 )
 

+ 2 - 2
aphrodite/modeling/model_loader/loader.py

@@ -37,7 +37,7 @@ from aphrodite.modeling.model_loader.weight_utils import (
     safetensors_weights_iterator)
 from aphrodite.modeling.models.interfaces import (has_inner_state,
                                                   supports_lora,
-                                                  supports_vision)
+                                                  supports_multimodal)
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.platforms import current_platform
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -127,7 +127,7 @@ def _get_model_initialization_kwargs(
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
 
-    if supports_vision(model_class):
+    if supports_multimodal(model_class):
         if multimodal_config is None:
             raise ValueError("Provide vision related configurations "
                              "through LLM entrypoint or engine arguments.")

+ 6 - 6
aphrodite/modeling/models/blip2.py

@@ -22,8 +22,8 @@ from aphrodite.quantization import QuantizationConfig
 
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
-from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .interfaces import SupportsMultiModal
+from .utils import merge_multimodal_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -458,7 +458,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
-class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
+class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: Blip2Config,
@@ -615,9 +615,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    BLIP2_IMAGE_TOKEN_ID)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                BLIP2_IMAGE_TOKEN_ID)
 
             input_ids = None
         else:

+ 2 - 2
aphrodite/modeling/models/chameleon.py

@@ -34,7 +34,7 @@ from aphrodite.multimodal.image import (cached_get_tokenizer,
                                         repeat_and_pad_image_tokens)
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 
 # These configs are not part of the model config but the preprocessor
 # and processor files, so we hardcode them in the model file for now.
@@ -883,7 +883,7 @@ class ChameleonModel(nn.Module):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
-class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
+class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(
         self,

+ 6 - 6
aphrodite/modeling/models/fuyu.py

@@ -40,8 +40,8 @@ from aphrodite.multimodal.image import (cached_get_image_processor,
                                         cached_get_tokenizer)
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .interfaces import SupportsMultiModal
+from .utils import merge_multimodal_embeddings
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -207,7 +207,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
-class FuyuForCausalLM(nn.Module, SupportsVision):
+class FuyuForCausalLM(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: FuyuConfig,
@@ -269,9 +269,9 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    self.image_token_id)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
 
         else:
             inputs_embeds = None

+ 16 - 12
aphrodite/modeling/models/interfaces.py

@@ -9,12 +9,15 @@ from aphrodite.common.config import (LoRAConfig, MultiModalConfig,
 
 
 @runtime_checkable
-class SupportsVision(Protocol):
-    """The interface required for all vision language models (VLMs)."""
+class SupportsMultiModal(Protocol):
+    """
+    The interface required for all multimodal (vision or audio) language
+    models.
+    """
 
-    supports_vision: ClassVar[Literal[True]] = True
+    supports_multimodal: ClassVar[Literal[True]] = True
     """
-    A flag that indicates this model supports vision inputs.
+    A flag that indicates this model supports multimodal inputs.
 
     Note:
         There is no need to redefine this flag if this class is in the
@@ -28,30 +31,31 @@ class SupportsVision(Protocol):
 # We can't use runtime_checkable with ClassVar for issubclass checks
 # so we need to treat the class as an instance and use isinstance instead
 @runtime_checkable
-class _SupportsVisionType(Protocol):
-    supports_vision: Literal[True]
+class _SupportsMultiModalType(Protocol):
+    supports_multimodal: Literal[True]
 
     def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
         ...
 
 
 @overload
-def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
+def supports_multimodal(
+        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
     ...
 
 
 @overload
-def supports_vision(model: object) -> TypeIs[SupportsVision]:
+def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
     ...
 
 
-def supports_vision(
+def supports_multimodal(
     model: Union[Type[object], object],
-) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
+) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
     if isinstance(model, type):
-        return isinstance(model, _SupportsVisionType)
+        return isinstance(model, _SupportsMultiModalType)
 
-    return isinstance(model, SupportsVision)
+    return isinstance(model, SupportsMultiModal)
 
 
 @runtime_checkable
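
The rename keeps the dual class/instance behaviour of the old `supports_vision` helper. A short sketch with toy classes that are illustrative only and not part of Aphrodite:

```py
# Toy classes for illustration only; they are not part of Aphrodite.
from aphrodite.modeling.models.interfaces import (SupportsMultiModal,
                                                  supports_multimodal)


class ToyTextOnlyModel:
    pass


class ToyMultiModalModel(SupportsMultiModal):
    # Inherits supports_multimodal = True from the protocol's ClassVar default.
    def __init__(self, *, multimodal_config=None) -> None:
        pass


assert not supports_multimodal(ToyTextOnlyModel)   # class-level check
assert supports_multimodal(ToyMultiModalModel)     # class-level check
assert supports_multimodal(                        # instance-level check
    ToyMultiModalModel(multimodal_config=None))
```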

+ 6 - 6
aphrodite/modeling/models/internvl.py

@@ -27,9 +27,9 @@ from aphrodite.quantization import QuantizationConfig
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .utils import (filter_weights, init_aphrodite_registered_model,
-                    merge_vision_embeddings)
+                    merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -287,7 +287,7 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
-class InternVLChatModel(nn.Module, SupportsVision):
+class InternVLChatModel(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -446,9 +446,9 @@ class InternVLChatModel(nn.Module, SupportsVision):
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    self.img_context_token_id)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.img_context_token_id)
             input_ids = None
         else:
             inputs_embeds = None

+ 4 - 4
aphrodite/modeling/models/llava.py

@@ -18,12 +18,12 @@ from aphrodite.quantization.base_config import QuantizationConfig
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
 from .utils import (filter_weights, init_aphrodite_registered_model,
-                    merge_vision_embeddings)
+                    merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -179,7 +179,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
-class LlavaForConditionalGeneration(nn.Module, SupportsVision):
+class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: LlavaConfig,
@@ -336,7 +336,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
 
-            inputs_embeds = merge_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.config.image_token_index)
 

+ 4 - 4
aphrodite/modeling/models/llava_next.py

@@ -21,13 +21,13 @@ from aphrodite.quantization.base_config import QuantizationConfig
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
 from .utils import (filter_weights, init_aphrodite_registered_model,
-                    merge_vision_embeddings)
+                    merge_multimodal_embeddings)
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -270,7 +270,7 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
-class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
+class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: LlavaNextConfig,
@@ -566,7 +566,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = self.language_model.model.get_input_embeddings(
                 input_ids)
 
-            inputs_embeds = merge_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.config.image_token_index)
 

+ 2 - 2
aphrodite/modeling/models/minicpmv.py

@@ -48,7 +48,7 @@ from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.utils import set_default_torch_dtype
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.models.minicpm import MiniCPMModel
 from aphrodite.modeling.models.qwen2 import Qwen2Model
@@ -480,7 +480,7 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     return llm_inputs
 
 
-class MiniCPMVBaseModel(nn.Module, SupportsVision):
+class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.

+ 4 - 4
aphrodite/modeling/models/paligemma.py

@@ -19,10 +19,10 @@ from aphrodite.multimodal import MULTIMODAL_REGISTRY
 from aphrodite.multimodal.image import cached_get_tokenizer
 from aphrodite.quantization.base_config import QuantizationConfig
 
-from .interfaces import SupportsVision
+from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import merge_vision_embeddings
+from .utils import merge_multimodal_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.model": "language_model",
@@ -126,7 +126,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
-class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
+class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: PaliGemmaConfig,
@@ -240,7 +240,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
 
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
 
-            inputs_embeds = merge_vision_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.config.image_token_index)
 

+ 6 - 6
aphrodite/modeling/models/phi3v.py

@@ -43,8 +43,8 @@ from aphrodite.quantization.base_config import QuantizationConfig
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    input_processor_for_clip)
-from .interfaces import SupportsVision
-from .utils import merge_vision_embeddings
+from .interfaces import SupportsMultiModal
+from .utils import merge_multimodal_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "model.vision_embed_tokens": "vision_embed_tokens",
@@ -448,7 +448,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
-class Phi3VForCausalLM(nn.Module, SupportsVision):
+class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -563,9 +563,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.model.get_input_embeddings(input_ids)
-            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
-                                                    vision_embeddings,
-                                                    self.image_token_id)
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
 
             input_ids = None
         else:

+ 15 - 15
aphrodite/modeling/models/utils.py

@@ -56,41 +56,41 @@ def init_aphrodite_registered_model(
     )
 
 
-def merge_vision_embeddings(input_ids: torch.Tensor,
-                            inputs_embeds: torch.Tensor,
-                            vision_embeddings: BatchedTensors,
-                            image_token_id: int) -> torch.Tensor:
+def merge_multimodal_embeddings(input_ids: torch.Tensor,
+                                inputs_embeds: torch.Tensor,
+                                multimodal_embeddings: BatchedTensors,
+                                placeholder_token_id: int) -> torch.Tensor:
     """
-    Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
-    positions in ``inputs_embeds`` corresponding to placeholder image tokens in
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
-
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    mask = (input_ids == image_token_id)
+    mask = (input_ids == placeholder_token_id)
     num_expected_tokens = mask.sum()
 
-    if isinstance(vision_embeddings, torch.Tensor):
-        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
+    if isinstance(multimodal_embeddings, torch.Tensor):
+        batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
         total_tokens = batch_size * batch_tokens
         if num_expected_tokens != total_tokens:
             expr = f"{batch_size} x {batch_tokens}"
             raise ValueError(
                 f"Attempted to assign {expr} = {total_tokens} "
-                f"image tokens to {num_expected_tokens} placeholders")
+                f"multimodal tokens to {num_expected_tokens} placeholders")
 
-        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
+        inputs_embeds[mask] = multimodal_embeddings.view(
+            total_tokens, embed_dim)
     else:
-        size_per_batch = [t.shape[0] for t in vision_embeddings]
+        size_per_batch = [t.shape[0] for t in multimodal_embeddings]
         total_tokens = sum(size_per_batch)
         if num_expected_tokens != total_tokens:
             expr = ' + '.join(map(str, size_per_batch))
             raise ValueError(
                 f"Attempted to assign {expr} = {total_tokens} "
-                f"image tokens to {num_expected_tokens} placeholders")
+                f"multimodal tokens to {num_expected_tokens} placeholders")
 
-        inputs_embeds[mask] = torch.cat(vision_embeddings)
+        inputs_embeds[mask] = torch.cat(multimodal_embeddings)
 
     return inputs_embeds
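
The renamed helper keeps the original in-place semantics: rows of `inputs_embeds` at placeholder positions are overwritten by the flattened multimodal embeddings. A standalone toy sketch of that masking logic on synthetic tensors (it mirrors, but does not call, the function above):

```py
import torch

# Synthetic example of the placeholder-overwrite logic used above.
PLACEHOLDER_ID = 32000
input_ids = torch.tensor([1, PLACEHOLDER_ID, PLACEHOLDER_ID, 2])
inputs_embeds = torch.zeros(4, 8)       # (seq_len, embed_dim)
mm_embeddings = torch.randn(1, 2, 8)    # one item producing 2 embedding rows

mask = input_ids == PLACEHOLDER_ID      # [False, True, True, False]
assert mask.sum() == mm_embeddings.shape[0] * mm_embeddings.shape[1]
inputs_embeds[mask] = mm_embeddings.view(-1, 8)  # in-place overwrite, as in the helper
```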
 

+ 17 - 0
aphrodite/multimodal/audio.py

@@ -0,0 +1,17 @@
+from aphrodite.inputs.registry import InputContext
+from aphrodite.multimodal.base import MultiModalInputs, MultiModalPlugin
+
+
+class AudioPlugin(MultiModalPlugin):
+    """Plugin for audio data."""
+
+    def get_data_key(self) -> str:
+        return "audio"
+
+    def _default_input_mapper(self, ctx: InputContext,
+                              data: object) -> MultiModalInputs:
+        raise NotImplementedError("There is no default audio input mapper")
+
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        raise NotImplementedError(
+            "There is no default maximum multimodal tokens")

+ 7 - 1
aphrodite/multimodal/base.py

@@ -3,8 +3,9 @@ from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from typing import Any, Callable, Dict, List, Optional
 from typing import Sequence as GenericSequence
-from typing import Type, TypedDict, TypeVar, Union, cast
+from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
 
+import numpy as np
 import torch
 import torch.types
 from loguru import logger
@@ -115,7 +116,12 @@ class MultiModalInputs(_MultiModalInputsBase):
 
 
 class MultiModalDataBuiltins(TypedDict, total=False):
+    """Modality types that are pre-defined by Aphrodite."""
     image: Image.Image
+    """The input image."""
+
+    audio: Tuple[np.ndarray, Union[int, float]]
+    """The input audio and its sampling rate."""
 
 
 MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
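
With the new builtin key, audio reaches the engine the same way images do: as a `(waveform, sampling_rate)` tuple under `"audio"`. A minimal sketch (the file path is a placeholder):

```py
import librosa

# Placeholder path; with sr=None librosa returns the native sampling rate.
waveform, sr = librosa.load("sample.wav", sr=None)

mm_data = {"audio": (waveform, sr)}  # matches MultiModalDataBuiltins["audio"]
# The dict is then attached to the engine input alongside the prompt, e.g.
# (assumed entrypoint shape): llm.generate({"prompt": ..., "multi_modal_data": mm_data})
```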

+ 2 - 1
aphrodite/multimodal/registry.py

@@ -6,6 +6,7 @@ from loguru import logger
 
 from aphrodite.common.config import ModelConfig
 
+from .audio import AudioPlugin
 from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
                    MultiModalPlugin, MultiModalTokensCalc)
 from .image import ImagePlugin
@@ -19,7 +20,7 @@ class MultiModalRegistry:
     The registry handles both external and internal data input.
     """
 
-    DEFAULT_PLUGINS = (ImagePlugin(), )
+    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin())
 
     def __init__(
             self,

+ 58 - 1
aphrodite/multimodal/utils.py

@@ -1,8 +1,11 @@
 import base64
 import os
 from io import BytesIO
-from typing import Union
+from typing import Tuple, Union
 
+import librosa
+import numpy as np
+import soundfile
 from PIL import Image
 
 from aphrodite.common.connections import global_http_connection
@@ -11,6 +14,9 @@ from aphrodite.multimodal.base import MultiModalDataDict
 APHRODITE_IMAGE_FETCH_TIMEOUT = int(
     os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
 
+APHRODITE_AUDIO_FETCH_TIMEOUT = int(
+    os.getenv("APHRODITE_AUDIO_FETCH_TIMEOUT", 10))
+
 
 def _load_image_from_bytes(b: bytes):
     image = Image.open(BytesIO(b))
@@ -64,11 +70,62 @@ async def async_fetch_image(image_url: str,
     return image.convert(image_mode)
 
 
+def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
+    """
+    Load audio from a URL.
+    """
+    if audio_url.startswith("http"):
+        audio_bytes = global_http_connection.get_bytes(
+            audio_url, timeout=APHRODITE_AUDIO_FETCH_TIMEOUT)
+    elif audio_url.startswith("data:audio"):
+        _, audio_base64 = audio_url.split(",", 1)
+        audio_bytes = base64.b64decode(audio_base64)
+    else:
+        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
+                         "with either 'data:audio' or 'http'.")
+
+    return librosa.load(BytesIO(audio_bytes), sr=None)
+
+
+async def async_fetch_audio(
+        audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
+    """
+    Asynchronously fetch audio from a URL.
+    """
+    if audio_url.startswith("http"):
+        audio_bytes = await global_http_connection.async_get_bytes(
+            audio_url, timeout=APHRODITE_AUDIO_FETCH_TIMEOUT)
+    elif audio_url.startswith("data:audio"):
+        _, audio_base64 = audio_url.split(",", 1)
+        audio_bytes = base64.b64decode(audio_base64)
+    else:
+        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
+                         "with either 'data:audio' or 'http'.")
+
+    return librosa.load(BytesIO(audio_bytes), sr=None)
+
+
+async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
+    audio, sr = await async_fetch_audio(audio_url)
+    return {"audio": (audio, sr)}
+
+
 async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
     image = await async_fetch_image(image_url)
     return {"image": image}
 
 
+def encode_audio_base64(
+    audio: np.ndarray,
+    sampling_rate: int,
+) -> str:
+    """Encode audio as base64."""
+    buffered = BytesIO()
+    soundfile.write(buffered, audio, sampling_rate, format="WAV")
+
+    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
 def encode_image_base64(
     image: Image.Image,
     *,
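
A round-trip sketch of the new audio helpers: encode an in-memory waveform to a base64 WAV data URL with `encode_audio_base64`, then decode it back through `fetch_audio`. The synthetic tone is for illustration only.

```py
import numpy as np

from aphrodite.multimodal.utils import encode_audio_base64, fetch_audio

# One second of a 440 Hz sine at 16 kHz; synthetic data for illustration.
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

audio_b64 = encode_audio_base64(tone, sr)
data_url = f"data:audio/wav;base64,{audio_b64}"

decoded, decoded_sr = fetch_audio(data_url)  # librosa.load(..., sr=None) under the hood
assert decoded_sr == sr
```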

+ 5 - 4
aphrodite/task_handler/model_runner.py

@@ -49,7 +49,8 @@ from aphrodite.lora.worker_manager import LRUCacheWorkerLoRAManager
 from aphrodite.modeling import SamplingMetadata, SamplingMetadataCache
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
-from aphrodite.modeling.models.interfaces import supports_lora, supports_vision
+from aphrodite.modeling.models.interfaces import (supports_lora,
+                                                  supports_multimodal)
 from aphrodite.modeling.models.utils import set_cpu_offload_max_bytes
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs)
@@ -908,9 +909,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         if self.lora_config:
             assert supports_lora(self.model), "Model does not support LoRA"
-            assert not supports_vision(
+            assert not supports_multimodal(
                 self.model
-            ), "To be tested: vision language model with LoRA settings."
+            ), "To be tested: multimodal language model with LoRA settings."
 
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
@@ -1072,7 +1073,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # the number of seqs (batch_size) is chosen to maximize the number
         # of images processed.
         model_config = self.model_config
-        if supports_vision(self.model):
+        if supports_multimodal(self.model):
             max_mm_tokens = MULTIMODAL_REGISTRY \
                 .get_max_multimodal_tokens(model_config)
             max_num_seqs_orig = max_num_seqs

+ 2 - 2
aphrodite/task_handler/xpu_model_runner.py

@@ -17,7 +17,7 @@ from aphrodite.common.utils import CudaMemoryProfiler, make_tensor_with_pad
 from aphrodite.distributed import broadcast_tensor_dict
 from aphrodite.inputs import INPUT_REGISTRY
 from aphrodite.modeling.model_loader import get_model
-from aphrodite.modeling.models.interfaces import supports_vision
+from aphrodite.modeling.models.interfaces import supports_multimodal
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs)
 from aphrodite.task_handler.model_runner import (AttentionMetadata,
@@ -166,7 +166,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         # the number of seqs (batch_size) is chosen to maximize the number
         # of images processed.
         model_config = self.model_config
-        if supports_vision(self.model):
+        if supports_multimodal(self.model):
             max_mm_tokens = MULTIMODAL_REGISTRY \
                 .get_max_multimodal_tokens(model_config)
             max_num_seqs_orig = max_num_seqs

+ 11 - 11
docs/pages/developer/multimodal.md

@@ -13,13 +13,13 @@ See also: [Adding a New Model](/pages/developer/adding-model).
 ## Step 1: Update the base Aphrodite model
 We assume that you have already created a new model by following the steps in the [Adding a New Model](/pages/developer/adding-model) guide. If not, please do so before proceeding.
 
-1. Implement the `aphrodite.modeling.models.interfaces.SupportsVision` interface:
+1. Implement the `aphrodite.modeling.models.interfaces.SupportsMultiModal` interface:
 
 ```py
-from aphrodite.modeling.models.interfaces import SupportsVision  # [!code ++]
+from aphrodite.modeling.models.interfaces import SupportsMultiModal  # [!code ++]
 
 class YourModelForImage2Seq(nn.Module):  # [!code --]
-class YourModelForImage2Seq(nn.Module, SupportsVision):  # [!code ++]
+class YourModelForImage2Seq(nn.Module, SupportsMultiModal):  # [!code ++]
 ```
 
 :::info
@@ -43,11 +43,11 @@ The model class does not have to be named `*ForCausalLM`. Check out the[ Hugging
 For each modality type that the model accepts as input, decorate the model class with `aphrodite.multimodal.MultiModalRegistry.register_input_mapper`. This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in `forward()`.
 
 ```py
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.multimodal import MULTIMODAL_REGISTRY  # [!code ++]
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()  # [!code ++]
-class YourModelForImage2Seq(nn.Module, SupportsVision):
+class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 ```
 
 A default mapper is available for each modality in the core Aphrodite library. This input mapper will be used if you do not provide your own function.
@@ -62,13 +62,13 @@ For each modality type that the model accepts as input, calculate the maximum po
 
 ```py
 from aphrodite.inputs import INPUT_REGISTRY  # [!code ++]
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
 @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)  # [!code ++]
-class YourModelForImage2Seq(nn.Module, SupportsVision):
+class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 ```
 
 Here are some examples:
@@ -85,13 +85,13 @@ During startup, dummy data is passed to the Aphrodite model to allocate memory.
 
 ```py
 from aphrodite.inputs import INPUT_REGISTRY
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
 @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)  # [!code ++]
-class YourModelForImage2Seq(nn.Module, SupportsVision):
+class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 ```
 
 :::info
@@ -111,14 +111,14 @@ Sometimes, there's a need to process inputs at the `aphrodite.AphroditeEngine` l
 
 ```py
 from aphrodite.inputs import INPUT_REGISTRY
-from aphrodite.modeling.models.interfaces import SupportsVision
+from aphrodite.modeling.models.interfaces import SupportsMultiModal
 from aphrodite.multimodal import MULTIMODAL_REGISTRY
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
 @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
 @INPUT_REGISTRY.register_input_processor(<your_input_processor>)  # [!code ++]
-class YourModelForImage2Seq(nn.Module, SupportsVision):
+class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
 ```
 
 A common use case of input processors is inserting placeholder tokens to leverage the Aphrodite framework for attention mask generation. Here are some examples:
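
For an audio-capable model the same registration pattern should carry over with the `"audio"` data key. The sketch below assumes the registry's generic `register_input_mapper`/`register_max_multimodal_tokens` accept a modality key; the class name and `<placeholders>` are hypothetical, following the snippets above.

```py
from aphrodite.inputs import INPUT_REGISTRY
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_input_mapper("audio", <your_audio_input_mapper>)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", <your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
@INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForAudio2Seq(nn.Module, SupportsMultiModal):
    ...
```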

+ 2 - 0
requirements-common.txt

@@ -24,4 +24,6 @@ scipy # for quip
 rich
 loguru
 hf_transfer # for faster downloads
+librosa  # Required for audio processing
+soundfile  # Required for audio processing
 gguf == 0.9.1