  1. """Compare the outputs of HF and Aphrodite for Mistral models using greedy
  2. sampling.
  3. Run `pytest tests/models/test_mistral.py`.
  4. """
import pytest

from aphrodite import LLM, SamplingParams

from ...utils import check_logprobs_close

MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.3",
    # Mistral-Nemo is too big for CI, but passes locally
    # "mistralai/Mistral-Nemo-Instruct-2407"
]

SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)

SYMBOLIC_LANG_PROMPTS = [
    "勇敢な船乗りについての詩を書く",  # Japanese: "Write a poem about a brave sailor"
    "寫一首關於勇敢的水手的詩",  # Chinese: "Write a poem about a brave sailor"
]

# for function calling
TOOLS = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description":
                    "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type": "string",
                    "description":
                    "the two-letter abbreviation for the state that the city is"
                    " in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]

MSGS = [{
    "role": "user",
    "content": ("Can you tell me what the temperature"
                " will be in Dallas, in fahrenheit?")
}]

EXPECTED_FUNC_CALL = (
    '[{"name": "get_current_weather", "arguments": '
    '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
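    """Compare greedy generations from the HF runner and the Aphrodite
    runner on the shared example prompts.

    Agreement is checked with check_logprobs_close rather than exact string
    equality, so small divergences are tolerated as long as each model's
    tokens stay within the other's top-`num_logprobs` candidates.
    """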
    # TODO(sang): Sliding window should be tested separately.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with aphrodite_runner(model, dtype=dtype,
                          tokenizer_mode="mistral") as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="hf",
        name_1="aphrodite",
    )


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
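    """Load the same checkpoint once through the HF-style path
    (tokenizer_mode="auto", safetensors weights, HF config) and once through
    the Mistral-native path, and require closely matching greedy logprobs
    from both.
    """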
    with aphrodite_runner(
            model,
            dtype=dtype,
            tokenizer_mode="auto",
            load_format="safetensors",
            config_format="hf",
    ) as hf_format_model:
        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with aphrodite_runner(
            model,
            dtype=dtype,
            tokenizer_mode="mistral",
            load_format="mistral",
            config_format="mistral",
    ) as mistral_format_model:
        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_format_outputs,
        outputs_1_lst=mistral_format_outputs,
        name_0="hf",
        name_1="mistral",
    )


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("prompt", SYMBOLIC_LANG_PROMPTS)
def test_mistral_symbolic_languages(
    model: str,
    dtype: str,
    prompt: str,
) -> None:
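    """Chat with a Japanese or Chinese prompt using the Mistral-native
    tokenizer and assert the output contains no U+FFFD replacement
    character, which would signal mis-decoded bytes.
    """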
    msg = {"role": "user", "content": prompt}
    llm = LLM(model=model,
              dtype=dtype,
              max_model_len=8192,
              tokenizer_mode="mistral",
              config_format="mistral",
              load_format="mistral")
    outputs = llm.chat([msg], sampling_params=SAMPLING_PARAMS)
    assert "�" not in outputs[0].outputs[0].text.strip()


@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", MODELS[1:])  # v0.1 can't do function calling
def test_mistral_function_calling(
    aphrodite_runner,
    model: str,
    dtype: str,
) -> None:
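    """Request the weather with the get_current_weather tool available and
    assert that greedy decoding emits exactly EXPECTED_FUNC_CALL.
    """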
    with aphrodite_runner(model,
                          dtype=dtype,
                          tokenizer_mode="mistral",
                          config_format="mistral",
                          load_format="mistral") as aphrodite_model:
        outputs = aphrodite_model.model.chat(MSGS,
                                             tools=TOOLS,
                                             sampling_params=SAMPLING_PARAMS)

        assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL