test_pixtral.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. """Compare the outputs of HF and Aphrodite for Mistral models using greedy
  2. sampling.
  3. Run `pytest tests/models/decoder_only/vision_language/test_pixtral.py`.
  4. """
  5. import json
  6. import uuid
  7. from dataclasses import asdict
  8. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
  9. import pytest
  10. from mistral_common.protocol.instruct.messages import ImageURLChunk
  11. from mistral_common.protocol.instruct.request import ChatCompletionRequest
  12. from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
  13. from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
  14. from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
  15. from aphrodite.common.sequence import Logprob, SampleLogprobs
  16. from aphrodite.inputs import TokensPrompt
  17. from aphrodite.multimodal import MultiModalDataBuiltins
  18. from ....utils import APHRODITE_PATH
  19. from ...utils import check_logprobs_close
  20. if TYPE_CHECKING:
  21. from _typeshed import StrPath
# Model under test; Pixtral requires the "mistral" tokenizer mode below.
MODELS = ["mistralai/Pixtral-12B-2409"]

# Public placeholder images with varying sizes and aspect ratios; the tests
# use prefixes of this list to cover the 1-, 2-, and 4-image cases.
IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
  29. PROMPT = "Describe each image in one short sentence."
  30. def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
  31. return [{
  32. "role":
  33. "user",
  34. "content": [{
  35. "type": "text",
  36. "text": PROMPT,
  37. }] + [{
  38. "type": "image_url",
  39. "image_url": {
  40. "url": url
  41. }
  42. } for url in urls],
  43. }]
  44. def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
  45. msg = _create_msg_format(urls)
  46. tokenizer = MistralTokenizer.from_model("pixtral")
  47. request = ChatCompletionRequest(messages=msg) # type: ignore[type-var]
  48. tokenized = tokenizer.encode_chat_completion(request)
  49. engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)
  50. images = []
  51. for chunk in request.messages[0].content:
  52. if isinstance(chunk, ImageURLChunk):
  53. images.append(image_from_chunk(chunk))
  54. mm_data = MultiModalDataBuiltins(image=images)
  55. engine_inputs["multi_modal_data"] = mm_data
  56. return engine_inputs
# Chat-message fixtures covering 1, 2, and 4 images per prompt.
MSGS = [
    _create_msg_format(IMG_URLS[:1]),
    _create_msg_format(IMG_URLS[:2]),
    _create_msg_format(IMG_URLS),
]
# NOTE(review): built at import time — _create_engine_inputs tokenizes and
# decodes the image chunks, so collecting this module appears to require
# network access to fetch IMG_URLS; confirm before running offline.
ENGINE_INPUTS = [
    _create_engine_inputs(IMG_URLS[:1]),
    _create_engine_inputs(IMG_URLS[:2]),
    _create_engine_inputs(IMG_URLS),
]

# Greedy decoding (temperature=0) with top-5 logprobs so generations can be
# compared against the recorded golden fixtures.
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
# Upper bound of images per prompt, matching the largest IMG_URLS case.
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]

FIXTURES_PATH = APHRODITE_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists()

# Golden logprob outputs for the chat-API test and the direct-engine test.
FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json"
FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json"

# One entry per generation: (token_ids, generated_text, per-step logprobs).
OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
  75. # For the test author to store golden output in JSON
  76. def _dump_outputs_w_logprobs(
  77. outputs: OutputsLogprobs,
  78. filename: "StrPath",
  79. ) -> None:
  80. json_data = [(tokens, text,
  81. [{k: asdict(v)
  82. for k, v in token_logprobs.items()}
  83. for token_logprobs in (logprobs or [])])
  84. for tokens, text, logprobs in outputs]
  85. with open(filename, "w") as f:
  86. json.dump(json_data, f)
  87. def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
  88. with open(filename, "rb") as f:
  89. json_data = json.load(f)
  90. return [(tokens, text,
  91. [{int(k): Logprob(**v)
  92. for k, v in token_logprobs.items()}
  93. for token_logprobs in logprobs])
  94. for tokens, text, logprobs in json_data]
  95. @pytest.mark.skip(
  96. reason=
  97. "Model is too big, test passed on A100 locally but will OOM on CI machine."
  98. )
  99. @pytest.mark.parametrize("model", MODELS)
  100. @pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
  101. @pytest.mark.parametrize("dtype", ["bfloat16"])
  102. def test_chat(
  103. aphrodite_runner,
  104. max_model_len: int,
  105. model: str,
  106. dtype: str,
  107. ) -> None:
  108. EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
  109. with aphrodite_runner(
  110. model,
  111. dtype=dtype,
  112. tokenizer_mode="mistral",
  113. enable_chunked_prefill=False,
  114. max_model_len=max_model_len,
  115. limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
  116. ) as aphrodite_model:
  117. outputs = []
  118. for msg in MSGS:
  119. output = aphrodite_model.model.chat(msg,
  120. sampling_params=SAMPLING_PARAMS)
  121. outputs.extend(output)
  122. logprobs = aphrodite_runner._final_steps_generate_w_logprobs(outputs)
  123. check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
  124. outputs_1_lst=logprobs,
  125. name_0="h100_ref",
  126. name_1="output")
  127. @pytest.mark.skip(
  128. reason=
  129. "Model is too big, test passed on A100 locally but will OOM on CI machine."
  130. )
  131. @pytest.mark.parametrize("model", MODELS)
  132. @pytest.mark.parametrize("dtype", ["bfloat16"])
  133. def test_model_engine(aphrodite_runner, model: str, dtype: str) -> None:
  134. EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
  135. args = EngineArgs(
  136. model=model,
  137. tokenizer_mode="mistral",
  138. enable_chunked_prefill=False,
  139. limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
  140. dtype=dtype,
  141. )
  142. engine = AphroditeEngine.from_engine_args(args)
  143. engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
  144. engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
  145. outputs = []
  146. count = 0
  147. while True:
  148. out = engine.step()
  149. count += 1
  150. for request_output in out:
  151. if request_output.finished:
  152. outputs.append(request_output)
  153. if count == 2:
  154. engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
  155. SAMPLING_PARAMS)
  156. if not engine.has_unfinished_requests():
  157. break
  158. logprobs = aphrodite_runner._final_steps_generate_w_logprobs(outputs)
  159. check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
  160. outputs_1_lst=logprobs,
  161. name_0="h100_ref",
  162. name_1="output")