# tacotron.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Union

class HighwayNetwork(nn.Module):
    # Highway layer: a sigmoid gate g blends a ReLU transform of the input with the input itself
    def __init__(self, size):
        super().__init__()
        self.W1 = nn.Linear(size, size)
        self.W2 = nn.Linear(size, size)
        self.W1.bias.data.fill_(0.)

    def forward(self, x):
        x1 = self.W1(x)
        x2 = self.W2(x)
        g = torch.sigmoid(x2)
        y = g * F.relu(x1) + (1. - g) * x
        return y

class Encoder(nn.Module):
    def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
        super().__init__()
        prenet_dims = (encoder_dims, encoder_dims)
        cbhg_channels = encoder_dims
        self.embedding = nn.Embedding(num_chars, embed_dims)
        self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                              dropout=dropout)
        self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
                         proj_channels=[cbhg_channels, cbhg_channels],
                         num_highways=num_highways)

    def forward(self, x, speaker_embedding=None):
        x = self.embedding(x)
        x = self.pre_net(x)
        x.transpose_(1, 2)
        x = self.cbhg(x)
        if speaker_embedding is not None:
            x = self.add_speaker_embedding(x, speaker_embedding)
        return x
    def add_speaker_embedding(self, x, speaker_embedding):
        # SV2TTS
        # The input x is the encoder output, a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
        # When training, speaker_embedding is a 2D tensor with size (batch_size, speaker_embedding_size)
        # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
        # This concatenates the speaker embedding to the encoder output for each character

        # Save the dimensions as human-readable names
        batch_size = x.size()[0]
        num_chars = x.size()[1]

        if speaker_embedding.dim() == 1:
            idx = 0
        else:
            idx = 1

        # Start by making a copy of each speaker embedding to match the input text length
        # The output of this has size (batch_size, num_chars * speaker_embedding_size)
        speaker_embedding_size = speaker_embedding.size()[idx]
        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)

        # Reshape it and transpose
        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
        e = e.transpose(1, 2)

        # Concatenate the tiled speaker embedding with the encoder output
        x = torch.cat((x, e), 2)
        return x

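# Illustrative shape walk-through of Encoder.add_speaker_embedding (the concrete
# values below are example choices, not taken from this file):
#   x:                 (batch_size=2, num_chars=5, encoder_dims)
#   speaker_embedding: (2, speaker_embedding_size=256)
#   repeat_interleave -> (2, 5 * 256), reshape -> (2, 256, 5), transpose -> (2, 5, 256)
#   output:            torch.cat((x, e), 2) -> (2, 5, encoder_dims + 256)
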
class BatchNormConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, relu=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
        self.bnorm = nn.BatchNorm1d(out_channels)
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x) if self.relu else x
        return self.bnorm(x)

class CBHG(nn.Module):
    def __init__(self, K, in_channels, channels, proj_channels, num_highways):
        super().__init__()

        # List of all rnns to call `flatten_parameters()` on
        self._to_flatten = []

        self.bank_kernels = [i for i in range(1, K + 1)]
        self.conv1d_bank = nn.ModuleList()
        for k in self.bank_kernels:
            conv = BatchNormConv(in_channels, channels, k)
            self.conv1d_bank.append(conv)

        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
        self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)

        # Fix the highway input if necessary
        if proj_channels[-1] != channels:
            self.highway_mismatch = True
            self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
        else:
            self.highway_mismatch = False

        self.highways = nn.ModuleList()
        for i in range(num_highways):
            hn = HighwayNetwork(channels)
            self.highways.append(hn)

        self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
        self._to_flatten.append(self.rnn)

        # Avoid fragmentation of RNN parameters and associated warning
        self._flatten_parameters()
    def forward(self, x):
        # Although we `_flatten_parameters()` on init, when using DataParallel
        # the model gets replicated, making it no longer guaranteed that the
        # weights are contiguous in GPU memory. Hence, we must call it again
        self._flatten_parameters()

        # Save these for later
        residual = x
        seq_len = x.size(-1)
        conv_bank = []

        # Convolution bank
        for conv in self.conv1d_bank:
            c = conv(x)  # Convolution
            conv_bank.append(c[:, :, :seq_len])

        # Concatenate along the channel axis
        conv_bank = torch.cat(conv_bank, dim=1)

        # Drop the extra padding so the length matches the residual
        x = self.maxpool(conv_bank)[:, :, :seq_len]

        # Conv1d projections
        x = self.conv_project1(x)
        x = self.conv_project2(x)

        # Residual connection
        x = x + residual

        # Through the highways
        x = x.transpose(1, 2)
        if self.highway_mismatch:
            x = self.pre_highway(x)
        for h in self.highways:
            x = h(x)

        # And then the RNN
        x, _ = self.rnn(x)
        return x

    def _flatten_parameters(self):
        """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
        to improve efficiency and avoid PyTorch yelling at us."""
        [m.flatten_parameters() for m in self._to_flatten]

class PreNet(nn.Module):
    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        self.p = dropout

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        # training=True keeps dropout active even when the module is in eval mode
        x = F.dropout(x, self.p, training=True)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        return x

class Attention(nn.Module):
    def __init__(self, attn_dims):
        super().__init__()
        self.W = nn.Linear(attn_dims, attn_dims, bias=False)
        self.v = nn.Linear(attn_dims, 1, bias=False)

    def forward(self, encoder_seq_proj, query, t):
        # Transform the query vector
        query_proj = self.W(query).unsqueeze(1)

        # Compute the scores
        u = self.v(torch.tanh(encoder_seq_proj + query_proj))
        scores = F.softmax(u, dim=1)

        return scores.transpose(1, 2)

class LSA(nn.Module):
    # Location-sensitive attention: the score for each encoder step also depends on
    # the attention weights accumulated over previous decoder steps, processed by a 1D convolution
    def __init__(self, attn_dim, kernel_size=31, filters=32):
        super().__init__()
        self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
        self.L = nn.Linear(filters, attn_dim, bias=False)
        self.W = nn.Linear(attn_dim, attn_dim, bias=True)  # Include the attention bias in this term
        self.v = nn.Linear(attn_dim, 1, bias=False)
        self.cumulative = None
        self.attention = None

    def init_attention(self, encoder_seq_proj):
        device = next(self.parameters()).device  # use same device as parameters
        b, t, c = encoder_seq_proj.size()
        self.cumulative = torch.zeros(b, t, device=device)
        self.attention = torch.zeros(b, t, device=device)

    def forward(self, encoder_seq_proj, query, t, chars):
        if t == 0: self.init_attention(encoder_seq_proj)

        processed_query = self.W(query).unsqueeze(1)

        location = self.cumulative.unsqueeze(1)
        processed_loc = self.L(self.conv(location).transpose(1, 2))

        u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
        u = u.squeeze(-1)

        # Mask zero padding chars
        u = u * (chars != 0).float()

        # Smooth attention
        # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
        scores = F.softmax(u, dim=1)
        self.attention = scores
        self.cumulative = self.cumulative + self.attention

        return scores.unsqueeze(-1).transpose(1, 2)

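# Rough shape sketch of LSA.forward (illustrative, derived from the calls above):
#   encoder_seq_proj: (B, T, attn_dim), query: (B, attn_dim), chars: (B, T)
#   processed_query -> (B, 1, attn_dim); location (cumulative weights) -> (B, 1, T)
#   conv(location) -> (B, filters, T) -> transpose -> (B, T, filters) -> L -> (B, T, attn_dim)
#   u -> (B, T); returned scores -> (B, 1, T), ready for `scores @ encoder_seq` in the decoder
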
class Decoder(nn.Module):
    # Class variable because its value doesn't change between instances,
    # yet ought to be scoped to the class because it's a property of a Decoder
    max_r = 20

    def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
                 dropout, speaker_embedding_size):
        super().__init__()
        self.register_buffer("r", torch.tensor(1, dtype=torch.int))
        self.n_mels = n_mels
        prenet_dims = (decoder_dims * 2, decoder_dims * 2)
        self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
                             dropout=dropout)
        self.attn_net = LSA(decoder_dims)
        self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
        self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
        self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
        self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
        self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)

    def zoneout(self, prev, current, p=0.1):
        device = next(self.parameters()).device  # Use same device as parameters
        mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
        return prev * mask + current * (1 - mask)
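    # Worked example of the zoneout above (illustrative values, not from this file):
    # with p=0.1 roughly 10% of units keep their previous state and the rest take the
    # new one, e.g. prev=[1., 1.], current=[0., 0.], mask=[1., 0.] -> [1., 0.]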
    def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
                hidden_states, cell_states, context_vec, t, chars):

        # Need this for reshaping mels
        batch_size = encoder_seq.size(0)

        # Unpack the hidden and cell states
        attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
        rnn1_cell, rnn2_cell = cell_states

        # PreNet for the Attention RNN
        prenet_out = self.prenet(prenet_in)

        # Compute the Attention RNN hidden state
        attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
        attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)

        # Compute the attention scores
        scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)

        # Dot product to create the context vector
        context_vec = scores @ encoder_seq
        context_vec = context_vec.squeeze(1)

        # Concat Attention RNN output w. Context Vector & project
        x = torch.cat([context_vec, attn_hidden], dim=1)
        x = self.rnn_input(x)

        # Compute first Residual RNN
        rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
        if self.training:
            rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
        else:
            rnn1_hidden = rnn1_hidden_next
        x = x + rnn1_hidden

        # Compute second Residual RNN
        rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
        if self.training:
            rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
        else:
            rnn2_hidden = rnn2_hidden_next
        x = x + rnn2_hidden

        # Project Mels
        mels = self.mel_proj(x)
        mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
        cell_states = (rnn1_cell, rnn2_cell)

        # Stop token prediction
        s = torch.cat((x, context_vec), dim=1)
        s = self.stop_proj(s)
        stop_tokens = torch.sigmoid(s)

        return mels, scores, hidden_states, cell_states, context_vec, stop_tokens

class Tacotron(nn.Module):
    def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
                 fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
                 dropout, stop_threshold, speaker_embedding_size):
        super().__init__()
        self.n_mels = n_mels
        self.lstm_dims = lstm_dims
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims
        self.speaker_embedding_size = speaker_embedding_size
        self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
                               encoder_K, num_highways, dropout)
        self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
        self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
                               dropout, speaker_embedding_size)
        self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
                            [postnet_dims, fft_bins], num_highways)
        self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)

        self.init_model()
        self.num_params()

        self.register_buffer("step", torch.zeros(1, dtype=torch.long))
        self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))

    @property
    def r(self):
        return self.decoder.r.item()

    @r.setter
    def r(self, value):
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
    def forward(self, x, m, speaker_embedding):
        device = next(self.parameters()).device  # use same device as parameters

        self.step += 1
        batch_size, _, steps = m.size()

        # Initialise all hidden states and pack into tuple
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Initialise all lstm cell states and pack into tuple
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = m[:, :, t - 1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-process for linear spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        stop_outputs = torch.cat(stop_outputs, 1)

        return mel_outputs, linear, attn_scores, stop_outputs

    def generate(self, x, speaker_embedding=None, steps=2000):
        self.eval()
        device = next(self.parameters()).device  # use same device as parameters

        batch_size, _ = x.size()

        # Need to initialise all hidden states and pack into tuple for tidiness
        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)

        # Need to initialise all lstm cell states and pack into tuple for tidiness
        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
        cell_states = (rnn1_cell, rnn2_cell)

        # Need a <GO> Frame for start of decoder loop
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Need an initial context vector
        context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)

        # SV2TTS: Run the encoder with the speaker embedding
        # The projection avoids unnecessary matmuls in the decoder loop
        encoder_seq = self.encoder(x, speaker_embedding)
        encoder_seq_proj = self.encoder_proj(encoder_seq)

        # Need a couple of lists for outputs
        mel_outputs, attn_scores, stop_outputs = [], [], []

        # Run the decoder loop
        for t in range(0, steps, self.r):
            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
                             hidden_states, cell_states, context_vec, t, x)
            mel_outputs.append(mel_frames)
            attn_scores.append(scores)
            stop_outputs.extend([stop_tokens] * self.r)

            # Stop the loop when all stop tokens in batch exceed threshold
            if (stop_tokens > 0.5).all() and t > 10: break

        # Concat the mel outputs into sequence
        mel_outputs = torch.cat(mel_outputs, dim=2)

        # Post-process for linear spectrograms
        postnet_out = self.postnet(mel_outputs)
        linear = self.post_proj(postnet_out)
        linear = linear.transpose(1, 2)

        # For easy visualisation
        attn_scores = torch.cat(attn_scores, 1)
        stop_outputs = torch.cat(stop_outputs, 1)

        self.train()

        return mel_outputs, linear, attn_scores

    def init_model(self):
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)

    def get_step(self):
        return self.step.data.item()

    def reset_step(self):
        # assignment to parameters or buffers is overloaded, updates internal dict entry
        self.step = self.step.data.new_tensor(1)

    def log(self, path, msg):
        with open(path, "a") as f:
            print(msg, file=f)

    def load(self, path, optimizer=None):
        # Use device of model params as location for loaded state
        device = next(self.parameters()).device
        checkpoint = torch.load(str(path), map_location=device)
        self.load_state_dict(checkpoint["model_state"])

        if "optimizer_state" in checkpoint and optimizer is not None:
            optimizer.load_state_dict(checkpoint["optimizer_state"])

    def save(self, path, optimizer=None):
        if optimizer is not None:
            torch.save({
                "model_state": self.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, str(path))
        else:
            torch.save({
                "model_state": self.state_dict(),
            }, str(path))

    def num_params(self, print_out=True):
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        if print_out:
            print("Trainable Parameters: %.3fM" % parameters)
        return parameters
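

if __name__ == "__main__":
    # Minimal smoke-test sketch of the model above. The hyperparameter values are
    # illustrative assumptions, not the surrounding project's configuration. Note that
    # fft_bins is set equal to n_mels here: in this CBHG, the postnet's residual add
    # requires its last projection width (fft_bins) to match its input width (n_mels).
    model = Tacotron(embed_dims=256, num_chars=66, encoder_dims=128, decoder_dims=256,
                     n_mels=80, fft_bins=80, postnet_dims=128, encoder_K=16,
                     lstm_dims=512, postnet_K=8, num_highways=4, dropout=0.5,
                     stop_threshold=-3.4, speaker_embedding_size=256)
    model.r = 2  # reduction factor: mel frames generated per decoder step

    batch_size, num_chars, mel_steps = 2, 17, 20
    x = torch.randint(1, 66, (batch_size, num_chars))  # dummy character ids (non-zero to avoid padding mask)
    m = torch.rand(batch_size, 80, mel_steps)          # dummy target mel spectrograms
    e = torch.rand(batch_size, 256)                    # dummy speaker embeddings

    # Teacher-forced training pass
    mels, linear, attn, stops = model(x, m, e)
    print(mels.shape, linear.shape, attn.shape, stops.shape)

    # Free-running inference pass
    mels_g, linear_g, attn_g = model.generate(x, e, steps=40)
    print(mels_g.shape, linear_g.shape, attn_g.shape)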