vision_example.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. """
  2. This example shows how to use Aphrodite for running offline inference
  3. with the correct prompt format on vision language models.
For most models, the prompt format should follow the corresponding examples
in the model's HuggingFace repository.
  6. """
  7. import os
  8. import cv2
  9. import numpy as np
  10. from PIL import Image
  11. from transformers import AutoTokenizer
  12. from aphrodite import LLM, SamplingParams
  13. from aphrodite.assets.video import VideoAsset
  14. from aphrodite.common.utils import FlexibleArgumentParser
  15. from aphrodite.multimodal.utils import sample_frames_from_video
  16. # Input image and question
  17. image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
  18. "burg.jpg")
  19. image = Image.open(image_path).convert("RGB")
  20. img_question = "What is the content of this image?"
  21. # Input video and question
  22. video_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
  23. "nadeko.mp4")
  24. vid_question = "What's in this video?"
  25. def load_video_frames(video_path: str, num_frames: int) -> np.ndarray:
  26. """
  27. Load video frames from a local file path.
  28. Args:
  29. video_path: Path to the video file
  30. num_frames: Number of frames to sample from the video
  31. Returns:
  32. np.ndarray: Array of sampled video frames
  33. """
  34. cap = cv2.VideoCapture(video_path)
  35. if not cap.isOpened():
  36. raise ValueError(f"Could not open video file {video_path}")
  37. frames = []
  38. while True:
  39. ret, frame = cap.read()
  40. if not ret:
  41. break
  42. frames.append(frame)
  43. cap.release()
  44. frames = np.stack(frames)
  45. return sample_frames_from_video(frames, num_frames)
  46. # LLaVA-1.5
  47. def run_llava(question):
  48. prompt = f"USER: <image>\n{question}\nASSISTANT:"
  49. llm = LLM(model="llava-hf/llava-1.5-7b-hf")
  50. stop_token_ids = None
  51. return llm, prompt, stop_token_ids
  52. # LLaVA-1.6/LLaVA-NeXT
  53. def run_llava_next(question):
  54. prompt = f"[INST] <image>\n{question} [/INST]"
  55. llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
  56. stop_token_ids = None
  57. return llm, prompt, stop_token_ids
  58. # LlaVA-NeXT-Video
  59. # Currently only support for video input
  60. def run_llava_next_video(question):
  61. prompt = f"USER: <video>\n{question} ASSISTANT:"
  62. llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")
  63. stop_token_ids = None
  64. return llm, prompt, stop_token_ids
  65. # Fuyu
  66. def run_fuyu(question):
  67. prompt = f"{question}\n"
  68. llm = LLM(model="adept/fuyu-8b")
  69. stop_token_ids = None
  70. return llm, prompt, stop_token_ids
  71. # Phi-3-Vision
  72. def run_phi3v(question):
  73. prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
  74. # Note: The default setting of max_num_seqs (256) and
  75. # max_model_len (128k) for this model may cause OOM.
  76. # You may lower either to run this example on lower-end GPUs.
  77. # In this example, we override max_num_seqs to 5 while
  78. # keeping the original context length of 128k.
  79. llm = LLM(
  80. model="microsoft/Phi-3-vision-128k-instruct",
  81. trust_remote_code=True,
  82. max_num_seqs=5,
  83. )
  84. stop_token_ids = None
  85. return llm, prompt, stop_token_ids
  86. # PaliGemma
  87. def run_paligemma(question):
  88. # PaliGemma has special prompt format for VQA
  89. prompt = "caption en"
  90. llm = LLM(model="google/paligemma-3b-mix-224")
  91. stop_token_ids = None
  92. return llm, prompt, stop_token_ids
  93. # Chameleon
  94. def run_chameleon(question):
  95. prompt = f"{question}<image>"
  96. llm = LLM(model="facebook/chameleon-7b")
  97. stop_token_ids = None
  98. return llm, prompt, stop_token_ids
  99. # MiniCPM-V
  100. def run_minicpmv(question):
  101. # 2.0
  102. # The official repo doesn't work yet, so we need to use a fork for now
  103. # model_name = "HwwwH/MiniCPM-V-2"
  104. # 2.5
  105. # model_name = "openbmb/MiniCPM-Llama3-V-2_5"
  106. #2.6
  107. model_name = "openbmb/MiniCPM-V-2_6"
  108. tokenizer = AutoTokenizer.from_pretrained(model_name,
  109. trust_remote_code=True)
  110. llm = LLM(
  111. model=model_name,
  112. trust_remote_code=True,
  113. )
  114. # NOTE The stop_token_ids are different for various versions of MiniCPM-V
  115. # 2.0
  116. # stop_token_ids = [tokenizer.eos_id]
  117. # 2.5
  118. # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
  119. # 2.6
  120. stop_tokens = ['<|im_end|>', '<|endoftext|>']
  121. stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
  122. messages = [{
  123. 'role': 'user',
  124. 'content': f'(<image>./</image>)\n{question}'
  125. }]
  126. prompt = tokenizer.apply_chat_template(messages,
  127. tokenize=False,
  128. add_generation_prompt=True)
  129. return llm, prompt, stop_token_ids
  130. # InternVL
  131. def run_internvl(question):
  132. model_name = "OpenGVLab/InternVL2-2B"
  133. llm = LLM(
  134. model=model_name,
  135. trust_remote_code=True,
  136. max_num_seqs=5,
  137. )
  138. tokenizer = AutoTokenizer.from_pretrained(model_name,
  139. trust_remote_code=True)
  140. messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
  141. prompt = tokenizer.apply_chat_template(messages,
  142. tokenize=False,
  143. add_generation_prompt=True)
  144. # Stop tokens for InternVL
  145. # models variants may have different stop tokens
  146. # please refer to the model card for the correct "stop words":
  147. # https://huggingface.co/OpenGVLab/InternVL2-2B#service
  148. stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
  149. stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
  150. return llm, prompt, stop_token_ids
  151. # BLIP-2
  152. def run_blip2(question):
  153. # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
  154. # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
  155. prompt = f"Question: {question} Answer:"
  156. llm = LLM(model="Salesforce/blip2-opt-2.7b")
  157. stop_token_ids = None
  158. return llm, prompt, stop_token_ids
  159. # Qwen
  160. def run_qwen_vl(question):
  161. llm = LLM(
  162. model="Qwen/Qwen-VL",
  163. trust_remote_code=True,
  164. max_num_seqs=5,
  165. )
  166. prompt = f"{question}Picture 1: <img></img>\n"
  167. stop_token_ids = None
  168. return llm, prompt, stop_token_ids
  169. # Qwen2-VL
  170. def run_qwen2_vl(question):
  171. model_name = "Qwen/Qwen2-VL-7B-Instruct"
  172. llm = LLM(
  173. model=model_name,
  174. max_num_seqs=5,
  175. )
  176. prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
  177. "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
  178. f"{question}<|im_end|>\n"
  179. "<|im_start|>assistant\n")
  180. stop_token_ids = None
  181. return llm, prompt, stop_token_ids
  182. model_example_map = {
  183. "llava": run_llava,
  184. "llava-next": run_llava_next,
  185. "llava-next-video": run_llava_next_video,
  186. "fuyu": run_fuyu,
  187. "phi3_v": run_phi3v,
  188. "paligemma": run_paligemma,
  189. "chameleon": run_chameleon,
  190. "minicpmv": run_minicpmv,
  191. "blip-2": run_blip2,
  192. "internvl_chat": run_internvl,
  193. "qwen_vl": run_qwen_vl,
  194. "qwen2_vl": run_qwen2_vl,
  195. }
  196. def get_multi_modal_input(args):
  197. """
  198. return {
  199. "data": image or video,
  200. "question": question,
  201. }
  202. """
  203. if args.modality == "image":
  204. return {
  205. "data": image,
  206. "question": img_question,
  207. }
  208. if args.modality == "video":
  209. video = VideoAsset(name="nadeko.mp4",
  210. num_frames=args.num_frames,
  211. local_path=video_path).np_ndarrays
  212. return {
  213. "data": video,
  214. "question": vid_question,
  215. }
  216. msg = f"Modality {args.modality} is not supported."
  217. raise ValueError(msg)
  218. def main(args):
  219. model = args.model_type
  220. if model not in model_example_map:
  221. raise ValueError(f"Model type {model} is not supported.")
  222. modality = args.modality
  223. mm_input = get_multi_modal_input(args)
  224. data = mm_input["data"]
  225. question = mm_input["question"]
  226. llm, prompt, stop_token_ids = model_example_map[model](question)
  227. # We set temperature to 0.2 so that outputs can be different
  228. # even when all prompts are identical when running batch inference.
  229. sampling_params = SamplingParams(temperature=0.2,
  230. max_tokens=512,
  231. stop_token_ids=stop_token_ids)
  232. assert args.num_prompts > 0
  233. if args.num_prompts == 1:
  234. # Single inference
  235. inputs = {
  236. "prompt": prompt,
  237. "multi_modal_data": {
  238. modality: data
  239. },
  240. }
  241. else:
  242. # Batch inference
  243. inputs = [{
  244. "prompt": prompt,
  245. "multi_modal_data": {
  246. modality: data
  247. },
  248. } for _ in range(args.num_prompts)]
  249. outputs = llm.generate(inputs, sampling_params=sampling_params)
  250. for o in outputs:
  251. generated_text = o.outputs[0].text
  252. print(generated_text)
  253. if __name__ == "__main__":
  254. parser = FlexibleArgumentParser(
  255. description='Demo on using vLLM for offline inference with '
  256. 'vision language models')
  257. parser.add_argument('--model-type',
  258. '-m',
  259. type=str,
  260. default="llava",
  261. choices=model_example_map.keys(),
  262. help='Huggingface "model_type".')
  263. parser.add_argument('--num-prompts',
  264. type=int,
  265. default=1,
  266. help='Number of prompts to run.')
  267. parser.add_argument('--modality',
  268. type=str,
  269. default="image",
  270. help='Modality of the input.')
  271. parser.add_argument('--num-frames',
  272. type=int,
  273. default=16,
  274. help='Number of frames to extract from the video.')
  275. args = parser.parse_args()
  276. main(args)