# test_oot_registration.py

  1. from ...utils import APHRODITE_PATH, RemoteOpenAIServer
  2. chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
  3. assert chatml_jinja_path.exists()
  4. def run_and_test_dummy_opt_api_server(model, tp=1):
  5. # the model is registered through the plugin
  6. server_args = [
  7. "--gpu-memory-utilization",
  8. "0.10",
  9. "--dtype",
  10. "float32",
  11. "--chat-template",
  12. str(chatml_jinja_path),
  13. "--load-format",
  14. "dummy",
  15. "-tp",
  16. f"{tp}",
  17. ]
  18. with RemoteOpenAIServer(model, server_args) as server:
  19. client = server.get_client()
  20. completion = client.chat.completions.create(
  21. model=model,
  22. messages=[{
  23. "role": "system",
  24. "content": "You are a helpful assistant."
  25. }, {
  26. "role": "user",
  27. "content": "Hello!"
  28. }],
  29. temperature=0,
  30. )
  31. generated_text = completion.choices[0].message.content
  32. assert generated_text is not None
  33. # make sure only the first token is generated
  34. rest = generated_text.replace("<s>", "")
  35. assert rest == ""
  36. def test_oot_registration_for_api_server(dummy_opt_path: str):
  37. run_and_test_dummy_opt_api_server(dummy_opt_path)