2
0

initializeApp.js 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. const fs = require("fs");
  2. const { resolve } = require("path");
  3. const { displayError, displayMessage } = require("./displayMessage.js");
  4. const { processExit } = require("./processExit.js");
  5. const { menu } = require("./menu.js");
  6. const { $, $$, $sh } = require("./shell.js");
  7. const { applyDatabaseConfig } = require("./applyDatabaseConfig.js");
  // When true, pip/uv installs are run with --dry-run (see dry_run_flag).
  const DEBUG_DRY_RUN = false;
  const torchVersion = "2.3.1"; // 2.4.1+cu118
  const cudaVersion = "11.8";
  const pythonVersion = `3.10.11`;
  const pythonPackage = `python=${pythonVersion}`;
  // ffmpeg pinned to a build string matching `lgpl*`.
  const ffmpegPackage = `conda-forge::ffmpeg=4.4.2[build=lgpl*]`;
  const nodePackage = `conda-forge::nodejs=22.9.0`;
  const anacondaPostgresqlPackage = `conda-forge::postgresql=16.4`;
  // Terraform is currently disabled; set this to `conda-forge::terraform=1.8.2`
  // to re-enable it. Empty specs are filtered out of commonPackages below.
  const terraformPackage = ``;
  // Each list starts with "" so join(" -c ") yields " -c <channel>" for every channel.
  const cudaChannels = [
    "",
    "pytorch",
    `nvidia/label/cuda-${cudaVersion}.0`,
    "nvidia",
  ].join(" -c ");
  const cpuChannels = ["", "pytorch"].join(" -c ");
  const windowsOnlyPackages =
    process.platform === "win32" ? ["conda-forge::vswhere"] : [];
  // Conda packages installed for every GPU choice. Disabled (empty) specs such as
  // terraformPackage are dropped so they never reach the conda command line.
  const commonPackages = [
    "conda-forge::uv=0.4.17",
    ...windowsOnlyPackages,
    terraformPackage,
    anacondaPostgresqlPackage,
    nodePackage,
    ffmpegPackage,
  ].filter((pkg) => pkg !== "");
  const cudaPackages = `pytorch[version=${torchVersion},build=py3.10_cuda${cudaVersion}.*] pytorch-cuda=${cudaVersion} torchvision torchaudio cuda-toolkit ninja`;
  // Full conda command for the NVIDIA/CUDA install.
  const cudaPytorchInstall$ = [
    "conda install -y -k",
    ...commonPackages,
    cudaPackages,
    cudaChannels,
  ].join(" ");
  const cpuPackages = `pytorch=${torchVersion} torchvision torchaudio cpuonly`;
  // Full conda command for the CPU / Apple M series install.
  const pytorchCPUInstall$ = [
    "conda install -y -k",
    ...commonPackages,
    cpuPackages,
    cpuChannels,
  ].join(" ");
  // Conda command for the shared packages only (torch is installed separately via pip).
  const baseOnlyInstall$ = ["conda install -y -k", ...commonPackages].join(" ");
  /**
   * Make sure the active interpreter reports exactly `Python ${pythonVersion}`.
   *
   * If it does not, python (pinned via pythonPackage) and pip 23.3.2 are
   * reinstalled from conda-forge. Any failure is reported with its cause via
   * displayError and is not rethrown: this check is best-effort.
   */
  const ensurePythonVersion = async () => {
    // `python --version` is redirected to a file and read back, relative to the
    // current working directory.
    const getPythonVersion = async () => {
      await $sh(`python --version > installer_scripts/.python_version`);
      return fs.readFileSync("installer_scripts/.python_version", "utf8").trim();
    };
    try {
      displayMessage("Checking python version...");
      const version = await getPythonVersion();
      if (version !== `Python ${pythonVersion}`) {
        displayMessage(`Current python version is """${version}"""`);
        displayMessage(`Python version is not ${pythonVersion}. Reinstalling...`);
        await $(`conda install -y -k -c conda-forge ${pythonPackage}`);
        await $(`conda install -y -k -c conda-forge pip==23.3.2`);
      }
    } catch (error) {
      const reason = error instanceof Error ? error.message : String(error);
      displayError(`Failed to check/install python version: ${reason}`);
    }
  };
  /**
   * Install PyTorch (and the shared conda packages) for the chosen device, record
   * the installer major version, then run the pip dependency installs.
   *
   * @param {string} gpuchoice - One of the labels offered by askForGPUChoice.
   *   Unsupported or cancelled choices remove the saved GPU choice and exit with code 1.
   * @throws Rethrows any install error after reporting it via displayError.
   */
  const installDependencies = async (gpuchoice) => {
    try {
      if (gpuchoice === "NVIDIA GPU") {
        await $(cudaPytorchInstall$);
      } else if (gpuchoice === "Apple M Series Chip" || gpuchoice === "CPU") {
        await $(pytorchCPUInstall$);
      } else if (gpuchoice === "AMD GPU (ROCM, Linux only, potentially broken)") {
        displayMessage(
          "ROCM is experimental and not well supported yet, installing..."
        );
        displayMessage("Linux only!");
        await $(baseOnlyInstall$);
        await $(
          `pip install torch==${torchVersion} torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.0`
        );
      } else if (gpuchoice === "pip torch cpu (experimental, faster install)") {
        await $(baseOnlyInstall$);
        await $(`pip install torch==${torchVersion} torchvision torchaudio`);
      } else {
        displayMessage("Unsupported or cancelled. Exiting...");
        removeGPUChoice();
        processExit(1);
        // NOTE(review): processExit may not stop execution synchronously (see
        // processExit.js). Never fall through and record a successful install.
        return;
      }
      saveMajorVersion(majorVersion);
      displayMessage(" Successfully installed torch");
      await pip_install_all(true); // approximate first install
    } catch (error) {
      displayError(`Error during installation: ${error.message}`);
      throw error;
    }
  };
  /**
   * Show the device-selection menu and resolve to the label the user picked.
   * Labels here must match the strings installDependencies dispatches on.
   */
  const askForGPUChoice = () => {
    const deviceOptions = [
      "NVIDIA GPU",
      "Apple M Series Chip",
      "CPU",
      "Cancel",
      "AMD GPU (ROCM, Linux only, potentially broken)",
      "Intel GPU (unsupported)",
      "Integrated GPU (unsupported)",
      "pip torch cpu (experimental, faster install)",
    ];
    const prompt = `
These are not yet automatically supported: AMD GPU, Intel GPU, Integrated GPU.
Select the device (GPU/CPU) you are using to run the application:
(use arrow keys to move, enter to select)
`;
    return menu(deviceOptions, prompt);
  };
  /** Absolute path under the directory one level above this script's directory. */
  const getInstallerFilesPath = (...files) => resolve(__dirname, "..", ...files);

  // Small state files used to remember installer decisions between runs.
  const gpuFile = getInstallerFilesPath(".gpu");
  const majorVersionFile = getInstallerFilesPath(".major_version");
  const pipPackagesFile = getInstallerFilesPath(".pip_packages");

  // Bumping this forces the torch/conda install to run again on the next start.
  const majorVersion = "4";

  // versions.json is read once at load time; its `pip_packages` value (as a
  // string) is compared with the last recorded pip install.
  const versionsFile = getInstallerFilesPath("versions.json");
  const versions = JSON.parse(fs.readFileSync(versionsFile, "utf8"));
  const newPipPackagesVersion = `${versions.pip_packages}`;
  /**
   * Read a persisted installer state value.
   *
   * @param {string} file - Path of the state file.
   * @returns {string|number} The file contents with surrounding whitespace
   *   trimmed, or -1 when the file does not exist.
   */
  const readGeneric = (file) => {
    if (!fs.existsSync(file)) return -1;
    // saveGeneric writes values without a trailing newline. Trimming still lets a
    // hand-edited file (e.g. `.gpu` saved by an editor) compare equal.
    return fs.readFileSync(file, "utf8").trim();
  };
  /** Persist `data` (via toString) to `file`, replacing previous contents. */
  const saveGeneric = (file, data) => fs.writeFileSync(file, data.toString());
  const readMajorVersion = () => readGeneric(majorVersionFile);
  const saveMajorVersion = (data) => saveGeneric(majorVersionFile, data);
  const readPipPackagesVersion = () => readGeneric(pipPackagesFile);
  const savePipPackagesVersion = (data) => saveGeneric(pipPackagesFile, data);
  const readGPUChoice = () => readGeneric(gpuFile);
  const saveGPUChoice = (data) => saveGeneric(gpuFile, data);
  /** Forget the saved GPU choice so the menu is shown again next time. */
  const removeGPUChoice = () => {
    if (fs.existsSync(gpuFile)) fs.unlinkSync(gpuFile);
  };
  // Inserted before the requirements of every pip/uv install when DEBUG_DRY_RUN is on.
  const dry_run_flag = DEBUG_DRY_RUN ? "--dry-run " : "";

  /**
   * Install Python requirements into the active environment.
   *
   * `torch==<torchVersion>` is appended to every install so the resolver keeps
   * the already-installed torch version. Failures are reported, not thrown.
   * NOTE(review): assumes `$sh` runs synchronously and throws on a failing
   * command (it is not awaited) — confirm against shell.js.
   *
   * @param {string} requirements - Arguments for `install` (e.g. "-r requirements.txt").
   * @param {string} [name=""] - Label for messages; falls back to `requirements`.
   * @param {boolean} [pipFallback=false] - Use plain `pip` instead of `uv pip`.
   * @returns {boolean} true if the install command completed, false if it threw.
   */
  function pip_install(requirements, name = "", pipFallback = false) {
    const label = name || requirements;
    const installer = pipFallback ? "pip" : "uv pip";
    try {
      displayMessage(`Installing ${label} dependencies...`);
      $sh(
        `${installer} install ${dry_run_flag}${requirements} torch==${torchVersion}`
      );
      displayMessage(`Successfully installed ${label} dependencies\n`);
      return true;
    } catch (error) {
      displayMessage(`Failed to install ${label} dependencies\n`);
      return false;
    }
  }

  /**
   * Install all pip dependencies unless versions.json's pip_packages value was
   * already installed.
   *
   * The user may opt into a single combined install. If that install fails, the
   * per-package installs below run instead. The new pip packages version is
   * recorded afterwards.
   *
   * @param {boolean} [first_install=false] - Route most installs through plain
   *   `pip` instead of `uv pip`. This is a temporary safeguard due to mysterious
   *   issues with uv on the first install.
   */
  async function pip_install_all(first_install = false) {
    if (readPipPackagesVersion() === newPipPackagesVersion)
      return displayMessage(
        "Dependencies are already up to date, skipping pip installs..."
      );
    const xformers =
      "xformers==0.0.27+cu118 --index-url https://download.pytorch.org/whl/cu118";
    const pip_install_all_choice = await menu(
      ["Yes", "No"],
      `Attempt single pip install of all dependencies (potentially faster)?
(use arrow keys to move, enter to select)`
    );
    if (pip_install_all_choice === "Yes") {
      displayMessage("Attempting single pip install of all dependencies...");
      pip_install(xformers, "xformers");
      const installedAll = pip_install(
        "-r requirements.txt -r requirements_bark_hubert_quantizer.txt -r requirements_rvc.txt -r requirements_audiocraft.txt -r requirements_styletts2.txt -r requirements_vall_e.txt -r requirements_maha_tts.txt -r requirements_stable_audio.txt hydra-core==1.3.2 nvidia-ml-py",
        "All dependencies",
        first_install
      );
      if (installedAll) {
        savePipPackagesVersion(newPipPackagesVersion);
        displayMessage("");
        return;
      }
      // pip_install reports failure via its return value (it never throws),
      // so the fallback must be driven by that result.
      displayMessage(
        "Failed to install all dependencies, falling back to individual installs..."
      );
    }
    displayMessage("Updating dependencies...");
    // [requirements, label, use plain pip instead of uv]
    const individualInstalls = [
      ["-r requirements.txt", "Core Packages, Bark, Tortoise", first_install],
      [xformers, "xformers", false],
      ["-r requirements_bark_hubert_quantizer.txt", "Bark Voice Clone", first_install],
      ["-r requirements_rvc.txt", "RVC", first_install],
      ["-r requirements_audiocraft.txt", "Audiocraft", first_install],
      ["-r requirements_styletts2.txt", "StyleTTS", first_install],
      ["-r requirements_vall_e.txt", "Vall-E-X", first_install],
      ["-r requirements_maha_tts.txt", "Maha TTS", first_install],
      // Stable Audio is always installed with plain pip.
      ["-r requirements_stable_audio.txt", "Stable Audio", true],
      // reinstall hydra-core==1.3.2 because of fairseq
      ["hydra-core==1.3.2", "hydra-core fix due to fairseq", first_install],
      ["nvidia-ml-py", "nvidia-ml-py", first_install],
    ];
    for (const [requirements, label, usePip] of individualInstalls) {
      pip_install(requirements, label, usePip);
    }
    savePipPackagesVersion(newPipPackagesVersion);
    displayMessage("");
  }
  /**
   * Resolve to true if the active `python` can locate the `torch` package,
   * false if the probe command fails for any reason.
   */
  const checkIfTorchInstalled = async () => {
    const probe =
      'import importlib.util; import sys; package_name = "torch"; spec = importlib.util.find_spec(package_name); sys.exit(0) if spec else sys.exit(1)';
    try {
      await $$(["python", "-c", probe]);
    } catch (error) {
      return false;
    }
    return true;
  };
  // Any non-empty FORCE_REINSTALL environment value forces the torch install to run again.
  const FORCE_REINSTALL = Boolean(process.env.FORCE_REINSTALL);

  /**
   * Ensure torch and its conda dependencies are installed for this major version.
   *
   * If the recorded major version matches, torch is importable, and no reinstall
   * is forced, only the pip dependency check runs. Otherwise the saved GPU choice
   * is used, or the user is asked and the answer saved, and installDependencies
   * performs the install.
   */
  async function applyCondaConfig() {
    displayMessage("Applying conda config...");
    displayMessage(" Checking if Torch is installed...");
    if (FORCE_REINSTALL) {
      displayMessage(" FORCE_REINSTALL is set. Reinstalling base environment");
    } else if (readMajorVersion() !== majorVersion) {
      displayMessage(
        " Major version update detected. Upgrading base environment"
      );
    } else if (await checkIfTorchInstalled()) {
      displayMessage(" Torch is already installed. Skipping installation...");
      await pip_install_all();
      return;
    } else {
      displayMessage(" Torch is not installed. Starting installation...\n");
    }

    let gpuchoice;
    if (fs.existsSync(gpuFile)) {
      gpuchoice = readGPUChoice();
      displayMessage(` Using saved GPU choice: ${gpuchoice}`);
    } else {
      gpuchoice = await askForGPUChoice();
      displayMessage(` You selected: ${gpuchoice}`);
      saveGPUChoice(gpuchoice);
    }
    await installDependencies(gpuchoice);
  }
  264. exports.initializeApp = async () => {
  265. displayMessage("Ensuring that python has the correct version...");
  266. await ensurePythonVersion();
  267. displayMessage("");
  268. await applyCondaConfig();
  269. displayMessage("");
  270. try {
  271. await applyDatabaseConfig();
  272. displayMessage("");
  273. } catch (error) {
  274. displayError("Failed to apply database config");
  275. }
  276. };
  /**
   * Resolve to true if `torch.cuda.is_available()` is true in the active
   * interpreter, false if the probe fails (no CUDA, no torch, or any other error).
   */
  const checkIfTorchHasCuda = async () => {
    const cudaProbe = "import torch; exit(0 if torch.cuda.is_available() else 1)";
    try {
      displayMessage("Checking if torch has CUDA...");
      await $$(["python", "-c", cudaProbe]);
    } catch (error) {
      return false;
    }
    return true;
  };
  290. exports.repairTorch = async () => {
  291. const gpuChoice = readGPUChoice();
  292. if (!checkIfTorchHasCuda() && gpuChoice === "NVIDIA GPU") {
  293. displayMessage("Backend is NVIDIA GPU, fixing PyTorch");
  294. try {
  295. await $(`conda install -y -k --force-reinstall ${cudaPackages}`);
  296. } catch (error) {
  297. displayError("Failed to fix torch");
  298. }
  299. } else if (gpuChoice === "CPU" || gpuChoice === "Apple M Series Chip") {
  300. displayMessage("Backend is CPU/Apple M Series Chip, fixing PyTorch");
  301. try {
  302. await $(`conda install -y -k --force-reinstall ${cpuPackages}`);
  303. } catch (error) {
  304. displayError("Failed to fix torch");
  305. }
  306. displayMessage("Torch has CUDA, skipping reinstall");
  307. }
  308. };
  /**
   * Make sure react-ui/src/extensions/package.json exists, creating it as "{}"
   * if missing. An existing file is left untouched.
   *
   * @throws Rethrows any filesystem error (e.g. the extensions directory is
   *   missing) after printing "Failed to install extensions".
   */
  function setupReactUIExtensions() {
    try {
      displayMessage("Initializing extensions...");
      const extensionsManifest = getInstallerFilesPath(
        "..",
        "react-ui",
        "src",
        "extensions",
        "package.json"
      );
      if (fs.existsSync(extensionsManifest)) return;
      fs.writeFileSync(extensionsManifest, "{}");
    } catch (error) {
      displayMessage("Failed to install extensions");
      throw error;
    }
  }
  329. exports.setupReactUI = () => {
  330. try {
  331. setupReactUIExtensions();
  332. displayMessage("Installing node_modules...");
  333. $sh("cd react-ui && npm install");
  334. displayMessage("Successfully installed node_modules");
  335. displayMessage("Building react-ui...");
  336. $sh("cd react-ui && npm run build");
  337. displayMessage("Successfully built react-ui");
  338. } catch (error) {
  339. displayMessage("Failed to install node_modules or build react-ui");
  340. throw error;
  341. }
  342. };