2
0

setup_or_recover.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import os
  2. def env_entry(name, value, comment, null=False):
  3. return f"# {comment}\n{'# ' if null else ''}{name}={value}\n"
  4. def generate_env(
  5. *,
  6. env_suno_use_small_models: bool = False,
  7. env_suno_enable_mps: bool = False,
  8. env_suno_offload_cpu: bool = False,
  9. model_location_hf_env_var: str = "",
  10. model_location_hf_env_var2: str = "",
  11. model_location_th_home: str = "",
  12. model_location_th_xdg: str = "",
  13. # data\models\rvc\checkpoints\Alina_Gray-20230627T032329Z-001\Alina_Gray.pth
  14. rvc_weight_root: str = "data/models/rvc/checkpoints",
  15. rvc_weight_uvr5_root: str = "data/models/rvc/uvr5_weights",
  16. rvc_index_root: str = "data/models/rvc/checkpoints",
  17. rvc_outside_index_root: str = "data/models/rvc/checkpoints",
  18. rvc_rmvpe_root: str = "data/models/rvc/rmvpe",
  19. ):
  20. def get_suno_env(name):
  21. return os.environ.get(name, "").lower() in ("true", "1")
  22. if not env_suno_use_small_models:
  23. env_suno_use_small_models = get_suno_env("SUNO_USE_SMALL_MODELS")
  24. if not env_suno_enable_mps:
  25. env_suno_enable_mps = get_suno_env("SUNO_ENABLE_MPS")
  26. if not env_suno_offload_cpu:
  27. env_suno_offload_cpu = get_suno_env("SUNO_OFFLOAD_CPU")
  28. if not model_location_hf_env_var:
  29. model_location_hf_env_var = os.environ.get("HUGGINGFACE_HUB_CACHE", "")
  30. if not model_location_hf_env_var2:
  31. model_location_hf_env_var2 = os.environ.get("HF_HOME", "")
  32. if not model_location_th_home:
  33. model_location_th_home = os.environ.get("TORCH_HOME", "")
  34. if not model_location_th_xdg:
  35. model_location_th_xdg = os.environ.get("XDG_CACHE_HOME", "")
  36. if not rvc_weight_root:
  37. rvc_weight_root = os.environ.get("weight_root", "")
  38. if not rvc_weight_uvr5_root:
  39. rvc_weight_uvr5_root = os.environ.get("weight_uvr5_root", "")
  40. if not rvc_index_root:
  41. rvc_index_root = os.environ.get("index_root", "")
  42. if not rvc_outside_index_root:
  43. rvc_outside_index_root = os.environ.get("outside_index_root", "")
  44. if not rvc_rmvpe_root:
  45. rvc_rmvpe_root = os.environ.get("rmvpe_root", "")
  46. env = "# This file gets updated automatically from the UI\n\n"
  47. env += "# If you wish to manually specify any ENV variables, please do so in the .env.user file\n"
  48. env += "# The variables in .env.user will take PRIORITY!\n\n"
  49. env += env_entry(
  50. "SUNO_USE_SMALL_MODELS",
  51. env_suno_use_small_models,
  52. "Duplicates small models checkboxes",
  53. )
  54. env += env_entry(
  55. "SUNO_ENABLE_MPS",
  56. env_suno_enable_mps,
  57. "Use MPS when CUDA is unavailable",
  58. )
  59. env += env_entry(
  60. "SUNO_OFFLOAD_CPU", env_suno_offload_cpu, "Offload GPU models to CPU"
  61. )
  62. env += "\n"
  63. env += env_entry(
  64. "HUGGINGFACE_HUB_CACHE",
  65. model_location_hf_env_var,
  66. "Environment variable for HuggingFace model location",
  67. null=not model_location_hf_env_var,
  68. )
  69. env += env_entry(
  70. "HF_HOME",
  71. model_location_hf_env_var2,
  72. "Environment variable for HuggingFace model location (alternative)",
  73. null=not model_location_hf_env_var2,
  74. )
  75. env += env_entry(
  76. "TORCH_HOME",
  77. model_location_th_home,
  78. "Default location for Torch Hub models",
  79. null=not model_location_th_home,
  80. )
  81. env += env_entry(
  82. "XDG_CACHE_HOME",
  83. model_location_th_xdg,
  84. "Default location for Torch Hub models (alternative)",
  85. null=not model_location_th_xdg,
  86. )
  87. env += "\n"
  88. env += env_entry(
  89. "weight_root",
  90. rvc_weight_root,
  91. "Root directory for RVC model weights",
  92. null=not rvc_weight_root,
  93. )
  94. env += env_entry(
  95. "weight_uvr5_root",
  96. rvc_weight_uvr5_root,
  97. "Root directory for RVC model weights (UVR5)",
  98. null=not rvc_weight_uvr5_root,
  99. )
  100. env += env_entry(
  101. "index_root",
  102. rvc_index_root,
  103. "Root directory for RVC model indices",
  104. null=not rvc_index_root,
  105. )
  106. env += env_entry(
  107. "outside_index_root",
  108. rvc_outside_index_root,
  109. "Root directory for RVC model indices (outside)",
  110. null=not rvc_outside_index_root,
  111. )
  112. env += env_entry(
  113. "rmvpe_root",
  114. rvc_rmvpe_root,
  115. "Root directory for RVC model RMVPE",
  116. null=not rvc_rmvpe_root,
  117. )
  118. env += "\n"
  119. return env
  120. def write_env(text: str):
  121. with open(".env", "w") as outfile:
  122. outfile.write(text)
  123. def setup_or_recover():
  124. if not os.path.exists("outputs"):
  125. os.makedirs("outputs")
  126. if not os.path.exists("favorites"):
  127. os.makedirs("favorites")
  128. if not os.path.exists(".env"):
  129. print("Env file not found. Creating default env.")
  130. write_env(generate_env())