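"""Tests for OpenAIServingChat: construction and default max_tokens handling."""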
import asyncio
from contextlib import suppress
from dataclasses import dataclass
from unittest.mock import MagicMock

from aphrodite.common.config import MultiModalConfig
from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
from aphrodite.endpoints.openai.serving_chat import OpenAIServingChat
from aphrodite.endpoints.openai.serving_engine import BaseModelPath
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]
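

# Minimal stand-ins for the engine-side config objects, exposing just the
# config attributes the serving layer consults in these tests.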
@dataclass
class MockModelConfig:
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    embedding_mode = False
    multimodal_config = MultiModalConfig()


@dataclass
class MockEngine:

    async def get_model_config(self):
        return MockModelConfig()


async def _async_serving_chat_init():
    engine = MockEngine()
    model_config = await engine.get_model_config()

    serving_completion = OpenAIServingChat(engine,
                                           model_config,
                                           BASE_MODEL_PATHS,
                                           response_role="assistant",
                                           chat_template=CHAT_TEMPLATE,
                                           lora_modules=None,
                                           prompt_adapters=None,
                                           request_logger=None)
    return serving_completion


def test_async_serving_chat_init():
    serving_completion = asyncio.run(_async_serving_chat_init())
    assert serving_completion.chat_template == CHAT_TEMPLATE
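

# With a MagicMock engine, create_chat_completion fails partway through, so
# the calls below are wrapped in suppress(Exception); the test only inspects
# the SamplingParams that mock_engine.generate captured.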
def test_serving_chat_should_set_correct_max_tokens():
    mock_engine = MagicMock(spec=AsyncAphrodite)
    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)

    serving_chat = OpenAIServingChat(mock_engine,
                                     MockModelConfig(),
                                     BASE_MODEL_PATHS,
                                     response_role="assistant",
                                     chat_template=CHAT_TEMPLATE,
                                     lora_modules=None,
                                     prompt_adapters=None,
                                     request_logger=None)
    req = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=[{
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )
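
    # No max_tokens on the request: the server should default to the space
    # left in the context window (max_model_len of 100 minus prompt tokens).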
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 93
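
    # An explicit max_tokens on the request should be passed through to the
    # engine unchanged.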
    req.max_tokens = 10
    with suppress(Exception):
        asyncio.run(serving_chat.create_chat_completion(req))

    assert mock_engine.generate.call_args.args[1].max_tokens == 10