
feat: initial encoder-decoder support with BART model (#633)

* wip

* wip

* wip

* wip

* wip

* wip

* fix enc dec model runner and upload example

* fix example

* add documentation
AlpinDale, 6 months ago
parent
commit
a0e446a17d

+ 3 - 1
aphrodite/attention/__init__.py

@@ -1,12 +1,14 @@
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata,
-                                                   AttentionMetadataBuilder)
+                                                   AttentionMetadataBuilder,
+                                                   AttentionType)
 from aphrodite.attention.layer import Attention
 from aphrodite.attention.selector import get_attn_backend
 
 __all__ = [
     "AttentionBackend",
     "AttentionMetadata",
+    "AttentionType",
     "AttentionMetadataBuilder",
     "Attention",
     "get_attn_backend",

+ 1 - 2
aphrodite/attention/layer.py

@@ -4,8 +4,7 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn as nn
 
-from aphrodite.attention.backends.abstract import (AttentionMetadata,
-                                                   AttentionType)
+from aphrodite.attention import AttentionMetadata, AttentionType
 from aphrodite.attention.selector import get_attn_backend
 from aphrodite.common.config import CacheConfig
 from aphrodite.quantization.base_config import QuantizationConfig

+ 101 - 11
aphrodite/attention/selector.py

@@ -1,13 +1,15 @@
 import enum
 import os
+from contextlib import contextmanager
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Generator, Optional, Type
 
 import torch
 from loguru import logger
 
 from aphrodite.attention.backends.abstract import AttentionBackend
-from aphrodite.common.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
+from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip,
+                                    is_openvino, is_tpu, is_xpu)
 from aphrodite.platforms import current_platform
 
 APHRODITE_ATTENTION_BACKEND = "APHRODITE_ATTENTION_BACKEND"
@@ -24,6 +26,61 @@ class _Backend(enum.Enum):
     IPEX = enum.auto()
 
 
+def backend_name_to_enum(backend_name: str) -> _Backend:
+    assert backend_name is not None
+
+    backend_members = _Backend.__members__
+    if backend_name not in backend_members:
+        raise ValueError(f"Invalid attention backend '{backend_name}'. "
+                         f"Available backends: {', '.join(backend_members)} "
+                         "(case-sensitive).")
+
+    return _Backend[backend_name]
+
+
+def get_env_variable_attn_backend() -> Optional[_Backend]:
+    '''
+    Get the backend override specified by the Aphrodite attention
+    backend environment variable, if one is specified.
+    Returns:
+    * _Backend enum value if an override is specified
+    * None otherwise
+    '''
+    backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
+    return (None
+            if backend_name is None else backend_name_to_enum(backend_name))
+
+
+# Global state allows a particular choice of backend
+# to be forced, overriding the logic which auto-selects
+# a backend based on system & workload configuration
+# (default behavior if this variable is None)
+#
+# THIS SELECTION TAKES PRECEDENCE OVER THE
+# APHRODITE ATTENTION BACKEND ENVIRONMENT VARIABLE
+forced_attn_backend: Optional[_Backend] = None
+
+
+def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
+    '''
+    Force all attention operations to use a specified backend.
+    Passing `None` for the argument re-enables automatic
+    backend selection.
+    Arguments:
+    * attn_backend: backend selection (None to revert to auto)
+    '''
+    global forced_attn_backend
+    forced_attn_backend = attn_backend
+
+
+def get_global_forced_attn_backend() -> Optional[_Backend]:
+    '''
+    Get the currently-forced choice of attention backend,
+    or None if auto-selection is currently enabled.
+    '''
+    return forced_attn_backend
+
+
 @lru_cache(maxsize=None)
 def get_attn_backend(
     num_heads: int,
@@ -104,15 +161,20 @@ def which_attn_to_use(
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
 
-    # Check the environment variable and override if specified
-    backend_by_env_var: Optional[str] = os.getenv(APHRODITE_ATTENTION_BACKEND)
-    if backend_by_env_var is not None:
-        backend_members = _Backend.__members__
-        if backend_by_env_var.upper() not in backend_members:
-            raise ValueError(
-                f"Invalid attention backend '{backend_by_env_var}'. "
-                f"Available backends: {', '.join(backend_members)} ")
-        selected_backend = _Backend[backend_by_env_var.upper()]
+    # Check whether a particular choice of backend was
+    # previously forced.
+    #
+    # THIS SELECTION OVERRIDES THE APHRODITE_ATTENTION_BACKEND
+    # ENVIRONMENT VARIABLE.
+    backend_by_global_setting: Optional[_Backend] = (
+        get_global_forced_attn_backend())
+    if backend_by_global_setting is not None:
+        selected_backend = backend_by_global_setting
+    else:
+        # Check the environment variable and override if specified
+        backend_by_env_var: Optional[str] = os.getenv(
+            APHRODITE_ATTENTION_BACKEND)
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var.upper())
     if is_cpu():
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info(f"Cannot use {selected_backend} backend on CPU.")
@@ -194,3 +256,31 @@ def which_attn_to_use(
             selected_backend = _Backend.XFORMERS
 
     return selected_backend
+
+
+@contextmanager
+def global_force_attn_backend_context_manager(
+        attn_backend: _Backend) -> Generator[None, None, None]:
+    '''
+    Globally force an Aphrodite attention backend override within a
+    context manager, reverting the global attention backend
+    override to its prior state upon exiting the context
+    manager.
+    Arguments:
+    * attn_backend: attention backend to force
+    Returns:
+    * Generator
+    '''
+
+    # Save the current state of the global backend override (if any)
+    original_value = get_global_forced_attn_backend()
+
+    # Globally force the new backend override
+    global_force_attn_backend(attn_backend)
+
+    # Yield control back to the enclosed code block
+    try:
+        yield
+    finally:
+        # Revert the original global backend override, if any
+        global_force_attn_backend(original_value)
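
For reference, a minimal usage sketch of the override hooks added above (all names are taken from this diff; the actual inference call is omitted):

import os

from aphrodite.attention.selector import (
    _Backend, get_global_forced_attn_backend,
    global_force_attn_backend_context_manager)

# Lower-precedence override: the environment variable consulted by
# which_attn_to_use() (its name is defined in aphrodite/common/utils.py)
os.environ["APHRODITE_ATTENTION_BACKEND"] = "XFORMERS"

# Higher-precedence override: the global forcing mechanism, scoped to a
# context manager so the prior state is restored on exit
with global_force_attn_backend_context_manager(_Backend.XFORMERS):
    assert get_global_forced_attn_backend() is _Backend.XFORMERS
    # ... run encoder/decoder inference here ...

# Outside the context manager, auto-selection (or the env var) applies again
assert get_global_forced_attn_backend() is None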

+ 34 - 2
aphrodite/common/config.py

@@ -9,7 +9,8 @@ import torch
 from loguru import logger
 from transformers import PretrainedConfig
 
-from aphrodite.common.utils import (cuda_device_count_stateless,
+from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
+                                    cuda_device_count_stateless,
                                     get_cpu_memory, is_cpu, is_hip, is_neuron,
                                     is_openvino, is_tpu, is_xpu,
                                     print_warning_once)
@@ -101,6 +102,9 @@ class ModelConfig:
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
+            If None, the user did not specify, so default to False -
+            except for encoder/decoder models, which currently require
+            eager mode.
         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
             When a sequence has context length larger than this, we fall back
             to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
@@ -136,7 +140,7 @@ class ModelConfig:
         quantization: Optional[str] = None,
         deepspeed_fp_bits: Optional[int] = None,
         quantization_param_path: Optional[str] = None,
-        enforce_eager: bool = True,
+        enforce_eager: Optional[bool] = None,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
@@ -178,6 +182,34 @@ class ModelConfig:
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
+        # Choose a default enforce_eager value if the user did not specify
+        # a value (enforce_eager is None)
+        if getattr(self.hf_config, 'is_encoder_decoder', False):
+            if self.enforce_eager is None:
+                # *Only for encoder/decoder models* and
+                # *only if enforce_eager is unset*, override
+                # to enforce_eager=True
+                #
+                # Add a logger message since it is *somewhat* non-intuitive that
+                # enforce_eager is True when the user has not specified its
+                # value.
+                logger.info("Forcing enforce_eager == True because "
+                            "enforce_eager setting was unspecified and "
+                            "CUDAGraph is not supported with "
+                            "encoder/decoder models.")
+                self.enforce_eager = True
+
+            if not self.enforce_eager:
+                # Eager mode explicitly disabled by user for an encoder/
+                # decoder model; however CUDAGRAPH + encoder/decoder is
+                # not currently supported
+                raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
+        elif self.enforce_eager is None:
+            # *Only for decoder-only models*, enforce_eager
+            # defaults to False if unset. This is intuitive
+            # so no logging message needed.
+            self.enforce_eager = False
+
         if (not self.disable_sliding_window
                 and self.hf_text_config.model_type == "gemma2"
                 and self.hf_text_config.sliding_window is not None):
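
The enforce_eager defaulting rule introduced here can be restated as a small standalone sketch (illustrative only; the authoritative logic is the ModelConfig code above):

from typing import Optional

def resolve_enforce_eager(enforce_eager: Optional[bool],
                          is_encoder_decoder: bool) -> bool:
    # Mirrors the ModelConfig defaulting logic, for illustration.
    if is_encoder_decoder:
        if enforce_eager is None:
            # Unset -> force eager, since CUDAGraph is unsupported here
            return True
        if not enforce_eager:
            raise ValueError("CUDAGraph is not currently supported with "
                             "encoder/decoder models.")
        return True
    # Decoder-only models default to hybrid CUDA graph + eager execution
    return False if enforce_eager is None else enforce_eager

assert resolve_enforce_eager(None, is_encoder_decoder=True) is True
assert resolve_enforce_eager(None, is_encoder_decoder=False) is False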

+ 18 - 0
aphrodite/common/outputs.py

@@ -70,12 +70,20 @@ class RequestOutput:
     Args:
         request_id: The unique ID of the request.
         prompt: The prompt string of the request.
+                For encoder/decoder models, this is the
+                decoder input prompt.
         prompt_token_ids: The token IDs of the prompt.
+                          For encoder/decoder models, this is the
+                          decoder input prompt token ids.
         prompt_logprobs: The log probabilities to return per prompt token.
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
         metrics: Metrics associated with the request.
         lora_request: The LoRA request that was used to generate the output.
+        encoder_prompt: The encoder prompt string of the request;
+                        None if decoder-only.
+        encoder_prompt_token_ids: The token IDs of the encoder prompt;
+                                  None if decoder-only.
     """
 
     def __init__(
@@ -88,6 +96,8 @@ class RequestOutput:
         finished: bool,
         metrics: Optional[RequestMetrics] = None,
         lora_request: Optional[LoRARequest] = None,
+        encoder_prompt: Optional[str] = None,
+        encoder_prompt_token_ids: Optional[List[int]] = None,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -97,6 +107,8 @@ class RequestOutput:
         self.finished = finished
         self.metrics = metrics
         self.lora_request = lora_request
+        self.encoder_prompt = encoder_prompt
+        self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -137,6 +149,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = seq_group.prompt
         prompt_token_ids = seq_group.prompt_token_ids
+        encoder_prompt = seq_group.encoder_prompt
+        encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
         prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
         finished_time = time.time() if finished else None
@@ -150,12 +164,16 @@ class RequestOutput:
             finished,
             seq_group.metrics,
             lora_request=seq_group.lora_request,
+            encoder_prompt=encoder_prompt,
+            encoder_prompt_token_ids=encoder_prompt_token_ids,
         )
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
+                f"encoder_prompt={self.encoder_prompt!r}, "
+                f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
                 f"finished={self.finished}, "
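
A short sketch of how the new fields surface to callers (the model name is an assumption for illustration; the field names come from this diff):

from aphrodite import LLM

llm = LLM(model="facebook/bart-base")  # hypothetical encoder/decoder model
for out in llm.generate("The rain in Spain falls mainly on the plain."):
    # Decoder-side prompt fields, as before
    print(out.prompt, out.prompt_token_ids)
    # New encoder-side fields; both are None for decoder-only models
    print(out.encoder_prompt, out.encoder_prompt_token_ids)
    print(out.outputs[0].text)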

+ 94 - 12
aphrodite/common/sequence.py

@@ -6,12 +6,13 @@ from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast
 
 import torch
 
 from aphrodite.common.pooling_params import PoolingParams
 from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.inputs import is_valid_encoder_decoder_llm_inputs
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 
@@ -242,24 +243,35 @@ class SequenceData:
 
 class Sequence:
     """Stores the data, status, and block information of a sequence.
-
+    The sequence is constructed from the LLMInputs instance passed
+    in through the `inputs` constructor argument.
+    For encoder/decoder models, LLMInputs encapsulates both a
+    decoder and encoder prompt, creating an ambiguity about which
+    prompt to construct the sequence from. The `from_decoder_prompt`
+    constructor argument signals whether to construct the Sequence
+    from the LLMInputs decoder prompt, or encoder prompt.
     Args:
         seq_id: The ID of the sequence.
         inputs: The inputs of the sequence.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
+        eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
         lora_request: LoRA request.
-        prompt_adapter_request: Prompt adapter request.
+        prompt_adapter_request: Prompt Adapter request.
+        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
+                             (True) or encoder prompt (False). Must be True
+                             for decoder-only models.
     """
 
     def __init__(
-            self,
-            seq_id: int,
-            inputs: "LLMInputs",
-            block_size: int,
-            eos_token_id: Optional[int] = None,
-            lora_request: Optional[LoRARequest] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        seq_id: int,
+        inputs: "LLMInputs",
+        block_size: int,
+        eos_token_id: Optional[int] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        from_decoder_prompt: bool = True,
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs
@@ -267,6 +279,36 @@ class Sequence:
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
+        self.from_decoder_prompt = from_decoder_prompt
+        self._prompt: Optional[str] = None
+        self._prompt_token_ids: Optional[List[int]] = None
+
+        # For decoder-only models, a Sequence is constructed
+        # from an LLMInputs instance (the `inputs` arg.)
+        #
+        # For encoder/decoder models the same `inputs`
+        # instance could be utilized to construct either an
+        # encoder sequence or a decoder sequence, because
+        # `LLMInputs` has both decoder- and encoder-oriented
+        # member variables (i.e. it encapsulates both an encoder
+        # and a decoder prompt.) The decision of which type of sequence
+        # to generate is determined by the `from_decoder_prompt` argument.
+        #
+        # When constructing an encoder sequence
+        # (`from_decoder_prompt` False) it matters that
+        # the `LLMInputs` instance stored in `inputs` is valid
+        # in the sense that its encoder-related member variables are
+        # populated; below, an exception is raised if this is
+        # not the case.
+        #
+        # When constructing a decoder sequence (`from_decoder_prompt` True)
+        # it does not matter whether `inputs` has its encoder-related
+        # member variables populated.
+        if not (from_decoder_prompt
+                or is_valid_encoder_decoder_llm_inputs(inputs)):
+            raise ValueError("Cannot extract encoder input prompt from "
+                             f"invalid input {inputs}; did you forget the "
+                             "encoder input prompt fields?")
 
         self.data = SequenceData(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
@@ -287,11 +329,35 @@ class Sequence:
 
     @property
     def prompt(self) -> Optional[str]:
-        return self.inputs.get("prompt")
+        if self._prompt is not None:
+            # Reuse precomputed prompt string
+            return self._prompt
+
+        # Select decoder or encoder input prompt str,
+        # as appropriate
+        prompt_key: str = ("prompt"
+                           if self.from_decoder_prompt else "encoder_prompt")
+
+        # Cache prompt
+        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
+        return self._prompt
 
     @property
     def prompt_token_ids(self) -> List[int]:
-        return self.inputs["prompt_token_ids"]
+        if self._prompt_token_ids is not None:
+            # Reuse precomputed prompt token ids
+            return self._prompt_token_ids
+
+        # Select decoder or encoder input prompt
+        # token ids, as appropriate
+        prompt_token_ids_key: str = ("prompt_token_ids"
+                                     if self.from_decoder_prompt else
+                                     "encoder_prompt_token_ids")
+
+        # Cache computed prompt token ids
+        self._prompt_token_ids = cast(List[int],
+                                      self.inputs.get(prompt_token_ids_key))
+        return self._prompt_token_ids
 
     @property
     def multi_modal_data(self) -> Optional["MultiModalDataDict"]:
@@ -467,6 +533,22 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return self.seqs[0].prompt_token_ids
 
+    @property
+    def encoder_prompt(self) -> Optional[str]:
+        # There are either 0 or 1 encoder sequences
+        # If one is present, its prompt is distinct
+        # from the decoder's.
+        return (self.encoder_seq.prompt
+                if self.encoder_seq is not None else None)
+
+    @property
+    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
+        # There are either 0 or 1 encoder sequences
+        # If one is present, its prompt token ids are
+        # distinct from the decoder's.
+        return (self.encoder_seq.prompt_token_ids
+                if self.encoder_seq is not None else None)
+
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
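
A sketch of how one LLMInputs instance can back both a decoder and an encoder Sequence (constructor arguments follow the signature in this diff; the token ids are made up):

from aphrodite.common.sequence import Sequence
from aphrodite.inputs import LLMInputs

inputs = LLMInputs(
    prompt_token_ids=[2, 0],                   # decoder prompt (made-up ids)
    prompt="<s>",
    encoder_prompt_token_ids=[0, 713, 16, 2],  # encoder prompt (made-up ids)
    encoder_prompt="This is",
)

# Default: a decoder sequence, reading prompt / prompt_token_ids
decoder_seq = Sequence(seq_id=0, inputs=inputs, block_size=16)

# Encoder sequence: reads encoder_prompt / encoder_prompt_token_ids and
# raises ValueError if the encoder fields are missing from `inputs`
encoder_seq = Sequence(seq_id=0, inputs=inputs, block_size=16,
                       from_decoder_prompt=False)

assert decoder_seq.prompt_token_ids == [2, 0]
assert encoder_seq.prompt_token_ids == [0, 713, 16, 2]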

+ 130 - 0
aphrodite/common/utils.py

@@ -28,6 +28,89 @@ from typing_extensions import ParamSpec
 
 from aphrodite import _custom_ops as ops
 from aphrodite.common.logger import enable_trace_function_call
+from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
+                              SingletonPromptInputs)
+
+# Exception strings for non-implemented encoder/decoder scenarios
+
+STR_NOT_IMPL_ENC_DEC_SWA = ("Sliding window attention for encoder/decoder "
+                            "models is not currently supported.")
+
+STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = ("Prefix caching for encoder/decoder "
+                                     "models is not currently supported.")
+
+STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = ("Chunked prefill for encoder/decoder "
+                                        "models is not currently supported.")
+
+STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
+    "Models with logits_soft_cap "
+    "require FlashInfer backend, which is "
+    "currently not supported for encoder/decoder "
+    "models.")
+
+STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
+                             "supported with encoder/decoder "
+                             "models.")
+
+STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
+                           "currently supported with "
+                           "encoder/decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
+                           "supported with encoder/decoder "
+                           "models.")
+
+STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
+                                 "currently supported with encoder/"
+                                 "decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
+                                  "currently supported with encoder/"
+                                  "decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
+                                "currently supported with encoder/"
+                                "decoder models.")
+
+STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
+                                       "currently supported with encoder/"
+                                       "decoder models.")
+
+# Convenience mapping of all enc/dec error strings, so callers can
+# import this single dict rather than each string individually
+STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
+    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
+    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
+    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
+    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
+    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
+    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
+    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
+    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
+    "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
+    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
+    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
+}
+
+# Constants related to forcing the attention backend selection
+
+# Name of the environment variable which may be set to force a
+# particular choice of attention backend, overriding the automatic
+# backend selection logic
+STR_BACKEND_ENV_VAR: str = "APHRODITE_ATTENTION_BACKEND"
+
+# Possible string values of the STR_BACKEND_ENV_VAR environment
+# variable, corresponding to the supported backends
+STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
+STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
+STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
+STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
+STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
+STR_INVALID_VAL: str = "INVALID"
 
 STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.half,
@@ -1025,3 +1108,50 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+def is_encoder_decoder_model_config(model_config) -> bool:
+    '''
+    Extract the HF encoder/decoder model flag from the ModelConfig instance.
+    Return False if model_config is None.
+    '''
+    return model_config is not None and \
+                getattr(model_config.hf_config,
+                        "is_encoder_decoder",
+                        False)
+
+
+def is_embedding_model_config(model_config) -> bool:
+    '''
+    Extract the embedding model flag from the ModelConfig instance.
+    Return False if model_config is None.
+    '''
+    return model_config is not None and \
+                model_config.embedding_mode
+
+
+def build_explicit_enc_dec_prompt(
+    encoder_prompt: SingletonPromptInputs,
+    decoder_prompt: SingletonPromptInputs,
+) -> ExplicitEncoderDecoderPrompt:
+    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
+                                        decoder_prompt=decoder_prompt)
+
+
+def zip_enc_dec_prompt_lists(
+    enc_prompt_list: List[SingletonPromptInputs],
+    dec_prompt_list: List[SingletonPromptInputs],
+) -> List[ExplicitEncoderDecoderPrompt]:
+    return [
+        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
+        for (encoder_prompt,
+             decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
+    ]
+
+
+def to_enc_dec_tuple_list(
+    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
+) -> List[Tuple[PromptInputs, PromptInputs]]:
+    return [(enc_dec_prompt['encoder_prompt'],
+             enc_dec_prompt['decoder_prompt'])
+            for enc_dec_prompt in enc_dec_prompts]
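
A short sketch of the prompt-pairing helpers added above (the prompt strings are arbitrary):

from aphrodite.common.utils import (build_explicit_enc_dec_prompt,
                                    to_enc_dec_tuple_list,
                                    zip_enc_dec_prompt_lists)

enc_prompts = ["An encoder prompt", "Another encoder prompt"]
dec_prompts = ["A decoder prompt", "Another decoder prompt"]

# Pair encoder and decoder prompts element-wise
enc_dec_prompts = zip_enc_dec_prompt_lists(enc_prompts, dec_prompts)
assert enc_dec_prompts[0] == build_explicit_enc_dec_prompt(
    "An encoder prompt", "A decoder prompt")

# ...and flatten them back into (encoder, decoder) tuples
assert to_enc_dec_tuple_list(enc_dec_prompts)[0] == (
    "An encoder prompt", "A decoder prompt")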

+ 16 - 3
aphrodite/endpoints/llm.py

@@ -117,12 +117,19 @@ class LLM:
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
         cpu_offload_gb: float = 0,
-        enforce_eager: bool = False,
+        enforce_eager: Optional[bool] = None,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         **kwargs,
     ) -> None:
+        '''
+        LLM constructor.
+        Note: if enforce_eager is unset (enforce_eager is None)
+        it defaults to False for decoder-only models and True
+        for encoder/decoder models, since encoder/decoder models
+        do not currently support CUDAGraph.
+        '''
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
         removed_vision_keys = ("image_token_id", "image_feature_size",
@@ -292,8 +299,8 @@ class LLM:
         """
         if self.llm_engine.model_config.embedding_mode:
             raise ValueError(
-                "LLM.generate() is only supported for generation models "
-                "(XForCausalLM).")
+                "LLM.generate() is only supported for (conditional) generation "
+                "models (XForCausalLM, XForConditionalGeneration).")
 
         if prompt_token_ids is not None:
             inputs = self._convert_v1_inputs(
@@ -629,3 +636,9 @@ class LLM:
         # This is necessary because some requests may be finished earlier than
         # its previous requests.
         return sorted(outputs, key=lambda x: int(x.request_id))
+
+    def _is_encoder_decoder_model(self):
+        return self.llm_engine.is_encoder_decoder_model()
+
+    def _is_embedding_model(self):
+        return self.llm_engine.is_embedding_model()
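
Putting the pieces together, an end-to-end sketch in the spirit of the example uploaded by this PR (model name, prompts, and sampling settings are assumptions, not taken from the commit):

from aphrodite import LLM, SamplingParams
from aphrodite.common.utils import zip_enc_dec_prompt_lists

# enforce_eager is left unset; for this encoder/decoder model it defaults
# to True, as noted in the constructor docstring above.
llm = LLM(model="facebook/bart-large-cnn")
params = SamplingParams(temperature=0.0, max_tokens=20)

# A singleton prompt is treated as the encoder prompt; a default decoder
# prompt (decoder_start_token_id / <BOS>) is generated automatically.
outputs = llm.generate("The quick brown fox jumps over the lazy dog.", params)

# Explicit encoder/decoder prompts may also be passed.
enc_dec_prompts = zip_enc_dec_prompt_lists(["An encoder prompt"],
                                           ["A decoder prompt"])
outputs = llm.generate(enc_dec_prompts, params)

for out in outputs:
    print(out.encoder_prompt, "->", out.outputs[0].text)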

+ 376 - 22
aphrodite/engine/aphrodite_engine.py

@@ -3,7 +3,7 @@ import time
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
-from typing import Type, TypeVar, Union
+from typing import Tuple, Type, TypeVar, Union
 
 from loguru import logger
 from transformers import PreTrainedTokenizer
@@ -22,7 +22,8 @@ from aphrodite.common.sequence import (EmbeddingSequenceGroupOutput,
                                        ExecuteModelRequest, PoolerOutput,
                                        SamplerOutput, Sequence, SequenceGroup,
                                        SequenceGroupMetadata, SequenceStatus)
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import (Counter, is_embedding_model_config,
+                                    is_encoder_decoder_model_config)
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
                                       StatLoggerBase, Stats)
@@ -33,7 +34,8 @@ from aphrodite.engine.output_processor.util import (
     create_output_by_sequence_group)
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.executor.ray_utils import initialize_ray_cluster
-from aphrodite.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
+from aphrodite.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
+                              get_prompt_type)
 from aphrodite.lora.request import LoRARequest
 from aphrodite.processing.scheduler import (ScheduledSequenceGroup, Scheduler,
                                             SchedulerOutputs)
@@ -450,8 +452,19 @@ class AphroditeEngine:
             self.prompt_adapter_config.verify_with_model_config(
                 self.model_config)
 
-    def _get_eos_token_id(
-            self, lora_request: Optional[LoRARequest]) -> Optional[int]:
+    def _get_bos_token_id(self,
+                          lora_request: Optional[LoRARequest] = None
+                          ) -> Optional[int]:
+        if self.tokenizer is None:
+            logger.warning("Using None for BOS token id because tokenizer "
+                           "is not initialized")
+            return None
+
+        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
+
+    def _get_eos_token_id(self,
+                          lora_request: Optional[LoRARequest] = None
+                          ) -> Optional[int]:
         if self.tokenizer is None:
             logger.warning("Using None for EOS token id because tokenizer "
                            "is not initialized")
@@ -459,6 +472,32 @@ class AphroditeEngine:
 
         return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
 
+    def _get_decoder_start_token_id(self, ) -> Optional[int]:
+        '''
+        Obtain the decoder start token id employed by an encoder/decoder
+        model. Returns None for non-encoder/decoder models or if the
+        model config is unavailable.
+        '''
+
+        if not self.is_encoder_decoder_model():
+            logger.warning("Using None for decoder start token id because "
+                           "this is not an encoder/decoder model.")
+            return None
+
+        if (self.model_config is None or self.model_config.hf_config is None):
+            logger.warning("Using None for decoder start token id because "
+                           "model config is not available.")
+            return None
+
+        dec_start_token_id = getattr(self.model_config.hf_config,
+                                     'decoder_start_token_id', None)
+        if dec_start_token_id is None:
+            logger.warning("Falling back on <BOS> for decoder start token id "
+                           "because decoder start token id is not available.")
+            dec_start_token_id = self._get_bos_token_id()
+
+        return dec_start_token_id
+
     def _add_processed_request(
         self,
         request_id: str,
@@ -476,6 +515,16 @@ class AphroditeEngine:
         seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)
 
+        encoder_seq = None
+        if 'encoder_prompt_token_ids' in processed_inputs:
+            encoder_seq = Sequence(seq_id,
+                                   processed_inputs,
+                                   block_size,
+                                   eos_token_id,
+                                   lora_request,
+                                   prompt_adapter_request,
+                                   from_decoder_prompt=False)
+
         # Create a SequenceGroup based on SamplingParams or PoolingParams
         if isinstance(params, SamplingParams):
             seq_group = self._create_sequence_group_with_sampling(
@@ -485,6 +534,7 @@ class AphroditeEngine:
                 arrival_time=arrival_time,
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request,
+                encoder_seq=encoder_seq,
             )
         elif isinstance(params, PoolingParams):
             seq_group = self._create_sequence_group_with_pooling(
@@ -494,6 +544,7 @@ class AphroditeEngine:
                 arrival_time=arrival_time,
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request,
+                encoder_seq=encoder_seq,
             )
         else:
             raise ValueError(
@@ -510,36 +561,329 @@ class AphroditeEngine:
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
-    def process_model_inputs(
+    _LLMInputComponentsType = Tuple[str, List[int], ]
+
+    def _prepare_decoder_input_ids_for_generation(
+        self,
+        decoder_input_ids: Optional[List[int]] = None,
+    ) -> List[int]:
+        """
+        Prepares `decoder_input_ids` for generation with encoder-decoder models.
+        Based on
+        https://github.com/huggingface/transformers/blob/
+        4037a2b5b1278736e566aec12e169100275545ea/
+        src/transformers/generation/utils.py
+        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
+        Arguments:
+        * decoder_input_ids: input token ids to preprocess
+        Returns:
+        * Processed token list
+        """
+
+        decoder_start_token_id: Optional[int] = (
+            self._get_decoder_start_token_id())
+        assert decoder_start_token_id is not None
+
+        if decoder_input_ids is None:
+            # no decoder prompt input ->
+            # use decoder_start_token_id as decoder_input_ids
+            (decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
+
+        if (len(decoder_input_ids) == 0
+                or decoder_input_ids[0] != decoder_start_token_id):
+            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
+        return decoder_input_ids
+
+    def _tokenize_prompt(
+        self,
+        prompt: str,
+        request_id: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> List[int]:
+        '''
+        Wrapper around application of the model's
+        tokenizer.
+        Arguments:
+        * prompt
+        * request_id
+        * lora_request
+        Returns:
+        * prompt token ids
+        '''
+
+        tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                             "skip_tokenizer_init is True")
+
+        prompt_token_ids = tokenizer.encode(request_id=request_id,
+                                            prompt=prompt,
+                                            lora_request=lora_request)
+
+        return prompt_token_ids
+
+    def _extract_single_prompt_for_enc_dec_input(
+        self,
+        inputs: Optional[PromptInputs],
+        request_id: Optional[str] = None,
+        ptype: Optional[str] = None,
+        is_encoder_prompt: bool = False,
+    ) -> Tuple[Optional[str], List[int]]:
+        '''
+        Only for encoder/decoder models:
+        Extract prompt & prompt_token_ids from any single
+        encoder or decoder input prompt. For encoder input prompts
+        in particular, also extract multi-modal data.
+        This function handles the following scenarios:
+        1. The user supplied a singleton encoder prompt
+          & the prompt/prompt-token-ids must be extracted.
+        2. The user supplied an explicit encoder/decoder
+          prompt & the prompt/prompt-token-ids must be
+          extracted from either the encoder or the decoder sub-prompt.
+        For decoder prompts in particular (scenario 2), special
+        processing is applied to the returned decoder token ids.
+        Arguments:
+        * request_id
+        * ptype: str representation of the input prompt type.
+                 If `ptype` is `None`, assume that the prompt
+                 type is unknown and must be inferred. This is the
+                 case for ExplicitEncoderDecoder sub-prompts.
+        * inputs: single encoder or decoder input prompt
+        * is_encoder_prompt: True if encoder input prompt.
+                             If False, decoder prompt tokens
+                             are preprocessed.
+        Returns:
+        * prompt
+        * prompt_token_ids
+        '''
+        prompt_token_ids = None
+        ptype = (get_prompt_type(inputs) if ptype is None else ptype)
+
+        if inputs is None:
+            prompt = None
+        elif ptype == 'str':
+            prompt = inputs
+            prompt_token_ids = self._tokenize_prompt(
+                prompt,
+                request_id=request_id,
+            )
+        elif ptype == 'TokensPrompt':
+            prompt = None
+            prompt_token_ids = inputs['prompt_token_ids']
+        else:
+            prompt = inputs['prompt']
+            prompt_token_ids = self._tokenize_prompt(
+                prompt,
+                request_id=request_id,
+            )
+
+        if not is_encoder_prompt:
+            # Apply special pre-processing to
+            # decoder prompts
+            prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
+                prompt_token_ids, ))
+
+        assert prompt_token_ids is not None
+
+        return (
+            prompt,
+            prompt_token_ids,
+        )
+
+    def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
+        '''
+        Specifically for encoder/decoder models:
+        generate a default decoder prompt for when
+        the user specifies only the encoder prompt.
+        Encoder/decoder models utilize the decoder
+        prompt in different ways; as new models are
+        added, it is intended that this function
+        will be extended to produce differing
+        default decoder prompts, depending on the
+        model variety.
+        Absent a special case, the default behavior
+        of this method is to mirror the behavior of
+        the HuggingFace (HF) GenerationMixin for a None
+        decoder prompt, which is to employ a logit processor
+        setting to force the first decoded token to be <BOS>.
+        Here, this behavior is approximated by having the
+        "default" decoder prompt be <BOS>.
+        However, it is possible that in the future
+        other models may have different or more 
+        complex logic for the default decoder prompt.
+        This motivates having a special helper method
+        for default decoder prompts.
+        Returns:
+        * prompt_token_ids
+        '''
+
+        bos_token_id = self._get_bos_token_id()
+        assert bos_token_id is not None
+        prompt_token_ids: List[int] = [bos_token_id]
+        return prompt_token_ids
+
+    def _process_encoder_decoder_prompt(
+        self,
+        inputs: PromptInputs,
+        request_id: Optional[str] = None,
+    ) -> LLMInputs:
+        '''
+        For encoder/decoder models only:
+        Process an input prompt
+        into an `LLMInputs` instance.
+        There are two types of input prompts:
+        singleton prompts which carry only the
+        encoder prompt, and explicit encoder/decoder
+        prompts which carry both the encoder and the
+        decoder prompts as member variables.
+        This function handles the following scenarios:
+        * Singleton encoder prompt: extract encoder prompt
+          token ids & infer default decoder prompt token ids
+        * Explicit encoder/decoder prompt: extract encoder
+          and decoder prompt token ids
+        Note that for Explicit encoder/decoder prompts,
+        each sub-prompt (encoder or decoder prompt) can
+        have any possible singleton type; thus this
+        method relies on helper functions to obtain
+        token ids for the sub-prompts.
+        
+        Arguments:
+        * inputs: an input prompt
+        * request_id
+        Returns:
+        * `LLMInputs` instance
+        '''
+
+        ptype = get_prompt_type(inputs)
+
+        # Obtain encoder and decoder prompt tokens. Note
+        # that, no matter what, the decoder
+        # prompt type is unknown.
+        if ptype == "ExplicitEncoderDecoder":
+            # If input is explicit encoder/decoder prompt,
+            # then it remains to be determined what type
+            # of encoder prompt we have
+            extracted_encoder_prompt = inputs.get('encoder_prompt')
+            encoder_ptype = None
+            # Extract decoder prompt from explicit
+            # encoder/decoder prompt
+            extracted_decoder_prompt = inputs.get('decoder_prompt')
+        else:
+            # If input is singleton encoder prompt, then
+            # we know the encoder prompt type
+            extracted_encoder_prompt = inputs
+            encoder_ptype = ptype
+            # Decoder prompt is always unknown if
+            # encoder/decoder prompt is not explicit
+            extracted_decoder_prompt = None
+
+        # Invoke helper function to obtain encoder
+        # prompt and prompt token ids, either from
+        # singleton encoder prompt or from the
+        # encoder sub-prompt of an explicit
+        # encoder/decoder prompt.
+        (
+            encoder_prompt,
+            encoder_prompt_token_ids,
+        ) = self._extract_single_prompt_for_enc_dec_input(
+            extracted_encoder_prompt,
+            request_id=request_id,
+            ptype=encoder_ptype,
+            is_encoder_prompt=True,
+        )
+
+        # Invoke helper method to obtain
+        # decoder prompt and prompt token ids.
+        #
+        # The helper method will detect the decoder
+        # prompt type.
+        #
+        # Helper method will also apply special
+        # preprocessing unique to decoder prompts.
+        (
+            decoder_prompt,
+            decoder_prompt_token_ids,
+        ) = self._extract_single_prompt_for_enc_dec_input(
+            extracted_decoder_prompt,
+            request_id=request_id,
+            ptype=None,
+            is_encoder_prompt=False,
+        )
+
+        return LLMInputs(
+            prompt_token_ids=decoder_prompt_token_ids,
+            prompt=decoder_prompt,
+            encoder_prompt_token_ids=encoder_prompt_token_ids,
+            encoder_prompt=encoder_prompt,
+        )
+
+    def _process_decoder_only_prompt(
         self,
-        request_id: str,
         inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
+        request_id: Optional[str] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
+        '''
+        For decoder-only models:
+        Process an input prompt
+        into an `LLMInputs` instance.
+        Arguments:
+        * inputs: input prompt
+        * lora_request
+        * request_id
+        * prompt_adapter_request
+        Returns:
+        * `LLMInputs` instance
+        '''
         if isinstance(inputs, str):
             inputs = {"prompt": inputs}
+        prompt = inputs.get("prompt")
 
         if "prompt_token_ids" not in inputs:
-            tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                                 "skip_tokenizer_init is True")
-
-            prompt_token_ids = tokenizer.encode(request_id=request_id,
-                                                prompt=inputs["prompt"],
-                                                lora_request=lora_request)
+            prompt_token_ids = self._tokenize_prompt(
+                prompt,
+                request_id=request_id,
+                lora_request=lora_request,
+            )
         else:
             prompt_token_ids = inputs["prompt_token_ids"]
 
         if prompt_adapter_request:
-            prompt_token_ids = \
-                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
-                         + prompt_token_ids
+            prompt_token_ids = (
+                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
+                + prompt_token_ids)
 
-        llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
-                               prompt=inputs.get("prompt"),
-                               multi_modal_data=inputs.get("multi_modal_data"))
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=prompt,
+                         multi_modal_data=inputs.get("multi_modal_data"))
 
-        return self.input_processor(llm_inputs)
+    def process_model_inputs(
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> LLMInputs:
+
+        if self.is_encoder_decoder_model():
+            # Encoder-decoder model requires special mapping of
+            # input prompts to encoder & decoder
+
+            model_inputs = self._process_encoder_decoder_prompt(
+                inputs,
+                request_id=request_id,
+            )
+        else:
+            # Decoder-only operation
+            model_inputs = self._process_decoder_only_prompt(
+                inputs,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
+            )
+
+        return self.input_processor(model_inputs)
 
     def add_request(
         self,
@@ -623,6 +967,7 @@ class AphroditeEngine:
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        encoder_seq: Optional[Sequence] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -646,7 +991,8 @@ class AphroditeEngine:
             arrival_time=arrival_time,
             sampling_params=sampling_params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            encoder_seq=encoder_seq)
 
         return seq_group
 
@@ -658,6 +1004,7 @@ class AphroditeEngine:
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        encoder_seq: Optional[Sequence] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -669,7 +1016,8 @@ class AphroditeEngine:
             arrival_time=arrival_time,
             lora_request=lora_request,
             pooling_params=pooling_params,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+            encoder_seq=encoder_seq)
 
         return seq_group
 
@@ -1101,5 +1449,11 @@ class AphroditeEngine:
             self.tokenizer.check_health()
         self.model_executor.check_health()
 
+    def is_encoder_decoder_model(self):
+        return is_encoder_decoder_model_config(self.model_config)
+
+    def is_embedding_model(self):
+        return is_embedding_model_config(self.model_config)
+
 
 setup_logger()
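
The decoder-prompt preparation performed by _prepare_decoder_input_ids_for_generation reduces to the following transformation (a standalone restatement with hard-coded, assumed token ids):

DECODER_START_TOKEN_ID = 2  # assumed value of config.decoder_start_token_id
BOS_TOKEN_ID = 0            # assumed value of the tokenizer's BOS id

def prepare_decoder_input_ids(decoder_input_ids):
    if decoder_input_ids is None:
        # No decoder prompt supplied: fall back to the default decoder
        # prompt (<BOS>), mirroring HF GenerationMixin behavior.
        decoder_input_ids = [BOS_TOKEN_ID]
    if (len(decoder_input_ids) == 0
            or decoder_input_ids[0] != DECODER_START_TOKEN_ID):
        # Prepend the decoder start token if it is not already present.
        decoder_input_ids = [DECODER_START_TOKEN_ID] + decoder_input_ids
    return decoder_input_ids

assert prepare_decoder_input_ids(None) == [2, 0]
assert prepare_decoder_input_ids([7, 8]) == [2, 7, 8]
assert prepare_decoder_input_ids([2, 7, 8]) == [2, 7, 8]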

+ 1 - 1
aphrodite/engine/args_tools.py

@@ -42,7 +42,7 @@ class EngineArgs:
     rope_scaling: Optional[dict] = None
     rope_theta: Optional[float] = None
     model_loader_extra_config: Optional[dict] = None
-    enforce_eager: Optional[bool] = True
+    enforce_eager: Optional[bool] = None
     skip_tokenizer_init: bool = False
     tokenizer_pool_size: int = 0
     # Note: Specifying a tokenizer pool by passing a class

+ 18 - 5
aphrodite/inputs/__init__.py

@@ -1,5 +1,7 @@
-from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
-                   TextPrompt, TokensPrompt, parse_and_batch_prompt)
+from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
+                   ParsedTokens, PromptInputs, SingletonPromptInputs,
+                   TextPrompt, TokensPrompt, get_prompt_type,
+                   is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
 from .registry import InputContext, InputRegistry
 
 INPUT_REGISTRY = InputRegistry()
@@ -12,7 +14,18 @@ See also:
 """
 
 __all__ = [
-    "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
-    "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
-    "InputContext", "InputRegistry"
+    "ParsedText",
+    "ParsedTokens",
+    "parse_and_batch_prompt",
+    "TextPrompt",
+    "TokensPrompt",
+    "PromptInputs",
+    "LLMInputs",
+    "INPUT_REGISTRY",
+    "InputContext",
+    "InputRegistry",
+    "get_prompt_type",
+    "is_valid_encoder_decoder_llm_inputs",
+    "ExplicitEncoderDecoderPrompt",
+    "SingletonPromptInputs",
 ]

+ 111 - 2
aphrodite/inputs/data.py

@@ -91,13 +91,101 @@ class TokensPrompt(TypedDict):
     """
 
 
-PromptInputs = Union[str, TextPrompt, TokensPrompt]
+SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
 """
-The inputs to the LLM, which can take one of the following forms:
+Set of possible schemas for a single LLM input:
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
+Note that "singleton" is as opposed to a data structure
+which encapsulates multiple prompts, i.e. of the sort
+which may be utilized for encoder/decoder models when
+the user desires to express both the encoder & decoder
+prompts explicitly, i.e. ExplicitEncoderDecoderPrompt.
+A prompt of type SingletonPromptInputs may be employed
+as (1) input to a decoder-only model, (2) input to
+the encoder of an encoder/decoder model, in the scenario
+where the decoder-prompt is not specified explicitly, or
+(3) as a member of a larger data structure encapsulating
+more than one prompt, i.e. ExplicitEncoderDecoderPrompt.
 """
 
+class ExplicitEncoderDecoderPrompt(TypedDict):
+    """Represents an encoder/decoder model input prompt,
+    comprising an explicit encoder prompt and a 
+    decoder prompt.
+    The encoder and decoder prompts, respectively,
+    may be formatted according to any of the
+    SingletonPromptInputs schemas, and are not
+    required to have the same schema.
+    Only the encoder prompt may have multi-modal data.
+    Note that an ExplicitEncoderDecoderPrompt may not
+    be used as an input to a decoder-only model,
+    and that the `encoder_prompt` and `decoder_prompt`
+    fields of this data structure must themselves
+    be SingletonPromptInputs instances.
+    """
+
+    encoder_prompt: SingletonPromptInputs
+
+    decoder_prompt: SingletonPromptInputs
+
+
+PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
+"""
+Set of possible schemas for an LLM input, including
+both decoder-only and encoder/decoder input types:
+- A text prompt (:class:`str` or :class:`TextPrompt`)
+- A tokenized prompt (:class:`TokensPrompt`)
+- A single data structure containing both an encoder and a decoder prompt
+  (:class:`ExplicitEncoderDecoderPrompt`)
+"""
+
+
+def _has_required_keys(
+    d: dict,
+    required_keys: set,
+) -> bool:
+    return required_keys.issubset(d.keys())
+
+
+def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
+    """
+    Get the type-name of the prompt argument instance, given that
+    isinstance() cannot apply to TypedDict subclasses directly.
+    If the prompt is None, return 'None' as the type name.
+    Arguments:
+    * prompt: LLM input prompt or None
+    Returns:
+    * String representation of prompt type
+    """
+
+    if prompt is None:
+        return 'None'
+
+    required_keys_dict = {
+        'TextPrompt': {'prompt'},
+        'TokensPrompt': {'prompt_token_ids'},
+        'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
+    }
+
+    if isinstance(prompt, dict):
+        for (ptype, required_keys) in required_keys_dict.items():
+            # Ignore type checking in the conditional below because type
+            # checker does not understand that is_dict(prompt) narrows
+            # down the possible types
+            if _has_required_keys(
+                    prompt,  # type: ignore
+                    required_keys):
+                return ptype
+
+        raise ValueError(f"Invalid prompt {prompt}, valid types are "
+                         f"required_keys_dict={required_keys_dict}")
+
+    if isinstance(prompt, str):
+        return "str"
+
+    raise ValueError(f"Invalid prompt {prompt}")
+
 
 class LLMInputs(TypedDict):
     """
@@ -113,8 +201,29 @@ class LLMInputs(TypedDict):
     The original prompt text corresponding to the token IDs, if available.
     """
 
+    encoder_prompt_token_ids: NotRequired[List[int]]
+    """The token IDs of the encoder prompt."""
+
+    encoder_prompt: NotRequired[Optional[str]]
+    """
+    The original encoder prompt text corresponding to the token IDs, if
+    available.
+    """
+
     multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
     """
     Optional multi-modal data to pass to the model,
     if the model supports it.
     """
+
+
+def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
+    """
+    Return True if the LLMInputs instance has the correct configuration
+    for encoder/decoder.
+    """
+
+    # True if encoder prompt token ids field exists &
+    # is not None
+    return ('encoder_prompt_token_ids' in inputs
+            and inputs['encoder_prompt_token_ids'] is not None)
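
A quick sketch of the new prompt-type helpers in action (values chosen purely for illustration):

from aphrodite.inputs import (LLMInputs, get_prompt_type,
                              is_valid_encoder_decoder_llm_inputs)

assert get_prompt_type("Hello") == "str"
assert get_prompt_type({"prompt": "Hello"}) == "TextPrompt"
assert get_prompt_type({"prompt_token_ids": [0, 1]}) == "TokensPrompt"
assert get_prompt_type({"encoder_prompt": "Hi",
                        "decoder_prompt": ""}) == "ExplicitEncoderDecoder"
assert get_prompt_type(None) == "None"

# Only LLMInputs with populated encoder fields qualify as enc/dec inputs
assert not is_valid_encoder_decoder_llm_inputs(
    LLMInputs(prompt_token_ids=[0, 1]))
assert is_valid_encoder_decoder_llm_inputs(
    LLMInputs(prompt_token_ids=[0, 1], encoder_prompt_token_ids=[2, 3]))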

+ 10 - 1
aphrodite/modeling/models/__init__.py

@@ -81,7 +81,16 @@ _EMBEDDING_MODELS = {
     "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
 }
 
-_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
+_CONDITIONAL_GENERATION_MODELS = {
+    "BartModel": ("bart", "BartForConditionalGeneration"),
+    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
+}
+
+_MODELS = {
+    **_GENERATION_MODELS,
+    **_EMBEDDING_MODELS,
+    **_CONDITIONAL_GENERATION_MODELS
+}
 
 # Architecture -> type.
 # out of tree models

+ 992 - 0
aphrodite/modeling/models/bart.py

@@ -0,0 +1,992 @@
+# Derived from BART implementation posted on HuggingFace; license below:
+#
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BART model."""
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import BartConfig
+
+from aphrodite.attention import Attention, AttentionMetadata, AttentionType
+from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.activation import get_act_fn
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+
+
+def get_bsz_seq_len(input_ids):
+    shp = input_ids.shape
+    ndim = len(shp)
+    if ndim == 1:
+        return 1, input_ids.numel()
+    else:
+        return shp[:2]
+
+
+class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # Bart is set up so that if padding_idx is
+        # specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately.
+        # Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        attn_type: AttentionType,
+    ) -> torch.Tensor:
+        """`positions` is expected to have shape [bsz x seqlen]."""
+
+        assert attn_type != AttentionType.ENCODER_DECODER
+
+        return super().forward(positions + self.offset)
+
+
+class BartScaledWordEmbedding(VocabParallelEmbedding):
+    """
+    This module overrides VocabParallelEmbedding's forward pass
+    by multiplying the output embeddings by `embed_scale`.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 embed_scale: float = 1.0):
+        super().__init__(num_embeddings, embedding_dim)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return super().forward(input_ids) * self.embed_scale
+
+
+class BartParallelLMHead(ParallelLMHead):
+    """
+    This module overrides ParallelLMHead's forward pass
+    by dividing the output by `embed_scale`, effectively
+    inverting BartScaledWordEmbedding.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 embed_scale: float = 1.0):
+        super().__init__(num_embeddings, embedding_dim)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return super().forward(input_ids) / self.embed_scale
+
+
+class BartEncoderAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        config: Optional[BartConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.embed_dim = embed_dim
+        self.total_num_heads = num_heads
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(f"embed_dim must be divisible by num_heads "
+                             f"(got `embed_dim`: {self.embed_dim}"
+                             f" and `num_heads`: {num_heads}).")
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            self.d_model,
+            self.d_model // self.total_num_heads,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+        )
+
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            quant_config=quant_config,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is at least TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        """Input shape: Batch x Time x Channel"""
+
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=AttentionType.ENCODER)
+
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class BartDecoderSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        config: Optional[BartConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.embed_dim = embed_dim
+        self.total_num_heads = num_heads
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(f"embed_dim must be divisible by num_heads "
+                             f"(got `embed_dim`: {self.embed_dim}"
+                             f" and `num_heads`: {num_heads}).")
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            self.d_model,
+            self.d_model // self.total_num_heads,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+        )
+
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            quant_config=quant_config,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is at least TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        """Input shape: Batch x Time x Channel"""
+
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=AttentionType.DECODER)
+
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class BartCrossAttention(nn.Module):
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        config: Optional[BartConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.embed_dim = embed_dim
+        self.total_num_heads = num_heads
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(f"embed_dim must be divisible by num_heads "
+                             f"(got `embed_dim`: {self.embed_dim}"
+                             f" and `num_heads`: {num_heads}).")
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            self.d_model,
+            self.d_model // self.total_num_heads,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+        )
+
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            quant_config=quant_config,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is at least TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+    def forward(
+        self,
+        decoder_hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Input shape: Batch x Time x Channel"""
+
+        # (afeldman-nm 2024/07/22) TODO:
+        # Need a more efficient solution for q/k/v
+        qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
+        q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
+        if encoder_hidden_states is None:
+            k = None
+            v = None
+        else:
+            qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
+            _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
+                                    dim=-1)
+
+        attn_output = self.attn(q,
+                                k,
+                                v,
+                                kv_cache,
+                                attn_metadata,
+                                attn_type=AttentionType.ENCODER_DECODER)
+
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class BartEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: BartConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = BartEncoderAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            config=config,
+            cache_config=cache_config,
+            quant_config=quant_config)
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config)
+
+        ffn_hidden_size = self.embed_dim
+        ffn_intermediate_size = config.encoder_ffn_dim
+        ffn_has_bias = True
+        self.fc1 = ColumnParallelLinear(
+            ffn_hidden_size,
+            ffn_intermediate_size,
+            bias=ffn_has_bias,
+            quant_config=quant_config,
+        )
+        self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
+        self.fc2 = RowParallelLinear(
+            ffn_intermediate_size,
+            ffn_hidden_size,
+            bias=ffn_has_bias,
+            quant_config=quant_config,
+        )
+
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        r"""
+        Args:
+            hidden_states
+                torch.Tensor of *encoder* input embeddings.
+            kv_cache:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                Aphrodite Attention metadata structure
+        Returns:
+            Encoder layer output torch.Tensor
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn(hidden_states=hidden_states,
+                                       kv_cache=kv_cache,
+                                       attn_metadata=attn_metadata)
+
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        residual = hidden_states
+        fc1_out, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(fc1_out)
+
+        hidden_states, _ = self.fc2(hidden_states)
+
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        if hidden_states.dtype == torch.float16 and (
+                torch.isinf(hidden_states).any()
+                or torch.isnan(hidden_states).any()):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states,
+                                        min=-clamp_value,
+                                        max=clamp_value)
+
+        return hidden_states
+
+
+class BartDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: BartConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = BartDecoderSelfAttention(
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            config=config,
+            cache_config=cache_config,
+            quant_config=quant_config)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config)
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        '''
+        afeldman-nm: personally I would call this "cross-attention",
+        however I left the name as "encoder_attn" to maintain consistency
+        with the name of the pretrained weights.
+        '''
+        self.encoder_attn = BartCrossAttention(
+            self.embed_dim,
+            config.decoder_attention_heads,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+
+        ffn_hidden_size = self.embed_dim
+        ffn_intermediate_size = config.decoder_ffn_dim
+        ffn_has_bias = True
+        self.fc1 = ColumnParallelLinear(
+            ffn_hidden_size,
+            ffn_intermediate_size,
+            bias=ffn_has_bias,
+            quant_config=quant_config,
+        )
+        self.fc2 = RowParallelLinear(
+            ffn_intermediate_size,
+            ffn_hidden_size,
+            bias=ffn_has_bias,
+            quant_config=quant_config,
+        )
+
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        decoder_hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            decoder_hidden_states
+                torch.Tensor of *decoder* input embeddings.
+            kv_cache:
+                KV cache tensor
+            attn_metadata:
+                Aphrodite Attention metadata structure
+            encoder_hidden_states
+                torch.Tensor of *encoder* input embeddings.
+        Returns:
+            Decoder layer output torch.Tensor
+        """
+        residual = decoder_hidden_states
+
+        # Self Attention
+        hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
+                                       kv_cache=kv_cache,
+                                       attn_metadata=attn_metadata)
+
+        hidden_states = residual + hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Cross-Attention Block
+
+        residual = hidden_states
+
+        hidden_states = self.encoder_attn(
+            decoder_hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+
+        hidden_states = residual + hidden_states
+        hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+        # Fully Connected
+        residual = hidden_states
+        fc1_out, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(fc1_out)
+
+        hidden_states, _ = self.fc2(hidden_states)
+
+        hidden_states = residual + hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+
+        return hidden_states
+
+
+class BartEncoder(nn.Module):
+    """
+    Transformer encoder consisting of *config.encoder_layers*
+    self attention layers. Each layer is a [`BartEncoderLayer`].
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self,
+                 config: BartConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None,
+                 embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__()
+
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.lora_config = lora_config
+        embed_dim = config.d_model
+        self.max_source_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+                                                    embed_dim,
+                                                    embed_scale=embed_scale)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = BartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+        )
+        self.layers = nn.ModuleList(
+            [BartEncoderLayer(config, cache_config, quant_config)
+             for _ in range(config.encoder_layers)])
+
+        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                Indices of *encoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            positions
+                Positions of *encoder* input sequence tokens.
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                Aphrodite Attention metadata structure
+        Returns:
+            Encoder output torch.Tensor
+        """
+        # retrieve input_ids and inputs_embeds
+
+        input_ids = input_ids.view(-1, input_ids.shape[-1])
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        embed_pos = self.embed_positions(
+            positions,
+            AttentionType.ENCODER,
+        )
+        embed_pos = embed_pos.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        for idx, encoder_layer in enumerate(self.layers):
+            hidden_states = encoder_layer(
+                hidden_states=hidden_states,
+                kv_cache=kv_caches[idx],
+                attn_metadata=attn_metadata,
+            )
+
+        return hidden_states
+
+
+class BartDecoder(nn.Module):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers.
+    Each layer is a [`BartDecoderLayer`]
+    Args:
+        config: BartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(
+        self,
+        config: BartConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+        embed_tokens: Optional[nn.Embedding] = None,
+    ):
+        super().__init__()
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+        self.lora_config = lora_config
+        self.max_target_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(
+            config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
+                                                    config.d_model,
+                                                    embed_scale=embed_scale)
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = BartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+        )
+
+        self.layers = nn.ModuleList(
+            [BartDecoderLayer(config, cache_config, quant_config)
+             for _ in range(config.decoder_layers)])
+
+        self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+    def forward(self, decoder_input_ids: torch.Tensor,
+                decoder_positions: torch.Tensor,
+                encoder_hidden_states: Optional[torch.Tensor],
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        r"""
+        Args:
+            decoder_input_ids
+                Indices of *decoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            decoder_positions
+                Positions of *decoder* input sequence tokens.
+            encoder_hidden_states:
+                Tensor of encoder output embeddings
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                Aphrodite Attention metadata structure
+        Returns:
+            Decoder output torch.Tensor
+        """
+
+        inputs_embeds = self.embed_tokens(decoder_input_ids)
+
+        # embed positions
+        embed_pos = self.embed_positions(
+            decoder_positions,
+            AttentionType.DECODER,
+        )
+        embed_pos = embed_pos.to(inputs_embeds.device)
+
+        hidden_states = inputs_embeds + embed_pos
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        # decoder layers
+
+        for idx, decoder_layer in enumerate(self.layers):
+            hidden_states = decoder_layer(
+                decoder_hidden_states=hidden_states,
+                kv_cache=kv_caches[idx],
+                attn_metadata=attn_metadata,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+
+        return hidden_states
+
+
+class BartModel(nn.Module):
+    _tied_weights_keys = [
+        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+    ]
+
+    def __init__(self,
+                 config: BartConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None):
+        super().__init__()
+
+        self.config = config
+
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.encoder = BartEncoder(config,
+                                   cache_config,
+                                   quant_config=quant_config)
+        self.decoder = BartDecoder(config,
+                                   cache_config,
+                                   quant_config=quant_config)
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                encoder_input_ids: torch.Tensor,
+                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                Indices of *decoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default should you
+                provide it.
+            positions
+                Positions of *decoder* input sequence tokens.
+            encoder_input_ids
+                Indices of *encoder* input sequence tokens in the vocabulary.
+            encoder_positions:
+                Positions of *encoder* input sequence tokens.
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                Aphrodite Attention metadata structure
+        Returns:
+            Model output torch.Tensor
+        """
+
+        encoder_hidden_states = None
+
+        if encoder_input_ids.numel() > 0:
+            # Run encoder attention if a non-zero number of encoder tokens
+            # are provided as input
+            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
+                                                 positions=encoder_positions,
+                                                 kv_caches=kv_caches,
+                                                 attn_metadata=attn_metadata)
+
+        # The decoder returns the final decoder hidden states
+        decoder_outputs = self.decoder(
+            decoder_input_ids=input_ids,
+            decoder_positions=positions,
+            encoder_hidden_states=encoder_hidden_states,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata)
+
+        return decoder_outputs
+
+
+class BartForConditionalGeneration(nn.Module):
+    base_model_prefix = "model"
+
+    def __init__(self,
+                 config: BartConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None):
+
+        super().__init__()
+        self.config = config
+        self.model = BartModel(config,
+                               cache_config,
+                               quant_config,
+                               lora_config=lora_config)
+
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        embed_scale = math.sqrt(
+            config.d_model) if config.scale_embedding else 1.0
+
+        self.lm_head = BartParallelLMHead(config.vocab_size,
+                                          config.d_model,
+                                          embed_scale=embed_scale)
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        encoder_input_ids: torch.Tensor,
+        encoder_positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                torch.Tensor of *decoder* input token ids.
+            positions
+                torch.Tensor of *decoder* position indices.
+            encoder_input_ids
+                torch.Tensor of *encoder* input token ids.
+            encoder_positions
+                torch.Tensor of *encoder* position indices
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                Aphrodite Attention metadata structure
+        Returns:
+            Output torch.Tensor
+        """
+        return self.model(input_ids, positions, encoder_input_ids,
+                          encoder_positions, kv_caches, attn_metadata)
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    stacked_params_mapping = {
+        "q_proj": {
+            "param_name": "qkv_proj",
+            "shard_id": "q",
+        },
+        "k_proj": {
+            "param_name": "qkv_proj",
+            "shard_id": "k",
+        },
+        "v_proj": {
+            "param_name": "qkv_proj",
+            "shard_id": "v",
+        },
+    }
+
+    params_mapping = {
+        "beta": "bias",
+        "gamma": "weight",
+        "LayerNorm": "layernorm",
+    }
+
+    def _rename_key(self, key: str):
+        prefix = f"{self.base_model_prefix}."
+        key = key[len(prefix):] if key.startswith(prefix) else key
+
+        for src, dst in self.params_mapping.items():
+            key = key.replace(src, dst)
+
+        return key
+
+    def _rename_stacked_param(
+        self,
+        name: str,
+    ) -> Tuple[str, Optional[str]]:
+        for key, mapping in self.stacked_params_mapping.items():
+            if key in name:
+                name = name.replace(key, mapping["param_name"])
+                return name, mapping["shard_id"]
+        return name, None
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        model_params_dict = dict(self.model.named_parameters())
+        top_params_dict = dict(self.named_parameters())
+
+        weights_tuple_list = list(weights)
+
+        shared_embedding_weight = None
+        shared_embedding_shard_id = None
+
+        for name, loaded_weight in weights_tuple_list:
+
+            name = self._rename_key(name)
+            name, shard_id = self._rename_stacked_param(name)
+
+            if ('shared.weight' in name
+                    or 'encoder.embed_tokens.weight' in name
+                    or 'decoder.embed_tokens.weight' in name
+                    or 'lm_head.weight' in name):
+                assert shared_embedding_weight is None, (
+                    "Conflicting embedding weights.")
+                shared_embedding_weight = loaded_weight
+                shared_embedding_shard_id = shard_id
+            else:
+                # Skip the specific downstream task weight.
+                if name.startswith('cls.'):
+                    continue
+                # use Pooler instead.
+                if name.startswith('pooler.'):
+                    continue
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in model_params_dict:
+                    continue
+
+                param = model_params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if shard_id:
+                    weight_loader(param, loaded_weight, shard_id)
+                else:
+                    weight_loader(param, loaded_weight)
+
+        # Assign shared weight values
+        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
+        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
+                                           default_weight_loader)
+
+        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
+        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
+                                           default_weight_loader)
+
+        lm_head_in_param = top_params_dict['lm_head.weight']
+        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
+                                           default_weight_loader)
+
+        assert shared_embedding_weight is not None
+
+        if shared_embedding_shard_id:
+            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
+                                     shared_embedding_shard_id)
+            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
+                                     shared_embedding_shard_id)
+            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
+                                     shared_embedding_shard_id)
+        else:
+            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
+            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
+            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
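
For illustration only (not part of this change), a standalone sketch of the checkpoint-name normalization that `load_weights` performs via `_rename_key` and `_rename_stacked_param` above; the real code additionally maps beta/gamma/LayerNorm names, which is omitted here:

    # Mirror of the stacked-parameter renaming applied during weight loading.
    STACKED = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
    }

    def rename(hf_name: str):
        # Strip the "model." base-model prefix, then fold q/k/v into qkv_proj.
        name = hf_name[len("model."):] if hf_name.startswith("model.") else hf_name
        for src, (dst, shard_id) in STACKED.items():
            if src in name:
                return name.replace(src, dst), shard_id
        return name, None

    print(rename("model.encoder.layers.0.self_attn.q_proj.weight"))
    # ('encoder.layers.0.self_attn.qkv_proj.weight', 'q')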

+ 2 - 10
aphrodite/processing/block/utils.py

@@ -1,15 +1,7 @@
 """Block manager utils."""
 from aphrodite.common.sequence import SequenceGroup
-
-# Exception strings for non-implemented block manager enc/dec scenarios
-
-STR_NOT_IMPL_ENC_DEC_SWA = \
-    "Sliding window attention for encoder/decoder models " + \
-                    "is not currently supported."
-
-STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
-    "Prefix caching for encoder/decoder models " + \
-                    "is not currently supported."
+from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
+                                    STR_NOT_IMPL_ENC_DEC_SWA)
 
 
 def _get_block_mgr_sliding_window_attr(block_mgr):

+ 28 - 0
aphrodite/processing/scheduler.py

@@ -391,6 +391,19 @@ class Scheduler:
                     seq.status = SequenceStatus.FINISHED_ABORTED
                     self.free_seq(seq)
 
+                self._free_seq_group_cross_attn_blocks(aborted_group)
+
+    def _free_seq_group_cross_attn_blocks(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        """
+        Free a sequence group from a cross-attention block table.
+        Has no effect on decoder-only models.
+        """
+        if seq_group.is_encoder_decoder():
+            self.block_manager.free_cross(seq_group)
+
     def has_unfinished_seqs(self) -> bool:
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
             self.swapped) != 0
@@ -961,6 +974,17 @@ class Scheduler:
             # seq_id -> physical block numbers
             block_tables: Dict[int, List[int]] = {}
 
+            if seq_group.is_encoder_decoder():
+                # Encoder associated with SequenceGroup
+                encoder_seq_data = seq_group.get_encoder_seq().data
+                # Block table for cross-attention
+                # Also managed at SequenceGroup level
+                cross_block_table = self.block_manager.get_cross_block_table(
+                    seq_group)
+            else:
+                encoder_seq_data = None
+                cross_block_table = None
+
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
@@ -999,6 +1023,8 @@ class Scheduler:
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,
+                encoder_seq_data=encoder_seq_data,
+                cross_block_table=cross_block_table,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1030,6 +1056,8 @@ class Scheduler:
         remaining: Deque[SequenceGroup] = deque()
         for seq_group in self.running:
             if seq_group.is_finished():
+                # Free cross-attention block table, if it exists
+                self._free_seq_group_cross_attn_blocks(seq_group)
                 # Add the finished requests to the finished requests list.
                 # This list will be used to update the Mamba cache in the
                 # next step.
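
For illustration only (not part of this change), a toy sketch of the bookkeeping split the scheduler relies on above: decoder block tables are tracked per sequence, while the encoder sequence data and the cross-attention block table are tracked once per sequence group. The class below is a stand-in, not Aphrodite's SequenceGroupMetadata:

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class ToySeqGroupMeta:
        block_tables: Dict[int, List[int]]                     # seq_id -> decoder blocks
        encoder_prompt_token_ids: Optional[List[int]] = None   # one per group
        cross_block_table: Optional[List[int]] = None          # one per group

    meta = ToySeqGroupMeta(
        block_tables={0: [3, 4], 1: [5, 6]},   # two decoder sequences (n > 1)
        encoder_prompt_token_ids=[0, 31414, 232, 2],
        cross_block_table=[7, 12],
    )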

+ 466 - 0
aphrodite/task_handler/enc_dec_model_runner.py

@@ -0,0 +1,466 @@
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Type, cast
+
+import torch
+import torch.distributed
+from loguru import logger
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionMetadata)
+from aphrodite.attention.selector import (_Backend,
+                                          get_env_variable_attn_backend,
+                                          get_global_forced_attn_backend,
+                                          global_force_attn_backend)
+from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
+                                     LoRAConfig, ModelConfig, MultiModalConfig,
+                                     ParallelConfig, PromptAdapterConfig,
+                                     SchedulerConfig)
+from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
+                                       SamplerOutput, SequenceGroupMetadata)
+from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND,
+                                    make_tensor_with_pad)
+from aphrodite.inputs import INPUT_REGISTRY
+from aphrodite.modeling import SamplingMetadata
+from aphrodite.task_handler.model_runner import (
+    _PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder,
+    ModelInputForGPUWithSamplingMetadata)
+from aphrodite.task_handler.model_runner_base import (
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict)
+from aphrodite.task_handler.utils import assert_enc_dec_mr_supported_scenario
+
+
+@dataclasses.dataclass(frozen=True)
+class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
+    """
+    Used by the EncoderDecoderModelRunner.
+    """
+    encoder_input_tokens: Optional[torch.Tensor] = None
+    encoder_input_positions: Optional[torch.Tensor] = None
+
+    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+            "encoder_input_tokens": self.encoder_input_tokens,
+            "encoder_input_positions": self.encoder_input_positions,
+            "virtual_engine": self.virtual_engine,
+            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
+            "finished_requests_ids": self.finished_requests_ids,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "EncoderDecoderModelInput":
+        return cast(
+            EncoderDecoderModelInput,
+            super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
+
+
+class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
+    _model_input_cls: Type[EncoderDecoderModelInput] = (
+        EncoderDecoderModelInput)
+    _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder)
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        kv_cache_dtype: Optional[str] = "auto",
+        is_driver_worker: bool = False,
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+        multimodal_config: Optional[MultiModalConfig] = None,
+        **kwargs,
+    ):
+        '''
+        EncoderDecoderModelRunner constructor.
+        `lora_config`, `multimodal_config`, and `prompt_adapter_config` are
+        unused (these features are not yet supported for encoder/decoder
+        models) but are accepted for compatibility with the base-class
+        constructor.
+        '''
+
+        self._maybe_force_supported_attention_backend()
+
+        super().__init__(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            cache_config,
+            load_config,
+            lora_config=None,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=is_driver_worker,
+            **kwargs,
+        )
+
+        # Crash for unsupported encoder/decoder scenarios
+        assert_enc_dec_mr_supported_scenario(self)
+
+    def _maybe_force_supported_attention_backend(self):
+        '''
+        Force the XFormers attention backend, currently the only option
+        supported for encoder/decoder models, unless the user has already
+        forced XFormers; raise if a different backend was forced.
+        '''
+
+        def raise_backend_err():
+            # The user has specified an attention backend override
+            # which is invalid for encoder/decoder models
+            raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND)
+
+        maybe_env_var_forced_backend = get_env_variable_attn_backend()
+        maybe_global_forced_backend = get_global_forced_attn_backend()
+        is_forced_by_global = maybe_global_forced_backend is not None
+        is_forced_by_env_var = maybe_env_var_forced_backend is not None
+
+        if not (is_forced_by_global or is_forced_by_env_var):
+            # The user has not already specified an attention backend
+            # override
+            logger.info("EncoderDecoderModelRunner requires "
+                        "XFormers backend; overriding backend "
+                        "auto-selection and forcing XFormers.")
+            global_force_attn_backend(_Backend.XFORMERS)
+        elif is_forced_by_global:
+            # Backend override enforced by global variable takes
+            # precedence over Aphrodite backend environment variable.
+            if maybe_global_forced_backend != _Backend.XFORMERS:
+                raise_backend_err()
+        elif is_forced_by_env_var:
+            # Backend override enforced by Aphrodite backend
+            # environment variable
+            if maybe_env_var_forced_backend != _Backend.XFORMERS:
+                raise_backend_err()
+
+    def _list_to_int32_tensor(
+        self,
+        _list: List[int],
+    ) -> torch.Tensor:
+        return torch.tensor(_list, dtype=torch.int32, device=self.device)
+
+    def _list_to_long_tensor(
+        self,
+        _list: List[int],
+    ) -> torch.Tensor:
+        return torch.tensor(_list, dtype=torch.long, device=self.device)
+
+    def _empty_int32_tensor(self) -> torch.Tensor:
+        return self._list_to_int32_tensor([])
+
+    def _empty_long_tensor(self) -> torch.Tensor:
+        return self._list_to_long_tensor([])
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: EncoderDecoderModelInput,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[List[PoolerOutput]]:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in "
+                             "EncoderDecoderModelRunner")
+
+        model_executable = self.model
+
+        seqlen_agnostic_kwargs = {
+            "finished_requests_ids": model_input.finished_requests_ids,
+            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
+        } if self.has_seqlen_agnostic else {}
+        hidden_or_intermediate_states = model_executable(
+            input_ids=model_input.input_tokens,
+            positions=model_input.input_positions,
+            encoder_input_ids=model_input.encoder_input_tokens,
+            encoder_positions=model_input.encoder_input_positions,
+            kv_caches=kv_caches,
+            attn_metadata=model_input.attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            **seqlen_agnostic_kwargs)
+
+        logits = self.model.compute_logits(hidden_or_intermediate_states,
+                                           model_input.sampling_metadata)
+
+        if not self.is_driver_worker:
+            return []
+
+        # Sample the next token.
+        output: SamplerOutput = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+
+        return [output]
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput:
+        return EncoderDecoderModelInput.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        )
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> EncoderDecoderModelInput:
+        """Prepare the model input based on a given sequence group, including
+        metadata for the sampling step.
+        Since chunked prefill is not supported for encoder/decoder models,
+        `input_tokens` is assumed to be either entirely prefill tokens or
+        entirely decode tokens.
+        """
+        model_input = self._prepare_model_input_tensors(
+            seq_group_metadata_list, finished_requests_ids)
+
+        (
+            attn_metadata,
+            encoder_input_tokens_tensor,
+            encoder_input_positions_tensor,
+        ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
+                                                       model_input))
+
+        # Inject attn_metadata encoder/cross-attention fields &
+        # encoder input tokens/positions into model_input.
+        # Frozen dataclass fields cannot be modified, so use
+        # dataclasses.replace to construct a new model input
+        # instance.
+        model_input = dataclasses.replace(
+            model_input,
+            attn_metadata=attn_metadata,
+            encoder_input_tokens=encoder_input_tokens_tensor,
+            encoder_input_positions=encoder_input_positions_tensor,
+        )
+
+        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                     model_input.seq_lens,
+                                                     model_input.query_lens,
+                                                     self.device,
+                                                     self.pin_memory)
+        is_prompt = (seq_group_metadata_list[0].is_prompt
+                     if seq_group_metadata_list else None)
+        return dataclasses.replace(model_input,
+                                   sampling_metadata=sampling_metadata,
+                                   is_prompt=is_prompt,
+                                   virtual_engine=virtual_engine)
+
+    @torch.inference_mode()
+    def profile_run(self) -> None:
+        # Enable top-k sampling to reflect the accurate memory usage.
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+
+        # Profile memory usage with max_num_sequences sequences and the total
+        # number of tokens equal to max_num_batched_tokens.
+        seqs: List[SequenceGroupMetadata] = []
+
+        model_config = self.model_config
+
+        batch_size = 0
+        for group_id in range(max_num_seqs):
+            seq_len = (max_num_batched_tokens // max_num_seqs +
+                       (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len
+
+            seq_data, _ = INPUT_REGISTRY \
+                .dummy_data_for_profiling(model_config, seq_len)
+
+            # Having more tokens is over-conservative but otherwise fine
+            assert len(seq_data.prompt_token_ids) >= seq_len, (
+                f"Expected at least {seq_len} dummy tokens for profiling, "
+                f"but got: {len(seq_data.prompt_token_ids)}")
+
+            seq = SequenceGroupMetadata(
+                request_id=str(group_id),
+                is_prompt=True,
+                seq_data={group_id: seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+                encoder_seq_data=seq_data,
+                cross_block_table=None,
+            )
+            seqs.append(seq)
+
+        # Run the model with the dummy inputs.
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        kv_caches = [None] * num_layers
+        finished_requests_ids = [seq.request_id for seq in seqs]
+        model_input = self.prepare_model_input(
+            seqs, finished_requests_ids=finished_requests_ids)
+        intermediate_tensors = None
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
+        torch.cuda.synchronize()
+        return
+
+    def _prepare_encoder_model_input_tensors(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        model_input: EncoderDecoderModelInput,
+    ) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        """Helper method to prepare the encoder- and cross-attn-related
+        model inputs based on a given sequence group. These additional inputs
+        augment an `EncoderDecoderModelInput` data structure whose
+        decoder-related fields have already been populated.
+        Sets the following attn_metadata fields:
+        * `num_encoder_tokens`
+        * `encoder_seq_lens`
+        * `encoder_seq_lens_tensor`
+        * `max_encoder_seq_len`
+        * `cross_slot_mapping`
+        * `cross_block_tables`
+        Computes the following values, reusing the decoder-oriented fields
+        already present in the `model_input` argument and updating
+        `attn_metadata` in place:
+        * attn_metadata
+        * encoder_input_tokens
+        * encoder_input_positions
+        Arguments:
+        * seq_group_metadata_list: list of sequence groups for which to
+                                   compute inputs
+        * model_input: model input data structure with decoder-oriented
+                       fields already computed.
+        Return:
+        * Tuple of (attn_metadata, encoder_input_tokens tensor,
+          encoder_input_positions tensor)
+        """
+
+        if len(seq_group_metadata_list) == 0:
+            return (model_input.attn_metadata, None, None)
+
+        # Since chunked prefill is not supported, the batch is either
+        # entirely prefill or entirely decode
+        is_prompt = seq_group_metadata_list[0].is_prompt
+
+        # Build encoder inputs
+        encoder_seq_lens: List[int] = []
+        if is_prompt:
+            # Prefill phase.
+            cross_block_tables = self._empty_int32_tensor().view(
+                len(seq_group_metadata_list), -1)
+
+            # Extract input tokens/positions, cross-attention slot-mapping,
+            # & seq len from each sequence group metadata
+            (
+                encoder_input_tokens,
+                encoder_input_positions,
+                cross_slot_mapping,
+            ) = (
+                [],
+                [],
+                [],
+            )
+            for seq_group_metadata in seq_group_metadata_list:
+                # Build seq lens
+                seq_len = seq_group_metadata.encoder_seq_data.get_len()
+                token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
+                encoder_seq_lens.append(seq_len)
+
+                # Build slot mapping
+                is_profile_run = (seq_group_metadata.block_tables is None)
+                if is_profile_run:
+                    # During memory profiling, the block tables are not
+                    # initialized yet. In this case, we just use a dummy
+                    # slot mapping.
+                    # (For embedding models, the block tables are {seq_id: None}.)
+                    cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+                else:
+                    for i in range(0, seq_len):
+                        block_number = seq_group_metadata.cross_block_table[
+                            i // self.block_size]
+                        block_offset = i % self.block_size
+                        slot = block_number * self.block_size + block_offset
+                        cross_slot_mapping.append(slot)
+
+                # Build encoder input tokens
+                encoder_input_tokens.extend(token_ids)
+                encoder_input_positions.extend(list(range(0, seq_len)))
+
+            # Convert tokens/positions & cross-attention
+            # slot-mapping to encoder input tensors
+            encoder_input_tokens_tensor = self._list_to_long_tensor(
+                encoder_input_tokens)
+            encoder_input_positions_tensor = self._list_to_long_tensor(
+                encoder_input_positions)
+            cross_slot_mapping_tensor = self._list_to_long_tensor(
+                cross_slot_mapping)
+
+        else:
+            # Decode phase.
+            encoder_input_tokens_tensor = self._empty_long_tensor()
+            encoder_input_positions_tensor = self._empty_long_tensor()
+            cross_slot_mapping_tensor = self._empty_long_tensor()
+
+            # Extract cross-attention block tables &
+            # seq len from each sequence group metadata.
+            # Cross-attention block tables are empty
+            # during Aphrodite memory profiling.
+            cross_block_tables = []
+            for seq_group_metadata in seq_group_metadata_list:
+                encoder_seq_lens.append(
+                    seq_group_metadata.encoder_seq_data.get_len())
+                cross_block_table = seq_group_metadata.cross_block_table
+                cross_block_tables.append([] if (
+                    cross_block_table is None) else cross_block_table)
+
+            # Convert cross-attention block tables to encoder input tensor
+            cross_block_tables = make_tensor_with_pad(
+                cross_block_tables,
+                max_len=max(
+                    len(block_table) for block_table in cross_block_tables),
+                pad=0,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
+        # Compute encoder sequence lengths & encoder
+        # sequence starting offset tensors
+        max_encoder_seq_len = max(encoder_seq_lens, default=0)
+        encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
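+        # encoder_seq_start_loc holds the prefix sum of encoder_seq_lens with a
+        # leading zero: entry i is the offset of the i-th sequence's first
+        # token within the flattened encoder input tensors.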
+        encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
+                                            1,
+                                            dtype=torch.int32,
+                                            device=self.device)
+        torch.cumsum(encoder_seq_lens_tensor,
+                     dim=0,
+                     dtype=encoder_seq_start_loc.dtype,
+                     out=encoder_seq_start_loc[1:])
+
+        # Update attention metadata with encoder-oriented attributes
+        attn_metadata = model_input.attn_metadata
+        assert attn_metadata is not None
+        (
+            attn_metadata.num_encoder_tokens,
+            attn_metadata.encoder_seq_lens,
+            attn_metadata.encoder_seq_lens_tensor,
+            attn_metadata.max_encoder_seq_len,
+            attn_metadata.cross_slot_mapping,
+            attn_metadata.cross_block_tables,
+        ) = (
+            sum(encoder_seq_lens),
+            encoder_seq_lens,
+            encoder_seq_lens_tensor,
+            max_encoder_seq_len,
+            cross_slot_mapping_tensor,
+            cross_block_tables,
+        )
+
+        return (attn_metadata, encoder_input_tokens_tensor,
+                encoder_input_positions_tensor)

+ 56 - 0
aphrodite/task_handler/utils.py

@@ -0,0 +1,56 @@
+'''
+Worker-related helper functions.
+'''
+
+from aphrodite.common.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
+from aphrodite.task_handler.model_runner import GPUModelRunnerBase
+
+
+def assert_enc_dec_mr_supported_scenario(
+        enc_dec_mr: GPUModelRunnerBase) -> None:
+    '''
+    Assert that the provided encoder/decoder model runner instance reflects
+    a supported scenario.
+    '''
+
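+    # Each check below corresponds to a feature that the encoder/decoder
+    # code path does not support yet; hitting one raises NotImplementedError.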
+    if enc_dec_mr.cache_config.enable_prefix_caching:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE'])
+
+    if enc_dec_mr.sliding_window is not None:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA'])
+
+    if enc_dec_mr.scheduler_config.chunked_prefill_enabled:
+        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
+            'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL'])
+
+    if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping',
+               None) is not None:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
+        )
+
+    if enc_dec_mr.lora_config is not None:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
+
+    if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
+
+    if enc_dec_mr.multimodal_config is not None:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM'])
+
+    if enc_dec_mr.scheduler_config.num_lookahead_slots > 0:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])
+
+    if not enc_dec_mr.model_config.enforce_eager:
+        raise NotImplementedError(
+            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH'])
+
+    if enc_dec_mr.prompt_adapter_config is not None:
+        raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[
+            'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER'])

+ 13 - 1
aphrodite/task_handler/worker.py

@@ -13,6 +13,8 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      ParallelConfig, PromptAdapterConfig,
                                      SchedulerConfig, SpeculativeConfig)
 from aphrodite.common.sequence import ExecuteModelRequest
+from aphrodite.common.utils import (is_embedding_model_config,
+                                    is_encoder_decoder_model_config)
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -25,6 +27,8 @@ from aphrodite.platforms import current_platform
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.task_handler.cache_engine import CacheEngine
 from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
+from aphrodite.task_handler.enc_dec_model_runner import (
+    EncoderDecoderModelRunner)
 from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
 from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
                                                 WorkerInput)
@@ -91,8 +95,10 @@ class Worker(LocalOrDistributedWorkerBase):
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
         if model_runner_cls is not None:
             ModelRunnerClass = model_runner_cls
-        elif self.model_config.embedding_mode:
+        elif self._is_embedding_model():
             ModelRunnerClass = EmbeddingModelRunner
+        elif self._is_encoder_decoder_model():
+            ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             model_config,
             parallel_config,
@@ -114,6 +120,12 @@ class Worker(LocalOrDistributedWorkerBase):
         # Initialize gpu_cache as embedding models don't initialize kv_caches
         self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
 
+    def _is_encoder_decoder_model(self):
+        return is_encoder_decoder_model_config(self.model_config)
+
+    def _is_embedding_model(self):
+        return is_embedding_model_config(self.model_config)
+
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until

+ 4 - 0
docs/.vitepress/config.mts

@@ -58,6 +58,10 @@ export default defineConfig({
 						text: "Vision Language Models",
 						link: "/pages/usage/vlm",
 					},
+					{
+						text: "Encoder-Decoder Models",
+						link: "/pages/usage/encoder-decoder",
+					},
 					{
 						text: "Distributed Inference",
 						link: "/pages/usage/distributed",

+ 184 - 0
docs/pages/usage/encoder-decoder.md

@@ -0,0 +1,184 @@
+---
+outline: deep
+---
+
+# Encoder-Decoder Model Support in Aphrodite
+
+Aphrodite now supports encoder-decoder language models, such as [BART](https://huggingface.co/facebook/bart-large-cnn), in addition to decoder-only models (for now, this support is only available when building from source). This document will guide you through using encoder-decoder models with Aphrodite.
+
+## Introduction
+Encoder-decoder models, like BART, consist of two main components: an encoder that processes the input sequence, and a decoder that generates the output sequence. Aphrodite's support for these models allows you to leverage their capabilities for tasks such as summarization, translation, and more.
+
+## Setting up an Encoder-Decoder Model
+
+To use an encoder-decoder model with Aphrodite, initialize an `LLM` instance with the appropriate model name. Here's an example using a BART model:
+
+```py
+from aphrodite import LLM
+
+llm = LLM(model="facebook/bart-large-cnn", dtype="float")
+```
+
+Keep in mind that the float (FP32) data type is recommended for BART models.
+
+The `LLM` class automatically detects whether the model is an encoder-decoder or a decoder-only model and sets up the appropriate internal configuration.
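+
+Under the hood, this detection is driven by the model's Hugging Face config (Aphrodite's worker uses the `is_encoder_decoder_model_config` helper from `aphrodite.common.utils`). The sketch below is illustrative only: it assumes the config exposes the standard `is_encoder_decoder` flag, and the helper name here is hypothetical:
+
+```py
+# Illustrative sketch only -- not Aphrodite's actual implementation.
+from transformers import AutoConfig
+
+def looks_like_encoder_decoder(model_name: str) -> bool:
+    # BART-style models set is_encoder_decoder=True in their HF config.
+    hf_config = AutoConfig.from_pretrained(model_name)
+    return getattr(hf_config, "is_encoder_decoder", False)
+
+print(looks_like_encoder_decoder("facebook/bart-large-cnn"))  # True
+```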
+
+## Input Types
+Aphrodite supports several input (prompt) types for encoder-decoder models. They are defined in the `aphrodite.inputs.data` module; the snippet below, taken from the prompt-parsing logic, shows the required keys that identify each type:
+
+```py
+    if prompt is None:
+        return 'None'
+
+    required_keys_dict = {
+        'TextPrompt': {'prompt'}, # [!code highlight]
+        'TokensPrompt': {'prompt_token_ids'}, # [!code highlight]
+        'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'}, # [!code highlight]
+    }
+
+    if isinstance(prompt, dict):
+        for (ptype, required_keys) in required_keys_dict.items():
+            # Ignore type checking in the conditional below because type
+            # checker does not understand that is_dict(prompt) narrows
+            # down the possible types
+            if _has_required_keys(
+                    prompt,  # type: ignore
+                    required_keys):
+                return ptype
+
+        raise ValueError(f"Invalid prompt {prompt}, valid types are "
+                         f"required_keys_dict={required_keys_dict}")
+
+    if isinstance(prompt, str):
+        return "str"
+```
+
+For encoder-decoder models, you can use these input types in different combinations:
+
+### Single Input (Implicit Encoder Input)
+You can provide a single input, which will be treated as the encoder input. The decoder input will be assumed to be empty (None).
+
+```py
+from aphrodite.inputs import TextPrompt, TokensPrompt
+
+# `tokenizer` comes from the LLM, e.g. llm.llm_engine.get_tokenizer_group()
+single_text_prompt_raw = "Hello, my name is"
+single_text_prompt = TextPrompt(prompt="The president of the United States is")
+single_tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(prompt="The capital of France is"))
+```
+
+### Explicit Encoder and Decoder Inputs
+For more control, you can explicitly specify both encoder and decoder inputs using the `ExplicitEncoderDecoderPrompt` class:
+
+```py
+class ExplicitEncoderDecoderPrompt(TypedDict):
+    """Represents an encoder/decoder model input prompt,
+    comprising an explicit encoder prompt and a 
+    decoder prompt.
+    The encoder and decoder prompts, respectively,
+    may be formatted according to any of the
+    SingletonPromptInputs schemas, and are not
+    required to have the same schema.
+    Only the encoder prompt may have multi-modal data.
+    Note that an ExplicitEncoderDecoderPrompt may not
+    be used as an input to a decoder-only model,
+    and that the `encoder_prompt` and `decoder_prompt`
+    fields of this data structure may not themselves
+    must be SingletonPromptInputs instances.
+    """
+
+    encoder_prompt: SingletonPromptInputs
+
+    decoder_prompt: SingletonPromptInputs
+```
+
+Example usage:
+
+```py
+from aphrodite.inputs import ExplicitEncoderDecoderPrompt
+
+enc_dec_prompt = ExplicitEncoderDecoderPrompt(
+    encoder_prompt="Summarize this text:",
+    decoder_prompt="Summary:"
+)
+```
+
+## Generating Text
+To generate text with an encoder-decoder model, use the `generate` method of the `LLM` instance. You can pass a single prompt or a list of prompts (using any of the input types described above), along with sampling parameters:
+
+```py
+from aphrodite import SamplingParams
+
+sampling_params = SamplingParams(
+    temperature=0,
+    top_p=1.0,
+    min_tokens=0,
+    max_tokens=20,
+)
+
+outputs = llm.generate(prompts, sampling_params)
+```
+
+The `generate` method returns a list of `RequestOutput` objects containing the generated text and other information.
+
+## Advanced Usage
+### Mixing Input Types
+You can mix different input types in a single generation request. Here, `enc_dec_prompt1` through `enc_dec_prompt3` are additional `ExplicitEncoderDecoderPrompt` instances like the one constructed above:
+
+```py
+prompts = [
+    single_text_prompt_raw,
+    single_text_prompt,
+    single_tokens_prompt,
+    enc_dec_prompt1,
+    enc_dec_prompt2,
+    enc_dec_prompt3
+]
+
+outputs = llm.generate(prompts, sampling_params)
+```
+
+### Batching Encoder and Decoder Prompts
+To conveniently build many encoder/decoder prompt pairs at once, use the `zip_enc_dec_prompt_lists` helper function:
+
+```py
+from aphrodite.common.utils import zip_enc_dec_prompt_lists
+
+zipped_prompt_list = zip_enc_dec_prompt_lists(
+    ['An encoder prompt', 'Another encoder prompt'],
+    ['A decoder prompt', 'Another decoder prompt']
+)
+```
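+
+The result is a list of `ExplicitEncoderDecoderPrompt` instances that can be passed directly to `generate`. Assuming the helper simply pairs the two lists element-wise, it is roughly equivalent to this sketch:
+
+```py
+# Rough equivalent of the helper, assuming element-wise pairing.
+zipped_prompt_list = [
+    ExplicitEncoderDecoderPrompt(encoder_prompt=enc, decoder_prompt=dec)
+    for enc, dec in zip(['An encoder prompt', 'Another encoder prompt'],
+                        ['A decoder prompt', 'Another decoder prompt'])
+]
+```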
+
+### Accessing Generated Text
+After generation, you can access the generated text and other information from the `RequestOutput` objects:
+
+```py
+for output in outputs:
+    prompt = output.prompt
+    encoder_prompt = output.encoder_prompt
+    generated_text = output.outputs[0].text
+    print(f"Encoder prompt: {encoder_prompt!r}, "
+          f"Decoder prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")
+```
+
+## API Reference
+### LLM Class
+The `LLM` class in the `aphrodite.endpoints.llm` module is the main interface for working with both decoder-only and encoder-decoder models.
+
+Key methods:
+
+- `__init__(self, model: str, ...)`: Initialize an LLM instance with the specified model.
+- `generate(self, prompts: Union[PromptInputs, Sequence[PromptInputs]], ...)`: Generate text based on the given prompts and sampling parameters.
+
+### Input Types
+- `TextPrompt`: Represents a text prompt.
+- `TokensPrompt`: Represents a tokenized prompt.
+- `ExplicitEncoderDecoderPrompt`: Represents an explicit encoder-decoder prompt pair.
+
+### RequestOutput
+The `RequestOutput` class in the `aphrodite.common.outputs` module contains the results of a generation request.
+
+Key attributes:
+
+- `prompt`: The input prompt (the decoder prompt for encoder-decoder models).
+- `encoder_prompt`: The encoder prompt for encoder-decoder models.
+- `outputs`: A list of `CompletionOutput` objects containing the generated text and other information.
+
+For detailed information on these classes and their methods, please refer to the source code.

+ 5 - 0
docs/pages/usage/models.md

@@ -53,6 +53,11 @@ Aphrodite supports a large variety of generative Transformer models in [Hugging
 On ROCm platforms, Mistral and Mixtral are capped to 4096 max context length due to sliding window issues.
 :::
 
+## Encoder-Decoder Language Models
+| Architecture                   |             Example Model |
+| ------------------------------ | ------------------------: |
+| `BartForConditionalGeneration` | `facebook/bart-large-cnn` |
+
 ## Multimodal Language Models
 
 | Architecture                        | Supported Modalities |                       Example Model |

+ 97 - 0
examples/offline_inference/encoder_decoder_inference.py

@@ -0,0 +1,97 @@
+"""Prompting encoder-decoder models, specifically the BART model."""
+
+from aphrodite import LLM, SamplingParams
+from aphrodite.common.utils import zip_enc_dec_prompt_lists
+from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                              TokensPrompt)
+
+dtype = "float"
+
+# Create a BART encoder/decoder model instance
+llm = LLM(
+    model="facebook/bart-large-cnn",
+    dtype=dtype,
+)
+
+# Get BART tokenizer
+tokenizer = llm.llm_engine.get_tokenizer_group()
+
+# Test prompts
+#
+# This section shows all of the valid ways to prompt an
+# encoder/decoder model.
+#
+# - Helpers for building prompts
+text_prompt_raw = "Hello, my name is"
+text_prompt = TextPrompt(prompt="The president of the United States is")
+tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
+    prompt="The capital of France is"))
+# - Pass a single prompt to encoder/decoder model
+#   (implicitly encoder input prompt);
+#   decoder input prompt is assumed to be None
+
+single_text_prompt_raw = text_prompt_raw  # Pass a string directly
+single_text_prompt = text_prompt  # Pass a TextPrompt
+single_tokens_prompt = tokens_prompt  # Pass a TokensPrompt
+
+# - Pass explicit encoder and decoder input prompts within one data structure.
+#   Encoder and decoder prompts can both independently be text or tokens, with
+#   no requirement that they be the same prompt type. Some example prompt-type
+#   combinations are shown below; note that these are not exhaustive.
+
+enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
+    # Pass encoder prompt string directly, &
+    # pass decoder prompt tokens
+    encoder_prompt=single_text_prompt_raw,
+    decoder_prompt=single_tokens_prompt,
+)
+enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
+    # Pass TextPrompt to encoder, and
+    # pass decoder prompt string directly
+    encoder_prompt=single_text_prompt,
+    decoder_prompt=single_text_prompt_raw,
+)
+enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
+    # Pass encoder prompt tokens directly, and
+    # pass TextPrompt to decoder
+    encoder_prompt=single_tokens_prompt,
+    decoder_prompt=single_text_prompt,
+)
+
+# - Finally, here's a useful helper function for zipping encoder and
+#   decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
+#   instances
+zipped_prompt_list = zip_enc_dec_prompt_lists(
+    ['An encoder prompt', 'Another encoder prompt'],
+    ['A decoder prompt', 'Another decoder prompt'])
+
+# - Let's put all of the above example prompts together into one list
+#   which we will pass to the encoder/decoder LLM.
+prompts = [
+    single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
+    enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
+] + zipped_prompt_list
+
+print(prompts)
+
+# Create a sampling params object.
+sampling_params = SamplingParams(
+    temperature=0,
+    top_p=1.0,
+    min_tokens=0,
+    max_tokens=20,
+)
+
+# Generate output tokens from the prompts. The output is a list of
+# RequestOutput objects that contain the prompt, generated
+# text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    encoder_prompt = output.encoder_prompt
+    generated_text = output.outputs[0].text
+    print(f"Encoder prompt: {encoder_prompt!r}, "
+          f"Decoder prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")