test_pixtral.py

  1. """Compare the outputs of HF and Aphrodite for Pixtral models using greedy
  2. sampling.
  3. Run `pytest tests/models/test_pixtral.py`.
  4. """
import pickle
import uuid
from typing import Any, Dict, List

import pytest
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk

from aphrodite import AphroditeEngine, EngineArgs, SamplingParams
from aphrodite.inputs import TokensPrompt
from aphrodite.multimodal import MultiModalDataBuiltins

from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

MODELS = ["mistralai/Pixtral-12B-2409"]
# Four publicly hosted test images at different resolutions and aspect
# ratios (width/height are encoded in each URL path).
IMG_URLS = [
    "https://picsum.photos/id/237/400/300",
    "https://picsum.photos/id/231/200/300",
    "https://picsum.photos/id/27/500/500",
    "https://picsum.photos/id/17/150/600",
]
PROMPT = "Describe each image in one short sentence."


def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
    return [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": PROMPT,
        }] + [{
            "type": "image_url",
            "image_url": {
                "url": url
            }
        } for url in urls],
    }]
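

# For a single URL, the helper above returns (illustration, derived directly
# from the code):
#
#     [{
#         "role": "user",
#         "content": [{"type": "text", "text": PROMPT},
#                     {"type": "image_url", "image_url": {"url": urls[0]}}],
#     }]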


def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
    msg = _create_msg_format(urls)

    tokenizer = MistralTokenizer.from_model("pixtral")

    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
    tokenized = tokenizer.encode_chat_completion(request)

    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)

    images = []
    for chunk in request.messages[0].content:
        if isinstance(chunk, ImageURLChunk):
            images.append(image_from_chunk(chunk))

    mm_data = MultiModalDataBuiltins(image=images)
    engine_inputs["multi_modal_data"] = mm_data

    return engine_inputs
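

# A minimal sketch of what `_create_engine_inputs` yields (field names come
# from the code above; the token/image contents are assumptions about
# mistral_common's output):
#
#     inputs = _create_engine_inputs(IMG_URLS[:1])
#     inputs["prompt_token_ids"]  # token ids from mistral_common's encoding
#     inputs["multi_modal_data"]  # MultiModalDataBuiltins(image=[<PIL image>])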


MSGS = [
    _create_msg_format(IMG_URLS[:1]),
    _create_msg_format(IMG_URLS[:2]),
    _create_msg_format(IMG_URLS),
]
ENGINE_INPUTS = [
    _create_engine_inputs(IMG_URLS[:1]),
    _create_engine_inputs(IMG_URLS[:2]),
    _create_engine_inputs(IMG_URLS),
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)

MAX_MODEL_LEN = [8192, 65536]
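
# Reference logprobs pickled from an earlier run; the `name_1="h100_ref"`
# label in the assertions below suggests they were recorded on an H100.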
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"


def load_logprobs(filename: str) -> Any:
    with open(filename, 'rb') as f:
        return pickle.load(f)
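

def dump_logprobs(filename: str, logprobs: Any) -> None:
    """Hypothetical counterpart to `load_logprobs`, not part of the original
    test: a sketch of how the reference fixtures could be regenerated on
    hardware large enough to run the model."""
    with open(filename, 'wb') as f:
        pickle.dump(logprobs, f)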


@pytest.mark.skip(
    reason="Model is too big; the test passes on an A100 locally but will "
    "OOM on the CI machine.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_chat(
    aphrodite_runner,
    max_model_len: int,
    model: str,
    dtype: str,
) -> None:
    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)

    with aphrodite_runner(
            model,
            dtype=dtype,
            tokenizer_mode="mistral",
            enable_chunked_prefill=False,
            max_model_len=max_model_len,
            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
    ) as aphrodite_model:
        outputs = []
        for msg in MSGS:
            output = aphrodite_model.model.chat(
                msg, sampling_params=SAMPLING_PARAMS)
            outputs.extend(output)

    logprobs = aphrodite_runner._final_steps_generate_w_logprobs(outputs)
    check_logprobs_close(outputs_0_lst=logprobs,
                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
                         name_0="output",
                         name_1="h100_ref")
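

# Unlike `test_chat`, which uses the high-level `.chat()` API, the test below
# drives `AphroditeEngine` directly: two requests are queued up front and a
# third is added after two scheduler steps, so scheduling of a request that
# arrives mid-stream is exercised as well.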
@pytest.mark.skip(
    reason="Model is too big; the test passes on an A100 locally but will "
    "OOM on the CI machine.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(aphrodite_runner, model: str, dtype: str) -> None:
    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)

    args = EngineArgs(
        model=model,
        tokenizer_mode="mistral",
        enable_chunked_prefill=False,
        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
        dtype=dtype,
    )
    engine = AphroditeEngine.from_engine_args(args)

    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)

    outputs = []
    count = 0
    while True:
        out = engine.step()
        count += 1
        for request_output in out:
            if request_output.finished:
                outputs.append(request_output)

        # Inject a third request after two scheduler steps, while the first
        # two requests are still in flight.
        if count == 2:
            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
                               SAMPLING_PARAMS)
        if not engine.has_unfinished_requests():
            break

    logprobs = aphrodite_runner._final_steps_generate_w_logprobs(outputs)
    check_logprobs_close(outputs_0_lst=logprobs,
                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
                         name_0="output",
                         name_1="h100_ref")