  1. """
  2. This example shows how to use Aphrodite for running offline inference
  3. with the correct prompt format on vision language models.
  4. For most models, the prompt format should follow corresponding examples
  5. on HuggingFace model repository.
  6. """
import os

import cv2
import numpy as np
from PIL import Image
from transformers import AutoTokenizer

from aphrodite import LLM, SamplingParams
from aphrodite.assets.video import VideoAsset
from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.multimodal.utils import sample_frames_from_video

# Input image and question
image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                          "burg.jpg")
image = Image.open(image_path).convert("RGB")
img_question = "What is the content of this image?"

# Input video and question
video_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                          "nadeko.mp4")
vid_question = "What's in this video?"


def load_video_frames(video_path: str, num_frames: int) -> np.ndarray:
    """
    Load video frames from a local file path.

    Args:
        video_path: Path to the video file.
        num_frames: Number of frames to sample from the video.

    Returns:
        np.ndarray: Array of sampled video frames.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {video_path}")

    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()

    frames = np.stack(frames)
    return sample_frames_from_video(frames, num_frames)
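
# A minimal sketch of calling the helper above directly, assuming a local
# video file "clip.mp4" (hypothetical path). OpenCV reads frames in BGR
# order, so the stacked result is an array of shape
# (num_frames, height, width, 3).
#
#   frames = load_video_frames("clip.mp4", num_frames=8)
#   print(frames.shape)  # e.g. (8, 720, 1280, 3)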


# LLaVA-1.5
def run_llava(question):
    prompt = f"USER: <image>\n{question}\nASSISTANT:"
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-1.6/LLaVA-NeXT
def run_llava_next(question):
    prompt = f"[INST] <image>\n{question} [/INST]"
    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# LLaVA-NeXT-Video
# Currently only supports video input
def run_llava_next_video(question):
    prompt = f"USER: <video>\n{question} ASSISTANT:"
    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Fuyu
def run_fuyu(question):
    prompt = f"{question}\n"
    llm = LLM(model="adept/fuyu-8b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Phi-3-Vision
def run_phi3v(question):
    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (128k) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.
    # In this example, we override max_num_seqs to 5 while
    # keeping the original context length of 128k.
    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        max_num_seqs=5,
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# PaliGemma
def run_paligemma(question):
    # PaliGemma has a special prompt format for VQA; the question itself
    # is not used here.
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Chameleon
def run_chameleon(question):
    prompt = f"{question}<image>"
    llm = LLM(model="facebook/chameleon-7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# MiniCPM-V
def run_minicpmv(question):
    # 2.0
    # The official repo doesn't work yet, so we need to use a fork for now
    # model_name = "HwwwH/MiniCPM-V-2"
    # 2.5
    # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
    # 2.6
    model_name = "openbmb/MiniCPM-V-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
    )
    # NOTE: The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
    # stop_token_ids = [tokenizer.eos_id]
    # 2.5
    # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
    # 2.6
    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]

    messages = [{
        'role': 'user',
        'content': f'(<image>./</image>)\n{question}'
    }]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return llm, prompt, stop_token_ids


# InternVL
def run_internvl(question):
    model_name = "OpenGVLab/InternVL2-2B"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        max_num_seqs=5,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
    # Stop tokens for InternVL.
    # Model variants may have different stop tokens;
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


# BLIP-2
def run_blip2(question):
    # The BLIP-2 prompt format shown in the HuggingFace model repository is
    # inaccurate; see
    # https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262  # noqa
    prompt = f"Question: {question} Answer:"
    llm = LLM(model="Salesforce/blip2-opt-2.7b")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen
def run_qwen_vl(question):
    llm = LLM(
        model="Qwen/Qwen-VL",
        trust_remote_code=True,
        max_num_seqs=5,
    )
    prompt = f"{question}Picture 1: <img></img>\n"
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Qwen2-VL
def run_qwen2_vl(question):
    model_name = "Qwen/Qwen2-VL-7B-Instruct"
    llm = LLM(
        model=model_name,
        max_num_seqs=5,
    )
    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              f"{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
    stop_token_ids = None
    return llm, prompt, stop_token_ids


# Molmo
def run_molmo(question):
    model_name = "allenai/Molmo-7B-D-0924"
    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
    )
    prompt = question
    stop_token_ids = None
    return llm, prompt, stop_token_ids


model_example_map = {
    "llava": run_llava,
    "llava-next": run_llava_next,
    "llava-next-video": run_llava_next_video,
    "fuyu": run_fuyu,
    "phi3_v": run_phi3v,
    "paligemma": run_paligemma,
    "chameleon": run_chameleon,
    "minicpmv": run_minicpmv,
    "blip-2": run_blip2,
    "internvl_chat": run_internvl,
    "qwen_vl": run_qwen_vl,
    "qwen2_vl": run_qwen2_vl,
    "molmo": run_molmo,
}
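
# Each runner above returns an (llm, prompt, stop_token_ids) triple; main()
# feeds the stop_token_ids into SamplingParams and passes the prompt to
# llm.generate() together with the image or video data.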


def get_multi_modal_input(args):
    """
    return {
        "data": image or video,
        "question": question,
    }
    """
    if args.modality == "image":
        return {
            "data": image,
            "question": img_question,
        }

    if args.modality == "video":
        video = VideoAsset(name="nadeko.mp4",
                           num_frames=args.num_frames,
                           local_path=video_path).np_ndarrays
        return {
            "data": video,
            "question": vid_question,
        }

    msg = f"Modality {args.modality} is not supported."
    raise ValueError(msg)


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    modality = args.modality
    mm_input = get_multi_modal_input(args)
    data = mm_input["data"]
    question = mm_input["question"]

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can differ
    # even when all prompts are identical during batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=512,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        }
    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                modality: data
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using Aphrodite for offline inference with '
        'vision language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="llava",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')
    parser.add_argument('--modality',
                        type=str,
                        default="image",
                        help='Modality of the input ("image" or "video").')
    parser.add_argument('--num-frames',
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
    args = parser.parse_args()

    main(args)