# test_accuracy.py
  1. """
  2. This file test accuracy of the Aphrodite server via LMEval.
  3. It uses local-completions, which interacts with Aphrodite
  4. through the OAI API with N concurrent connections.
  5. This simulates real work usage of the API and makes
  6. sure that the zmq frontend mp RPC message passing and
  7. AsyncAPhroditeEngine are working correctly.
  8. """
  9. import lm_eval
  10. import pytest
  11. from ...utils import RemoteOpenAIServer
  12. MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
  13. NUM_CONCURRENT = 500
  14. TASK = "gsm8k"
  15. FILTER = "exact_match,strict-match"
  16. RTOL = 0.03
  17. EXPECTED_VALUE = 0.58
  18. DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
  19. MORE_ARGS_LIST = [
  20. ["--enable-chunked-prefill"], # Chunked
  21. ["--num-scheduler-steps", "8"], # MS
  22. ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream
  23. ]
  24. @pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
  25. def test_lm_eval_accuracy(more_args):
  26. args = list(DEFAULT_ARGS)
  27. args.extend(more_args)
  28. print(f"Running with: {args}")
  29. with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
  30. url = f"{remote_server.url_for('v1')}/completions"
  31. model_args = (
  32. f"model={MODEL_NAME},"
  33. f"base_url={url},"
  34. f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
  35. results = lm_eval.simple_evaluate(
  36. model="local-completions",
  37. model_args=model_args,
  38. tasks=TASK,
  39. )
  40. measured_value = results["results"][TASK][FILTER]
  41. assert (measured_value - RTOL < EXPECTED_VALUE
  42. and measured_value + RTOL > EXPECTED_VALUE
  43. ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"