vision_example.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. """
  2. This example shows how to use Aphrodite for running offline inference
  3. with the correct prompt format on vision language models.
  4. For most models, the prompt format should follow corresponding examples
  5. on HuggingFace model repository.
  6. """
  7. import os
  8. from PIL import Image
  9. from transformers import AutoTokenizer
  10. from aphrodite import LLM, SamplingParams
  11. from aphrodite.common.utils import FlexibleArgumentParser
  12. # Input image and question
  13. image_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
  14. "burg.jpg")
  15. image = Image.open(image_path).convert("RGB")
  16. question = "What is the content of this image?"
  17. # LLaVA-1.5
  18. def run_llava(question):
  19. prompt = f"USER: <image>\n{question}\nASSISTANT:"
  20. llm = LLM(model="llava-hf/llava-1.5-7b-hf")
  21. return llm, prompt
  22. # LLaVA-1.6/LLaVA-NeXT
  23. def run_llava_next(question):
  24. prompt = f"[INST] <image>\n{question} [/INST]"
  25. llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf")
  26. return llm, prompt
  27. # Fuyu
  28. def run_fuyu(question):
  29. prompt = f"{question}\n"
  30. llm = LLM(model="adept/fuyu-8b")
  31. return llm, prompt
  32. # Phi-3-Vision
  33. def run_phi3v(question):
  34. prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
  35. # Note: The default setting of max_num_seqs (256) and
  36. # max_model_len (128k) for this model may cause OOM.
  37. # You may lower either to run this example on lower-end GPUs.
  38. # In this example, we override max_num_seqs to 5 while
  39. # keeping the original context length of 128k.
  40. llm = LLM(
  41. model="microsoft/Phi-3-vision-128k-instruct",
  42. trust_remote_code=True,
  43. max_num_seqs=5,
  44. )
  45. return llm, prompt
  46. # PaliGemma
  47. def run_paligemma(question):
  48. prompt = "caption en"
  49. llm = LLM(model="google/paligemma-3b-mix-224")
  50. return llm, prompt
  51. # Chameleon
  52. def run_chameleon(question):
  53. prompt = f"{question}<image>"
  54. llm = LLM(model="facebook/chameleon-7b")
  55. return llm, prompt
  56. # MiniCPM-V
  57. def run_minicpmv(question):
  58. # 2.0
  59. # The official repo doesn't work yet, so we need to use a fork for now
  60. # model_name = "HwwwH/MiniCPM-V-2"
  61. # 2.5
  62. model_name = "openbmb/MiniCPM-Llama3-V-2_5"
  63. tokenizer = AutoTokenizer.from_pretrained(model_name,
  64. trust_remote_code=True)
  65. llm = LLM(
  66. model=model_name,
  67. trust_remote_code=True,
  68. )
  69. messages = [{
  70. 'role': 'user',
  71. 'content': f'(<image>./</image>)\n{question}'
  72. }]
  73. prompt = tokenizer.apply_chat_template(messages,
  74. tokenize=False,
  75. add_generation_prompt=True)
  76. return llm, prompt
  77. # BLIP-2
  78. def run_blip2(question):
  79. # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
  80. # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
  81. prompt = f"Question: {question} Answer:"
  82. llm = LLM(model="Salesforce/blip2-opt-2.7b")
  83. return llm, prompt
  84. # InternVL
  85. def run_internvl(question):
  86. # Generally, InternVL can use chatml template for conversation
  87. TEMPLATE = "<|im_start|>User\n{prompt}<|im_end|>\n<|im_start|>Assistant\n"
  88. prompt = f"<image>\n{question}\n"
  89. prompt = TEMPLATE.format(prompt=prompt)
  90. llm = LLM(
  91. model="OpenGVLab/InternVL2-4B",
  92. trust_remote_code=True,
  93. max_num_seqs=28,
  94. tensor_parallel_size=2,
  95. max_model_len=8192,
  96. )
  97. return llm, prompt
  98. # Qwen2-VL
  99. def run_qwen2_vl(question):
  100. model_name = "Qwen/Qwen2-VL-7B-Instruct"
  101. llm = LLM(
  102. model=model_name,
  103. max_num_seqs=5,
  104. )
  105. prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
  106. "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
  107. f"{question}<|im_end|>\n"
  108. "<|im_start|>assistant\n")
  109. stop_token_ids = None
  110. return llm, prompt, stop_token_ids
  111. model_example_map = {
  112. "llava": run_llava,
  113. "llava-next": run_llava_next,
  114. "fuyu": run_fuyu,
  115. "phi3_v": run_phi3v,
  116. "paligemma": run_paligemma,
  117. "chameleon": run_chameleon,
  118. "minicpmv": run_minicpmv,
  119. "blip-2": run_blip2,
  120. "internvl_chat": run_internvl,
  121. "qwen2_vl": run_qwen2_vl,
  122. }
  123. def main(args):
  124. model = args.model_type
  125. if model not in model_example_map:
  126. raise ValueError(f"Model type {model} is not supported.")
  127. llm, prompt = model_example_map[model](question)
  128. # We set temperature to 0.2 so that outputs can be different
  129. # even when all prompts are identical when running batch inference.
  130. sampling_params = SamplingParams(temperature=0.2, max_tokens=128)
  131. assert args.num_prompts > 0
  132. if args.num_prompts == 1:
  133. # Single inference
  134. inputs = {
  135. "prompt": prompt,
  136. "multi_modal_data": {
  137. "image": image
  138. },
  139. }
  140. else:
  141. # Batch inference
  142. inputs = [{
  143. "prompt": prompt,
  144. "multi_modal_data": {
  145. "image": image
  146. },
  147. } for _ in range(args.num_prompts)]
  148. outputs = llm.generate(inputs, sampling_params=sampling_params)
  149. for o in outputs:
  150. generated_text = o.outputs[0].text
  151. print(generated_text)
  152. if __name__ == "__main__":
  153. parser = FlexibleArgumentParser(
  154. description='Demo on using Aphrodite for offline inference with '
  155. 'vision language models')
  156. parser.add_argument('--model-type',
  157. '-m',
  158. type=str,
  159. default="llava",
  160. choices=model_example_map.keys(),
  161. help='Huggingface "model_type".')
  162. parser.add_argument('--num-prompts',
  163. type=int,
  164. default=1,
  165. help='Number of prompts to run.')
  166. args = parser.parse_args()
  167. main(args)