# aphrodite_engine_example.py
# Demo of using the AphroditeEngine class directly, driving it by calling
# `engine.step()` manually.
import argparse
from collections import deque

from aphrodite import EngineArgs, AphroditeEngine, SamplingParams
  3. def main(args: argparse.Namespace):
  4. # Parse the CLI argument and initialize the engine.
  5. engine_args = EngineArgs.from_cli_args(args)
  6. engine = AphroditeEngine.from_engine_args(engine_args)
  7. # Test the following prompts.
  8. test_prompts = [
  9. ("<|system|>Enter chat mode.<|user|>Hello!<|model|>",
  10. SamplingParams(temperature=0.0)),
  11. ("<|system|>Enter RP mode.<|model|>Hello!<|user|>What are you doing?<|model|>",
  12. SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
  13. ("<|system|>Enter chat mode.<|user|>What is the meaning of life?<|model|>",
  14. SamplingParams(n=2,
  15. best_of=5,
  16. temperature=0.8,
  17. top_p=0.95,
  18. frequency_penalty=0.1)),
  19. ("<|system|>Enter QA mode.<|user|>What is a man?<|model|>A miserable",
  20. SamplingParams(n=3, best_of=3, use_beam_search=True,
  21. temperature=0.0)),
  22. ]
  23. # Run the engine by calling `engine.step()` manually.
  24. request_id = 0
  25. while True:
  26. # To test continuous batching, we add one request at each step.
  27. if test_prompts:
  28. prompt, sampling_params = test_prompts.pop(0)
  29. engine.add_request(str(request_id), prompt, sampling_params)
  30. request_id += 1
  31. request_outputs = engine.step()
  32. for request_output in request_outputs:
  33. if request_output.finished:
  34. print(request_output)
  35. if not (engine.has_unfinished_requests() or test_prompts):
  36. break
  37. if __name__ == '__main__':
  38. parser = argparse.ArgumentParser(
  39. description='Demo on using the AphroditeEngine class directly')
  40. parser = EngineArgs.add_cli_args(parser)
  41. args = parser.parse_args()
  42. main(args)