soft_prompt_inference.py

from aphrodite import LLM, SamplingParams
from aphrodite.prompt_adapter.request import PromptAdapterRequest

# Define the model and prompt adapter paths
MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'


def do_sample(llm, pa_name: str, pa_id: int):
    # Sample prompts
    prompts = [
        "Tweet text : @nationalgridus I have no water and the bill is \
        current and paid. Can you do something about this? Label : ",
        "Tweet text : @nationalgridus Looks good thanks! Label : "
    ]
    # Define sampling parameters
    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=3,
                                     stop_token_ids=[3])
    # Generate outputs using the LLM
    outputs = llm.generate(prompts,
                           sampling_params,
                           prompt_adapter_request=PromptAdapterRequest(
                               pa_name, pa_id, PA_PATH, 8) if pa_id else None)
    # Print the outputs
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def main():
    # Create an LLM with prompt adapter enabled
    llm = LLM(MODEL_PATH,
              enforce_eager=True,
              enable_prompt_adapter=True,
              max_prompt_adapter_token=8)
    # Run the sampling function
    do_sample(llm, "twitter_pa", pa_id=1)


if __name__ == "__main__":
    main()
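
Usage: assuming aphrodite-engine is installed and the Hugging Face checkpoints named above are reachable, the script can be run directly with "python soft_prompt_inference.py". The PromptAdapterRequest positional arguments are taken here to be the adapter name, an integer adapter id, the adapter path, and the number of virtual prompt tokens (8, matching the max_prompt_adapter_token value passed to the LLM); when pa_id is 0 or None, no adapter request is attached and the base model generates unmodified.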