# test_serving_chat.py
  1. import asyncio
  2. from contextlib import suppress
  3. from dataclasses import dataclass
  4. from unittest.mock import MagicMock
  5. from aphrodite.common.config import MultiModalConfig
  6. from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
  7. from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
  8. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  9. from aphrodite.transformers_utils.tokenizer import get_tokenizer
# Model and chat-template fixtures shared by every test in this module.
MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
@dataclass
class MockModelConfig:
    """Minimal stand-in for the engine's model config used by these tests.

    NOTE(review): the attributes are unannotated, so @dataclass generates no
    fields for them — they remain ordinary class attributes shared by all
    instances.
    """

    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    # presumably caps generation at 100 total tokens; the max_tokens test
    # below expects 100 - prompt length == 93 — confirm against serving code
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False
    multimodal_config = MultiModalConfig()
@dataclass
class MockEngine:
    """Async engine stub exposing only the get_model_config() coroutine."""

    async def get_model_config(self):
        # Returns a fresh mock config; async to match the real engine's API
        # (presumably AsyncAphrodite.get_model_config — confirm).
        return MockModelConfig()
  25. async def _async_serving_chat_init():
  26. engine = MockEngine()
  27. model_config = await engine.get_model_config()
  28. serving_completion = OpenAIServingChat(engine,
  29. model_config,
  30. served_model_names=[MODEL_NAME],
  31. response_role="assistant",
  32. chat_template=CHAT_TEMPLATE,
  33. lora_modules=None,
  34. prompt_adapters=None,
  35. request_logger=None)
  36. return serving_completion
  37. def test_async_serving_chat_init():
  38. serving_completion = asyncio.run(_async_serving_chat_init())
  39. assert serving_completion.chat_template == CHAT_TEMPLATE
  40. def test_serving_chat_should_set_correct_max_tokens():
  41. mock_engine = MagicMock(spec=AsyncAphrodite)
  42. mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
  43. serving_chat = OpenAIServingChat(mock_engine,
  44. MockModelConfig(),
  45. served_model_names=[MODEL_NAME],
  46. response_role="assistant",
  47. chat_template=CHAT_TEMPLATE,
  48. lora_modules=None,
  49. prompt_adapters=None,
  50. request_logger=None)
  51. req = ChatCompletionRequest(
  52. model=MODEL_NAME,
  53. messages=[{
  54. "role": "user",
  55. "content": "what is 1+1?"
  56. }],
  57. guided_decoding_backend="outlines",
  58. )
  59. with suppress(Exception):
  60. asyncio.run(serving_chat.create_chat_completion(req))
  61. assert mock_engine.generate.call_args.args[1].max_tokens == 93
  62. req.max_tokens = 10
  63. with suppress(Exception):
  64. asyncio.run(serving_chat.create_chat_completion(req))
  65. assert mock_engine.generate.call_args.args[1].max_tokens == 10