aphrodite_engine_example.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import argparse
  2. from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
  def _create_test_prompts() -> "list[tuple[str, SamplingParams]]":
      """Return the demo (prompt, sampling parameters) pairs.

      Each prompt is paired with a different ``SamplingParams`` setup:
      temperature 0.0; top-k with a presence penalty; ``n=2`` out of
      ``best_of=5`` with top-p and a frequency penalty; and beam search
      (``use_beam_search=True``, ``n=3``, ``best_of=3``).
      """
      return [
          ("<|system|>Enter chat mode.<|user|>Hello!<|model|>",
           SamplingParams(temperature=0.0)),
          (
              "<|system|>Enter RP mode.<|model|>Hello!<|user|>What are you doing?<|model|>",  # noqa: E501
              SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
          (
              "<|system|>Enter chat mode.<|user|>What is the meaning of life?<|model|>",  # noqa: E501
              SamplingParams(n=2,
                             best_of=5,
                             temperature=0.8,
                             top_p=0.95,
                             frequency_penalty=0.1)),
          ("<|system|>Enter QA mode.<|user|>What is a man?<|model|>A miserable",
           SamplingParams(n=3, best_of=3, use_beam_search=True,
                          temperature=0.0)),
      ]


  def _process_requests(
          engine: "AphroditeEngine",
          test_prompts: "list[tuple[str, SamplingParams]]") -> None:
      """Drive ``engine`` step by step and print every finished request.

      One new request is added before each ``engine.step()`` call instead of
      submitting them all up front. This is meant to exercise continuous
      batching. Requests get consecutive string ids starting at ``"0"``.
      The loop stops once every prompt has been submitted and
      ``engine.has_unfinished_requests()`` is false.

      Args:
          engine: The engine to drive.
          test_prompts: (prompt, sampling parameters) pairs, submitted in
              order. The caller's list is not modified.
      """
      from collections import deque

      # deque gives O(1) removal from the front, unlike list.pop(0).
      pending = deque(test_prompts)
      request_id = 0
      while True:
          # To test continuous batching, add one request at each step.
          if pending:
              prompt, sampling_params = pending.popleft()
              engine.add_request(str(request_id), prompt, sampling_params)
              request_id += 1

          for request_output in engine.step():
              if request_output.finished:
                  print(request_output)

          if not (engine.has_unfinished_requests() or pending):
              break


  def main(args: argparse.Namespace) -> None:
      """Build an AphroditeEngine from parsed CLI arguments and run the demo.

      Args:
          args: Namespace produced by a parser that was extended with
              ``EngineArgs.add_cli_args``. It is converted with
              ``EngineArgs.from_cli_args``.
      """
      engine_args = EngineArgs.from_cli_args(args)
      engine = AphroditeEngine.from_engine_args(engine_args)
      _process_requests(engine, _create_test_prompts())
  if __name__ == '__main__':
      # EngineArgs.add_cli_args extends (and returns) the parser, so the
      # parsed namespace is in the form EngineArgs.from_cli_args expects.
      arg_parser = argparse.ArgumentParser(
          description='Demo on using the AphroditeEngine class directly')
      arg_parser = EngineArgs.add_cli_args(arg_parser)
      main(arg_parser.parse_args())