# s2_train.py
import utils, os

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging, traceback

logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint

from module import commons
from module.data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)
from module.models import (
    SynthesizerTrn,
    MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
### FP32 is faster on an A100 anyway, so let's try TF32.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); has no effect on the results
# from config import pretrained_s2G,pretrained_s2D

global_step = 0
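
# main() spawns one worker process per visible GPU (a single process on Apple
# MPS) via torch.multiprocessing.spawn; each worker receives its rank plus the
# shared hyperparameters. The DDP rendezvous uses a random localhost port.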
def main():
    """Assume Single Node Multi GPUs Training Only"""
    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "Only GPU training is allowed."

    if torch.backends.mps.is_available():
        n_gpus = 1
    else:
        n_gpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    mp.spawn(
        run,
        nprocs=n_gpus,
        args=(
            n_gpus,
            hps,
        ),
    )
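
# Per-process entry point. Only rank 0 creates a logger and the TensorBoard
# writers; every rank joins the process group ("gloo" on Windows/MPS, "nccl"
# on CUDA) and builds its own model, optimizer, and data-pipeline replicas.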
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.data.exp_dir)
        logger.info(hps)
        # utils.check_git_hash(hps.s2_ckpt_dir)
        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
        backend="gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl",
        init_method="env://",
        world_size=n_gpus,
        rank=rank,
    )
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        [32, 300, 400, 500, 600, 700, 800, 900, 1000,
         1100, 1200, 1300, 1400, 1500, 1600, 1700, 1800, 1900],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
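
    # The list above gives the sampler's length-bucket boundaries: it groups
    # similar-length utterances so batches need little padding, and it shards
    # the resulting batches across ranks for DDP.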
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=6,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=16,
    )
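
    # shuffle must stay False here: ordering is owned by the batch_sampler.
    # Persistent workers plus a large prefetch_factor keep the GPU fed, at the
    # cost of extra host memory.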
    # if rank == 0:
    #     eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
    #     eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
    #                              batch_size=1, pin_memory=True,
    #                              drop_last=False, collate_fn=collate_fn)

    net_g = (
        SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).cuda(rank)
        if torch.cuda.is_available()
        else SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).to("mps")
    )
    net_d = (
        MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
        if torch.cuda.is_available()
        else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("mps")
    )
    for name, param in net_g.named_parameters():
        if not param.requires_grad:
            print(name, "not requires_grad")

    te_p = list(map(id, net_g.enc_p.text_embedding.parameters()))
    et_p = list(map(id, net_g.enc_p.encoder_text.parameters()))
    mrte_p = list(map(id, net_g.enc_p.mrte.parameters()))
    base_params = filter(
        lambda p: id(p) not in te_p + et_p + mrte_p and p.requires_grad,
        net_g.parameters(),
    )
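
    # The three text-side submodules (text_embedding, encoder_text, mrte) are
    # excluded from base_params by parameter id so they can be given a reduced
    # learning rate (text_low_lr_rate) in the param groups below.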
    # te_p=net_g.enc_p.text_embedding.parameters()
    # et_p=net_g.enc_p.encoder_text.parameters()
    # mrte_p=net_g.enc_p.mrte.parameters()

    optim_g = torch.optim.AdamW(
        # filter(lambda p: p.requires_grad, net_g.parameters()),  ### by default every layer would share the same lr
        [
            {"params": base_params, "lr": hps.train.learning_rate},
            {
                "params": net_g.enc_p.text_embedding.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
            {
                "params": net_g.enc_p.encoder_text.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
            {
                "params": net_g.enc_p.mrte.parameters(),
                "lr": hps.train.learning_rate * hps.train.text_low_lr_rate,
            },
        ],
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
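
    # Wrap in DDP only on CUDA; find_unused_parameters=True tolerates generator
    # branches whose parameters do not receive gradients on every step.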
    if torch.cuda.is_available():
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
        net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    else:
        net_g = net_g.to("mps")
        net_d = net_d.to("mps")

    try:  # auto-resume if a checkpoint can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "D_*.pth"),
            net_d,
            optim_d,
        )  # loading D usually just works
        if rank == 0:
            logger.info("loaded D")
        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path("%s/logs_s2" % hps.data.exp_dir, "G_*.pth"),
            net_g,
            optim_g,
        )
        global_step = (epoch_str - 1) * len(train_loader)
        # epoch_str = 1
        # global_step = 0
    except Exception:  # nothing to resume on a fresh run, so load the pretrained weights instead
        # traceback.print_exc()
        epoch_str = 1
        global_step = 0
        if hps.train.pretrained_s2G != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
            print(
                net_g.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
                )
                if torch.cuda.is_available()
                else net_g.load_state_dict(
                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                    strict=False,
                )
            )  ## experiment: do not load the optimizer state
        if hps.train.pretrained_s2D != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
            print(
                net_d.module.load_state_dict(
                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
                )
                if torch.cuda.is_available()
                else net_d.load_state_dict(
                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
                )
            )

    # scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    # scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=-1
    )
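
    # Fast-forward the exponential lr decay to the resumed epoch; the
    # schedulers were created fresh with last_epoch=-1.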
    for _ in range(epoch_str):
        scheduler_g.step()
        scheduler_d.step()
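
    # GradScaler is a no-op when fp16_run is off; when on, it scales the loss
    # so fp16 gradients do not underflow, then unscales before each step.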
    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                # [train_loader, eval_loader], logger, [writer, writer_eval])
                [train_loader, None],
                logger,
                [writer, writer_eval],
            )
        else:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
            )
        scheduler_g.step()
        scheduler_d.step()
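
# One epoch of GAN training: the discriminator is updated on the detached
# generator output first, then the generator is updated with adversarial,
# feature-matching, mel-reconstruction, and KL terms. Checkpoints are written
# at the end of the epoch on rank 0.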
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    net_g, net_d = nets
    optim_g, optim_d = optims
    # scheduler_g, scheduler_d = schedulers
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()
    for batch_idx, (
        ssl,
        ssl_lengths,
        spec,
        spec_lengths,
        y,
        y_lengths,
        text,
        text_lengths,
    ) in tqdm(enumerate(train_loader)):
        if torch.cuda.is_available():
            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
                rank, non_blocking=True
            )
            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
                rank, non_blocking=True
            )
            ssl = ssl.cuda(rank, non_blocking=True)
            ssl.requires_grad = False
            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
                rank, non_blocking=True
            )
        else:
            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
            y, y_lengths = y.to("mps"), y_lengths.to("mps")
            ssl = ssl.to("mps")
            ssl.requires_grad = False
            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
            text, text_lengths = text.to("mps"), text_lengths.to("mps")
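
        # Generator forward under autocast: returns the predicted waveform for
        # a random training segment plus the posterior/prior statistics that
        # feed the KL losses.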
        with autocast(enabled=hps.train.fp16_run):
            (
                y_hat,
                kl_ssl,
                ids_slice,
                x_mask,
                z_mask,
                (z, z_p, m_p, logs_p, m_q, logs_q),
                stats_ssl,
            ) = net_g(ssl, spec, spec_lengths, text, text_lengths)
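
            # Build mels for the same random segment on both sides: ids_slice
            # indexes spectrogram frames, so the waveform slice is offset by
            # ids_slice * hop_length.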
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            y_hat_mel = mel_spectrogram_torch(
                y_hat.squeeze(1),
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.hop_length,
                hps.data.win_length,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )

            y = commons.slice_segments(
                y, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
                loss_disc_all = loss_disc
        optim_d.zero_grad()
        scaler.scale(loss_disc_all).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)
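
        # unscale_ lets clip_grad_value_ see true gradient magnitudes; with
        # clip_value=None the commons helper only reports the norm, it does
        # not actually clip.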

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + kl_ssl * 1 + loss_kl
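
                # The loss arithmetic above runs with autocast disabled so the
                # GAN losses stay in fp32 even when fp16_run is enabled.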
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                logger.info([x.item() for x in losses] + [global_step, lr])
                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc_all,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl_ssl": kl_ssl,
                        "loss/g/kl": loss_kl,
                    }
                )
                # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)})
                # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
                # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                    "all/stats_ssl": utils.plot_spectrogram_to_numpy(
                        stats_ssl[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1
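
    # Checkpointing at epoch boundaries: with if_save_latest == 0 each save
    # interval gets its own step-numbered files; otherwise the fixed dummy step
    # 233333333333 is reused so only the latest checkpoint is kept on disk.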
    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
        if hps.train.if_save_latest == 0:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(global_step)
                ),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(global_step)
                ),
            )
        else:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "G_{}.pth".format(233333333333)
                ),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(
                    "%s/logs_s2" % hps.data.exp_dir, "D_{}.pth".format(233333333333)
                ),
            )
        if rank == 0 and hps.train.if_save_every_weights == True:
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        ckpt,
                        hps.name + "_e%s_s%s" % (epoch, global_step),
                        epoch,
                        global_step,
                        hps,
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
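
# evaluate() is currently dead code: train_and_evaluate() is always called
# with eval_loader=None above. When wired up it renders ground-truth and
# generated mels/audio to the eval TensorBoard writer.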
def evaluate(hps, generator, eval_loader, writer_eval):
    generator.eval()
    image_dict = {}
    audio_dict = {}
    print("Evaluating ...")
    with torch.no_grad():
        for batch_idx, (
            ssl,
            ssl_lengths,
            spec,
            spec_lengths,
            y,
            y_lengths,
            text,
            text_lengths,
        ) in enumerate(eval_loader):
            print(111)
            if torch.cuda.is_available():
                spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
                y, y_lengths = y.cuda(), y_lengths.cuda()
                ssl = ssl.cuda()
                text, text_lengths = text.cuda(), text_lengths.cuda()
            else:
                spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
                y, y_lengths = y.to("mps"), y_lengths.to("mps")
                ssl = ssl.to("mps")
                text, text_lengths = text.to("mps"), text_lengths.to("mps")
            for test in [0, 1]:
                y_hat, mask, *_ = (
                    generator.module.infer(
                        ssl, spec, spec_lengths, text, text_lengths, test=test
                    )
                    if torch.cuda.is_available()
                    else generator.infer(
                        ssl, spec, spec_lengths, text, text_lengths, test=test
                    )
                )
                y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length

                mel = spec_to_mel_torch(
                    spec,
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.squeeze(1).float(),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
                image_dict.update(
                    {
                        f"gen/mel_{batch_idx}_{test}": utils.plot_spectrogram_to_numpy(
                            y_hat_mel[0].cpu().numpy()
                        )
                    }
                )
                audio_dict.update(
                    {f"gen/audio_{batch_idx}_{test}": y_hat[0, :, : y_hat_lengths[0]]}
                )
                image_dict.update(
                    {
                        f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
                            mel[0].cpu().numpy()
                        )
                    }
                )
                audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})

        # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
        # audio_dict.update({
        #     f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
        # })

    utils.summarize(
        writer=writer_eval,
        global_step=global_step,
        images=image_dict,
        audios=audio_dict,
        audio_sampling_rate=hps.data.sampling_rate,
    )
    generator.train()

if __name__ == "__main__":
    main()
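
# Note: hps is read at import time via utils.get_hparams(stage=2), so this
# script is configured through the config/CLI handling in utils rather than
# through arguments to main().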