models.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992
  1. import copy
  2. import math
  3. import torch
  4. from torch import nn
  5. from torch.nn import functional as F
  6. from module import commons
  7. from module import modules
  8. from module import attentions
  9. from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
  10. from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
  11. from module.commons import init_weights, get_padding
  12. from module.mrte_model import MRTE
  13. from module.quantize import ResidualVectorQuantizer
  14. from text import symbols
  15. from torch.cuda.amp import autocast
  class StochasticDurationPredictor(nn.Module):
      """Flow-based stochastic duration predictor.

      With ``reverse=False`` the forward pass scores observed durations ``w`` and
      returns, per batch item, ``nll + logq`` summed over channels and time.
      With ``reverse=True`` it draws Gaussian noise, runs the flows backwards and
      returns the first channel of the result as log-durations.
      """

      def __init__(
          self,
          in_channels,
          filter_channels,
          kernel_size,
          p_dropout,
          n_flows=4,
          gin_channels=0,
      ):
          super().__init__()
          # The ``filter_channels`` argument is overridden: the hidden width is
          # tied to ``in_channels`` (upstream marks this for future removal).
          filter_channels = in_channels
          self.in_channels = in_channels
          self.filter_channels = filter_channels
          self.kernel_size = kernel_size
          self.p_dropout = p_dropout
          self.n_flows = n_flows
          self.gin_channels = gin_channels

          self.log_flow = modules.Log()
          self.flows = self._build_flow_stack(n_flows, filter_channels, kernel_size)

          # Branch that encodes the observed durations ``w`` for the posterior flows.
          self.post_pre = nn.Conv1d(1, filter_channels, 1)
          self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
          self.post_convs = modules.DDSConv(
              filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
          )
          self.post_flows = self._build_flow_stack(4, filter_channels, kernel_size)

          # Branch that encodes the conditioning features ``x``.
          self.pre = nn.Conv1d(in_channels, filter_channels, 1)
          self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
          self.convs = modules.DDSConv(
              filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
          )
          if gin_channels != 0:
              self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

      @staticmethod
      def _build_flow_stack(count, channels, kernel_size):
          """Return ``[ElementwiseAffine(2)]`` followed by ``count`` (ConvFlow, Flip) pairs."""
          stack = nn.ModuleList()
          stack.append(modules.ElementwiseAffine(2))
          for _ in range(count):
              stack.append(modules.ConvFlow(2, channels, kernel_size, n_layers=3))
              stack.append(modules.Flip())
          return stack

      def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
          """Score durations (training) or sample log-durations (inference).

          Args:
              x: conditioning features; detached before use so no gradient flows
                  back into the producer of ``x``.
              x_mask: mask multiplied into every intermediate.
              w: observed durations with one channel (``post_pre`` is a
                  ``Conv1d(1, ...)``); required when ``reverse`` is false.
              g: optional global conditioning, also detached. It is only usable when
                  the module was built with ``gin_channels != 0``.
              reverse: ``False`` to score ``w``; truthy to sample.
              noise_scale: multiplier on the sampling noise in reverse mode.

          Returns:
              ``nll + logq`` (one value per batch item) when scoring, otherwise the
              first of the two flow channels (log-durations).
          """
          cond = self.pre(x.detach())
          if g is not None:
              cond = cond + self.cond(g.detach())
          cond = self.convs(cond, x_mask)
          cond = self.proj(cond) * x_mask

          if reverse:
              inverse_flows = list(reversed(self.flows))
              # Skip the second-to-last module of the reversed stack (upstream
              # calls it "a useless vflow"); everything else runs in reverse order.
              inverse_flows = inverse_flows[:-2] + [inverse_flows[-1]]
              noise = torch.randn(cond.size(0), 2, cond.size(2)).to(
                  device=cond.device, dtype=cond.dtype
              )
              z = noise * noise_scale
              for flow in inverse_flows:
                  z = flow(z, x_mask, g=cond, reverse=reverse)
              log_w, _unused = torch.split(z, [1, 1], 1)
              return log_w

          forward_flows = self.flows
          assert w is not None
          log_two_pi = math.log(2 * math.pi)

          # Posterior: transform noise e_q conditioned on (cond + encoded w).
          logdet_tot_q = 0
          w_feats = self.post_pre(w)
          w_feats = self.post_convs(w_feats, x_mask)
          w_feats = self.post_proj(w_feats) * x_mask
          e_q = (
              torch.randn(w.size(0), 2, w.size(2)).to(device=cond.device, dtype=cond.dtype)
              * x_mask
          )
          z_q = e_q
          for flow in self.post_flows:
              z_q, logdet_q = flow(z_q, x_mask, g=(cond + w_feats))
              logdet_tot_q += logdet_q
          z_u, z1 = torch.split(z_q, [1, 1], 1)
          u = torch.sigmoid(z_u) * x_mask
          z0 = (w - u) * x_mask
          logdet_tot_q += torch.sum(
              (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
          )
          logq = (
              torch.sum(-0.5 * (log_two_pi + e_q**2) * x_mask, [1, 2]) - logdet_tot_q
          )

          # Prior: push (log-flow(z0), z1) through the main flows.
          logdet_tot = 0
          z0, logdet = self.log_flow(z0, x_mask)
          logdet_tot += logdet
          z = torch.cat([z0, z1], 1)
          for flow in forward_flows:
              z, logdet = flow(z, x_mask, g=cond, reverse=reverse)
              logdet_tot = logdet_tot + logdet
          nll = torch.sum(0.5 * (log_two_pi + z**2) * x_mask, [1, 2]) - logdet_tot
          return nll + logq  # summed over dims 1 and 2 -> one value per batch item
  class DurationPredictor(nn.Module):
      """Deterministic duration predictor.

      Two stages of (masked conv -> ReLU -> LayerNorm -> dropout) followed by a
      1x1 projection to a single output channel, masked by ``x_mask``.
      """

      def __init__(
          self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
      ):
          super().__init__()
          self.in_channels = in_channels
          self.filter_channels = filter_channels
          self.kernel_size = kernel_size
          self.p_dropout = p_dropout
          self.gin_channels = gin_channels

          same_pad = kernel_size // 2
          self.drop = nn.Dropout(p_dropout)
          self.conv_1 = nn.Conv1d(
              in_channels, filter_channels, kernel_size, padding=same_pad
          )
          self.norm_1 = modules.LayerNorm(filter_channels)
          self.conv_2 = nn.Conv1d(
              filter_channels, filter_channels, kernel_size, padding=same_pad
          )
          self.norm_2 = modules.LayerNorm(filter_channels)
          self.proj = nn.Conv1d(filter_channels, 1, 1)
          if gin_channels != 0:
              # Conditioning is added to the input, so it projects to in_channels.
              self.cond = nn.Conv1d(gin_channels, in_channels, 1)

      def forward(self, x, x_mask, g=None):
          """Predict one masked value per frame.

          ``x`` and ``g`` are detached first, so this predictor does not push
          gradients into them. ``g`` requires ``gin_channels != 0``.
          """
          hidden = x.detach()
          if g is not None:
              hidden = hidden + self.cond(g.detach())
          for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
              hidden = self.drop(norm(torch.relu(conv(hidden * x_mask))))
          return self.proj(hidden * x_mask) * x_mask
  class TextEncoder(nn.Module):
      """Prior encoder that fuses projected SSL features with embedded text.

      The SSL sequence is projected (768 -> ``hidden_channels``) and run through
      ``encoder_ssl``. The text tokens are embedded and run through
      ``encoder_text``. ``MRTE`` merges both with ``ge``, and ``encoder2`` plus a
      1x1 projection produce the prior statistics ``m`` and ``logs``.
      """

      def __init__(
          self,
          out_channels,
          hidden_channels,
          filter_channels,
          n_heads,
          n_layers,
          kernel_size,
          p_dropout,
          latent_channels=192,
      ):
          super().__init__()
          self.out_channels = out_channels
          self.hidden_channels = hidden_channels
          self.filter_channels = filter_channels
          self.n_heads = n_heads
          self.n_layers = n_layers
          self.kernel_size = kernel_size
          self.p_dropout = p_dropout
          self.latent_channels = latent_channels
          self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)
          self.encoder_ssl = attentions.Encoder(
              hidden_channels,
              filter_channels,
              n_heads,
              n_layers // 2,
              kernel_size,
              p_dropout,
          )
          self.encoder_text = attentions.Encoder(
              hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
          )
          self.text_embedding = nn.Embedding(len(symbols), hidden_channels)
          self.mrte = MRTE()
          self.encoder2 = attentions.Encoder(
              hidden_channels,
              filter_channels,
              n_heads,
              n_layers // 2,
              kernel_size,
              p_dropout,
          )
          self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

      def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
          """Encode SSL features ``y`` and token ids ``text`` into prior stats.

          Args:
              y: SSL features; time is ``y.size(2)``.
              y_lengths: valid lengths used to build ``y_mask``.
              text: token ids; time is ``text.size(1)``.
              text_lengths: valid lengths used to build ``text_mask``.
              ge: global conditioning forwarded to ``MRTE``.
              test: when ``== 1`` every token id is replaced by 0. A new tensor is
                  built, so the caller's ``text`` is NOT modified.

          Returns:
              ``(y, m, logs, y_mask)``, where ``m``/``logs`` are the two halves of
              the projected statistics split along dim 1.
          """
          y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
              y.dtype
          )
          y = self.ssl_proj(y * y_mask) * y_mask
          y = self.encoder_ssl(y * y_mask, y_mask)
          text_mask = torch.unsqueeze(
              commons.sequence_mask(text_lengths, text.size(1)), 1
          ).to(y.dtype)
          if test == 1:
              # Previously ``text[:, :] = 0`` zeroed the caller's tensor in place.
              # Build a zeroed copy instead so the argument is left untouched.
              text = torch.zeros_like(text)
          text = self.text_embedding(text).transpose(1, 2)
          text = self.encoder_text(text * text_mask, text_mask)
          y = self.mrte(y, y_mask, text, text_mask, ge)
          y = self.encoder2(y * y_mask, y_mask)
          stats = self.proj(y) * y_mask
          m, logs = torch.split(stats, self.out_channels, dim=1)
          return y, m, logs, y_mask

      def extract_latent(self, x):
          """Project ``x`` and return its quantizer codes with dims 0/1 swapped.

          NOTE(review): ``self.quantizer`` is not created in ``__init__``; this only
          works if a quantizer is attached to the instance externally.
          """
          x = self.ssl_proj(x)
          _, codes, _, _ = self.quantizer(x)
          return codes.transpose(0, 1)

      def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
          """Decode codes and run them through the SSL/MRTE/encoder2 stack.

          NOTE(review): relies on ``self.quantizer`` and ``self.vq_proj``, neither
          of which is created in ``__init__``; confirm they are attached externally.
          """
          quantized = self.quantizer.decode(codes)
          y = self.vq_proj(quantized) * y_mask
          y = self.encoder_ssl(y * y_mask, y_mask)
          y = self.mrte(y, y_mask, refer, refer_mask, ge)
          y = self.encoder2(y * y_mask, y_mask)
          stats = self.proj(y) * y_mask
          m, logs = torch.split(stats, self.out_channels, dim=1)
          return y, m, logs, y_mask, quantized
  class ResidualCouplingBlock(nn.Module):
      """Stack of ``n_flows`` mean-only residual coupling layers, each followed by a Flip.

      Forward mode applies the stack in order and discards per-layer log-dets.
      Reverse mode applies it in reverse order.
      """

      def __init__(
          self,
          channels,
          hidden_channels,
          kernel_size,
          dilation_rate,
          n_layers,
          n_flows=4,
          gin_channels=0,
      ):
          super().__init__()
          self.channels = channels
          self.hidden_channels = hidden_channels
          self.kernel_size = kernel_size
          self.dilation_rate = dilation_rate
          self.n_layers = n_layers
          self.n_flows = n_flows
          self.gin_channels = gin_channels

          self.flows = nn.ModuleList()
          for _ in range(n_flows):
              coupling = modules.ResidualCouplingLayer(
                  channels,
                  hidden_channels,
                  kernel_size,
                  dilation_rate,
                  n_layers,
                  gin_channels=gin_channels,
                  mean_only=True,
              )
              self.flows.append(coupling)
              self.flows.append(modules.Flip())

      def forward(self, x, x_mask, g=None, reverse=False):
          """Run the flow stack forwards (default) or in reverse."""
          if reverse:
              for layer in reversed(self.flows):
                  x = layer(x, x_mask, g=g, reverse=reverse)
              return x
          for layer in self.flows:
              x, _logdet = layer(x, x_mask, g=g, reverse=reverse)
          return x
  class PosteriorEncoder(nn.Module):
      """WaveNet-style posterior encoder producing a reparameterised latent.

      The input is projected to ``hidden_channels`` and passed through ``modules.WN``.
      It is then projected to ``2 * out_channels`` statistics, split into ``m`` and
      ``logs``, and sampled as ``z = m + eps * exp(logs)``, masked.
      """

      def __init__(
          self,
          in_channels,
          out_channels,
          hidden_channels,
          kernel_size,
          dilation_rate,
          n_layers,
          gin_channels=0,
      ):
          super().__init__()
          self.in_channels = in_channels
          self.out_channels = out_channels
          self.hidden_channels = hidden_channels
          self.kernel_size = kernel_size
          self.dilation_rate = dilation_rate
          self.n_layers = n_layers
          self.gin_channels = gin_channels
          self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
          self.enc = modules.WN(
              hidden_channels,
              kernel_size,
              dilation_rate,
              n_layers,
              gin_channels=gin_channels,
          )
          self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

      def forward(self, x, x_lengths, g=None):
          """Encode ``x`` (time is ``x.size(2)``) into ``(z, m, logs, x_mask)``.

          ``g`` is optional conditioning; it is detached so this encoder does
          not backpropagate into whatever produced it.
          """
          # Identity check: ``g != None`` on a tensor only worked via Python's
          # comparison fallback; ``is not None`` is the explicit, correct test.
          if g is not None:
              g = g.detach()
          x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
              x.dtype
          )
          x = self.pre(x) * x_mask
          x = self.enc(x, x_mask, g=g)
          stats = self.proj(x) * x_mask
          m, logs = torch.split(stats, self.out_channels, dim=1)
          z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
          return z, m, logs, x_mask
  class WNEncoder(nn.Module):
      """WaveNet-style encoder: 1x1 in-proj -> ``modules.WN`` -> 1x1 out-proj -> LayerNorm."""

      def __init__(
          self,
          in_channels,
          out_channels,
          hidden_channels,
          kernel_size,
          dilation_rate,
          n_layers,
          gin_channels=0,
      ):
          super().__init__()
          self.in_channels = in_channels
          self.out_channels = out_channels
          self.hidden_channels = hidden_channels
          self.kernel_size = kernel_size
          self.dilation_rate = dilation_rate
          self.n_layers = n_layers
          self.gin_channels = gin_channels
          self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
          self.enc = modules.WN(
              hidden_channels,
              kernel_size,
              dilation_rate,
              n_layers,
              gin_channels=gin_channels,
          )
          self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
          self.norm = modules.LayerNorm(out_channels)

      def forward(self, x, x_lengths, g=None):
          """Return the normalised encoding of ``x``; time is ``x.size(2)``."""
          x_mask = torch.unsqueeze(
              commons.sequence_mask(x_lengths, x.size(2)), 1
          ).to(x.dtype)
          hidden = self.enc(self.pre(x) * x_mask, x_mask, g=g)
          return self.norm(self.proj(hidden) * x_mask)
  class Generator(torch.nn.Module):
      """HiFi-GAN style waveform generator.

      ``conv_pre`` is followed by one stage per upsample rate. Each stage is
      leaky-ReLU, a transposed conv, and the mean of ``num_kernels`` residual
      blocks. The output passes through leaky-ReLU, ``conv_post`` and ``tanh``.
      """

      def __init__(
          self,
          initial_channel,
          resblock,
          resblock_kernel_sizes,
          resblock_dilation_sizes,
          upsample_rates,
          upsample_initial_channel,
          upsample_kernel_sizes,
          gin_channels=0,
      ):
          super(Generator, self).__init__()
          self.num_kernels = len(resblock_kernel_sizes)
          self.num_upsamples = len(upsample_rates)
          self.conv_pre = Conv1d(
              initial_channel, upsample_initial_channel, 7, 1, padding=3
          )
          block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

          # Each upsample stage halves the channel count.
          self.ups = nn.ModuleList()
          for stage, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
              in_ch = upsample_initial_channel // (2**stage)
              out_ch = upsample_initial_channel // (2 ** (stage + 1))
              self.ups.append(
                  weight_norm(
                      ConvTranspose1d(in_ch, out_ch, kernel, rate, padding=(kernel - rate) // 2)
                  )
              )

          self.resblocks = nn.ModuleList()
          for stage in range(len(self.ups)):
              ch = upsample_initial_channel // (2 ** (stage + 1))
              for kernel, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                  self.resblocks.append(block_cls(ch, kernel, dilation))

          # ``ch`` is the channel count of the last stage.
          self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
          self.ups.apply(init_weights)
          if gin_channels != 0:
              self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

      def forward(self, x, g=None):
          """Generate a waveform from ``x``, optionally adding ``self.cond(g)``."""
          x = self.conv_pre(x)
          if g is not None:
              x = x + self.cond(g)
          for stage in range(self.num_upsamples):
              x = F.leaky_relu(x, modules.LRELU_SLOPE)
              x = self.ups[stage](x)
              first = stage * self.num_kernels
              summed = None
              for block in self.resblocks[first : first + self.num_kernels]:
                  if summed is None:
                      summed = block(x)
                  else:
                      summed += block(x)
              x = summed / self.num_kernels
          # Default-slope leaky ReLU here, as in the original.
          x = self.conv_post(F.leaky_relu(x))
          return torch.tanh(x)

      def remove_weight_norm(self):
          """Strip weight normalisation from the upsamplers and residual blocks."""
          print("Removing weight norm...")
          for layer in self.ups:
              remove_weight_norm(layer)
          for layer in self.resblocks:
              layer.remove_weight_norm()
  class DiscriminatorP(torch.nn.Module):
      """Period discriminator.

      The 1-D signal is folded into a 2-D grid of width ``period`` (reflect-padding
      the tail if needed). It then passes through five ``(k, 1)`` convs and a final
      ``(3, 1)`` conv. Every intermediate activation is returned as a feature map.
      """

      def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
          super(DiscriminatorP, self).__init__()
          self.period = period
          self.use_spectral_norm = use_spectral_norm
          norm_f = spectral_norm if use_spectral_norm != False else weight_norm  # noqa: E712
          # (in_channels, out_channels, stride along the folded time axis)
          layer_plan = [
              (1, 32, stride),
              (32, 128, stride),
              (128, 512, stride),
              (512, 1024, stride),
              (1024, 1024, 1),
          ]
          self.convs = nn.ModuleList(
              [
                  norm_f(
                      Conv2d(
                          c_in,
                          c_out,
                          (kernel_size, 1),
                          (step, 1),
                          padding=(get_padding(kernel_size, 1), 0),
                      )
                  )
                  for c_in, c_out, step in layer_plan
              ]
          )
          self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

      def forward(self, x):
          """Return ``(flattened_scores, feature_maps)`` for a ``(b, c, t)`` input."""
          fmap = []
          b, c, t = x.shape
          remainder = t % self.period
          if remainder != 0:
              n_pad = self.period - remainder
              x = F.pad(x, (0, n_pad), "reflect")
              t += n_pad
          x = x.view(b, c, t // self.period, self.period)
          for conv in self.convs:
              x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
              fmap.append(x)
          x = self.conv_post(x)
          fmap.append(x)
          return torch.flatten(x, 1, -1), fmap
  class DiscriminatorS(torch.nn.Module):
      """Scale discriminator: six 1-D convs (grouped, strided) plus a 3-tap output conv."""

      def __init__(self, use_spectral_norm=False):
          super(DiscriminatorS, self).__init__()
          norm_f = spectral_norm if use_spectral_norm != False else weight_norm  # noqa: E712
          # (in_channels, out_channels, kernel, stride, groups, padding)
          layer_plan = [
              (1, 16, 15, 1, 1, 7),
              (16, 64, 41, 4, 4, 20),
              (64, 256, 41, 4, 16, 20),
              (256, 1024, 41, 4, 64, 20),
              (1024, 1024, 41, 4, 256, 20),
              (1024, 1024, 5, 1, 1, 2),
          ]
          self.convs = nn.ModuleList(
              [
                  norm_f(Conv1d(c_in, c_out, k, s, groups=grp, padding=pad))
                  for c_in, c_out, k, s, grp, pad in layer_plan
              ]
          )
          self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

      def forward(self, x):
          """Return ``(flattened_scores, feature_maps)``."""
          fmap = []
          for conv in self.convs:
              x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
              fmap.append(x)
          x = self.conv_post(x)
          fmap.append(x)
          return torch.flatten(x, 1, -1), fmap
  class MultiPeriodDiscriminator(torch.nn.Module):
      """One scale discriminator plus period discriminators for periods 2, 3, 5, 7, 11."""

      def __init__(self, use_spectral_norm=False):
          super(MultiPeriodDiscriminator, self).__init__()
          periods = [2, 3, 5, 7, 11]
          discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
          discs.extend(
              DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
              for period in periods
          )
          self.discriminators = nn.ModuleList(discs)

      def forward(self, y, y_hat):
          """Score real ``y`` and generated ``y_hat`` with every sub-discriminator.

          Returns four lists, each with one entry per sub-discriminator: real
          scores, generated scores, real feature maps, generated feature maps.
          """
          y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
          for disc in self.discriminators:
              score_real, feats_real = disc(y)
              score_fake, feats_fake = disc(y_hat)
              y_d_rs.append(score_real)
              y_d_gs.append(score_fake)
              fmap_rs.append(feats_real)
              fmap_gs.append(feats_fake)
          return y_d_rs, y_d_gs, fmap_rs, fmap_gs
  class ReferenceEncoder(nn.Module):
      """Conv + GRU reference encoder.

      The input is viewed as ``[N, 1, frames, spec_channels]`` and passed through
      six stride-2 3x3 weight-normed convs with ReLU. Channels and frequency are
      flattened per frame and fed to a GRU with 128 hidden units. The final
      hidden state is linearly projected to ``gin_channels``.
      Output shape: ``[N, gin_channels, 1]``.
      """

      def __init__(self, spec_channels, gin_channels=0):
          super().__init__()
          self.spec_channels = spec_channels
          ref_enc_filters = [32, 32, 64, 64, 128, 128]
          n_convs = len(ref_enc_filters)
          channel_pairs = zip([1] + ref_enc_filters[:-1], ref_enc_filters)
          self.convs = nn.ModuleList(
              weight_norm(
                  nn.Conv2d(
                      in_channels=c_in,
                      out_channels=c_out,
                      kernel_size=(3, 3),
                      stride=(2, 2),
                      padding=(1, 1),
                  )
              )
              for c_in, c_out in channel_pairs
          )
          freq_bins = self.calculate_channels(spec_channels, 3, 2, 1, n_convs)
          self.gru = nn.GRU(
              input_size=ref_enc_filters[-1] * freq_bins,
              hidden_size=256 // 2,
              batch_first=True,
          )
          self.proj = nn.Linear(128, gin_channels)

      def forward(self, inputs):
          """Encode ``inputs`` into a ``[N, gin_channels, 1]`` style vector."""
          batch = inputs.size(0)
          out = inputs.view(batch, 1, -1, self.spec_channels)
          for conv in self.convs:
              out = F.relu(conv(out))
          # [N, C, frames, freq] -> [N, frames, C * freq]
          out = out.transpose(1, 2)
          batch, frames = out.size(0), out.size(1)
          out = out.contiguous().view(batch, frames, -1)
          self.gru.flatten_parameters()
          _, last_hidden = self.gru(out)  # last_hidden: [1, N, 128]
          return self.proj(last_hidden.squeeze(0)).unsqueeze(-1)

      def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
          """Return the size of an axis of length ``L`` after ``n_convs`` identical convs."""
          size = L
          for _ in range(n_convs):
              size = (size - kernel_size + 2 * pad) // stride + 1
          return size
  class Quantizer_module(torch.nn.Module):
      """Single codebook of ``n_e`` vectors of size ``e_dim`` with nearest-neighbour lookup."""

      def __init__(self, n_e, e_dim):
          super(Quantizer_module, self).__init__()
          self.embedding = nn.Embedding(n_e, e_dim)
          bound = 1.0 / n_e
          self.embedding.weight.data.uniform_(-bound, bound)

      def forward(self, x):
          """Return ``(nearest_code_vectors, nearest_indices)`` for rows of ``x``.

          Squared Euclidean distances are expanded as ``|x|^2 + |e|^2 - 2 x.e``.
          """
          codebook = self.embedding.weight
          distances = (
              (x**2).sum(1, keepdim=True)
              + (codebook**2).sum(1)
              - 2 * (x @ codebook.T)
          )
          nearest = distances.argmin(1)
          return self.embedding(nearest), nearest
  class Quantizer(torch.nn.Module):
      """Group quantizer: splits the channel axis into ``n_code_groups`` equal chunks,
      each quantized by its own ``Quantizer_module`` codebook of ``n_codes`` entries."""

      def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
          super(Quantizer, self).__init__()
          assert embed_dim % n_code_groups == 0
          group_dim = embed_dim // n_code_groups
          self.quantizer_modules = nn.ModuleList(
              [Quantizer_module(n_codes, group_dim) for _ in range(n_code_groups)]
          )
          self.n_code_groups = n_code_groups
          self.embed_dim = embed_dim

      def forward(self, xin):
          """Quantize a ``(B, C, T)`` tensor.

          Returns ``(z_q, loss, codes)``. ``z_q`` is ``(B, C, T)`` and uses a
          straight-through estimator. ``loss`` is ``0.25 * commitment + codebook``
          MSE. ``codes`` is ``(B, n_code_groups, T)``.
          """
          B, _C, T = xin.shape
          xin = xin.transpose(1, 2)
          flat = xin.reshape(-1, self.embed_dim)
          chunks = torch.split(flat, self.embed_dim // self.n_code_groups, dim=-1)
          results = [
              module(chunk) for chunk, module in zip(chunks, self.quantizer_modules)
          ]
          z_q = torch.cat([vectors for vectors, _ in results], -1).reshape(xin.shape)
          commitment = torch.mean((z_q.detach() - xin) ** 2)
          codebook = torch.mean((z_q - xin.detach()) ** 2)
          loss = 0.25 * commitment + codebook
          # Straight-through: forward value is z_q, gradient flows to xin.
          z_q = (xin + (z_q - xin).detach()).transpose(1, 2)
          codes = torch.stack([indices for _, indices in results], -1).reshape(
              B, T, self.n_code_groups
          )
          return z_q, loss, codes.transpose(1, 2)

      def embed(self, x):
          """Map codes ``(N, n_code_groups, T)`` back to embeddings ``(N, C, T)``."""
          per_group = torch.split(x.transpose(1, 2), 1, 2)
          vectors = [
              module.embedding(ids.squeeze(-1))
              for ids, module in zip(per_group, self.quantizer_modules)
          ]
          return torch.cat(vectors, -1).transpose(1, 2)
  class CodePredictor(nn.Module):
      """Predicts residual codebooks 1..n_q-1 from projected SSL features.

      The features are projected to ``hidden_channels`` and summed with a style
      vector from ``MelStyleEncoder(refer)``. They are then encoded by an
      attention encoder and projected to ``(n_q - 1) * dims`` logits per frame.
      """

      def __init__(
          self,
          hidden_channels,
          filter_channels,
          n_heads,
          n_layers,
          kernel_size,
          p_dropout,
          n_q=8,
          dims=1024,
          ssl_dim=768,
      ):
          super().__init__()
          self.hidden_channels = hidden_channels
          self.filter_channels = filter_channels
          self.n_heads = n_heads
          self.n_layers = n_layers
          self.kernel_size = kernel_size
          self.p_dropout = p_dropout
          self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
          self.ref_enc = modules.MelStyleEncoder(
              ssl_dim, style_vector_dim=hidden_channels
          )
          self.encoder = attentions.Encoder(
              hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
          )
          self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
          self.n_q = n_q
          self.dims = dims

      def forward(self, x, x_mask, refer, codes, infer=False):
          """Return the cross-entropy loss (training) or predicted codes (``infer``).

          ``codes[0]`` is treated as given; ``codes[1:]`` are the targets.
          With ``infer=True``, top-k (k = min(10, dims)) and top-1 accuracies are
          printed, and the argmax codes are returned with dims 0/1 swapped.
          """
          x = x.detach()
          x = self.vq_proj(x * x_mask) * x_mask
          g = self.ref_enc(refer, x_mask)
          x = x + g
          x = self.encoder(x * x_mask, x_mask)
          x = self.out_proj(x * x_mask) * x_mask
          logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
              2, 3
          )
          target = codes[1:].transpose(0, 1)
          if not infer:
              logits = logits.reshape(-1, self.dims)
              target = target.reshape(-1)
              return torch.nn.functional.cross_entropy(logits, target)

          # torch.topk raises if k exceeds the class dimension, so clamp to dims.
          k = min(10, self.dims)
          _, topk_preds = torch.topk(logits, k, dim=-1)
          correct_topk = torch.any(topk_preds == target.unsqueeze(-1), dim=-1)
          topk_acc = 100 * torch.mean(correct_topk.float()).item()
          print(f"Top-{k} Accuracy:", topk_acc, "%")
          pred_codes = torch.argmax(logits, dim=-1)
          acc = 100 * torch.mean((pred_codes == target).float()).item()
          print("Top-1 Accuracy:", acc, "%")
          return pred_codes.transpose(0, 1)
  class SynthesizerTrn(nn.Module):
      """Training-time synthesizer.

      Components: a text/SSL prior encoder (``enc_p``), a posterior encoder over
      spectrograms (``enc_q``), a coupling flow (``flow``), a waveform generator
      (``dec``), a style encoder (``ref_enc``), and an SSL projection +
      residual vector quantizer (``ssl_proj``/``quantizer``).
      """

      def __init__(
          self,
          spec_channels,
          segment_size,
          inter_channels,
          hidden_channels,
          filter_channels,
          n_heads,
          n_layers,
          kernel_size,
          p_dropout,
          resblock,
          resblock_kernel_sizes,
          resblock_dilation_sizes,
          upsample_rates,
          upsample_initial_channel,
          upsample_kernel_sizes,
          n_speakers=0,
          gin_channels=0,
          use_sdp=True,
          semantic_frame_rate=None,
          freeze_quantizer=None,
          **kwargs
      ):
          super().__init__()
          self.spec_channels = spec_channels
          self.inter_channels = inter_channels
          self.hidden_channels = hidden_channels
          self.filter_channels = filter_channels
          self.n_heads = n_heads
          self.n_layers = n_layers
          self.kernel_size = kernel_size
          self.p_dropout = p_dropout
          self.resblock = resblock
          self.resblock_kernel_sizes = resblock_kernel_sizes
          self.resblock_dilation_sizes = resblock_dilation_sizes
          self.upsample_rates = upsample_rates
          self.upsample_initial_channel = upsample_initial_channel
          self.upsample_kernel_sizes = upsample_kernel_sizes
          self.segment_size = segment_size
          self.n_speakers = n_speakers
          self.gin_channels = gin_channels
          self.use_sdp = use_sdp
          self.enc_p = TextEncoder(
              inter_channels,
              hidden_channels,
              filter_channels,
              n_heads,
              n_layers,
              kernel_size,
              p_dropout,
          )
          self.dec = Generator(
              inter_channels,
              resblock,
              resblock_kernel_sizes,
              resblock_dilation_sizes,
              upsample_rates,
              upsample_initial_channel,
              upsample_kernel_sizes,
              gin_channels=gin_channels,
          )
          self.enc_q = PosteriorEncoder(
              spec_channels,
              inter_channels,
              hidden_channels,
              5,
              1,
              16,
              gin_channels=gin_channels,
          )
          self.flow = ResidualCouplingBlock(
              inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
          )
          self.ref_enc = modules.MelStyleEncoder(
              spec_channels, style_vector_dim=gin_channels
          )
          ssl_dim = 768
          assert semantic_frame_rate in ["25hz", "50hz"]
          self.semantic_frame_rate = semantic_frame_rate
          if semantic_frame_rate == "25hz":
              # Stride-2 conv halves the SSL frame rate before quantization.
              self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
          else:
              self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
          self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
          if freeze_quantizer:
              self.ssl_proj.requires_grad_(False)
              self.quantizer.requires_grad_(False)

      @staticmethod
      def _length_mask(lengths, max_len, dtype):
          """Build ``commons.sequence_mask(lengths, max_len)`` unsqueezed at dim 1, cast to ``dtype``."""
          return torch.unsqueeze(commons.sequence_mask(lengths, max_len), 1).to(dtype)

      def _upsample_semantic(self, quantized):
          """Double the time axis (nearest) in the 25hz setting; otherwise return as-is.

          This undoes the stride-2 ``ssl_proj`` used in the 25hz configuration.
          """
          if self.semantic_frame_rate == "25hz":
              quantized = F.interpolate(
                  quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
              )
          return quantized

      def forward(self, ssl, y, y_lengths, text, text_lengths):
          """Training step.

          Returns ``(o, commit_loss, ids_slice, y_mask, y_mask,
          (z, z_p, m_p, logs_p, m_q, logs_q), quantized)``. ``y_mask`` appears twice,
          as in the original interface.
          """
          y_mask = self._length_mask(y_lengths, y.size(2), y.dtype)
          ge = self.ref_enc(y * y_mask, y_mask)
          # Quantization runs in full precision even under mixed-precision training.
          with autocast(enabled=False):
              ssl = self.ssl_proj(ssl)
              quantized, _, commit_loss, _ = self.quantizer(ssl, layers=[0])
          quantized = self._upsample_semantic(quantized)
          _, m_p, logs_p, y_mask = self.enc_p(
              quantized, y_lengths, text, text_lengths, ge
          )
          z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
          z_p = self.flow(z, y_mask, g=ge)
          z_slice, ids_slice = commons.rand_slice_segments(
              z, y_lengths, self.segment_size
          )
          o = self.dec(z_slice, g=ge)
          return (
              o,
              commit_loss,
              ids_slice,
              y_mask,
              y_mask,
              (z, z_p, m_p, logs_p, m_q, logs_q),
              quantized,
          )

      def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
          """Synthesize from SSL features, using spectrogram ``y`` as the style reference.

          Returns ``(o, y_mask, (z, z_p, m_p, logs_p))``.
          """
          y_mask = self._length_mask(y_lengths, y.size(2), y.dtype)
          ge = self.ref_enc(y * y_mask, y_mask)
          ssl = self.ssl_proj(ssl)
          quantized, _, _, _ = self.quantizer(ssl, layers=[0])
          quantized = self._upsample_semantic(quantized)
          _, m_p, logs_p, y_mask = self.enc_p(
              quantized, y_lengths, text, text_lengths, ge, test=test
          )
          z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
          z = self.flow(z_p, y_mask, g=ge, reverse=True)
          o = self.dec(z * y_mask, g=ge)
          return o, y_mask, (z, z_p, m_p, logs_p)

      @torch.no_grad()
      def decode(self, codes, text, refer, noise_scale=0.5):
          """Synthesize a waveform from quantizer codes and text.

          ``refer`` (optional) is a reference spectrogram for the style vector; its
          length is taken from ``refer.size(2)``. A single-item batch is assumed
          (lengths are built as one-element tensors).
          """
          ge = None
          if refer is not None:
              refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
              refer_mask = self._length_mask(refer_lengths, refer.size(2), refer.dtype)
              ge = self.ref_enc(refer * refer_mask, refer_mask)
          quantized = self._upsample_semantic(self.quantizer.decode(codes))
          # The length now follows the actual (possibly upsampled) sequence. The
          # old ``codes.size(2) * 2`` ignored ``semantic_frame_rate``.
          y_lengths = torch.LongTensor([quantized.size(-1)]).to(codes.device)
          text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
          _, m_p, logs_p, y_mask = self.enc_p(
              quantized, y_lengths, text, text_lengths, ge
          )
          z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
          z = self.flow(z_p, y_mask, g=ge, reverse=True)
          return self.dec(z * y_mask, g=ge)

      def extract_latent(self, x):
          """Project SSL features ``x`` and return quantizer codes with dims 0/1 swapped."""
          ssl = self.ssl_proj(x)
          _, codes, _, _ = self.quantizer(ssl)
          return codes.transpose(0, 1)
  865. def extract_latent(self, x):
  866. ssl = self.ssl_proj(x)
  867. quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
  868. return codes.transpose(0, 1)