"""Tests for OpenAIServingChat: constructor wiring and max_tokens defaulting."""
  1. import asyncio
  2. from contextlib import suppress
  3. from dataclasses import dataclass
  4. from unittest.mock import MagicMock
  5. from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
  6. from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
  7. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  8. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  9. MODEL_NAME = "openai-community/gpt2"
  10. CHAT_TEMPLATE = "Dummy chat template for testing {}"
  11. @dataclass
  12. class MockModelConfig:
  13. tokenizer = MODEL_NAME
  14. trust_remote_code = False
  15. tokenizer_mode = "auto"
  16. max_model_len = 100
  17. tokenizer_revision = None
  18. embedding_mode = False
  19. @dataclass
  20. class MockEngine:
  21. async def get_model_config(self):
  22. return MockModelConfig()
  23. async def _async_serving_chat_init():
  24. engine = MockEngine()
  25. model_config = await engine.get_model_config()
  26. serving_completion = OpenAIServingChat(engine,
  27. model_config,
  28. served_model_names=[MODEL_NAME],
  29. response_role="assistant",
  30. chat_template=CHAT_TEMPLATE,
  31. lora_modules=None,
  32. prompt_adapters=None,
  33. request_logger=None)
  34. return serving_completion
  35. def test_async_serving_chat_init():
  36. serving_completion = asyncio.run(_async_serving_chat_init())
  37. assert serving_completion.chat_template == CHAT_TEMPLATE
  38. def test_serving_chat_should_set_correct_max_tokens():
  39. mock_engine = MagicMock(spec=AsyncAphrodite)
  40. mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
  41. serving_chat = OpenAIServingChat(mock_engine,
  42. MockModelConfig(),
  43. served_model_names=[MODEL_NAME],
  44. response_role="assistant",
  45. chat_template=CHAT_TEMPLATE,
  46. lora_modules=None,
  47. prompt_adapters=None,
  48. request_logger=None)
  49. req = ChatCompletionRequest(
  50. model=MODEL_NAME,
  51. messages=[{
  52. "role": "user",
  53. "content": "what is 1+1?"
  54. }],
  55. guided_decoding_backend="outlines",
  56. )
  57. with suppress(Exception):
  58. asyncio.run(serving_chat.create_chat_completion(req))
  59. assert mock_engine.generate.call_args.args[1].max_tokens == 93
  60. req.max_tokens = 10
  61. with suppress(Exception):
  62. asyncio.run(serving_chat.create_chat_completion(req))
  63. assert mock_engine.generate.call_args.args[1].max_tokens == 10