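"""End-to-end tests for audio inputs on the OpenAI-compatible chat API.

These tests register a fake audio-capable model on top of OPT-125m, launch
the API server in a spawned subprocess, and exercise URL, base64, streaming,
and multi-audio request paths through an ``openai.AsyncOpenAI`` client.
"""
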
import math
import sys
import time
from typing import Dict, List, Optional, Tuple, Union, cast
from unittest.mock import patch

import librosa
import numpy as np
import openai
import pytest
import requests
import torch

from aphrodite import ModelRegistry
from aphrodite.common.config import MultiModalConfig
from aphrodite.common.utils import get_open_port
from aphrodite.inputs import INPUT_REGISTRY
from aphrodite.inputs.data import LLMInputs
from aphrodite.inputs.registry import InputContext
from aphrodite.modeling.models.interfaces import SupportsMultiModal
from aphrodite.modeling.models.opt import OPTForCausalLM
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import (cached_get_tokenizer,
                                        repeat_and_pad_image_tokens)
from aphrodite.multimodal.utils import encode_audio_base64, fetch_audio

from ...utils import APHRODITE_PATH

chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
assert chatml_jinja_path.exists()

MODEL_NAME = "facebook/opt-125m"
TEST_AUDIO_URLS = [
    "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
]
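

# Target of the spawned server process (see the `client` fixture below):
# registers a fake audio input mapper/processor and an audio-capable OPT
# model under the existing "OPTForCausalLM" architecture, then starts the
# OpenAI-compatible API server in this process via runpy.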
def server_function(port):

    def fake_input_mapper(ctx: InputContext, data: object):
        assert isinstance(data, tuple)
        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)

        # Resample it to 1 sample per second
        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})

    def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
        multi_modal_data = llm_inputs.get("multi_modal_data")
        if multi_modal_data is None or "audio" not in multi_modal_data:
            return llm_inputs

        audio, sr = multi_modal_data.get("audio")
        audio_duration = math.ceil(len(audio) / sr)

        # Pad the prompt with one placeholder token per second of audio.
        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
            cached_get_tokenizer(ctx.model_config.tokenizer),
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            image_token_id=62,  # "_"
            repeat_count=audio_duration)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
        "audio", lambda *_, **__: 100)
    @INPUT_REGISTRY.register_input_processor(fake_input_processor)
    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):

        def __init__(self, *args, multimodal_config: MultiModalConfig,
                     **kwargs):
            assert multimodal_config is not None
            super().__init__(*args, **kwargs)

        def forward(
            self,
            *args,
            processed_audio: Optional[torch.Tensor] = None,
            **kwargs,
        ) -> torch.Tensor:
            return super().forward(*args, **kwargs)

    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)

    with patch(
            "aphrodite.endpoints.chat_utils._mm_token_str",
            lambda *_, **__: "_"), patch(
                "aphrodite.modeling.models.ModelRegistry.is_multimodal_model"
            ) as mock:
        mock.return_value = True
        sys.argv = ["placeholder.py"] + \
            (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
             "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
             f"--port {port} --chat-template {chatml_jinja_path} "
             "--disable-frontend-multiprocessing").split()
        import runpy
        runpy.run_module('aphrodite.endpoints.openai.api_server',
                         run_name='__main__')
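

# Module-scoped fixture: spawns the server in a separate process, polls the
# /health endpoint until it responds (or MAX_SERVER_START_WAIT_S elapses),
# and yields an AsyncOpenAI client pointed at it.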
@pytest.fixture(scope="module")
def client():
    port = get_open_port()
    ctx = torch.multiprocessing.get_context("spawn")
    server = ctx.Process(target=server_function, args=(port, ))
    server.start()
    MAX_SERVER_START_WAIT_S = 60
    client = openai.AsyncOpenAI(
        base_url=f"http://localhost:{port}/v1",
        api_key="token-abc123",
    )
    # run health check
    health_url = f"http://localhost:{port}/health"
    start = time.time()
    while True:
        try:
            if requests.get(health_url).status_code == 200:
                break
        except Exception as err:
            result = server.exitcode
            if result is not None:
                raise RuntimeError("Server exited unexpectedly.") from err

            time.sleep(0.5)
            if time.time() - start > MAX_SERVER_START_WAIT_S:
                raise RuntimeError("Server failed to start in time.") from err

    try:
        yield client
    finally:
        server.kill()
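

# Session-scoped fixture: fetches each test audio file once and caches its
# base64 encoding, keyed by URL.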
@pytest.fixture(scope="session")
def base64_encoded_audio() -> Dict[str, str]:
    return {
        audio_url: encode_audio_base64(*fetch_audio(audio_url))
        for audio_url in TEST_AUDIO_URLS
    }
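

# Single-turn audio chat via a plain URL, followed by a text-only follow-up
# turn; checks finish_reason, token usage, and the assistant message.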
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
                                          model_name: str, audio_url: str):
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=model_name,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=36, total_tokens=46)

    message = choice.message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
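

# Same flow as above, but the audio is passed inline as a base64 data URL
# built from the `base64_encoded_audio` fixture.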
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_single_chat_session_audio_base64encoded(
        client: openai.AsyncOpenAI, model_name: str, audio_url: str,
        base64_encoded_audio: Dict[str, str]):
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url":
                    f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(model=model_name,
                                                           messages=messages,
                                                           max_tokens=10,
                                                           logprobs=True,
                                                           top_logprobs=5)
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
        completion_tokens=10, prompt_tokens=36, total_tokens=46)

    message = choice.message
    assert message.content is not None and len(message.content) >= 10
    assert message.role == "assistant"
    messages.append({"role": "assistant", "content": message.content})

    # test multi-turn dialogue
    messages.append({"role": "user", "content": "express your result in json"})
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
    )
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
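

# Streaming should yield the same content as the equivalent non-streaming
# request, with a finish_reason reported only on the final chunk.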
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_chat_streaming_audio(client: openai.AsyncOpenAI,
                                    model_name: str, audio_url: str):
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    # test single completion
    chat_completion = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
    )
    output = chat_completion.choices[0].message.content
    stop_reason = chat_completion.choices[0].finish_reason

    # test streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
        stream=True,
    )
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.role:
            assert delta.role == "assistant"
        if delta.content:
            chunks.append(delta.content)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1

    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == stop_reason
    assert delta.content
    assert "".join(chunks) == output
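

# More than one audio item per request is expected to be rejected with a
# BadRequestError, and the server should remain usable afterwards.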
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("audio_url", TEST_AUDIO_URLS)
async def test_multi_audio_input(client: openai.AsyncOpenAI, model_name: str,
                                 audio_url: str):
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                }
            },
            {
                "type": "text",
                "text": "What's happening in this audio?"
            },
        ],
    }]

    with pytest.raises(openai.BadRequestError):  # test multi-audio input
        await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=10,
            temperature=0.0,
        )

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    completion = completion.choices[0].text
    assert completion is not None and len(completion) >= 0