test_vision.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. from typing import Dict, List
  2. import openai
  3. import pytest
  4. from aphrodite.multimodal.utils import encode_image_base64, fetch_image
  5. from ...utils import APHRODITE_PATH, RemoteOpenAIServer
# Model under test; the repo-bundled LLaVA chat template must exist so the
# server can format multimodal chat messages.
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
LLAVA_CHAT_TEMPLATE = APHRODITE_PATH / "examples/chat_templates/llava.jinja"
assert LLAVA_CHAT_TEMPLATE.exists()

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
  16. @pytest.fixture(scope="module")
  17. def server():
  18. args = [
  19. "--dtype",
  20. "bfloat16",
  21. "--max-model-len",
  22. "4096",
  23. "--enforce-eager",
  24. "--chat-template",
  25. str(LLAVA_CHAT_TEMPLATE),
  26. ]
  27. with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
  28. yield remote_server
  29. @pytest.fixture(scope="module")
  30. def client(server):
  31. return server.get_async_client()
  32. @pytest.fixture(scope="session")
  33. def base64_encoded_image() -> Dict[str, str]:
  34. return {
  35. image_url: encode_image_base64(fetch_image(image_url))
  36. for image_url in TEST_IMAGE_URLS
  37. }
  38. @pytest.mark.asyncio
  39. @pytest.mark.parametrize("model_name", [MODEL_NAME])
  40. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  41. async def test_single_chat_session_image(client: openai.AsyncOpenAI,
  42. model_name: str, image_url: str):
  43. messages = [{
  44. "role":
  45. "user",
  46. "content": [
  47. {
  48. "type": "image_url",
  49. "image_url": {
  50. "url": image_url
  51. }
  52. },
  53. {
  54. "type": "text",
  55. "text": "What's in this image?"
  56. },
  57. ],
  58. }]
  59. # test single completion
  60. chat_completion = await client.chat.completions.create(model=model_name,
  61. messages=messages,
  62. max_tokens=10,
  63. logprobs=True,
  64. top_logprobs=5)
  65. assert len(chat_completion.choices) == 1
  66. choice = chat_completion.choices[0]
  67. assert choice.finish_reason == "length"
  68. assert chat_completion.usage == openai.types.CompletionUsage(
  69. completion_tokens=10, prompt_tokens=596, total_tokens=606)
  70. message = choice.message
  71. message = chat_completion.choices[0].message
  72. assert message.content is not None and len(message.content) >= 10
  73. assert message.role == "assistant"
  74. messages.append({"role": "assistant", "content": message.content})
  75. # test multi-turn dialogue
  76. messages.append({"role": "user", "content": "express your result in json"})
  77. chat_completion = await client.chat.completions.create(
  78. model=model_name,
  79. messages=messages,
  80. max_tokens=10,
  81. )
  82. message = chat_completion.choices[0].message
  83. assert message.content is not None and len(message.content) >= 0
  84. @pytest.mark.asyncio
  85. @pytest.mark.parametrize("model_name", [MODEL_NAME])
  86. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  87. async def test_single_chat_session_image_base64encoded(
  88. client: openai.AsyncOpenAI, model_name: str, image_url: str,
  89. base64_encoded_image: Dict[str, str]):
  90. messages = [{
  91. "role":
  92. "user",
  93. "content": [
  94. {
  95. "type": "image_url",
  96. "image_url": {
  97. "url":
  98. f"data:image/jpeg;base64,{base64_encoded_image[image_url]}"
  99. }
  100. },
  101. {
  102. "type": "text",
  103. "text": "What's in this image?"
  104. },
  105. ],
  106. }]
  107. # test single completion
  108. chat_completion = await client.chat.completions.create(model=model_name,
  109. messages=messages,
  110. max_tokens=10,
  111. logprobs=True,
  112. top_logprobs=5)
  113. assert len(chat_completion.choices) == 1
  114. choice = chat_completion.choices[0]
  115. assert choice.finish_reason == "length"
  116. assert chat_completion.usage == openai.types.CompletionUsage(
  117. completion_tokens=10, prompt_tokens=596, total_tokens=606)
  118. message = choice.message
  119. message = chat_completion.choices[0].message
  120. assert message.content is not None and len(message.content) >= 10
  121. assert message.role == "assistant"
  122. messages.append({"role": "assistant", "content": message.content})
  123. # test multi-turn dialogue
  124. messages.append({"role": "user", "content": "express your result in json"})
  125. chat_completion = await client.chat.completions.create(
  126. model=model_name,
  127. messages=messages,
  128. max_tokens=10,
  129. )
  130. message = chat_completion.choices[0].message
  131. assert message.content is not None and len(message.content) >= 0
  132. @pytest.mark.asyncio
  133. @pytest.mark.parametrize("model_name", [MODEL_NAME])
  134. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  135. async def test_chat_streaming_image(client: openai.AsyncOpenAI,
  136. model_name: str, image_url: str):
  137. messages = [{
  138. "role":
  139. "user",
  140. "content": [
  141. {
  142. "type": "image_url",
  143. "image_url": {
  144. "url": image_url
  145. }
  146. },
  147. {
  148. "type": "text",
  149. "text": "What's in this image?"
  150. },
  151. ],
  152. }]
  153. # test single completion
  154. chat_completion = await client.chat.completions.create(
  155. model=model_name,
  156. messages=messages,
  157. max_tokens=10,
  158. temperature=0.0,
  159. )
  160. output = chat_completion.choices[0].message.content
  161. stop_reason = chat_completion.choices[0].finish_reason
  162. # test streaming
  163. stream = await client.chat.completions.create(
  164. model=model_name,
  165. messages=messages,
  166. max_tokens=10,
  167. temperature=0.0,
  168. stream=True,
  169. )
  170. chunks: List[str] = []
  171. finish_reason_count = 0
  172. async for chunk in stream:
  173. delta = chunk.choices[0].delta
  174. if delta.role:
  175. assert delta.role == "assistant"
  176. if delta.content:
  177. chunks.append(delta.content)
  178. if chunk.choices[0].finish_reason is not None:
  179. finish_reason_count += 1
  180. # finish reason should only return in last block
  181. assert finish_reason_count == 1
  182. assert chunk.choices[0].finish_reason == stop_reason
  183. assert delta.content
  184. assert "".join(chunks) == output
  185. @pytest.mark.asyncio
  186. @pytest.mark.parametrize("model_name", [MODEL_NAME])
  187. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  188. async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
  189. image_url: str):
  190. messages = [{
  191. "role":
  192. "user",
  193. "content": [
  194. {
  195. "type": "image_url",
  196. "image_url": {
  197. "url": image_url
  198. }
  199. },
  200. {
  201. "type": "image_url",
  202. "image_url": {
  203. "url": image_url
  204. }
  205. },
  206. {
  207. "type": "text",
  208. "text": "What's in this image?"
  209. },
  210. ],
  211. }]
  212. with pytest.raises(openai.BadRequestError): # test multi-image input
  213. await client.chat.completions.create(
  214. model=model_name,
  215. messages=messages,
  216. max_tokens=10,
  217. temperature=0.0,
  218. )
  219. # the server should still work afterwards
  220. completion = await client.completions.create(
  221. model=model_name,
  222. prompt=[0, 0, 0, 0, 0],
  223. max_tokens=5,
  224. temperature=0.0,
  225. )
  226. completion = completion.choices[0].text
  227. assert completion is not None and len(completion) >= 0