feat: classifier free guidance - take 2

AlpinDale 4 months ago
parent
commit
2242c38a30

+ 0 - 0
aphrodite/cfg/__init__.py


+ 163 - 0
aphrodite/cfg/cfg_model_runner.py

@@ -0,0 +1,163 @@
+from typing import List, Optional, Union
+
+import torch
+
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.distributed import get_pp_group
+from aphrodite.multimodal import MultiModalInputs
+from aphrodite.task_handler.model_runner import (
+    FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper, ModelInputForGPUWithSamplingMetadata,
+    ModelRunner)
+
+
+class CFGModelRunner(ModelRunner):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @torch.inference_mode()
+    def model_execute(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> torch.Tensor:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in ModelRunner")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        if self.prompt_adapter_config:
+            assert model_input.prompt_adapter_requests is not None
+            assert model_input.prompt_adapter_mapping is not None
+            self.set_active_prompt_adapters(
+                model_input.prompt_adapter_requests,
+                model_input.prompt_adapter_mapping)
+
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                    self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                    self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = self.graph_runners[
+                    model_input.
+                    virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
+        # Currently cuda graph is only supported by the decode phase.
+        assert model_input.attn_metadata is not None
+        prefill_meta = model_input.attn_metadata.prefill_metadata
+        decode_meta = model_input.attn_metadata.decode_metadata
+        # TODO(andoorve): We can remove this once all
+        # virtual engines share the same kv cache.
+        virtual_engine = model_input.virtual_engine
+        if prefill_meta is None and decode_meta.use_cuda_graph:
+            assert model_input.input_tokens is not None
+            graph_batch_size = model_input.input_tokens.shape[0]
+            model_executable = self.graph_runners[virtual_engine][
+                graph_batch_size]
+        else:
+            model_executable = self.model
+
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+        seqlen_agnostic_kwargs = {
+            "finished_requests_ids": model_input.finished_requests_ids,
+            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
+        } if self.has_seqlen_agnostic else {}
+
+        hidden_or_intermediate_states = model_executable(
+            input_ids=model_input.input_tokens,
+            positions=model_input.input_positions,
+            kv_caches=kv_caches,
+            attn_metadata=model_input.attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+                                         device=self.device),
+            **seqlen_agnostic_kwargs)
+
+        return hidden_or_intermediate_states
+
+    @torch.inference_mode()
+    def get_logits(
+        self,
+        hidden_or_intermediate_states: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model._get_logits(hidden_or_intermediate_states, 
+                                      model_input.sampling_metadata)
+
+    @torch.inference_mode()
+    def compute_logits(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model.compute_logits(logits,
+                                         model_input.sampling_metadata)
+
+    @torch.inference_mode()
+    def do_sample(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> List[SamplerOutput]:
+        if not self.is_driver_worker:
+            return []
+
+        # Sample the next token.
+        output: SamplerOutput = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+
+        if self.return_hidden_states:
+            raise NotImplementedError(
+                "return_hidden_states is not supported in CFGModelRunner")
+
+        return [output]
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+
+        hidden_or_intermediate_states = self.model_execute(
+            model_input, kv_caches, intermediate_tensors, num_steps)
+
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        hidden_or_intermediate_states = self.get_logits(
+            hidden_or_intermediate_states, model_input)
+        logits = self.compute_logits(hidden_or_intermediate_states, model_input)
+
+        return self.do_sample(logits, model_input)

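Note: the runner above deliberately splits the usual execute_model flow into model_execute, get_logits, compute_logits, and do_sample so a caller can modify the raw logits between the forward pass and sampling. A minimal sketch of the intended call order under a single pipeline stage; `runner`, `model_input`, `kv_caches`, and `mix_logits` are illustrative names, not part of this commit:

    # Hedged sketch of how a driver can use the split CFGModelRunner API.
    hidden = runner.model_execute(model_input, kv_caches)
    logits = runner.get_logits(hidden, model_input)      # raw lm_head logits
    logits = mix_logits(logits)                          # e.g. CFG mixing
    scores = runner.compute_logits(logits, model_input)  # logits processors
    outputs = runner.do_sample(scores, model_input)      # List[SamplerOutput]
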
+ 194 - 0
aphrodite/cfg/cfg_worker.py

@@ -0,0 +1,194 @@
+import copy
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from aphrodite.cfg.cfg_model_runner import CFGModelRunner
+from aphrodite.cfg.separated_worker import SeparatedWorker
+from aphrodite.common.config import CFGConfig, ParallelConfig
+from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+                                       SequenceData, SequenceGroupMetadata)
+from aphrodite.distributed import get_pp_group, get_tp_group
+from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
+                                                WorkerBase)
+
+
+def create_cfg_worker(*args, **kwargs) -> "CFGWorker":
+    assert "cfg_config" in kwargs
+    cfg_config: CFGConfig = kwargs.get("cfg_config")
+    assert cfg_config is not None
+    kwargs.pop("cfg_config")
+
+    kwargs["model_runner_cls"] = CFGModelRunner
+    root_worker = SeparatedWorker(*args, **kwargs)
+
+    guidance_model_config = cfg_config.guidance_model_config
+    guidance_parallel_config = cfg_config.guidance_parallel_config
+    kwargs.update(
+        model_config=guidance_model_config,
+        parallel_config=guidance_parallel_config,
+    )
+    guidance_worker = SeparatedWorker(*args, **kwargs)
+
+    return CFGWorker(
+        root_worker=root_worker,
+        guidance_worker=guidance_worker,
+        is_driver_worker=kwargs["is_driver_worker"],
+        parallel_config=kwargs["parallel_config"],
+    )
+
+
+class CFGWorker(LoraNotSupportedWorkerBase):
+    def __init__(
+        self,
+        root_worker: WorkerBase,
+        guidance_worker: WorkerBase,
+        is_driver_worker: bool,
+        parallel_config: ParallelConfig,
+    ):
+        self.root_worker = root_worker
+        self.guidance_worker = guidance_worker
+        self.is_driver_worker = is_driver_worker
+        self.parallel_config = parallel_config
+        assert self.parallel_config.pipeline_parallel_size == 1
+
+    def init_device(self):
+        self.root_worker.init_device()
+        self.guidance_worker.init_device()
+
+    def load_model(self):
+        self.root_worker.load_model()
+        self.guidance_worker.share_model(self.root_worker)
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        (
+            num_gpu_blocks,
+            num_cpu_blocks,
+        ) = self.root_worker.determine_num_available_blocks()
+
+        root_cache_block_size_bytes = (
+            self.root_worker.get_cache_block_size_bytes()
+        )
+        guidance_cache_block_size_bytes = (
+            self.guidance_worker.get_cache_block_size_bytes()
+        )
+
+        new_num_gpu_blocks = int(
+            num_gpu_blocks
+            * root_cache_block_size_bytes
+            / (guidance_cache_block_size_bytes + root_cache_block_size_bytes)
+        )
+        return new_num_gpu_blocks, num_cpu_blocks
+
+    def initialize_cache(
+        self, num_gpu_blocks: int, num_cpu_blocks: int
+    ) -> None:
+        self.root_worker.initialize_cache(
+            num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks
+        )
+        self.guidance_worker.initialize_cache(
+            num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks
+        )
+
+    @property
+    def do_metadata_broadcast(self) -> bool:
+        return self.parallel_config.tensor_parallel_size > 1
+
+    @torch.inference_mode()
+    def execute_model(
+        self, execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        # Prepare the negative (unconditional) request as a shallow copy.
+        if execute_model_req is not None:
+            negative_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+            negative_execute_model_req = execute_model_req.clone(
+                negative_seq_group_metadata_list
+            )
+            for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+                negative_seq_group_metadata = copy.copy(seq_group_metadata)
+                negative_seq_data: Dict[int, SequenceData] = {}
+                negative_block_tables: Dict[int, List[int]] = {}
+                assert len(seq_group_metadata.seq_data) == 1
+                for seq_id in seq_group_metadata.seq_data.keys():
+                    negative_seq_data[
+                        seq_id
+                    ] = seq_group_metadata.negative_seq_data
+                    negative_block_tables[
+                        seq_id
+                    ] = seq_group_metadata.negative_block_table
+
+                if negative_seq_group_metadata.is_prompt:
+                    negative_seq_group_metadata.token_chunk_size = list(
+                        negative_seq_data.values()
+                    )[0].get_len()
+
+                negative_seq_group_metadata.seq_data = negative_seq_data
+                negative_seq_group_metadata.block_tables = negative_block_tables
+                negative_seq_group_metadata.negative_seq_data = None
+                negative_seq_group_metadata.negative_block_table = None
+                negative_seq_group_metadata_list.append(
+                    negative_seq_group_metadata
+                )
+            negative_execute_model_req.seq_group_metadata_list = (
+                negative_seq_group_metadata_list
+            )
+        else:
+            negative_execute_model_req = None
+
+        inputs = self.root_worker.prepare_input(execute_model_req)
+        negative_inputs = self.guidance_worker.prepare_input(
+            negative_execute_model_req
+        )
+        if inputs is None:
+            assert negative_inputs is None
+            return None
+
+        # Get the conditional logits from the root model.
+        condition_logits = self.root_worker.execute_model_part(inputs)
+        # Get the unconditional logits from the guidance model.
+        unconditional_logits = self.guidance_worker.execute_model_part(
+            negative_inputs
+        )
+
+        # Apply classifier-free guidance to the logits.
+        model_input, _ = inputs
+        if condition_logits is not None:
+            for seq_group in model_input.sampling_metadata.seq_groups:
+                seq_ids = seq_group.seq_ids
+                guidance_scale = seq_group.sampling_params.guidance_scale
+                if guidance_scale is None or guidance_scale == 1.0:
+                    continue
+                for seq_id, logits_row_idx in zip(
+                    seq_ids, seq_group.sample_indices
+                ):
+                    logits_row = torch.nn.functional.log_softmax(
+                        condition_logits[logits_row_idx], dim=-1
+                    )
+                    unconditional_logits_row = torch.nn.functional.log_softmax(
+                        unconditional_logits[logits_row_idx], dim=-1
+                    )
+                    condition_logits[logits_row_idx] = (
+                        guidance_scale * (logits_row - unconditional_logits_row)
+                        + unconditional_logits_row
+                    )
+
+        # Apply the logits processors.
+        scores = self.root_worker.compute_logits(condition_logits, model_input)
+        if not self.is_driver_worker:
+            return []
+
+        # Sample the next tokens.
+        output = self.root_worker.do_sample(scores, model_input)
+
+        if not get_pp_group().is_last_rank:
+            # output is IntermediateTensors
+            get_pp_group().send_tensor_dict(
+                output.tensors, all_gather_group=get_tp_group()
+            )
+            return [None]
+
+        # output is List[SamplerOutput]
+        return output
+
+    def get_cache_block_size_bytes(self):
+        raise NotImplementedError

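The guidance loop above mixes the two distributions in log-probability space: guided = uncond + scale * (cond - uncond), which leaves the conditional distribution unchanged at a scale of 1.0. A self-contained sketch of the same per-row combination (the function name and shapes are illustrative, not from this commit):

    import torch

    def cfg_combine(cond_logits: torch.Tensor,
                    uncond_logits: torch.Tensor,
                    guidance_scale: float) -> torch.Tensor:
        """Classifier-free guidance for one row of next-token logits."""
        cond = torch.log_softmax(cond_logits, dim=-1)
        uncond = torch.log_softmax(uncond_logits, dim=-1)
        # guidance_scale == 1.0 reproduces the conditional log-probs exactly.
        return uncond + guidance_scale * (cond - uncond)
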
+ 77 - 0
aphrodite/cfg/separated_worker.py

@@ -0,0 +1,77 @@
+from typing import List, Optional, Tuple
+
+import torch
+
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.distributed import get_pp_group, get_tp_group
+from aphrodite.task_handler.model_runner import (
+    ModelInputForGPUWithSamplingMetadata)
+from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
+from aphrodite.task_handler.worker import Worker
+from aphrodite.task_handler.worker_base import WorkerInput
+
+
+class SeparatedWorker(Worker):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @torch.inference_mode()
+    def get_logits(
+        self,
+        hidden_or_intermediate_states: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model_runner.get_logits(
+            hidden_or_intermediate_states, model_input)
+
+    @torch.inference_mode()
+    def compute_logits(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model_runner.compute_logits(logits, model_input)
+
+    @torch.inference_mode()
+    def do_sample(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> List[SamplerOutput]:
+        return self.model_runner.do_sample(logits, model_input)
+
+    @torch.inference_mode()
+    def execute_model_part(
+        self,
+        inputs: Tuple[BroadcastableModelInput, WorkerInput],
+    ) -> Optional[List[SamplerOutput]]:
+
+        model_input, worker_input = inputs
+        num_steps = worker_input.num_steps
+
+        self.execute_worker(worker_input)
+
+        # If there is no input, we don't need to execute the model.
+        if worker_input.num_seq_groups == 0:
+            return []
+
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group()))
+
+        hidden_or_intermediate_states = self.model_runner.model_execute(
+            model_input, 
+            self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None, 
+            intermediate_tensors,
+            num_steps
+        )
+
+        # Compute the logits in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        logits = self.get_logits(hidden_or_intermediate_states, model_input)
+
+        return logits

+ 38 - 0
aphrodite/common/config.py

@@ -1555,6 +1555,43 @@ class SpeculativeConfig:
         return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"
 
 
+class CFGConfig:
+    @staticmethod
+    def maybe_create_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        guidance_model: Optional[str],
+    ):
+        if guidance_model is None:
+            return None
+
+        guidance_parallel_config = target_parallel_config
+        assert target_model_config.model == guidance_model
+        guidance_model_config = target_model_config
+
+        return CFGConfig(
+            guidance_model_config,
+            guidance_parallel_config
+        )
+
+    def __init__(
+        self,
+        guidance_model_config: ModelConfig,
+        guidance_parallel_config: ParallelConfig,
+    ):
+        self.guidance_model_config = guidance_model_config
+        self.guidance_parallel_config = guidance_parallel_config
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.guidance_model_config:
+            self.guidance_model_config.verify_with_parallel_config(
+                self.guidance_parallel_config)
+
+    def __repr__(self) -> str:
+        guidance_model = self.guidance_model_config.model
+        return f"CFGConfig({guidance_model=})"
+
+
 @dataclass
 class LoRAConfig:
     max_lora_rank: int
@@ -1877,6 +1914,7 @@ class EngineConfig:
     speculative_config: Optional[SpeculativeConfig]
     decoding_config: Optional[DecodingConfig]
     prompt_adapter_config: Optional[PromptAdapterConfig]
+    cfg_config: Optional[CFGConfig]
 
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.

+ 3 - 0
aphrodite/common/sampling_params.py

@@ -175,6 +175,7 @@ class SamplingParams(
             Defaults to None.
         skew: Bias the token selection towards higher or lower probability
             tokens. Defaults to 0 (disabled).
+        guidance_scale: The classifier-free guidance (CFG) scale. A value of
+            1.0 leaves the conditional distribution unchanged; larger values
+            push the output away from the negative prompt. Defaults to None.
     """
 
     n: int = 1
@@ -227,6 +228,7 @@ class SamplingParams(
     dry_allowed_length: int = 2
     dry_sequence_breaker_ids: List[int] = []
     skew: float = 0.0
+    guidance_scale: Optional[float] = None
     # The below fields are not supposed to be used as an input.
     # They are set in post_init.
     output_text_buffer_length: int = 0
@@ -279,6 +281,7 @@ class SamplingParams(
         "dry_allowed_length": 2,
         "dry_sequence_breaker_ids": [],
         "skew": 0.0,
+        "guidance_scale": None,
     }
 
     def __post_init__(self) -> None:

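With the new field, guidance strength is set per request on SamplingParams; per the worker logic in this commit, a value of 1.0 skips the mixing for that request. A minimal, assumed usage (the package-root import path is an assumption):

    from aphrodite import SamplingParams  # assumed import path

    params = SamplingParams(
        max_tokens=128,
        temperature=0.7,
        guidance_scale=1.5,  # >1.0 steers away from the negative prompt
    )
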
+ 47 - 6
aphrodite/common/sequence.py

@@ -330,6 +330,7 @@ class Sequence:
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         from_decoder_prompt: bool = True,
+        from_negative_prompt: bool = False,
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs
@@ -338,6 +339,7 @@ class Sequence:
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
+        self.from_negative_prompt = from_negative_prompt
         self._prompt: Optional[str] = None
         self._prompt_token_ids: Optional[List[int]] = None
 
@@ -395,8 +397,12 @@ class Sequence:
 
         # Select decoder or encoder input prompt str,
         # as appropriate
-        prompt_key: str = ("prompt"
-                           if self.from_decoder_prompt else "encoder_prompt")
+        prompt_key: str = "prompt"
+        if not self.from_decoder_prompt:
+            prompt_key = "encoder_prompt"
+        if self.from_negative_prompt:
+            assert self.from_decoder_prompt is True
+            prompt_key = "negative_prompt"
 
         # Cache prompt
         self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
@@ -410,9 +416,12 @@ class Sequence:
 
         # Select decoder or encoder input prompt
         # token ids, as appropriate
-        prompt_token_ids_key: str = ("prompt_token_ids"
-                                     if self.from_decoder_prompt else
-                                     "encoder_prompt_token_ids")
+        prompt_token_ids_key: str = "prompt_token_ids"
+        if not self.from_decoder_prompt:
+            "encoder_prompt_token_ids"
+        if self.from_negative_prompt:
+            assert self.from_decoder_prompt is True
+            prompt_token_ids_key = "negative_prompt_token_ids"
 
         # Cache computed prompt token ids
         self._prompt_token_ids = cast(List[int],
@@ -476,6 +485,9 @@ class Sequence:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()
 
+    def get_negative_token_ids(self) -> List[int]:
+        return self.data.get_negative_token_ids()
+
     def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.data.get_prompt_token_ids()
 
@@ -532,7 +544,8 @@ class Sequence:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={self.n_blocks}, ")
+                f"num_blocks={self.n_blocks}, "
+                f"data={self.data})")
 
 
 class SequenceGroupState(
@@ -576,6 +589,7 @@ class SequenceGroup:
         embeddings: Optional[List[float]] = None,
         pooling_params: Optional[PoolingParams] = None,
         encoder_seq: Optional[Sequence] = None,
+        negative_seq: Optional[Sequence] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
@@ -596,6 +610,9 @@ class SequenceGroup:
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
 
+        if negative_seq is not None:
+            assert self.is_single_seq is True
+        self.negative_seq = negative_seq
+
     @property
     def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
@@ -624,6 +641,22 @@ class SequenceGroup:
         return (self.encoder_seq.prompt_token_ids
                 if self.encoder_seq is not None else None)
 
+    @property
+    def negative_prompt(self) -> Optional[str]:
+        # There is either zero or one negative sequence;
+        # return its prompt, if any.
+        assert self.is_single_seq is True
+        return (self.negative_seq.prompt
+                if self.negative_seq is not None else None)
+
+    @property
+    def negative_prompt_token_ids(self) -> Optional[List[int]]:
+        # There is either zero or one negative sequence;
+        # return its prompt token ids, if any.
+        assert self.is_single_seq is True
+        return (self.negative_seq.prompt_token_ids
+                if self.negative_seq is not None else None)
+
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
@@ -723,6 +756,12 @@ class SequenceGroup:
     def get_encoder_seq(self) -> Optional[Sequence]:
         return self.encoder_seq
 
+    def has_negative_prompt(self) -> bool:
+        return self.negative_seq is not None
+
+    def get_negative_seq(self) -> Optional[Sequence]:
+        return self.negative_seq
+
     def get_unfinished_seqs(self) -> List[Sequence]:
         if self.is_single_seq:
             return self.seqs if not self.seqs[0].is_finished() else []
@@ -921,6 +960,8 @@ class SequenceGroupMetadata(
     multi_modal_data: Optional[Any] = None
     encoder_seq_data: Optional[SequenceData] = None
     cross_block_table: Optional[List[int]] = None
+    negative_seq_data: Optional[SequenceData] = None
+    negative_block_table: Optional[List[int]] = None
     prompt_adapter_request: Optional[PromptAdapterRequest] = None
     token_chunk_size: Optional[int] = None
 

+ 64 - 14
aphrodite/engine/aphrodite_engine.py

@@ -9,9 +9,9 @@ from loguru import logger
 from transformers import PreTrainedTokenizer
 from typing_extensions import assert_never
 
-from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig,
-                                     EngineConfig, LoadConfig, LoRAConfig,
-                                     ModelConfig, ParallelConfig,
+from aphrodite.common.config import (CacheConfig, CFGConfig, DecodingConfig,
+                                     DeviceConfig, EngineConfig, LoadConfig,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig,
                                      SpeculativeConfig)
 from aphrodite.common.logger import setup_logger
@@ -70,9 +70,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
 
 PromptComponents = Tuple[Optional[str], List[int],
-                         Optional[MultiModalDataDict]]
+                         Optional[MultiModalDataDict],
+                         Optional[str], Optional[List[int]]]
 DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
-                                Optional[MultiModalDataDict]]
+                                Optional[MultiModalDataDict],
+                                Optional[str], Optional[List[int]]]
 
 
 class AphroditeEngine:
@@ -171,6 +173,7 @@ class AphroditeEngine:
         speculative_config: Optional[SpeculativeConfig],
         decoding_config: Optional[DecodingConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
+        cfg_config: Optional[CFGConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@@ -185,6 +188,7 @@ class AphroditeEngine:
         config_dict = {
             "Model": model_config.model,
             "Speculative Config": speculative_config,
+            "CFG Config": cfg_config,
             "DataType": model_config.dtype,
             "Model Load Format": load_config.load_format,
             "Tensor Parallel Size": parallel_config.tensor_parallel_size,
@@ -233,6 +237,7 @@ class AphroditeEngine:
         self.load_config = load_config
         self.decoding_config = decoding_config or DecodingConfig()
         self.prompt_adapter_config = prompt_adapter_config
+        self.cfg_config = cfg_config
         self.log_stats = log_stats
 
         if not self.model_config.skip_tokenizer_init:
@@ -269,6 +274,7 @@ class AphroditeEngine:
             speculative_config=speculative_config,
             load_config=load_config,
             prompt_adapter_config=prompt_adapter_config,
+            cfg_config=cfg_config,
         )
 
         if not self.model_config.embedding_mode:
@@ -533,6 +539,16 @@ class AphroditeEngine:
         seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)
 
+        negative_seq = None
+        if 'negative_prompt_token_ids' in processed_inputs:
+            negative_seq = Sequence(seq_id,
+                                    processed_inputs,
+                                    block_size,
+                                    eos_token_id,
+                                    lora_request,
+                                    prompt_adapter_request,
+                                    from_negative_prompt=True)
+
         encoder_seq = None
         if 'encoder_prompt_token_ids' in processed_inputs:
             encoder_seq = Sequence(seq_id,
@@ -553,6 +569,7 @@ class AphroditeEngine:
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request,
                 encoder_seq=encoder_seq,
+                negative_seq=negative_seq,
             )
         elif isinstance(params, PoolingParams):
             seq_group = self._create_sequence_group_with_pooling(
@@ -661,6 +678,8 @@ class AphroditeEngine:
                 lora_request=lora_request,
             )
             multi_modal_data = None
+            negative_prompt = None
+            negative_prompt_token_ids = None
         elif isinstance(inputs, dict):
             if "prompt_token_ids" in inputs:
                 prompt = None
@@ -674,11 +693,27 @@ class AphroditeEngine:
                     lora_request=lora_request,
                 )
 
+            if "negative_prompt_token_ids" in inputs:
+                negative_prompt = None
+                negative_prompt_token_ids = inputs["negative_prompt_token_ids"]
+            elif "negative_prompt" in inputs:
+                negative_prompt = parsed_negative_prompt = inputs[
+                    "negative_prompt"]
+                negative_prompt_token_ids = self._tokenize_prompt(
+                    parsed_negative_prompt,
+                    request_id=request_id,
+                    lora_request=lora_request,
+                )
+            else:
+                negative_prompt = None
+                negative_prompt_token_ids = None
+
             multi_modal_data = inputs.get("multi_modal_data")
         else:
             assert_never(inputs)
 
-        return prompt, prompt_token_ids, multi_modal_data
+        return (prompt, prompt_token_ids, multi_modal_data,
+                negative_prompt, negative_prompt_token_ids)
 
     def _apply_prompt_adapter(
         self,
@@ -728,8 +763,10 @@ class AphroditeEngine:
         encoder_comps: PromptComponents,
         decoder_comps: DecoderPromptComponents,
     ) -> EncoderDecoderLLMInputs:
-        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
-        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
+        encoder_prompt, encoder_prompt_ids, encoder_mm_data, \
+            encoder_negative_prompt, encoder_negative_prompt_ids = encoder_comps
+        decoder_prompt, decoder_prompt_ids, decoder_mm_data, \
+            decoder_negative_prompt, decoder_negative_prompt_ids = decoder_comps
 
         if encoder_mm_data is not None or decoder_mm_data is not None:
             raise ValueError("Multi-modal encoder-decoder models are "
@@ -737,12 +774,18 @@ class AphroditeEngine:
 
         decoder_prompt_ids = (
             self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
+        decoder_negative_prompt_ids = (
+            self._prepare_decoder_input_ids_for_generation(
+                decoder_negative_prompt_ids))
 
         return EncoderDecoderLLMInputs(
             prompt_token_ids=decoder_prompt_ids,
             prompt=decoder_prompt,
+            negative_prompt_token_ids=decoder_negative_prompt_ids,
+            negative_prompt=decoder_negative_prompt,
             encoder_prompt_token_ids=encoder_prompt_ids,
             encoder_prompt=encoder_prompt,
+            encoder_negative_prompt_token_ids=encoder_negative_prompt_ids,
+            encoder_negative_prompt=encoder_negative_prompt,
         )
 
     def _process_encoder_decoder_prompt(
@@ -787,7 +830,7 @@ class AphroditeEngine:
             )
 
             if (decoder_input := inputs["decoder_prompt"]) is None:
-                decoder_comps = None, None, None
+                decoder_comps = None, None, None, None, None
             else:
                 decoder_comps = self._extract_prompt_components(
                     decoder_input,
@@ -799,7 +842,7 @@ class AphroditeEngine:
                 request_id=request_id,
             )
 
-            decoder_comps = None, None, None
+            decoder_comps = None, None, None, None, None
 
         return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
 
@@ -808,14 +851,17 @@ class AphroditeEngine:
         prompt_comps: PromptComponents,
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> LLMInputs:
-        prompt, prompt_token_ids, multi_modal_data = prompt_comps
+        prompt, prompt_token_ids, multi_modal_data, \
+            negative_prompt, negative_prompt_token_ids = prompt_comps
 
         prompt_token_ids = self._apply_prompt_adapter(
             prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
 
         return LLMInputs(prompt_token_ids=prompt_token_ids,
                          prompt=prompt,
-                         multi_modal_data=multi_modal_data)
+                         multi_modal_data=multi_modal_data,
+                         negative_prompt_token_ids=negative_prompt_token_ids,
+                         negative_prompt=negative_prompt)
 
     def _process_decoder_only_prompt(
         self,
@@ -960,6 +1006,7 @@ class AphroditeEngine:
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         encoder_seq: Optional[Sequence] = None,
+        negative_seq: Optional[Sequence] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -984,7 +1031,8 @@ class AphroditeEngine:
             sampling_params=sampling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
-            encoder_seq=encoder_seq)
+            encoder_seq=encoder_seq,
+            negative_seq=negative_seq)
 
         return seq_group
 
@@ -997,6 +1045,7 @@ class AphroditeEngine:
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         encoder_seq: Optional[Sequence] = None,
+        negative_seq: Optional[Sequence] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -1009,7 +1058,8 @@ class AphroditeEngine:
             lora_request=lora_request,
             pooling_params=pooling_params,
             prompt_adapter_request=prompt_adapter_request,
-            encoder_seq=encoder_seq)
+            encoder_seq=encoder_seq,
+            negative_seq=negative_seq)
 
         return seq_group
 

+ 23 - 7
aphrodite/engine/args_tools.py

@@ -8,12 +8,12 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
 
 from loguru import logger
 
-from aphrodite.common.config import (CacheConfig, ConfigFormat, DecodingConfig,
-                                     DeviceConfig, EngineConfig, LoadConfig,
-                                     LoadFormat, LoRAConfig, ModelConfig,
-                                     ParallelConfig, PromptAdapterConfig,
-                                     SchedulerConfig, SpeculativeConfig,
-                                     TokenizerPoolConfig)
+from aphrodite.common.config import (CacheConfig, CFGConfig, ConfigFormat,
+                                     DecodingConfig, DeviceConfig,
+                                     EngineConfig, LoadConfig, LoadFormat,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
+                                     PromptAdapterConfig, SchedulerConfig,
+                                     SpeculativeConfig, TokenizerPoolConfig)
 from aphrodite.common.utils import FlexibleArgumentParser, is_cpu
 from aphrodite.executor.executor_base import ExecutorBase
 from aphrodite.quantization import QUANTIZATION_METHODS
@@ -149,6 +149,8 @@ class EngineArgs:
     max_prompt_adapter_token: int = 0
     # Log Options
     disable_log_stats: bool = False
+    # Classifier-Free-Guidance (CFG) options
+    cfg_model: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -855,6 +857,14 @@ class EngineArgs:
             "disable logging statistics",
         )
 
+        # CFG Options
+        parser.add_argument(
+            "--cfg-model",
+            type=str,
+            default=EngineArgs.cfg_model,
+            help="The name or path of the model used for classifier-free "
+            "guidance (CFG). Currently it must match the main model."
+        )
+
         return parser
 
     @classmethod
@@ -1033,6 +1043,11 @@ class EngineArgs:
             if speculative_config is None \
             else speculative_config.num_lookahead_slots
 
+        cfg_config = CFGConfig.maybe_create_config(
+            target_model_config=model_config,
+            target_parallel_config=parallel_config,
+            guidance_model=self.cfg_model)
+
         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
@@ -1099,7 +1114,8 @@ class EngineArgs:
                             speculative_config=speculative_config,
                             load_config=load_config,
                             decoding_config=decoding_config,
-                            prompt_adapter_config=prompt_adapter_config)
+                            prompt_adapter_config=prompt_adapter_config,
+                            cfg_config=cfg_config)
 
 
 @dataclass

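On the engine side, CFG is enabled by pointing --cfg-model at the guidance model; the config logic above requires it to equal the target model for now. A hedged sketch using EngineArgs directly (the model name is illustrative):

    from aphrodite.engine.args_tools import EngineArgs  # path per this diff

    # Assumption: in this commit the guidance model must match the target
    # model (see the assert in CFGConfig).
    engine_args = EngineArgs(
        model="meta-llama/Meta-Llama-3-8B-Instruct",
        cfg_model="meta-llama/Meta-Llama-3-8B-Instruct",
    )
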
+ 17 - 1
aphrodite/engine/async_aphrodite.py

@@ -437,6 +437,7 @@ class _AsyncAphrodite(AphroditeEngine):
                 lora_request=lora_request,
             )
             multi_modal_data = None
+            negative_prompt = negative_prompt_token_ids = None
         elif isinstance(inputs, dict):
             if "prompt_token_ids" in inputs:
                 prompt = None
@@ -450,11 +451,26 @@ class _AsyncAphrodite(AphroditeEngine):
                     lora_request=lora_request,
                 )
 
+            if "negative_prompt_token_ids" in inputs:
+                negative_prompt = None
+                negative_prompt_token_ids = inputs["negative_prompt_token_ids"]
+            elif "negative_prompt" in inputs:
+                negative_prompt = parsed_negative_prompt = inputs[
+                    "negative_prompt"]
+                negative_prompt_token_ids = await self._tokenize_prompt_async(
+                    parsed_negative_prompt,
+                    request_id=request_id,
+                    lora_request=lora_request,
+                )
+            else:
+                negative_prompt = negative_prompt_token_ids = None
+
             multi_modal_data = inputs.get("multi_modal_data")
         else:
             assert_never(inputs)
 
-        return prompt, prompt_token_ids, multi_modal_data
+        return (prompt, prompt_token_ids, multi_modal_data,
+                negative_prompt, negative_prompt_token_ids)
 
     async def _process_encoder_decoder_prompt_async(
         self,

+ 2 - 0
aphrodite/engine/output_processor/single_step.py

@@ -84,6 +84,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
             # only have one sequence
             seq = seq_group.seqs[0]
             seq.append_token_id(sample.output_token, sample.logprobs)
+            negative_seq = seq_group.negative_seq
+            if negative_seq is not None:
+                negative_seq.append_token_id(sample.output_token,
+                                             sample.logprobs)
             if sampling_params.detokenize and self.detokenizer:
                 new_char_count = self.detokenizer.decode_sequence_inplace(
                     seq, sampling_params)

+ 6 - 4
aphrodite/executor/executor_base.py

@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
 
-from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
-                                     LoRAConfig, ModelConfig, ParallelConfig,
-                                     PromptAdapterConfig, SchedulerConfig,
-                                     SpeculativeConfig)
+from aphrodite.common.config import (CacheConfig, CFGConfig, DeviceConfig,
+                                     LoadConfig, LoRAConfig, ModelConfig,
+                                     ParallelConfig, PromptAdapterConfig,
+                                     SchedulerConfig, SpeculativeConfig)
 from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
 from aphrodite.lora.request import LoRARequest
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
@@ -31,6 +31,7 @@ class ExecutorBase(ABC):
         lora_config: Optional[LoRAConfig],
         speculative_config: Optional[SpeculativeConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
+        cfg_config: Optional[CFGConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -41,6 +42,7 @@ class ExecutorBase(ABC):
         self.device_config = device_config
         self.speculative_config = speculative_config
         self.prompt_adapter_config = prompt_adapter_config
+        self.cfg_config = cfg_config
 
         self._init_executor()
 

+ 5 - 0
aphrodite/executor/gpu_executor.py

@@ -57,6 +57,7 @@ class GPUExecutor(ExecutorBase):
             lora_config=self.lora_config,
             speculative_config=self.speculative_config,
             prompt_adapter_config=self.prompt_adapter_config,
+            cfg_config=self.cfg_config,
             is_driver_worker=(not self.parallel_config)
             or (rank % self.parallel_config.tensor_parallel_size == 0),
         )
@@ -76,6 +77,10 @@ class GPUExecutor(ExecutorBase):
             worker_kwargs.update(
                 worker_module_name="aphrodite.spec_decode.spec_decode_worker",
                 worker_class_name="create_spec_worker")
+        elif self.cfg_config:
+            worker_kwargs.update(
+                worker_module_name="aphrodite.cfg.cfg_worker",
+                worker_class_name="create_cfg_worker")
         else:
             worker_kwargs.update(
                 worker_module_name="aphrodite.task_handler.worker",

+ 22 - 1
aphrodite/inputs/data.py

@@ -33,7 +33,22 @@ class TokensPrompt(TypedDict):
     """
 
 
-SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
+class NegativeTextPrompt(TypedDict):
+    """Schema for a negative text prompt."""
+
+    negative_prompt: str
+    """The negative prompt text to be tokenized before passing to the
+    model."""
+
+
+class NegativeTokensPrompt(TypedDict):
+    """Schema for a tokenized negative prompt."""
+
+    negative_prompt_token_ids: List[int]
+    """A list of negative prompt token IDs to pass to the model."""
+
+
+SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt,
+                              NegativeTextPrompt, NegativeTokensPrompt]
 """
 Set of possible schemas for a single LLM input:
 - A text prompt (:class:`str` or :class:`TextPrompt`)
@@ -116,6 +131,9 @@ class LLMInputs(TypedDict):
     if the model supports it.
     """
 
+    negative_prompt_token_ids: NotRequired[Optional[List[int]]]
+    """The token IDs of the negative prompt, if provided."""
+
+    negative_prompt: NotRequired[Optional[str]]
+    """The original negative prompt text, if provided."""
+
 
 class EncoderDecoderLLMInputs(LLMInputs):
     """
@@ -132,6 +150,9 @@ class EncoderDecoderLLMInputs(LLMInputs):
     available.
     """
 
+    encoder_negative_prompt_token_ids: NotRequired[Optional[List[int]]]
+    """The token IDs of the encoder negative prompt, if provided."""
+
+    encoder_negative_prompt: NotRequired[Optional[str]]
+    """The original encoder negative prompt text, if provided."""
+
 
 _T1 = TypeVar("_T1",
               bound=SingletonPromptInputs,

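The new prompt schema lets a request carry a negative prompt next to the positive one; per the engine changes in this commit, the dict form accepts either negative_prompt or pre-tokenized negative_prompt_token_ids. A hedged example through the offline LLM entry point (the LLM import path and the cfg_model keyword pass-through are assumptions, not shown in this diff):

    from aphrodite import LLM, SamplingParams  # assumed import path

    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
              cfg_model="meta-llama/Meta-Llama-3-8B-Instruct")  # assumed kwarg
    outputs = llm.generate(
        {
            "prompt": "Write a cheerful story about a robot gardener.",
            "negative_prompt": "violence, sad ending",
        },
        SamplingParams(max_tokens=128, guidance_scale=1.5),
    )
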
+ 24 - 3
aphrodite/modeling/models/llama.py

@@ -34,13 +34,15 @@ from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+                                   get_tensor_model_parallel_world_size,
+                                   tensor_model_parallel_gather)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
-from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.logits_processor import (LogitsProcessor,
+                                                        _prune_hidden_states)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -429,7 +431,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
-                                                    logit_scale)
+                                                    logit_scale,
+                                                    logits_as_input=True)
+            self.org_vocab_size = config.vocab_size
             self.sampler = Sampler()
         else:
             self.lm_head = PPMissingLayer()
@@ -446,6 +450,23 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                                   attn_metadata, intermediate_tensors)
         return model_output
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        # Get the logits for the next tokens.
+        logits = self.lm_head.linear_method.apply(
+            self.lm_head,
+            hidden_states,
+            bias=None,
+        )
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,

+ 49 - 0
aphrodite/processing/block_manager_v2.py

@@ -17,6 +17,7 @@ from aphrodite.processing.block.utils import (
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 
 SeqId = int
+NegativeSeqId = str
 EncoderSeqId = str
 
 
@@ -98,6 +99,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         )
 
         self.block_tables: Dict[SeqId, BlockTable] = {}
+        self.negative_block_tables: Dict[NegativeSeqId, BlockTable] = {}
         self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {}
 
         self._computed_blocks_tracker = ComputedBlocksTracker(
@@ -123,6 +125,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
                 block_size=self.block_size,
             )
 
+        if seq_group.has_negative_prompt():
+            num_required_blocks += BlockTable.get_num_required_blocks(
+                seq_group.get_negative_seq().get_token_ids(),
+                block_size=self.block_size)
+
         if self.max_block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
                                       self.max_block_sliding_window)
@@ -183,6 +190,15 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         assert (request_id
                 not in self.cross_block_tables), \
                 "block table already exists"
+        assert (request_id
+                not in self.negative_block_tables), \
+                "negative block table already exists"
+
+        if seq_group.has_negative_prompt():
+            block_table = self._allocate_sequence(
+                seq_group.get_negative_seq())
+            self.negative_block_tables[request_id] = block_table
 
         check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
 
@@ -215,6 +231,15 @@ class BlockSpaceManagerV2(BlockSpaceManager):
                     num_lookahead_slots=num_lookahead_slots,
                 ))
 
+            if seq_group.has_negative_prompt():
+                negative_block_table = self.negative_block_tables[
+                    seq_group.request_id]
+                num_touched_blocks += (
+                    negative_block_table
+                    .get_num_blocks_touched_by_append_slots(
+                        token_ids=negative_block_table.get_unseen_token_ids(
+                            seq_group.get_negative_seq().get_token_ids()),
+                        num_lookahead_slots=num_lookahead_slots,
+                    ))
+
         num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
             Device.GPU)
         return num_touched_blocks <= num_free_gpu_blocks
@@ -223,6 +248,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self,
         seq: Sequence,
         num_lookahead_slots: int,
+        seq_group: SequenceGroup,
     ) -> List[Tuple[int, int]]:
 
         block_table = self.block_tables[seq.seq_id]
@@ -232,6 +258,15 @@ class BlockSpaceManagerV2(BlockSpaceManager):
             num_lookahead_slots=num_lookahead_slots,
             num_computed_slots=seq.data.get_num_computed_tokens(),
         )
+
+        if seq_group.has_negative_prompt():
+            negative_block_table = self.negative_block_tables[
+                seq_group.request_id]
+            negative_seq = seq_group.negative_seq
+            negative_block_table.append_token_ids(
+                token_ids=negative_block_table.get_unseen_token_ids(
+                    negative_seq.get_token_ids()),
+                num_lookahead_slots=num_lookahead_slots,
+                num_computed_slots=(
+                    negative_seq.data.get_num_computed_tokens()),
+            )
         # Return any new copy-on-writes.
         new_cows = self.block_allocator.clear_copy_on_writes()
         return new_cows
@@ -263,6 +298,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self.cross_block_tables[request_id].free()
         del self.cross_block_tables[request_id]
 
+    def free_negative(self, seq_group: SequenceGroup) -> None:
+        request_id = seq_group.request_id
+        if request_id not in self.negative_block_tables:
+            return
+        self.negative_block_tables[request_id].free()
+        del self.negative_block_tables[request_id]
+
     def get_block_table(self, seq: Sequence) -> List[int]:
         block_ids = self.block_tables[seq.seq_id].physical_block_ids
         return block_ids  # type: ignore
@@ -274,6 +316,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         assert all(b is not None for b in block_ids)
         return block_ids  # type: ignore
 
+    def get_negative_block_table(self, seq_group: SequenceGroup) -> List[int]:
+        request_id = seq_group.request_id
+        assert request_id in self.negative_block_tables
+        block_ids = self.negative_block_tables[request_id].physical_block_ids
+        assert all(b is not None for b in block_ids)
+        return block_ids
+
     def access_all_blocks_in_seq(self, seq: Sequence, now: float):
         if self.enable_caching:
             # Record the latest access time for the sequence. The actual update

+ 29 - 4
aphrodite/processing/scheduler.py

@@ -441,6 +441,7 @@ class Scheduler:
                     self.free_seq(seq)
 
                 self._free_seq_group_cross_attn_blocks(aborted_group)
+                self._free_seq_group_negative_blocks(aborted_group)
 
     def _free_seq_group_cross_attn_blocks(
         self,
@@ -453,6 +454,13 @@ class Scheduler:
         if seq_group.is_encoder_decoder():
             self.block_manager.free_cross(seq_group)
 
+    def _free_seq_group_negative_blocks(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        if seq_group.has_negative_prompt():
+            self.block_manager.free_negative(seq_group)
+
     def has_unfinished_seqs(self) -> bool:
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
             self.swapped) != 0
@@ -1036,7 +1044,8 @@ class Scheduler:
 
         return self.block_manager.can_append_slots(
             seq_group=seq_group,
-            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
+            num_lookahead_slots=self._get_num_lookahead_slots(
+                is_prefill, seq_group),
         )
 
     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
@@ -1073,6 +1082,14 @@ class Scheduler:
                 encoder_seq_data = None
                 cross_block_table = None
 
+            if seq_group.has_negative_prompt():
+                negative_seq_data = seq_group.get_negative_seq().data
+                negative_block_table = (
+                    self.block_manager.get_negative_block_table(seq_group))
+            else:
+                negative_seq_data = None
+                negative_block_table = None
+
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 seq_id = seq.seq_id
                 seq_data[seq_id] = seq.data
@@ -1120,6 +1137,8 @@ class Scheduler:
                     computed_block_nums=common_computed_block_nums,
                     encoder_seq_data=encoder_seq_data,
                     cross_block_table=cross_block_table,
+                    negative_seq_data=negative_seq_data,
+                    negative_block_table=negative_block_table,
                     state=seq_group.state,
                     # `multi_modal_data` will only be present for the 1st comm
                     # between engine and worker.
@@ -1169,6 +1188,7 @@ class Scheduler:
             if seq_group.is_finished():
                 # Free cross-attention block table, if it exists
                 self._free_seq_group_cross_attn_blocks(seq_group)
+                self._free_seq_group_negative_blocks(seq_group)
                 # Add the finished requests to the finished requests list.
                 # This list will be used to update the Mamba cache in the
                 # next step.
@@ -1198,11 +1218,14 @@ class Scheduler:
                 the new source and destination block indices for the appended
                 slots.
         """
-        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        num_lookahead_slots = self._get_num_lookahead_slots(
+            is_prefill=False, seq_group=seq_group)
         seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
 
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            cows = self.block_manager.append_slots(seq, num_lookahead_slots)
+            cows = self.block_manager.append_slots(seq, num_lookahead_slots,
+                                                   seq_group)
             if len(cows) > 0:
                 blocks_to_copy.extend(cows)
 
@@ -1313,7 +1336,9 @@ class Scheduler:
             passed_delay = True
         return passed_delay
 
-    def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
+    def _get_num_lookahead_slots(self, is_prefill: bool,
+                                 seq_group: Optional[SequenceGroup] = None
+                                 ) -> int:
         """The number of slots to allocate per sequence per step, beyond known
         token ids. Speculative decoding uses these slots to store KV activations
         of tokens which may or may not be accepted.

+ 3 - 0
aphrodite/task_handler/model_runner.py

@@ -979,6 +979,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     def get_model_memory_usage(self):
         return self.model_memory_usage
 
+    def share_model(self, model: nn.Module) -> None:
+        """Reuse an already-loaded model instead of loading a new copy."""
+        self.model = model
+
     def save_sharded_state(
         self,
         path: str,

+ 9 - 4
aphrodite/task_handler/worker.py

@@ -8,10 +8,10 @@ import torch
 import torch.distributed
 from loguru import logger
 
-from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
-                                     LoRAConfig, ModelConfig, ParallelConfig,
-                                     PromptAdapterConfig, SchedulerConfig,
-                                     SpeculativeConfig)
+from aphrodite.common.config import (CacheConfig, CFGConfig, DeviceConfig,
+                                     LoadConfig, LoRAConfig, ModelConfig,
+                                     ParallelConfig, PromptAdapterConfig,
+                                     SchedulerConfig, SpeculativeConfig)
 from aphrodite.common.sequence import (ExecuteModelRequest,
                                        IntermediateTensors, SamplerOutput,
                                        SequenceGroupMetadata,
@@ -56,6 +56,7 @@ class Worker(LocalOrDistributedWorkerBase):
         lora_config: Optional[LoRAConfig] = None,
         speculative_config: Optional[SpeculativeConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+        cfg_config: Optional[CFGConfig] = None,
         is_driver_worker: bool = False,
         model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
     ) -> None:
@@ -70,6 +71,7 @@ class Worker(LocalOrDistributedWorkerBase):
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
+        self.cfg_config = cfg_config
         self.load_config = load_config
         self.is_driver_worker = is_driver_worker
         if parallel_config and is_driver_worker:
@@ -155,6 +157,9 @@ class Worker(LocalOrDistributedWorkerBase):
     def load_model(self):
         self.model_runner.load_model()
 
+    def share_model(self, shared_worker: "Worker") -> None:
+        """Reuse the model already loaded by another worker's runner."""
+        self.model_runner.share_model(shared_worker.model_runner.model)
+
     def save_sharded_state(
         self,
         path: str,