
clean-up and example script

AlpinDale 4 months ago
commit 2242cf6fc5

+ 2 - 1
aphrodite/cfg/cfg_model_runner.py

@@ -137,7 +137,8 @@ class CFGModelRunner(ModelRunner):
         )
 
         if self.return_hidden_states:
-            raise NotImplementedError("return_hidden_states is not supported in CFGModelRunner")
+            raise NotImplementedError("return_hidden_states is not supported "
+                                      "in CFGModelRunner")
 
         return [output]
 

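The wrapped raise relies on Python's implicit concatenation of adjacent string literals, so the exception message itself is unchanged, only the line length. A minimal standalone sketch (not Aphrodite code):

msg = ("return_hidden_states is not supported "
       "in CFGModelRunner")
# Adjacent string literals are joined at compile time into one string.
assert msg == "return_hidden_states is not supported in CFGModelRunner"
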
+ 1 - 1
aphrodite/cfg/cfg_worker.py

@@ -109,7 +109,7 @@ class CFGWorker(LoraNotSupportedWorkerBase):
                 negative_seq_data: Dict[int, SequenceData] = {}
                 negative_block_tables: Dict[int, List[int]] = {}
                 assert len(seq_group_metadata.seq_data) == 1
-                for seq_id in seq_group_metadata.seq_data.keys():
+                for seq_id in seq_group_metadata.seq_data:
                     negative_seq_data[
                         seq_id
                     ] = seq_group_metadata.negative_seq_data

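Iterating a dict directly yields its keys, so dropping .keys() here is purely idiomatic and does not change behavior. A tiny sketch, using a stand-in dict in place of seq_group_metadata.seq_data:

seq_data = {7: "dummy SequenceData"}  # stand-in for seq_group_metadata.seq_data
# Plain iteration over a dict and iteration over .keys() visit the same keys.
assert list(seq_data) == list(seq_data.keys())
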
+ 42 - 0
examples/offline_inference/cfg_inference.py

@@ -0,0 +1,42 @@
+from typing import List
+from aphrodite import LLM, SamplingParams
+from aphrodite.inputs import PromptInputs
+
+llm = LLM(
+    model="NousResearch/Meta-Llama-3.1-8B-Instruct",
+    use_v2_block_manager=True,
+    cfg_model="NousResearch/Meta-Llama-3.1-8B-Instruct",
+    max_model_len=8192,
+)
+
+prompt_pairs = [
+    {
+        "prompt": "Hello, my name is",
+        "negative_prompt": "I am uncertain and confused about who I am"
+    },
+    {
+        "prompt": "The president of the United States is",
+        "negative_prompt": "I don't know anything about US politics or leadership"
+    },
+]
+
+tokenizer = llm.get_tokenizer()
+
+inputs: List[PromptInputs] = [
+    {
+        "prompt_token_ids": tokenizer.encode(text=pair["prompt"]),
+        "negative_prompt_token_ids": tokenizer.encode(text=pair["negative_prompt"])
+    }
+    for pair in prompt_pairs
+]
+
+sampling_params = SamplingParams(guidance_scale=5.0, max_tokens=128)
+outputs = llm.generate(inputs, sampling_params)
+
+for i, output in enumerate(outputs):
+    prompt_pair = prompt_pairs[i]
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt_pair['prompt']!r}")
+    print(f"Negative Prompt: {prompt_pair['negative_prompt']!r}")
+    print(f"Generated text: {generated_text!r}")
+    print("-" * 50)
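
Assuming a checkout that includes this commit, the example should run as a plain script, e.g. python examples/offline_inference/cfg_inference.py (the usual model-download and GPU requirements for an 8B model apply).

For intuition, guidance_scale=5.0 controls how strongly sampling is pushed away from the negative prompt. Below is a minimal sketch of the standard classifier-free guidance blend of per-token logits; whether CFGModelRunner uses exactly this form is an assumption here, not something shown in this diff:

import numpy as np

def cfg_combine(positive_logits: np.ndarray,
                negative_logits: np.ndarray,
                guidance_scale: float) -> np.ndarray:
    # Move guidance_scale times along the direction from the negative-prompt
    # logits toward the positive-prompt logits: scale 1.0 recovers the
    # positive logits unchanged; larger scales push further from the negative.
    return negative_logits + guidance_scale * (positive_logits - negative_logits)

pos = np.array([2.0, 0.5, -1.0])
neg = np.array([1.0, 1.0, 1.0])
print(cfg_combine(pos, neg, 5.0))  # [ 6.  -1.5 -9. ]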