api.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. from fastapi import FastAPI, Body
  3. from fastapi.exceptions import HTTPException
  4. from PIL import Image
  5. import os
  6. import gradio as gr
  7. from modules.api.models import *
  8. from modules.api import api
  9. import glob
  10. import base64
  11. from pydantic import BaseModel
  12. def get_root_path():
  13. path = os.path.dirname(os.path.realpath(__file__))
  14. path = os.path.dirname(path)
  15. path = os.path.dirname(path)
  16. path = os.path.dirname(path)
  17. return path
  18. class GetGeneratedImagesRequest(BaseModel):
  19. limit: int = 10
  20. def StableStudio_api(_: gr.Blocks, app: FastAPI):
  21. @app.get("/StableStudio/check-extension-installed")
  22. async def check_extension_installed(extension_name: str):
  23. extension_path = os.path.join(get_root_path(), "extensions", extension_name)
  24. installed = 0
  25. if os.path.exists(extension_path):
  26. installed = 1
  27. return {
  28. "extension_path": extension_path,
  29. "installed": installed
  30. }
  31. @app.post("/StableStudio/get-generated-images")
  32. async def get_generated_images(request: GetGeneratedImagesRequest):
  33. outputs_path = os.path.join(get_root_path(), "outputs")
  34. txt2img_folder = os.path.join(outputs_path, 'txt2img-images', '**')
  35. img2img_folder = os.path.join(outputs_path, 'img2img-images', '**')
  36. files = glob.glob(txt2img_folder, recursive=True) + glob.glob(img2img_folder, recursive=True)
  37. files = [f for f in files if os.path.isfile(f)]
  38. files.sort(key=os.path.getctime, reverse=True)
  39. files = files[:request.limit]
  40. return_values = []
  41. for file in files:
  42. with open(file, "rb") as f:
  43. img = Image.open(f)
  44. width, height = img.size
  45. encoded_content = api.encode_pil_to_base64(img)
  46. image_name = os.path.basename(file)
  47. seed = int(image_name.split(".")[0].split("-")[1])
  48. return_value = {
  49. "image_name": image_name,
  50. "create_date": os.path.getctime(file),
  51. "content": encoded_content,
  52. "width": width,
  53. "height": height,
  54. "seed": seed
  55. }
  56. return_values.append(return_value)
  57. return return_values
  58. try:
  59. import modules.script_callbacks as script_callbacks
  60. script_callbacks.on_app_started(StableStudio_api)
  61. except:
  62. pass