- """Compare the outputs of HF and Aphrodite for Mistral models using greedy
- sampling.
- Run `pytest tests/models/decoder_only/vision_language/test_pixtral.py`.
- """
- import json
- import uuid
- from dataclasses import asdict
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
- import pytest
- from mistral_common.protocol.instruct.messages import ImageURLChunk
- from mistral_common.protocol.instruct.request import ChatCompletionRequest
- from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
- from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
- from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
- from aphrodite.common.sequence import Logprob, SampleLogprobs
- from aphrodite.inputs import TokensPrompt
- from aphrodite.multimodal import MultiModalDataBuiltins
- from ....utils import APHRODITE_PATH
- from ...utils import check_logprobs_close
- if TYPE_CHECKING:
- from _typeshed import StrPath
- MODELS = ["mistralai/Pixtral-12B-2409"]
- IMG_URLS = [
- "https://picsum.photos/id/237/400/300",
- "https://picsum.photos/id/231/200/300",
- "https://picsum.photos/id/27/500/500",
- "https://picsum.photos/id/17/150/600",
- ]
- PROMPT = "Describe each image in one short sentence."


def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
    return [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": PROMPT,
        }] + [{
            "type": "image_url",
            "image_url": {
                "url": url
            },
        } for url in urls],
    }]
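

# Build engine-ready inputs: tokenize the chat request with mistral_common and
# attach the images decoded from the URL chunks as multi-modal data.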
def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
    msg = _create_msg_format(urls)

    tokenizer = MistralTokenizer.from_model("pixtral")

    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
    tokenized = tokenizer.encode_chat_completion(request)

    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)

    images = []
    for chunk in request.messages[0].content:
        if isinstance(chunk, ImageURLChunk):
            images.append(image_from_chunk(chunk))

    mm_data = MultiModalDataBuiltins(image=images)
    engine_inputs["multi_modal_data"] = mm_data

    return engine_inputs


MSGS = [
    _create_msg_format(IMG_URLS[:1]),
    _create_msg_format(IMG_URLS[:2]),
    _create_msg_format(IMG_URLS),
]
ENGINE_INPUTS = [
    _create_engine_inputs(IMG_URLS[:1]),
    _create_engine_inputs(IMG_URLS[:2]),
    _create_engine_inputs(IMG_URLS),
]
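
# Greedy decoding (temperature 0) with top-5 logprobs so the generated tokens
# can be compared against the golden fixtures.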
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]

FIXTURES_PATH = APHRODITE_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists()

FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json"
FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json"

OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]


# For the test author to store golden output in JSON
def _dump_outputs_w_logprobs(
    outputs: OutputsLogprobs,
    filename: "StrPath",
) -> None:
    json_data = [(tokens, text,
                  [{k: asdict(v)
                    for k, v in token_logprobs.items()}
                   for token_logprobs in (logprobs or [])])
                 for tokens, text, logprobs in outputs]

    with open(filename, "w") as f:
        json.dump(json_data, f)
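

# Load the golden outputs back, rebuilding Logprob objects keyed by integer
# token id (JSON stores dict keys as strings).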
def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
    with open(filename, "rb") as f:
        json_data = json.load(f)

    return [(tokens, text,
             [{int(k): Logprob(**v)
               for k, v in token_logprobs.items()}
              for token_logprobs in logprobs])
            for tokens, text, logprobs in json_data]


@pytest.mark.skip(
    reason=
    "Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat(
    aphrodite_runner,
    max_model_len: int,
    model: str,
    dtype: str,
) -> None:
    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
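
    # Generate completions for the one-, two-, and four-image chat prompts and
    # compare their logprobs against the stored reference outputs.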
    with aphrodite_runner(
            model,
            dtype=dtype,
            tokenizer_mode="mistral",
            enable_chunked_prefill=False,
            max_model_len=max_model_len,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as aphrodite_model:
        outputs = []
        for msg in MSGS:
            output = aphrodite_model.model.chat(
                msg, sampling_params=SAMPLING_PARAMS)
            outputs.extend(output)

    logprobs = aphrodite_runner._final_steps_generate_w_logprobs(outputs)
    check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
                         outputs_1_lst=logprobs,
                         name_0="h100_ref",
                         name_1="output")


@pytest.mark.skip(
    reason=
    "Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(aphrodite_runner, model: str, dtype: str) -> None:
    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
    args = EngineArgs(
        model=model,
        tokenizer_mode="mistral",
        enable_chunked_prefill=False,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
        dtype=dtype,
    )
    engine = AphroditeEngine.from_engine_args(args)

    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)

    outputs = []
    count = 0
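    # Step the engine manually, collecting finished requests; after two steps
    # enqueue the third request so it joins the ones already in flight, then
    # drain until no unfinished requests remain.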
    while True:
        out = engine.step()
        count += 1

        for request_output in out:
            if request_output.finished:
                outputs.append(request_output)

        if count == 2:
            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
                               SAMPLING_PARAMS)
        if not engine.has_unfinished_requests():
            break

    logprobs = aphrodite_runner._final_steps_generate_w_logprobs(outputs)
    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
                         outputs_1_lst=logprobs,
                         name_0="h100_ref",
                         name_1="output")