# s1_train.py (7.1 KB)
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import os
import pdb

# A launcher process may stash the desired GPU list in _CUDA_VISIBLE_DEVICES;
# promote it to CUDA_VISIBLE_DEVICES before torch (imported below) reads it.
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
from pathlib import Path
import torch, platform
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config

# Quiet down chatty third-party loggers.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
# Allow TF32 matmuls for speed on Ampere+ GPUs at slightly reduced precision.
torch.set_float32_matmul_precision("high")
from AR.utils import get_newest_ckpt
from collections import OrderedDict
from time import time as ttime
import shutil
  25. def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
  26. dir=os.path.dirname(path)
  27. name=os.path.basename(path)
  28. tmp_path="%s.pth"%(ttime())
  29. torch.save(fea,tmp_path)
  30. shutil.move(tmp_path,"%s/%s"%(dir,name))
  31. class my_model_ckpt(ModelCheckpoint):
  32. def __init__(
  33. self,
  34. config,
  35. if_save_latest,
  36. if_save_every_weights,
  37. half_weights_save_dir,
  38. exp_name,
  39. **kwargs
  40. ):
  41. super().__init__(**kwargs)
  42. self.if_save_latest = if_save_latest
  43. self.if_save_every_weights = if_save_every_weights
  44. self.half_weights_save_dir = half_weights_save_dir
  45. self.exp_name = exp_name
  46. self.config = config
  47. def on_train_epoch_end(self, trainer, pl_module):
  48. # if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
  49. if self._should_save_on_train_epoch_end(trainer):
  50. monitor_candidates = self._monitor_candidates(trainer)
  51. if (
  52. self._every_n_epochs >= 1
  53. and (trainer.current_epoch + 1) % self._every_n_epochs == 0
  54. ):
  55. if (
  56. self.if_save_latest == True
  57. ): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
  58. to_clean = list(os.listdir(self.dirpath))
  59. self._save_topk_checkpoint(trainer, monitor_candidates)
  60. if self.if_save_latest == True:
  61. for name in to_clean:
  62. try:
  63. os.remove("%s/%s" % (self.dirpath, name))
  64. except:
  65. pass
  66. if self.if_save_every_weights == True:
  67. to_save_od = OrderedDict()
  68. to_save_od["weight"] = OrderedDict()
  69. dictt = trainer.strategy._lightning_module.state_dict()
  70. for key in dictt:
  71. to_save_od["weight"][key] = dictt[key].half()
  72. to_save_od["config"] = self.config
  73. to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
  74. # torch.save(
  75. my_save(
  76. to_save_od,
  77. "%s/%s-e%s.ckpt"
  78. % (
  79. self.half_weights_save_dir,
  80. self.exp_name,
  81. trainer.current_epoch + 1,
  82. ),
  83. )
  84. self._save_last_checkpoint(trainer, monitor_candidates)
  85. def main(args):
  86. config = load_yaml_config(args.config_file)
  87. output_dir = Path(config["output_dir"])
  88. output_dir.mkdir(parents=True, exist_ok=True)
  89. ckpt_dir = output_dir / "ckpt"
  90. ckpt_dir.mkdir(parents=True, exist_ok=True)
  91. seed_everything(config["train"]["seed"], workers=True)
  92. ckpt_callback: ModelCheckpoint = my_model_ckpt(
  93. config=config,
  94. if_save_latest=config["train"]["if_save_latest"],
  95. if_save_every_weights=config["train"]["if_save_every_weights"],
  96. half_weights_save_dir=config["train"]["half_weights_save_dir"],
  97. exp_name=config["train"]["exp_name"],
  98. save_top_k=-1,
  99. monitor="top_3_acc",
  100. mode="max",
  101. save_on_train_epoch_end=True,
  102. every_n_epochs=config["train"]["save_every_n_epoch"],
  103. dirpath=ckpt_dir,
  104. )
  105. logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
  106. os.environ["MASTER_ADDR"]="localhost"
  107. trainer: Trainer = Trainer(
  108. max_epochs=config["train"]["epochs"],
  109. accelerator="gpu",
  110. # val_check_interval=9999999999999999999999,###不要验证
  111. # check_val_every_n_epoch=None,
  112. limit_val_batches=0,
  113. devices=-1,
  114. benchmark=False,
  115. fast_dev_run=False,
  116. strategy = "auto" if torch.backends.mps.is_available() else DDPStrategy(
  117. process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
  118. ), # mps 不支持多节点训练
  119. precision=config["train"]["precision"],
  120. logger=logger,
  121. num_sanity_val_steps=0,
  122. callbacks=[ckpt_callback],
  123. )
  124. model: Text2SemanticLightningModule = Text2SemanticLightningModule(
  125. config, output_dir
  126. )
  127. data_module: Text2SemanticDataModule = Text2SemanticDataModule(
  128. config,
  129. train_semantic_path=config["train_semantic_path"],
  130. train_phoneme_path=config["train_phoneme_path"],
  131. # dev_semantic_path=args.dev_semantic_path,
  132. # dev_phoneme_path=args.dev_phoneme_path
  133. )
  134. try:
  135. # 使用正则表达式匹配文件名中的数字部分,并按数字大小进行排序
  136. newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
  137. ckpt_path = ckpt_dir / newest_ckpt_name
  138. except Exception:
  139. ckpt_path = None
  140. print("ckpt_path:", ckpt_path)
  141. trainer.fit(model, data_module, ckpt_path=ckpt_path)
  142. # srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
  143. if __name__ == "__main__":
  144. parser = argparse.ArgumentParser()
  145. parser.add_argument(
  146. "-c",
  147. "--config_file",
  148. type=str,
  149. default="configs/s1longer.yaml",
  150. help="path of config file",
  151. )
  152. # args for dataset
  153. # parser.add_argument('--train_semantic_path',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/6-name2semantic.tsv')
  154. # parser.add_argument('--train_phoneme_path', type=str, default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/2-name2text.txt')
  155. # parser.add_argument('--dev_semantic_path', type=str, default='dump_mix/semantic_dev.tsv')
  156. # parser.add_argument('--dev_phoneme_path', type=str, default='dump_mix/phoneme_dev.npy')
  157. # parser.add_argument('--output_dir',type=str,default='/data/docker/liujing04/gpt-vits/fine_tune_dataset/xuangou/logs_s1',help='directory to save the results')
  158. # parser.add_argument('--output_dir',type=str,default='/liujing04/gpt_logs/s1/xuangou_ft',help='directory to save the results')
  159. args = parser.parse_args()
  160. logging.info(str(args))
  161. main(args)