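"""Tests for the LLM entrypoint: v1/v2 ``generate`` API consistency for
string and token-id prompts, per-prompt sampling parameters, and the
``chat`` API for text-only and multi-image conversations."""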
import weakref
from typing import List

import pytest

from aphrodite import LLM, RequestOutput, SamplingParams

from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS

MODEL_NAME = "facebook/opt-125m"

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

TOKEN_IDS = [
    [0],
    [0, 1],
    [0, 2, 1],
    [0, 3, 1, 2],
]
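

# Keep the module-scoped engine lightweight: a tiny model, a small slice of
# GPU memory (gpu_memory_utilization=0.10), and enforce_eager=True to skip
# CUDA graph capture.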
@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              max_num_batched_tokens=4096,
              tensor_parallel_size=1,
              gpu_memory_utilization=0.10,
              enforce_eager=True)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup()


def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
    assert [o.outputs for o in o1] == [o.outputs for o in o2]
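

# The tests below send the same request through the deprecated v1 keyword
# arguments (prompts=... / prompt_token_ids=...) and the v2 call style
# (positional prompts or dicts), and assert that both paths yield identical
# outputs.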
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompts'"):
        v1_output = llm.generate(prompts=prompt,
                                 sampling_params=sampling_params)

    v2_output = llm.generate(prompt, sampling_params=sampling_params)
    assert_outputs_equal(v1_output, v2_output)

    v2_output = llm.generate({"prompt": prompt},
                             sampling_params=sampling_params)
    assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
                                                    prompt_token_ids):
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
                                 sampling_params=sampling_params)

    v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
                             sampling_params=sampling_params)
    assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompts'"):
        v1_output = llm.generate(prompts=PROMPTS,
                                 sampling_params=sampling_params)

    v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
    assert_outputs_equal(v1_output, v2_output)

    v2_output = llm.generate(
        [{
            "prompt": p
        } for p in PROMPTS],
        sampling_params=sampling_params,
    )
    assert_outputs_equal(v1_output, v2_output)


@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

    with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
        v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
                                 sampling_params=sampling_params)

    v2_output = llm.generate(
        [{
            "prompt_token_ids": p
        } for p in TOKEN_IDS],
        sampling_params=sampling_params,
    )
    assert_outputs_equal(v1_output, v2_output)
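

# LLM.generate also accepts a list of SamplingParams, one per prompt; the
# test below covers the matching-length, mismatched-length, single-object,
# and None (default params) cases.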
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
    sampling_params = [
        SamplingParams(temperature=0.01, top_p=0.95),
        SamplingParams(temperature=0.3, top_p=0.95),
        SamplingParams(temperature=0.7, top_p=0.95),
        SamplingParams(temperature=0.99, top_p=0.95),
    ]

    # Multiple SamplingParams should be matched one-to-one with the prompts
    outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
    assert len(PROMPTS) == len(outputs)

    # An exception is raised if the number of params does not match the
    # number of prompts
    with pytest.raises(ValueError):
        outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3])

    # A single SamplingParams should be applied to every prompt
    single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
    outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
    assert len(PROMPTS) == len(outputs)

    # When sampling_params is None, the default params should be applied
    outputs = llm.generate(PROMPTS, sampling_params=None)
    assert len(PROMPTS) == len(outputs)
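

# The chat tests below construct their own LLM instances (an instruction-tuned
# text model and a multimodal model) rather than reusing the opt-125m fixture.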
def test_chat():
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    prompt1 = "Explain the concept of entropy."
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant"
        },
        {
            "role": "user",
            "content": prompt1
        },
    ]
    outputs = llm.chat(messages)
    assert len(outputs) == 1
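

# Multi-image chat: limit_mm_per_prompt={"image": 2} allows two images in a
# single prompt, and the message content interleaves image_url entries with a
# final text entry.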
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": 2},
    )

    messages = [{
        "role": "user",
        "content": [
            *({
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            } for image_url in image_urls),
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]
    outputs = llm.chat(messages)
    assert len(outputs) >= 0