vision_example.py

  1. """
  2. This example shows how to use Aphrodite for running offline inference
  3. with the correct prompt format on vision language models.
  4. For most models, the prompt format should follow corresponding examples
  5. on HuggingFace model repository.
  6. """

import os

import cv2
import numpy as np
from PIL import Image
from transformers import AutoTokenizer

from aphrodite import LLM, SamplingParams
from aphrodite.assets.video import VideoAsset
from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.multimodal.utils import sample_frames_from_video

# Input image and question
image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                          "burg.jpg")
image = Image.open(image_path).convert("RGB")
img_question = "What is the content of this image?"

# Input video and question
video_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                          "nadeko.mp4")
vid_question = "What's in this video?"


def load_video_frames(video_path: str, num_frames: int) -> np.ndarray:
    """
    Load video frames from a local file path.

    Args:
        video_path: Path to the video file
        num_frames: Number of frames to sample from the video

    Returns:
        np.ndarray: Array of sampled video frames
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {video_path}")

    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()

    frames = np.stack(frames)
    return sample_frames_from_video(frames, num_frames)
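

# Illustrative use of the helper above (not exercised by main(), which relies
# on VideoAsset instead): with num_frames=16 this returns an array of shape
# (16, height, width, 3) in OpenCV's BGR channel order.
#   frames = load_video_frames(video_path, num_frames=16)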


# LLaVA-1.5
def run_llava(question, modality):
    assert modality == "image"

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question, modality):
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only supports video input
def run_llava_next_video(question, modality):
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-OneVision
def run_llava_onevision(question, modality):
    if modality == "video":
        prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n"
    elif modality == "image":
        prompt = f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n"

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
              max_model_len=32768)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question, modality):
    assert modality == "image"

    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Phi-3-Vision
def run_phi3v(question, modality):
    assert modality == "image"

    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.

    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
    )
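    # A lower-memory alternative (illustrative only, not used in this example):
    # also cap the context window, e.g.
    #   llm = LLM(model="microsoft/Phi-3-vision-128k-instruct",
    #             trust_remote_code=True, max_num_seqs=5, max_model_len=4096)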
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# PaliGemma
def run_paligemma(question, modality):
    assert modality == "image"

    # PaliGemma has special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question, modality):
    assert modality == "image"

    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-V
def run_minicpmv(question, modality):
    assert modality == "image"

    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now
    # model_name = "HwwwH/MiniCPM-V-2"
    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
    )
    # NOTE: The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]
    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question, modality):
    assert modality == "image"

    model_name = "OpenGVLab/InternVL2-2B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    # Stop tokens for InternVL
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question, modality):
    assert modality == "image"

    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262  # noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen
def run_qwen_vl(question, modality):
    assert modality == "image"

    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_num_seqs=5,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen2-VL
def run_qwen2_vl(question, modality):
    assert modality == "image"

    model_name = "Qwen/Qwen2-VL-7B-Instruct"

    llm = LLM(
        model=model_name,
        max_num_seqs=5,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Molmo
def run_molmo(question, modality):
    assert modality == "image"

    model_name = "allenai/Molmo-7B-D-0924"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
    )

    prompt = question
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "llava-onevision": run_llava_onevision,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "molmo": run_molmo,
}
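
# The runners above can also be used directly, bypassing the CLI. A minimal
# sketch (assuming enough GPU memory for the chosen model), mirroring main():
#   llm, prompt, stop_token_ids = model_example_map["llava"](img_question, "image")
#   params = SamplingParams(temperature=0.2, max_tokens=512,
#                           stop_token_ids=stop_token_ids)
#   outputs = llm.generate({"prompt": prompt,
#                           "multi_modal_data": {"image": image}},
#                          sampling_params=params)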


def get_multi_modal_input(args):
    """
    return {
        "data": image or video,
        "question": question,
    }
    """
    if args.modality == "image":
        return {
            "data": image,
            "question": img_question,
        }

    if args.modality == "video":
        video = VideoAsset(name="nadeko.mp4",
                           num_frames=args.num_frames,
                           local_path=video_path).np_ndarrays
        return {
            "data": video,
            "question": vid_question,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question, modality)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=512,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }
    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using Aphrodite for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        choices=['image', 'video'],
                        help='Modality of the input.')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()
    main(args)