@@ -5,6 +5,7 @@ from aphrodite.prompt_adapter.request import PromptAdapterRequest
 MODEL_PATH = "bigscience/bloomz-560m"
 PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
 
+
 def do_sample(llm, pa_name: str, pa_id: int):
     # Sample prompts
     prompts = [
@@ -13,16 +14,15 @@ def do_sample(llm, pa_name: str, pa_id: int):
         "Tweet text : @nationalgridus Looks good thanks! Label : "
     ]
     # Define sampling parameters
-    sampling_params = SamplingParams(temperature=0.0, max_tokens=3,
+    sampling_params = SamplingParams(temperature=0.0,
+                                     max_tokens=3,
                                      stop_token_ids=[3])
 
     # Generate outputs using the LLM
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        prompt_adapter_request=PromptAdapterRequest(pa_name, pa_id, PA_PATH,
-                                                    8) if pa_id else None
-    )
+    outputs = llm.generate(prompts,
+                           sampling_params,
+                           prompt_adapter_request=PromptAdapterRequest(
+                               pa_name, pa_id, PA_PATH, 8) if pa_id else None)
 
     # Print the outputs
     for output in outputs:
@@ -30,17 +30,17 @@ def do_sample(llm, pa_name: str, pa_id: int):
         generated_text = output.outputs[0].text.strip()
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
+
 def main():
     # Create an LLM with prompt adapter enabled
-    llm = LLM(
-        MODEL_PATH,
-        enforce_eager=True,
-        enable_prompt_adapter=True,
-        max_prompt_adapter_token=8
-    )
+    llm = LLM(MODEL_PATH,
+              enforce_eager=True,
+              enable_prompt_adapter=True,
+              max_prompt_adapter_token=8)
 
     # Run the sampling function
     do_sample(llm, "twitter_pa", pa_id=1)
 
+
 if __name__ == "__main__":
     main()
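
A note for readers of the reformatted generate() call: the four positional arguments to PromptAdapterRequest appear to correspond to the fields of the PromptAdapterRequest dataclass in aphrodite's prompt adapter API (which mirrors vLLM's). The keyword names in the sketch below are an assumption from that API, not something this diff confirms:

    # Sketch only -- keyword form of PromptAdapterRequest(pa_name, pa_id, PA_PATH, 8).
    # Field names are assumed from the dataclass definition, not confirmed by this diff.
    request = PromptAdapterRequest(
        prompt_adapter_name=pa_name,           # label for this adapter, e.g. "twitter_pa"
        prompt_adapter_id=pa_id,               # unique positive integer id
        prompt_adapter_local_path=PA_PATH,     # HF repo id or local directory
        prompt_adapter_num_virtual_tokens=8,   # soft-prompt length; must fit within max_prompt_adapter_token
    )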