2
0

initializeApp.js 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. const fs = require("fs");
  2. const { resolve } = require("path");
  3. const { displayError, displayMessage } = require("./displayMessage.js");
  4. const { processExit } = require("./processExit.js");
  5. const { menu } = require("./menu.js");
  6. const { $, $$, $sh } = require("./shell.js");
  7. const { applyDatabaseConfig } = require("./applyDatabaseConfig.js");
  // When true, pip/uv installs get --dry-run (see dry_run_flag).
  const DEBUG_DRY_RUN = false;
  const torchVersion = "2.3.1";
  const cudaVersion = "11.8";
  const pythonVersion = `3.10.11`;
  const pythonPackage = `python=${pythonVersion}`;
  const ffmpegPackage = `conda-forge::ffmpeg=4.4.2[build=lgpl*]`;
  const nodePackage = `conda-forge::nodejs=22.9.0`;
  const anacondaPostgresqlPackage = `conda-forge::postgresql=16.4`;
  // Terraform is disabled; set to `conda-forge::terraform=1.8.2` to re-enable it.
  const terraformPackage = ``;
  // Conda channel flags in priority order; each string starts with " -c ".
  const cudaChannels = [
    "",
    "pytorch",
    `nvidia/label/cuda-${cudaVersion}.0`,
    "nvidia",
  ].join(" -c ");
  const cpuChannels = ["", "pytorch"].join(" -c ");
  // vswhere is only installed on Windows.
  const windowsOnlyPackages =
    process.platform === "win32" ? ["conda-forge::vswhere"] : [];
  // Conda packages installed for every device choice. May contain empty
  // entries for disabled packages (e.g. terraformPackage).
  const commonPackages = [
    "conda-forge::uv=0.4.17",
    ...windowsOnlyPackages,
    terraformPackage,
    anacondaPostgresqlPackage,
    nodePackage,
    ffmpegPackage,
  ];
  /**
   * Joins command fragments with single spaces. Each fragment is trimmed and
   * empty fragments are dropped, so a disabled package never turns into an
   * empty argument or a double space in the final command.
   * @param {...string} parts
   * @returns {string}
   */
  const buildCondaCommand = (...parts) =>
    parts
      .map((part) => part.trim())
      .filter((part) => part !== "")
      .join(" ");
  // Major.minor of pythonVersion (e.g. "3.10"), used in the CUDA build selector.
  const pythonMajorMinor = pythonVersion.split(".").slice(0, 2).join(".");
  const cudaPackages = `pytorch[version=${torchVersion},build=py${pythonMajorMinor}_cuda${cudaVersion}*] pytorch-cuda=${cudaVersion} torchvision torchaudio cuda-toolkit ninja`;
  const cudaPytorchInstall$ = buildCondaCommand(
    "conda install -y -k",
    ...commonPackages,
    cudaPackages,
    cudaChannels
  );
  const cpuPackages = `pytorch=${torchVersion} torchvision torchaudio cpuonly`;
  const pytorchCPUInstall$ = buildCondaCommand(
    "conda install -y -k",
    ...commonPackages,
    cpuPackages,
    cpuChannels
  );
  const baseOnlyInstall$ = buildCondaCommand(
    "conda install -y -k",
    ...commonPackages
  );
  /**
   * Ensures the active interpreter reports `Python ${pythonVersion}`. If it
   * does not, reinstalls Python from conda-forge and pins pip to 23.3.2.
   *
   * Best effort: any failure is reported through displayError, including the
   * underlying error message, and is not rethrown.
   */
  const ensurePythonVersion = async () => {
    try {
      displayMessage("Checking python version...");
      const version = await getPythonVersion();
      if (version !== `Python ${pythonVersion}`) {
        displayMessage(`Current python version is """${version}"""`);
        displayMessage(`Python version is not ${pythonVersion}. Reinstalling...`);
        await $(`conda install -y -k -c conda-forge ${pythonPackage}`);
        await $(`conda install -y -k -c conda-forge pip==23.3.2`);
      }
    } catch (error) {
      // Keep the cause so a failed check is diagnosable.
      const reason = error instanceof Error ? error.message : String(error);
      displayError(`Failed to check/install python version: ${reason}`);
    }
    /**
     * Runs `python --version` and captures its output through a file in
     * installer_scripts/, a path relative to the current working directory.
     * @returns {Promise<string>} The trimmed output, e.g. "Python 3.10.11".
     */
    async function getPythonVersion() {
      await $sh(`python --version > installer_scripts/.python_version`);
      return fs.readFileSync("installer_scripts/.python_version", "utf8").trim();
    }
  };
  /**
   * Installs PyTorch and the shared conda packages for the chosen device.
   * Then records the installer major version and runs the first pip install
   * of all Python requirements.
   *
   * For "Cancel", unsupported devices or unknown choices, removes the saved
   * GPU choice and exits with code 1. Nothing is recorded as installed.
   *
   * @param {string} gpuchoice - One of the labels offered by askForGPUChoice.
   * @throws Rethrows any install error after reporting it via displayError.
   */
  const installDependencies = async (gpuchoice) => {
    try {
      if (gpuchoice === "NVIDIA GPU") {
        await $(cudaPytorchInstall$);
      } else if (gpuchoice === "Apple M Series Chip" || gpuchoice === "CPU") {
        await $(pytorchCPUInstall$);
      } else if (gpuchoice === "AMD GPU (ROCM, Linux only, potentially broken)") {
        displayMessage(
          "ROCM is experimental and not well supported yet, installing..."
        );
        displayMessage("Linux only!");
        await $(baseOnlyInstall$);
        await $(
          `pip install torch==${torchVersion} torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0`
        );
      } else if (gpuchoice === "pip torch cpu (experimental, faster install)") {
        await $(baseOnlyInstall$);
        await $(`pip install torch==${torchVersion} torchvision torchaudio`);
      } else {
        displayMessage("Unsupported or cancelled. Exiting...");
        removeGPUChoice();
        processExit(1);
        // Never fall through to recording a successful install, even if
        // processExit returns instead of terminating the process.
        return;
      }
      saveMajorVersion(majorVersion);
      displayMessage(" Successfully installed torch");
      await pip_install_all(true); // approximate first install
    } catch (error) {
      displayError(`Error during installation: ${error.message}`);
      throw error;
    }
  };
  /**
   * Asks the user which device PyTorch should be installed for.
   * @returns {*} Whatever `menu` returns for the selection. Callers treat it as
   *   the chosen option label.
   */
  const askForGPUChoice = () => {
    const deviceOptions = [
      "NVIDIA GPU",
      "Apple M Series Chip",
      "CPU",
      "Cancel",
      "AMD GPU (ROCM, Linux only, potentially broken)",
      "Intel GPU (unsupported)",
      "Integrated GPU (unsupported)",
      "pip torch cpu (experimental, faster install)",
    ];
    const promptText = `
These are not yet automatically supported: AMD GPU, Intel GPU, Integrated GPU.
Select the device (GPU/CPU) you are using to run the application:
(use arrow keys to move, enter to select)
`;
    return menu(deviceOptions, promptText);
  };
  /** Resolves `files` relative to the parent directory of this script's directory. */
  const getInstallerFilesPath = (...files) => resolve(__dirname, "..", ...files);
  // State files the installer uses to remember choices between runs.
  const gpuFile = getInstallerFilesPath(".gpu");
  const majorVersionFile = getInstallerFilesPath(".major_version");
  const pipPackagesFile = getInstallerFilesPath(".pip_packages");
  // When the saved major version differs from this, the base environment is reinstalled.
  const majorVersion = "4";
  // Read eagerly at module load; a missing or invalid versions.json throws here.
  const versions = JSON.parse(
    fs.readFileSync(getInstallerFilesPath("versions.json"), "utf8")
  );
  const newPipPackagesVersion = String(versions.pip_packages);
  /**
   * Reads a small installer state file.
   *
   * @param {string} file - Path of the state file.
   * @returns {string|number} The file contents with surrounding whitespace
   *   trimmed, or -1 when the file does not exist.
   */
  const readGeneric = (file) => {
    if (!fs.existsSync(file)) {
      return -1;
    }
    // Trim so a trailing newline (e.g. after a manual edit) does not defeat
    // the exact `===` comparisons made on these values.
    return fs.readFileSync(file, "utf8").trim();
  };
  /** Writes `data` (stringified) to `file`, replacing previous contents. */
  const saveGeneric = (file, data) => fs.writeFileSync(file, data.toString());
  // Accessors for the individual state files.
  const readMajorVersion = () => readGeneric(majorVersionFile);
  const saveMajorVersion = (data) => saveGeneric(majorVersionFile, data);
  const readPipPackagesVersion = () => readGeneric(pipPackagesFile);
  const savePipPackagesVersion = (data) => saveGeneric(pipPackagesFile, data);
  const readGPUChoice = () => readGeneric(gpuFile);
  const saveGPUChoice = (data) => saveGeneric(gpuFile, data);
  /** Deletes the saved GPU choice, if present. */
  const removeGPUChoice = () => {
    if (fs.existsSync(gpuFile)) fs.unlinkSync(gpuFile);
  };
  // Extra flag inserted into every pip/uv install command when DEBUG_DRY_RUN is set.
  const dry_run_flag = DEBUG_DRY_RUN ? "--dry-run " : "";

  /**
   * Runs one `uv pip install` (or plain `pip install`) command.
   *
   * `torch==${torchVersion}` is appended to every command. Presumably this keeps
   * the resolver on the PyTorch version installed earlier.
   *
   * @param {string} requirements - Install arguments, e.g. "-r requirements.txt"
   *   or "pkg==1.0 --index-url <url>".
   * @param {string} [name=""] - Label for progress messages; defaults to `requirements`.
   * @param {boolean} [pipFallback=false] - Use plain `pip` instead of `uv pip`.
   * @returns {Promise<boolean>} true on success, false on failure. Failures are
   *   reported via displayMessage and never thrown.
   */
  async function pip_install(requirements, name = "", pipFallback = false) {
    const label = name || requirements;
    const installer = pipFallback ? "pip" : "uv pip";
    try {
      displayMessage(`Installing ${label} dependencies...`);
      // Awaiting works whether $sh runs synchronously or returns a Promise.
      await $sh(
        `${installer} install ${dry_run_flag}${requirements} torch==${torchVersion}`
      );
      displayMessage(`Successfully installed ${label} dependencies\n`);
      return true;
    } catch (error) {
      displayMessage(`Failed to install ${label} dependencies\n`);
      return false;
    }
  }

  /**
   * Installs all Python requirements, unless the pip_packages version from
   * versions.json is already recorded in the .pip_packages state file.
   *
   * The user may opt into one combined install. If that fails, the function
   * falls back to installing each requirements group separately.
   *
   * The first install is a temporary safeguard due to mysterious issues with
   * uv: when `first_install` is true, installs use plain pip instead of uv.
   *
   * @param {boolean} [first_install=false]
   */
  async function pip_install_all(first_install = false) {
    if (readPipPackagesVersion() === newPipPackagesVersion) {
      return displayMessage(
        "Dependencies are already up to date, skipping pip installs..."
      );
    }
    const pip_install_all_choice = await menu(
      ["Yes", "No"],
      `Attempt single pip install of all dependencies (potentially faster)?
(use arrow keys to move, enter to select)`
    );
    if (pip_install_all_choice === "Yes") {
      displayMessage("Attempting single pip install of all dependencies...");
      // pip_install reports failure through its return value (it never throws),
      // so the fallback must be driven by that value, not by try/catch.
      const installedAll = await pip_install(
        "-r requirements.txt -r requirements_bark_hubert_quantizer.txt -r requirements_rvc.txt -r requirements_audiocraft.txt -r requirements_styletts2.txt -r requirements_vall_e.txt -r requirements_maha_tts.txt -r requirements_stable_audio.txt hydra-core==1.3.2 nvidia-ml-py",
        "All dependencies",
        first_install
      );
      if (installedAll) {
        savePipPackagesVersion(newPipPackagesVersion);
        displayMessage("");
        return;
      }
      displayMessage(
        "Failed to install all dependencies, falling back to individual installs..."
      );
    }
    displayMessage("Updating dependencies...");
    // [requirements, label, pipFallback], installed in this order.
    const individualInstalls = [
      ["-r requirements.txt", "Core Packages, Bark, Tortoise", first_install],
      // Always installed with uv, from the CUDA 11.8 wheel index.
      [
        "xformers==0.0.27+cu118 --index-url https://download.pytorch.org/whl/cu118",
        "xformers",
        false,
      ],
      [
        "-r requirements_bark_hubert_quantizer.txt",
        "Bark Voice Clone",
        first_install,
      ],
      ["-r requirements_rvc.txt", "RVC", first_install],
      ["-r requirements_audiocraft.txt", "Audiocraft", first_install],
      ["-r requirements_styletts2.txt", "StyleTTS", first_install],
      ["-r requirements_vall_e.txt", "Vall-E-X", first_install],
      ["-r requirements_maha_tts.txt", "Maha TTS", first_install],
      // Stable Audio always uses plain pip, regardless of first_install.
      ["-r requirements_stable_audio.txt", "Stable Audio", true],
      // Reinstall hydra-core==1.3.2 because of fairseq.
      ["hydra-core==1.3.2", "hydra-core fix due to fairseq", first_install],
      ["nvidia-ml-py", "nvidia-ml-py", first_install],
    ];
    for (const [requirements, label, pipFallback] of individualInstalls) {
      await pip_install(requirements, label, pipFallback);
    }
    // NOTE(review): the version is recorded even if some individual installs
    // failed, so they are not retried on the next run. Confirm this is intended.
    savePipPackagesVersion(newPipPackagesVersion);
    displayMessage("");
  }
  /**
   * Reports whether `python` can locate the `torch` package, using
   * importlib's find_spec without importing it.
   * @returns {Promise<boolean>} true if the probe exits successfully.
   */
  const checkIfTorchInstalled = async () => {
    const findTorchScript =
      'import importlib.util; import sys; package_name = "torch"; spec = importlib.util.find_spec(package_name); sys.exit(0) if spec else sys.exit(1)';
    let found = true;
    try {
      await $$(["python", "-c", findTorchScript]);
    } catch (error) {
      found = false;
    }
    return found;
  };
  // Any non-empty FORCE_REINSTALL environment value forces a base reinstall.
  const FORCE_REINSTALL = Boolean(process.env.FORCE_REINSTALL);
  /**
   * Brings the conda/PyTorch layer up to date.
   *
   * If the saved major version matches (and no reinstall is forced) and torch
   * is importable, only the pip requirements are refreshed. Otherwise the base
   * environment is installed for the saved GPU choice, or for a newly prompted
   * choice that is then saved.
   */
  async function applyCondaConfig() {
    displayMessage("Applying conda config...");
    displayMessage(" Checking if Torch is installed...");
    const isSameMajorVersion =
      readMajorVersion() === majorVersion && !FORCE_REINSTALL;
    if (!isSameMajorVersion) {
      displayMessage(
        " Major version update detected. Upgrading base environment"
      );
    } else if (await checkIfTorchInstalled()) {
      displayMessage(" Torch is already installed. Skipping installation...");
      await pip_install_all();
      return;
    } else {
      displayMessage(" Torch is not installed. Starting installation...\n");
    }
    let gpuchoice;
    if (fs.existsSync(gpuFile)) {
      gpuchoice = readGPUChoice();
      displayMessage(` Using saved GPU choice: ${gpuchoice}`);
    } else {
      gpuchoice = await askForGPUChoice();
      displayMessage(` You selected: ${gpuchoice}`);
      saveGPUChoice(gpuchoice);
    }
    await installDependencies(gpuchoice);
  }
  258. exports.initializeApp = async () => {
  259. displayMessage("Ensuring that python has the correct version...");
  260. await ensurePythonVersion();
  261. displayMessage("");
  262. await applyCondaConfig();
  263. displayMessage("");
  264. try {
  265. await applyDatabaseConfig();
  266. displayMessage("");
  267. } catch (error) {
  268. displayError("Failed to apply database config");
  269. }
  270. };
  /**
   * Asks the environment's python whether torch reports CUDA as available.
   * @returns {Promise<boolean>} true if the probe script exits with code 0.
   */
  const checkIfTorchHasCuda = async () => {
    const cudaProbe = "import torch; exit(0 if torch.cuda.is_available() else 1)";
    let available = true;
    try {
      displayMessage("Checking if torch has CUDA...");
      await $$(["python", "-c", cudaProbe]);
    } catch (error) {
      available = false;
    }
    return available;
  };
  284. exports.repairTorch = async () => {
  285. const gpuChoice = readGPUChoice();
  286. if (!checkIfTorchHasCuda() && gpuChoice === "NVIDIA GPU") {
  287. displayMessage("Backend is NVIDIA GPU, fixing PyTorch");
  288. try {
  289. await $(`conda install -y -k --force-reinstall ${cudaPackages}`);
  290. } catch (error) {
  291. displayError("Failed to fix torch");
  292. }
  293. } else if (gpuChoice === "CPU" || gpuChoice === "Apple M Series Chip") {
  294. displayMessage("Backend is CPU/Apple M Series Chip, fixing PyTorch");
  295. try {
  296. await $(`conda install -y -k --force-reinstall ${cpuPackages}`);
  297. } catch (error) {
  298. displayError("Failed to fix torch");
  299. }
  300. displayMessage("Torch has CUDA, skipping reinstall");
  301. }
  302. };
  /**
   * Ensures react-ui/src/extensions/package.json exists, creating it (and its
   * directory) as an empty "{}" manifest if missing. Existing files are left
   * untouched.
   *
   * @throws Rethrows any filesystem error after reporting it.
   */
  function setupReactUIExtensions() {
    try {
      displayMessage("Initializing extensions...");
      const extensionsDir = getInstallerFilesPath(
        "..",
        "react-ui",
        "src",
        "extensions"
      );
      const packageJSONpath = resolve(extensionsDir, "package.json");
      if (!fs.existsSync(packageJSONpath)) {
        // writeFileSync fails with ENOENT when the parent directory is missing.
        fs.mkdirSync(extensionsDir, { recursive: true });
        fs.writeFileSync(packageJSONpath, "{}");
      }
      // NOTE: installing extension dependencies (`npm install` inside
      // react-ui/src/extensions) is currently disabled.
    } catch (error) {
      displayMessage("Failed to install extensions");
      throw error;
    }
  }
  323. exports.setupReactUI = () => {
  324. try {
  325. setupReactUIExtensions();
  326. displayMessage("Installing node_modules...");
  327. $sh("cd react-ui && npm install");
  328. displayMessage("Successfully installed node_modules");
  329. displayMessage("Building react-ui...");
  330. $sh("cd react-ui && npm run build");
  331. displayMessage("Successfully built react-ui");
  332. } catch (error) {
  333. displayMessage("Failed to install node_modules or build react-ui");
  334. throw error;
  335. }
  336. };