# test_serving_chat.py — unit tests for OpenAIServingChat: construction via an
# async mock engine, and correct max_tokens defaulting/override on requests.
  1. import asyncio
  2. from contextlib import suppress
  3. from dataclasses import dataclass
  4. from unittest.mock import MagicMock
  5. from aphrodite.common.config import MultiModalConfig
  6. from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
  7. from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
  8. from aphrodite.endpoints.openai.serving_engine import BaseModelPath
  9. from aphrodite.engine.async_aphrodite import AsyncAphrodite
  10. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  11. MODEL_NAME = "openai-community/gpt2"
  12. CHAT_TEMPLATE = "Dummy chat template for testing {}"
  13. BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
@dataclass
class MockModelConfig:
    # Minimal stand-in for the engine's model config, exposing only the
    # attributes OpenAIServingChat reads during init and request handling.
    #
    # NOTE: the attributes are intentionally left without type annotations.
    # @dataclass only promotes *annotated* class attributes to init fields,
    # so these remain plain class attributes and MockModelConfig() still
    # takes no constructor arguments. Do not add annotations here.
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100  # small context so max_tokens math is easy to check
    tokenizer_revision = None
    embedding_mode = False
    # NOTE(review): evaluated once at class-definition time, so this single
    # MultiModalConfig instance is shared by all MockModelConfig objects —
    # fine for these tests, which never mutate it.
    multimodal_config = MultiModalConfig()
  23. @dataclass
  24. class MockEngine:
  25. async def get_model_config(self):
  26. return MockModelConfig()
  27. async def _async_serving_chat_init():
  28. engine = MockEngine()
  29. model_config = await engine.get_model_config()
  30. serving_completion = OpenAIServingChat(engine,
  31. model_config,
  32. BASE_MODEL_PATHS,
  33. response_role="assistant",
  34. chat_template=CHAT_TEMPLATE,
  35. lora_modules=None,
  36. prompt_adapters=None,
  37. request_logger=None)
  38. return serving_completion
  39. def test_async_serving_chat_init():
  40. serving_completion = asyncio.run(_async_serving_chat_init())
  41. assert serving_completion.chat_template == CHAT_TEMPLATE
  42. def test_serving_chat_should_set_correct_max_tokens():
  43. mock_engine = MagicMock(spec=AsyncAphrodite)
  44. mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
  45. serving_chat = OpenAIServingChat(mock_engine,
  46. MockModelConfig(),
  47. BASE_MODEL_PATHS,
  48. response_role="assistant",
  49. chat_template=CHAT_TEMPLATE,
  50. lora_modules=None,
  51. prompt_adapters=None,
  52. request_logger=None)
  53. req = ChatCompletionRequest(
  54. model=MODEL_NAME,
  55. messages=[{
  56. "role": "user",
  57. "content": "what is 1+1?"
  58. }],
  59. guided_decoding_backend="outlines",
  60. )
  61. with suppress(Exception):
  62. asyncio.run(serving_chat.create_chat_completion(req))
  63. assert mock_engine.generate.call_args.args[1].max_tokens == 93
  64. req.max_tokens = 10
  65. with suppress(Exception):
  66. asyncio.run(serving_chat.create_chat_completion(req))
  67. assert mock_engine.generate.call_args.args[1].max_tokens == 10