# chat_inference.py

from aphrodite import LLM, SamplingParams

llm = LLM(model="NousResearch/Meta-Llama-3.1-8B-Instruct")
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        print("-" * 80)
    print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:
conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation,
                   sampling_params=sampling_params,
                   use_tqdm=False)
print_outputs(outputs)
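
# For longer, essay-style answers you may want a larger generation budget.
# A minimal sketch (assumption: Aphrodite mirrors vLLM-style SamplingParams,
# so `top_p` and `max_tokens` are available; check your installed version):
essay_params = SamplingParams(temperature=0.5, top_p=0.9, max_tokens=512)
outputs = llm.chat(conversation,
                   sampling_params=essay_params,
                   use_tqdm=False)
print_outputs(outputs)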

# You can also run batch inference with the llm.chat API; here we reuse the
# conversation defined above for every request in the batch.
  35. conversation = [
  36. {
  37. "role": "system",
  38. "content": "You are a helpful assistant"
  39. },
  40. {
  41. "role": "user",
  42. "content": "Hello"
  43. },
  44. {
  45. "role": "assistant",
  46. "content": "Hello! How can I assist you today?"
  47. },
  48. {
  49. "role": "user",
  50. "content": "Write an essay about the importance of higher education.",
  51. },
  52. ]
conversations = [conversation for _ in range(10)]

# We turn on the tqdm progress bar to verify it's indeed running batch inference.
outputs = llm.chat(messages=conversations,
                   sampling_params=sampling_params,
                   use_tqdm=True)
print_outputs(outputs)
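
# Outputs come back in the same order as the input conversations, so results
# can be paired with their requests by index, for example:
for i, output in enumerate(outputs):
    print(f"[conversation {i}] {output.outputs[0].text[:60]!r}...")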

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
#     chat_template = f.read()

# outputs = llm.chat(
#     conversations,
#     sampling_params=sampling_params,
#     use_tqdm=False,
#     chat_template=chat_template,
# )
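
# An alternative sketch: render the prompt yourself with the tokenizer's chat
# template (via Hugging Face transformers) and use the plain generate API.
# Assumption: the model repo ships a chat template, as Llama 3.1 Instruct does.

# from transformers import AutoTokenizer
#
# tokenizer = AutoTokenizer.from_pretrained(
#     "NousResearch/Meta-Llama-3.1-8B-Instruct")
# prompt = tokenizer.apply_chat_template(conversation,
#                                        tokenize=False,
#                                        add_generation_prompt=True)
# outputs = llm.generate(prompt, sampling_params)
# print_outputs(outputs)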