add an optional to alter the loss and model structure of tacotron2, add an alternative config
This commit is contained in:
parent
4fc86abf5a
commit
263d3eb88b
|
@ -23,8 +23,8 @@ _C.data = CN(
|
||||||
n_fft=1024, # fft frame size
|
n_fft=1024, # fft frame size
|
||||||
win_length=1024, # window size
|
win_length=1024, # window size
|
||||||
hop_length=256, # hop size between ajacent frame
|
hop_length=256, # hop size between ajacent frame
|
||||||
f_max=8000, # Hz, max frequency when converting to mel
|
fmax=8000, # Hz, max frequency when converting to mel
|
||||||
f_min=0, # Hz, min frequency when converting to mel
|
fmin=0, # Hz, min frequency when converting to mel
|
||||||
d_mels=80, # mel bands
|
d_mels=80, # mel bands
|
||||||
padding_idx=0, # text embedding's padding index
|
padding_idx=0, # text embedding's padding index
|
||||||
))
|
))
|
||||||
|
@ -50,7 +50,9 @@ _C.model = CN(
|
||||||
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
|
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
|
||||||
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
|
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
|
||||||
p_postnet_dropout=0.5, # droput probability in decoder postnet
|
p_postnet_dropout=0.5, # droput probability in decoder postnet
|
||||||
guided_attn_loss_sigma=0.2 # sigma in guided attention loss
|
use_stop_token=True, # wherther to use binary classifier to predict when to stop
|
||||||
|
use_guided_attention_loss=False, # whether to use guided attention loss
|
||||||
|
guided_attention_loss_sigma=0.2 # sigma in guided attention loss
|
||||||
))
|
))
|
||||||
|
|
||||||
_C.training = CN(
|
_C.training = CN(
|
||||||
|
|
|
@ -58,7 +58,7 @@ class LJSpeechCollector(object):
|
||||||
mels = []
|
mels = []
|
||||||
text_lens = []
|
text_lens = []
|
||||||
mel_lens = []
|
mel_lens = []
|
||||||
stop_tokens = []
|
|
||||||
for data in examples:
|
for data in examples:
|
||||||
text, mel = data
|
text, mel = data
|
||||||
text = np.array(text, dtype=np.int64)
|
text = np.array(text, dtype=np.int64)
|
||||||
|
@ -66,8 +66,6 @@ class LJSpeechCollector(object):
|
||||||
mels.append(mel)
|
mels.append(mel)
|
||||||
texts.append(text)
|
texts.append(text)
|
||||||
mel_lens.append(mel.shape[1])
|
mel_lens.append(mel.shape[1])
|
||||||
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
|
|
||||||
stop_tokens.append(np.append(stop_token, 1.0))
|
|
||||||
|
|
||||||
# Sort by text_len in descending order
|
# Sort by text_len in descending order
|
||||||
texts = [
|
texts = [
|
||||||
|
@ -86,15 +84,13 @@ class LJSpeechCollector(object):
|
||||||
for i, _ in sorted(
|
for i, _ in sorted(
|
||||||
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
|
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
mel_lens = np.array(mel_lens, dtype=np.int64)
|
mel_lens = np.array(mel_lens, dtype=np.int64)
|
||||||
|
|
||||||
|
|
||||||
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
|
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
|
||||||
|
|
||||||
# Pad sequence with largest len of the batch
|
# Pad sequence with largest len of the batch
|
||||||
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
|
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
|
||||||
mels, _ = np.transpose(
|
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||||
batch_spec(
|
mels = np.transpose(mels, axes=(0, 2, 1))
|
||||||
mels, pad_value=self.padding_value), axes=(0, 2, 1))
|
|
||||||
|
|
||||||
return texts, mels, text_lens, mel_lens
|
return texts, mels, text_lens, mel_lens
|
||||||
|
|
|
@ -39,8 +39,8 @@ def create_dataset(config, source_path, target_path, verbose=False):
|
||||||
n_mels=config.data.d_mels,
|
n_mels=config.data.d_mels,
|
||||||
win_length=config.data.win_length,
|
win_length=config.data.win_length,
|
||||||
hop_length=config.data.hop_length,
|
hop_length=config.data.hop_length,
|
||||||
f_max=config.data.f_max,
|
f_max=config.data.fmax,
|
||||||
f_min=config.data.f_min)
|
f_min=config.data.fmin)
|
||||||
normalizer = LogMagnitude()
|
normalizer = LogMagnitude()
|
||||||
|
|
||||||
records = []
|
records = []
|
||||||
|
|
|
@ -20,6 +20,8 @@ import paddle
|
||||||
import parakeet
|
import parakeet
|
||||||
from parakeet.frontend import EnglishCharacter
|
from parakeet.frontend import EnglishCharacter
|
||||||
from parakeet.models.tacotron2 import Tacotron2
|
from parakeet.models.tacotron2 import Tacotron2
|
||||||
|
from parakeet.utils import display
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
from config import get_cfg_defaults
|
from config import get_cfg_defaults
|
||||||
|
|
||||||
|
@ -46,10 +48,13 @@ def main(config, args):
|
||||||
for i, sentence in enumerate(sentences):
|
for i, sentence in enumerate(sentences):
|
||||||
sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0)
|
sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0)
|
||||||
|
|
||||||
mel_output, _ = model.predict(sentence)
|
outputs = model.infer(sentence)
|
||||||
mel_output = mel_output["mel_outputs_postnet"][0].numpy().T
|
mel_output = outputs["mel_outputs_postnet"][0].numpy().T
|
||||||
|
alignment = outputs["alignments"][0].numpy().T
|
||||||
|
|
||||||
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
||||||
|
display.plot_alignment(alignment)
|
||||||
|
plt.savefig(str(output_dir / f"sentence_{i}.png"))
|
||||||
if args.verbose:
|
if args.verbose:
|
||||||
print("spectrogram saved at {}".format(output_dir /
|
print("spectrogram saved at {}".format(output_dir /
|
||||||
f"sentence_{i}.npy"))
|
f"sentence_{i}.npy"))
|
||||||
|
|
|
@ -34,14 +34,18 @@ from ljspeech import LJSpeech, LJSpeechCollector
|
||||||
|
|
||||||
class Experiment(ExperimentBase):
|
class Experiment(ExperimentBase):
|
||||||
def compute_losses(self, inputs, outputs):
|
def compute_losses(self, inputs, outputs):
|
||||||
_, mel_targets, plens, slens, stop_tokens = inputs
|
texts, mel_targets, plens, slens = inputs
|
||||||
|
|
||||||
mel_outputs = outputs["mel_output"]
|
mel_outputs = outputs["mel_output"]
|
||||||
mel_outputs_postnet = outputs["mel_outputs_postnet"]
|
mel_outputs_postnet = outputs["mel_outputs_postnet"]
|
||||||
attention_weight = outputs["alignments"]
|
attention_weight = outputs["alignments"]
|
||||||
|
if self.config.model.use_stop_token:
|
||||||
|
stop_logits = outputs["stop_logits"]
|
||||||
|
else:
|
||||||
|
stop_logits = None
|
||||||
|
|
||||||
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
|
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
|
||||||
attention_weight, slens, plens)
|
attention_weight, slens, plens, stop_logits)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def train_batch(self):
|
def train_batch(self):
|
||||||
|
@ -52,7 +56,7 @@ class Experiment(ExperimentBase):
|
||||||
self.optimizer.clear_grad()
|
self.optimizer.clear_grad()
|
||||||
self.model.train()
|
self.model.train()
|
||||||
texts, mels, text_lens, output_lens = batch
|
texts, mels, text_lens, output_lens = batch
|
||||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
outputs = self.model(texts, text_lens, mels, output_lens)
|
||||||
losses = self.compute_losses(batch, outputs)
|
losses = self.compute_losses(batch, outputs)
|
||||||
loss = losses["loss"]
|
loss = losses["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -80,7 +84,7 @@ class Experiment(ExperimentBase):
|
||||||
valid_losses = defaultdict(list)
|
valid_losses = defaultdict(list)
|
||||||
for i, batch in enumerate(self.valid_loader):
|
for i, batch in enumerate(self.valid_loader):
|
||||||
texts, mels, text_lens, output_lens = batch
|
texts, mels, text_lens, output_lens = batch
|
||||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
outputs = self.model(texts, text_lens, mels, output_lens)
|
||||||
losses = self.compute_losses(batch, outputs)
|
losses = self.compute_losses(batch, outputs)
|
||||||
for k, v in losses.items():
|
for k, v in losses.items():
|
||||||
valid_losses[k].append(float(v))
|
valid_losses[k].append(float(v))
|
||||||
|
@ -88,7 +92,7 @@ class Experiment(ExperimentBase):
|
||||||
attention_weights = outputs["alignments"]
|
attention_weights = outputs["alignments"]
|
||||||
self.visualizer.add_figure(
|
self.visualizer.add_figure(
|
||||||
f"valid_sentence_{i}_alignments",
|
f"valid_sentence_{i}_alignments",
|
||||||
display.plot_alignment(attention_weights[0].numpy()),
|
display.plot_alignment(attention_weights[0].numpy().T),
|
||||||
self.iteration)
|
self.iteration)
|
||||||
self.visualizer.add_figure(
|
self.visualizer.add_figure(
|
||||||
f"valid_sentence_{i}_target_spectrogram",
|
f"valid_sentence_{i}_target_spectrogram",
|
||||||
|
@ -114,9 +118,8 @@ class Experiment(ExperimentBase):
|
||||||
|
|
||||||
def setup_model(self):
|
def setup_model(self):
|
||||||
config = self.config
|
config = self.config
|
||||||
frontend = EnglishCharacter()
|
|
||||||
model = Tacotron2(
|
model = Tacotron2(
|
||||||
frontend,
|
vocab_size=config.model.vocab_size,
|
||||||
d_mels=config.data.d_mels,
|
d_mels=config.data.d_mels,
|
||||||
d_encoder=config.model.d_encoder,
|
d_encoder=config.model.d_encoder,
|
||||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||||
|
@ -135,7 +138,8 @@ class Experiment(ExperimentBase):
|
||||||
p_prenet_dropout=config.model.p_prenet_dropout,
|
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||||
p_attention_dropout=config.model.p_attention_dropout,
|
p_attention_dropout=config.model.p_attention_dropout,
|
||||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||||
p_postnet_dropout=config.model.p_postnet_dropout)
|
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||||
|
use_stop_token=config.model.use_stop_token)
|
||||||
|
|
||||||
if self.parallel:
|
if self.parallel:
|
||||||
model = paddle.DataParallel(model)
|
model = paddle.DataParallel(model)
|
||||||
|
@ -148,7 +152,10 @@ class Experiment(ExperimentBase):
|
||||||
weight_decay=paddle.regularizer.L2Decay(
|
weight_decay=paddle.regularizer.L2Decay(
|
||||||
config.training.weight_decay),
|
config.training.weight_decay),
|
||||||
grad_clip=grad_clip)
|
grad_clip=grad_clip)
|
||||||
criterion = Tacotron2Loss(config.mode.guided_attn_loss_sigma)
|
criterion = Tacotron2Loss(
|
||||||
|
use_stop_token_loss=config.model.use_stop_token,
|
||||||
|
use_guided_attention_loss=config.model.use_guided_attention_loss,
|
||||||
|
sigma=config.model.guided_attention_loss_sigma)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
|
|
|
@ -23,6 +23,7 @@ from paddle.fluid.layers import sequence_mask
|
||||||
from parakeet.modules.conv import Conv1dBatchNorm
|
from parakeet.modules.conv import Conv1dBatchNorm
|
||||||
from parakeet.modules.attention import LocationSensitiveAttention
|
from parakeet.modules.attention import LocationSensitiveAttention
|
||||||
from parakeet.modules.losses import guided_attention_loss
|
from parakeet.modules.losses import guided_attention_loss
|
||||||
|
from parakeet.utils import checkpoint
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||||
|
@ -274,7 +275,7 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int,
|
d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int,
|
||||||
d_attention: int, attention_filters: int,
|
d_attention: int, attention_filters: int,
|
||||||
attention_kernel_size: int, p_prenet_dropout: float,
|
attention_kernel_size: int, p_prenet_dropout: float,
|
||||||
p_attention_dropout: float, p_decoder_dropout: float):
|
p_attention_dropout: float, p_decoder_dropout: float, use_stop_token: bool=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.d_mels = d_mels
|
self.d_mels = d_mels
|
||||||
self.reduction_factor = reduction_factor
|
self.reduction_factor = reduction_factor
|
||||||
|
@ -303,6 +304,10 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
d_decoder_rnn)
|
d_decoder_rnn)
|
||||||
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
|
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
|
||||||
d_mels * reduction_factor)
|
d_mels * reduction_factor)
|
||||||
|
|
||||||
|
self.use_stop_token = use_stop_token
|
||||||
|
if use_stop_token:
|
||||||
|
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||||
|
|
||||||
# states - temporary attributes
|
# states - temporary attributes
|
||||||
self.attention_hidden = None
|
self.attention_hidden = None
|
||||||
|
@ -379,6 +384,9 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
[self.decoder_hidden, self.attention_context], axis=-1)
|
[self.decoder_hidden, self.attention_context], axis=-1)
|
||||||
decoder_output = self.linear_projection(
|
decoder_output = self.linear_projection(
|
||||||
decoder_hidden_attention_context)
|
decoder_hidden_attention_context)
|
||||||
|
if self.use_stop_token:
|
||||||
|
stop_logit = self.stop_layer(decoder_hidden_attention_context)
|
||||||
|
return decoder_output, self.attention_weights, stop_logit
|
||||||
return decoder_output, self.attention_weights
|
return decoder_output, self.attention_weights
|
||||||
|
|
||||||
def forward(self, keys, querys, mask):
|
def forward(self, keys, querys, mask):
|
||||||
|
@ -416,16 +424,24 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
querys = self.prenet(querys)
|
querys = self.prenet(querys)
|
||||||
|
|
||||||
mel_outputs, alignments = [], []
|
mel_outputs, alignments = [], []
|
||||||
|
stop_logits = []
|
||||||
# Ignore the last time step
|
# Ignore the last time step
|
||||||
while len(mel_outputs) < querys.shape[1] - 1:
|
while len(mel_outputs) < querys.shape[1] - 1:
|
||||||
query = querys[:, len(mel_outputs), :]
|
query = querys[:, len(mel_outputs), :]
|
||||||
mel_output, attention_weights = self._decode(query)
|
if self.use_stop_token:
|
||||||
|
mel_output, attention_weights, stop_logit = self._decode(query)
|
||||||
|
else:
|
||||||
|
mel_output, attention_weights = self._decode(query)
|
||||||
mel_outputs.append(mel_output)
|
mel_outputs.append(mel_output)
|
||||||
alignments.append(attention_weights)
|
alignments.append(attention_weights)
|
||||||
|
if self.use_stop_token:
|
||||||
|
stop_logits.append(stop_logit)
|
||||||
|
|
||||||
alignments = paddle.stack(alignments, axis=1)
|
alignments = paddle.stack(alignments, axis=1)
|
||||||
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||||
|
if self.use_stop_token:
|
||||||
|
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||||
|
return mel_outputs, alignments, stop_logits
|
||||||
return mel_outputs, alignments
|
return mel_outputs, alignments
|
||||||
|
|
||||||
def infer(self, key, max_decoder_steps=1000):
|
def infer(self, key, max_decoder_steps=1000):
|
||||||
|
@ -460,19 +476,30 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
first_hit_end = None
|
first_hit_end = None
|
||||||
|
|
||||||
mel_outputs, alignments = [], []
|
mel_outputs, alignments = [], []
|
||||||
|
stop_logits = []
|
||||||
for i in trange(max_decoder_steps):
|
for i in trange(max_decoder_steps):
|
||||||
query = self.prenet(query)
|
query = self.prenet(query)
|
||||||
mel_output, alignment = self._decode(query)
|
if self.use_stop_token:
|
||||||
|
mel_output, alignment, stop_logit = self._decode(query)
|
||||||
|
else:
|
||||||
|
mel_output, alignment = self._decode(query)
|
||||||
|
|
||||||
mel_outputs.append(mel_output)
|
mel_outputs.append(mel_output)
|
||||||
alignments.append(alignment) # (B=1, T)
|
alignments.append(alignment) # (B=1, T)
|
||||||
|
if self.use_stop_token:
|
||||||
|
stop_logits.append(stop_logit)
|
||||||
|
|
||||||
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
|
if self.use_stop_token:
|
||||||
if first_hit_end is None:
|
if F.sigmoid(stop_logit) > 0.5:
|
||||||
first_hit_end = i
|
print("hit stop condition!")
|
||||||
if first_hit_end is not None and i > (first_hit_end + 10):
|
break
|
||||||
print("content exhausted!")
|
else:
|
||||||
break
|
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
|
||||||
|
if first_hit_end is None:
|
||||||
|
first_hit_end = i
|
||||||
|
elif i > (first_hit_end + 10):
|
||||||
|
print("content exhausted!")
|
||||||
|
break
|
||||||
if len(mel_outputs) == max_decoder_steps:
|
if len(mel_outputs) == max_decoder_steps:
|
||||||
print("Warning! Reached max decoder steps!!!")
|
print("Warning! Reached max decoder steps!!!")
|
||||||
break
|
break
|
||||||
|
@ -481,7 +508,9 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
|
|
||||||
alignments = paddle.stack(alignments, axis=1)
|
alignments = paddle.stack(alignments, axis=1)
|
||||||
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||||
|
if self.use_stop_token:
|
||||||
|
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||||
|
return mel_outputs, alignments, stop_logits
|
||||||
return mel_outputs, alignments
|
return mel_outputs, alignments
|
||||||
|
|
||||||
|
|
||||||
|
@ -580,7 +609,8 @@ class Tacotron2(nn.Layer):
|
||||||
p_attention_dropout: float = 0.1,
|
p_attention_dropout: float = 0.1,
|
||||||
p_decoder_dropout: float = 0.1,
|
p_decoder_dropout: float = 0.1,
|
||||||
p_postnet_dropout: float = 0.5,
|
p_postnet_dropout: float = 0.5,
|
||||||
d_global_condition=None):
|
d_global_condition=None,
|
||||||
|
use_stop_token=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
std = math.sqrt(2.0 / (vocab_size + d_encoder))
|
std = math.sqrt(2.0 / (vocab_size + d_encoder))
|
||||||
|
@ -606,7 +636,7 @@ class Tacotron2(nn.Layer):
|
||||||
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
|
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
|
||||||
d_decoder_rnn, d_attention, attention_filters,
|
d_decoder_rnn, d_attention, attention_filters,
|
||||||
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
|
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
|
||||||
p_decoder_dropout)
|
p_decoder_dropout, use_stop_token=use_stop_token)
|
||||||
self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor,
|
self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor,
|
||||||
d_hidden=d_postnet,
|
d_hidden=d_postnet,
|
||||||
kernel_size=postnet_kernel_size,
|
kernel_size=postnet_kernel_size,
|
||||||
|
@ -664,10 +694,14 @@ class Tacotron2(nn.Layer):
|
||||||
# [B, T_enc, 1]
|
# [B, T_enc, 1]
|
||||||
mask = sequence_mask(text_lens,
|
mask = sequence_mask(text_lens,
|
||||||
dtype=encoder_outputs.dtype).unsqueeze(-1)
|
dtype=encoder_outputs.dtype).unsqueeze(-1)
|
||||||
mel_outputs, alignments = self.decoder(encoder_outputs,
|
if self.decoder.use_stop_token:
|
||||||
mels,
|
mel_outputs, alignments, stop_logits = self.decoder(encoder_outputs,
|
||||||
mask=mask)
|
mels,
|
||||||
|
mask=mask)
|
||||||
|
else:
|
||||||
|
mel_outputs, alignments = self.decoder(encoder_outputs,
|
||||||
|
mels,
|
||||||
|
mask=mask)
|
||||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
|
||||||
|
@ -681,6 +715,8 @@ class Tacotron2(nn.Layer):
|
||||||
"mel_outputs_postnet": mel_outputs_postnet,
|
"mel_outputs_postnet": mel_outputs_postnet,
|
||||||
"alignments": alignments
|
"alignments": alignments
|
||||||
}
|
}
|
||||||
|
if self.decoder.use_stop_token:
|
||||||
|
outputs["stop_logits"] = stop_logits
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -723,8 +759,12 @@ class Tacotron2(nn.Layer):
|
||||||
global_condition, [-1, encoder_outputs.shape[1], -1])
|
global_condition, [-1, encoder_outputs.shape[1], -1])
|
||||||
encoder_outputs = paddle.concat(
|
encoder_outputs = paddle.concat(
|
||||||
[encoder_outputs, global_condition], -1)
|
[encoder_outputs, global_condition], -1)
|
||||||
mel_outputs, alignments = self.decoder.infer(
|
if self.decoder.use_stop_token:
|
||||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
mel_outputs, alignments, stop_logits = self.decoder.infer(
|
||||||
|
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||||
|
else:
|
||||||
|
mel_outputs, alignments = self.decoder.infer(
|
||||||
|
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||||
|
|
||||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
@ -734,22 +774,72 @@ class Tacotron2(nn.Layer):
|
||||||
"mel_outputs_postnet": mel_outputs_postnet,
|
"mel_outputs_postnet": mel_outputs_postnet,
|
||||||
"alignments": alignments
|
"alignments": alignments
|
||||||
}
|
}
|
||||||
|
if self.decoder.use_stop_token:
|
||||||
|
outputs["stop_logits"] = stop_logits
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, config, checkpoint_path):
|
||||||
|
"""Build a Tacotron2 model from a pretrained model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config: yacs.config.CfgNode
|
||||||
|
model configs
|
||||||
|
|
||||||
|
checkpoint_path: Path or str
|
||||||
|
the path of pretrained model checkpoint, without extension name
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ConditionalWaveFlow
|
||||||
|
The model built from pretrained result.
|
||||||
|
"""
|
||||||
|
model = cls(
|
||||||
|
vocab_size=config.model.vocab_size,
|
||||||
|
d_mels=config.data.d_mels,
|
||||||
|
d_encoder=config.model.d_encoder,
|
||||||
|
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||||
|
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||||
|
d_prenet=config.model.d_prenet,
|
||||||
|
d_attention_rnn=config.model.d_attention_rnn,
|
||||||
|
d_decoder_rnn=config.model.d_decoder_rnn,
|
||||||
|
attention_filters=config.model.attention_filters,
|
||||||
|
attention_kernel_size=config.model.attention_kernel_size,
|
||||||
|
d_attention=config.model.d_attention,
|
||||||
|
d_postnet=config.model.d_postnet,
|
||||||
|
postnet_kernel_size=config.model.postnet_kernel_size,
|
||||||
|
postnet_conv_layers=config.model.postnet_conv_layers,
|
||||||
|
reduction_factor=config.model.reduction_factor,
|
||||||
|
p_encoder_dropout=config.model.p_encoder_dropout,
|
||||||
|
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||||
|
p_attention_dropout=config.model.p_attention_dropout,
|
||||||
|
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||||
|
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||||
|
use_stop_token=config.model.use_stop_token)
|
||||||
|
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
class Tacotron2Loss(nn.Layer):
|
class Tacotron2Loss(nn.Layer):
|
||||||
""" Tacotron2 Loss module
|
""" Tacotron2 Loss module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sigma=0.2):
|
def __init__(self,
|
||||||
|
use_stop_token_loss=True,
|
||||||
|
use_guided_attention_loss=False,
|
||||||
|
sigma=0.2):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.spec_criterion = nn.MSELoss()
|
self.spec_criterion = nn.MSELoss()
|
||||||
|
self.use_stop_token_loss = use_stop_token_loss
|
||||||
|
self.use_guided_attention_loss = use_guided_attention_loss
|
||||||
self.attn_criterion = guided_attention_loss
|
self.attn_criterion = guided_attention_loss
|
||||||
|
self.stop_criterion = paddle.nn.BCEWithLogitsLoss()
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
|
|
||||||
def forward(self, mel_outputs, mel_outputs_postnet, mel_targets,
|
def forward(self, mel_outputs, mel_outputs_postnet, mel_targets,
|
||||||
attention_weights, slens, plens):
|
attention_weights=None, slens=None, plens=None, stop_logits=None):
|
||||||
"""Calculate tacotron2 loss.
|
"""Calculate tacotron2 loss.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -775,13 +865,24 @@ class Tacotron2Loss(nn.Layer):
|
||||||
"""
|
"""
|
||||||
mel_loss = self.spec_criterion(mel_outputs, mel_targets)
|
mel_loss = self.spec_criterion(mel_outputs, mel_targets)
|
||||||
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
|
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
|
||||||
gal_loss = self.attn_criterion(attention_weights, slens, plens,
|
total_loss = mel_loss + post_mel_loss
|
||||||
|
if self.use_guided_attention_loss:
|
||||||
|
gal_loss = self.attn_criterion(attention_weights, slens, plens,
|
||||||
self.sigma)
|
self.sigma)
|
||||||
total_loss = mel_loss + post_mel_loss + gal_loss
|
total_loss += gal_loss
|
||||||
|
if self.use_stop_token_loss:
|
||||||
|
T_dec = mel_targets.shape[1]
|
||||||
|
stop_labels = F.one_hot(slens - 1, num_classes=T_dec)
|
||||||
|
stop_token_loss = self.stop_criterion(stop_logits, stop_labels)
|
||||||
|
total_loss += stop_token_loss
|
||||||
|
|
||||||
losses = {
|
losses = {
|
||||||
"loss": total_loss,
|
"loss": total_loss,
|
||||||
"mel_loss": mel_loss,
|
"mel_loss": mel_loss,
|
||||||
"post_mel_loss": post_mel_loss,
|
"post_mel_loss": post_mel_loss
|
||||||
"guided_attn_loss": gal_loss
|
|
||||||
}
|
}
|
||||||
|
if self.use_guided_attention_loss:
|
||||||
|
losses["guided_attn_loss"] = gal_loss
|
||||||
|
if self.use_stop_token_loss:
|
||||||
|
losses["stop_loss"] = stop_token_loss
|
||||||
return losses
|
return losses
|
||||||
|
|
Loading…
Reference in New Issue