llava_example.py

import argparse
import os
import subprocess

import torch
from PIL import Image

from aphrodite import LLM
from aphrodite.multimodal.image import ImageFeatureData, ImagePixelData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them.


def run_llava_pixel_values(*, disable_image_processor: bool = False):
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
        disable_image_processor=disable_image_processor,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    if disable_image_processor:
        image = torch.load("images/stop_sign_pixel_values.pt")
    else:
        image = Image.open("images/stop_sign.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImageFeatureData(image),
    })

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()

    # Download the image assets from S3.
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it.
    os.makedirs(local_directory, exist_ok=True)

    # Use the AWS CLI to sync the directory, assuming anonymous access.
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])

    main(args)
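
# Example invocations (a sketch, assuming the AWS CLI is installed and on PATH
# so the S3 sync above succeeds, and that the model weights can be fetched
# from the HuggingFace Hub):
#
#   python llava_example.py --type pixel_values
#   python llava_example.py --type image_features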