add an optional to alter the loss and model structure of tacotron2, add an alternative config

This commit is contained in:
chenfeiyu 2021-04-26 21:18:29 +08:00
parent 4fc86abf5a
commit 263d3eb88b
6 changed files with 160 additions and 49 deletions

View File

@ -23,8 +23,8 @@ _C.data = CN(
n_fft=1024, # fft frame size
win_length=1024, # window size
hop_length=256, # hop size between ajacent frame
f_max=8000, # Hz, max frequency when converting to mel
f_min=0, # Hz, min frequency when converting to mel
fmax=8000, # Hz, max frequency when converting to mel
fmin=0, # Hz, min frequency when converting to mel
d_mels=80, # mel bands
padding_idx=0, # text embedding's padding index
))
@ -50,7 +50,9 @@ _C.model = CN(
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
p_postnet_dropout=0.5, # droput probability in decoder postnet
guided_attn_loss_sigma=0.2 # sigma in guided attention loss
use_stop_token=True, # wherther to use binary classifier to predict when to stop
use_guided_attention_loss=False, # whether to use guided attention loss
guided_attention_loss_sigma=0.2 # sigma in guided attention loss
))
_C.training = CN(

View File

@ -58,7 +58,7 @@ class LJSpeechCollector(object):
mels = []
text_lens = []
mel_lens = []
stop_tokens = []
for data in examples:
text, mel = data
text = np.array(text, dtype=np.int64)
@ -66,8 +66,6 @@ class LJSpeechCollector(object):
mels.append(mel)
texts.append(text)
mel_lens.append(mel.shape[1])
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
stop_tokens.append(np.append(stop_token, 1.0))
# Sort by text_len in descending order
texts = [
@ -86,15 +84,13 @@ class LJSpeechCollector(object):
for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
]
mel_lens = np.array(mel_lens, dtype=np.int64)
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
# Pad sequence with largest len of the batch
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
mels, _ = np.transpose(
batch_spec(
mels, pad_value=self.padding_value), axes=(0, 2, 1))
mels, _ = batch_spec(mels, pad_value=self.padding_value)
mels = np.transpose(mels, axes=(0, 2, 1))
return texts, mels, text_lens, mel_lens

View File

@ -39,8 +39,8 @@ def create_dataset(config, source_path, target_path, verbose=False):
n_mels=config.data.d_mels,
win_length=config.data.win_length,
hop_length=config.data.hop_length,
f_max=config.data.f_max,
f_min=config.data.f_min)
f_max=config.data.fmax,
f_min=config.data.fmin)
normalizer = LogMagnitude()
records = []

View File

@ -20,6 +20,8 @@ import paddle
import parakeet
from parakeet.frontend import EnglishCharacter
from parakeet.models.tacotron2 import Tacotron2
from parakeet.utils import display
from matplotlib import pyplot as plt
from config import get_cfg_defaults
@ -46,10 +48,13 @@ def main(config, args):
for i, sentence in enumerate(sentences):
sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0)
mel_output, _ = model.predict(sentence)
mel_output = mel_output["mel_outputs_postnet"][0].numpy().T
outputs = model.infer(sentence)
mel_output = outputs["mel_outputs_postnet"][0].numpy().T
alignment = outputs["alignments"][0].numpy().T
np.save(str(output_dir / f"sentence_{i}"), mel_output)
display.plot_alignment(alignment)
plt.savefig(str(output_dir / f"sentence_{i}.png"))
if args.verbose:
print("spectrogram saved at {}".format(output_dir /
f"sentence_{i}.npy"))

View File

@ -34,14 +34,18 @@ from ljspeech import LJSpeech, LJSpeechCollector
class Experiment(ExperimentBase):
def compute_losses(self, inputs, outputs):
_, mel_targets, plens, slens, stop_tokens = inputs
texts, mel_targets, plens, slens = inputs
mel_outputs = outputs["mel_output"]
mel_outputs_postnet = outputs["mel_outputs_postnet"]
attention_weight = outputs["alignments"]
if self.config.model.use_stop_token:
stop_logits = outputs["stop_logits"]
else:
stop_logits = None
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
attention_weight, slens, plens)
attention_weight, slens, plens, stop_logits)
return losses
def train_batch(self):
@ -52,7 +56,7 @@ class Experiment(ExperimentBase):
self.optimizer.clear_grad()
self.model.train()
texts, mels, text_lens, output_lens = batch
outputs = self.model(texts, mels, text_lens, output_lens)
outputs = self.model(texts, text_lens, mels, output_lens)
losses = self.compute_losses(batch, outputs)
loss = losses["loss"]
loss.backward()
@ -80,7 +84,7 @@ class Experiment(ExperimentBase):
valid_losses = defaultdict(list)
for i, batch in enumerate(self.valid_loader):
texts, mels, text_lens, output_lens = batch
outputs = self.model(texts, mels, text_lens, output_lens)
outputs = self.model(texts, text_lens, mels, output_lens)
losses = self.compute_losses(batch, outputs)
for k, v in losses.items():
valid_losses[k].append(float(v))
@ -88,7 +92,7 @@ class Experiment(ExperimentBase):
attention_weights = outputs["alignments"]
self.visualizer.add_figure(
f"valid_sentence_{i}_alignments",
display.plot_alignment(attention_weights[0].numpy()),
display.plot_alignment(attention_weights[0].numpy().T),
self.iteration)
self.visualizer.add_figure(
f"valid_sentence_{i}_target_spectrogram",
@ -114,9 +118,8 @@ class Experiment(ExperimentBase):
def setup_model(self):
config = self.config
frontend = EnglishCharacter()
model = Tacotron2(
frontend,
vocab_size=config.model.vocab_size,
d_mels=config.data.d_mels,
d_encoder=config.model.d_encoder,
encoder_conv_layers=config.model.encoder_conv_layers,
@ -135,7 +138,8 @@ class Experiment(ExperimentBase):
p_prenet_dropout=config.model.p_prenet_dropout,
p_attention_dropout=config.model.p_attention_dropout,
p_decoder_dropout=config.model.p_decoder_dropout,
p_postnet_dropout=config.model.p_postnet_dropout)
p_postnet_dropout=config.model.p_postnet_dropout,
use_stop_token=config.model.use_stop_token)
if self.parallel:
model = paddle.DataParallel(model)
@ -148,7 +152,10 @@ class Experiment(ExperimentBase):
weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay),
grad_clip=grad_clip)
criterion = Tacotron2Loss(config.mode.guided_attn_loss_sigma)
criterion = Tacotron2Loss(
use_stop_token_loss=config.model.use_stop_token,
use_guided_attention_loss=config.model.use_guided_attention_loss,
sigma=config.model.guided_attention_loss_sigma)
self.model = model
self.optimizer = optimizer
self.criterion = criterion

View File

@ -23,6 +23,7 @@ from paddle.fluid.layers import sequence_mask
from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention
from parakeet.modules.losses import guided_attention_loss
from parakeet.utils import checkpoint
from tqdm import trange
__all__ = ["Tacotron2", "Tacotron2Loss"]
@ -274,7 +275,7 @@ class Tacotron2Decoder(nn.Layer):
d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int,
d_attention: int, attention_filters: int,
attention_kernel_size: int, p_prenet_dropout: float,
p_attention_dropout: float, p_decoder_dropout: float):
p_attention_dropout: float, p_decoder_dropout: float, use_stop_token: bool=False):
super().__init__()
self.d_mels = d_mels
self.reduction_factor = reduction_factor
@ -304,6 +305,10 @@ class Tacotron2Decoder(nn.Layer):
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
d_mels * reduction_factor)
self.use_stop_token = use_stop_token
if use_stop_token:
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
# states - temporary attributes
self.attention_hidden = None
self.attention_cell = None
@ -379,6 +384,9 @@ class Tacotron2Decoder(nn.Layer):
[self.decoder_hidden, self.attention_context], axis=-1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
if self.use_stop_token:
stop_logit = self.stop_layer(decoder_hidden_attention_context)
return decoder_output, self.attention_weights, stop_logit
return decoder_output, self.attention_weights
def forward(self, keys, querys, mask):
@ -416,16 +424,24 @@ class Tacotron2Decoder(nn.Layer):
querys = self.prenet(querys)
mel_outputs, alignments = [], []
stop_logits = []
# Ignore the last time step
while len(mel_outputs) < querys.shape[1] - 1:
query = querys[:, len(mel_outputs), :]
if self.use_stop_token:
mel_output, attention_weights, stop_logit = self._decode(query)
else:
mel_output, attention_weights = self._decode(query)
mel_outputs.append(mel_output)
alignments.append(attention_weights)
if self.use_stop_token:
stop_logits.append(stop_logit)
alignments = paddle.stack(alignments, axis=1)
mel_outputs = paddle.stack(mel_outputs, axis=1)
if self.use_stop_token:
stop_logits = paddle.concat(stop_logits, axis=1)
return mel_outputs, alignments, stop_logits
return mel_outputs, alignments
def infer(self, key, max_decoder_steps=1000):
@ -460,17 +476,28 @@ class Tacotron2Decoder(nn.Layer):
first_hit_end = None
mel_outputs, alignments = [], []
stop_logits = []
for i in trange(max_decoder_steps):
query = self.prenet(query)
if self.use_stop_token:
mel_output, alignment, stop_logit = self._decode(query)
else:
mel_output, alignment = self._decode(query)
mel_outputs.append(mel_output)
alignments.append(alignment) # (B=1, T)
if self.use_stop_token:
stop_logits.append(stop_logit)
if self.use_stop_token:
if F.sigmoid(stop_logit) > 0.5:
print("hit stop condition!")
break
else:
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
if first_hit_end is None:
first_hit_end = i
if first_hit_end is not None and i > (first_hit_end + 10):
elif i > (first_hit_end + 10):
print("content exhausted!")
break
if len(mel_outputs) == max_decoder_steps:
@ -481,7 +508,9 @@ class Tacotron2Decoder(nn.Layer):
alignments = paddle.stack(alignments, axis=1)
mel_outputs = paddle.stack(mel_outputs, axis=1)
if self.use_stop_token:
stop_logits = paddle.concat(stop_logits, axis=1)
return mel_outputs, alignments, stop_logits
return mel_outputs, alignments
@ -580,7 +609,8 @@ class Tacotron2(nn.Layer):
p_attention_dropout: float = 0.1,
p_decoder_dropout: float = 0.1,
p_postnet_dropout: float = 0.5,
d_global_condition=None):
d_global_condition=None,
use_stop_token=True):
super().__init__()
std = math.sqrt(2.0 / (vocab_size + d_encoder))
@ -606,7 +636,7 @@ class Tacotron2(nn.Layer):
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
d_decoder_rnn, d_attention, attention_filters,
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
p_decoder_dropout)
p_decoder_dropout, use_stop_token=use_stop_token)
self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor,
d_hidden=d_postnet,
kernel_size=postnet_kernel_size,
@ -664,10 +694,14 @@ class Tacotron2(nn.Layer):
# [B, T_enc, 1]
mask = sequence_mask(text_lens,
dtype=encoder_outputs.dtype).unsqueeze(-1)
if self.decoder.use_stop_token:
mel_outputs, alignments, stop_logits = self.decoder(encoder_outputs,
mels,
mask=mask)
else:
mel_outputs, alignments = self.decoder(encoder_outputs,
mels,
mask=mask)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@ -681,6 +715,8 @@ class Tacotron2(nn.Layer):
"mel_outputs_postnet": mel_outputs_postnet,
"alignments": alignments
}
if self.decoder.use_stop_token:
outputs["stop_logits"] = stop_logits
return outputs
@ -723,6 +759,10 @@ class Tacotron2(nn.Layer):
global_condition, [-1, encoder_outputs.shape[1], -1])
encoder_outputs = paddle.concat(
[encoder_outputs, global_condition], -1)
if self.decoder.use_stop_token:
mel_outputs, alignments, stop_logits = self.decoder.infer(
encoder_outputs, max_decoder_steps=max_decoder_steps)
else:
mel_outputs, alignments = self.decoder.infer(
encoder_outputs, max_decoder_steps=max_decoder_steps)
@ -734,22 +774,72 @@ class Tacotron2(nn.Layer):
"mel_outputs_postnet": mel_outputs_postnet,
"alignments": alignments
}
if self.decoder.use_stop_token:
outputs["stop_logits"] = stop_logits
return outputs
@classmethod
def from_pretrained(cls, config, checkpoint_path):
"""Build a Tacotron2 model from a pretrained model.
Parameters
----------
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
ConditionalWaveFlow
The model built from pretrained result.
"""
model = cls(
vocab_size=config.model.vocab_size,
d_mels=config.data.d_mels,
d_encoder=config.model.d_encoder,
encoder_conv_layers=config.model.encoder_conv_layers,
encoder_kernel_size=config.model.encoder_kernel_size,
d_prenet=config.model.d_prenet,
d_attention_rnn=config.model.d_attention_rnn,
d_decoder_rnn=config.model.d_decoder_rnn,
attention_filters=config.model.attention_filters,
attention_kernel_size=config.model.attention_kernel_size,
d_attention=config.model.d_attention,
d_postnet=config.model.d_postnet,
postnet_kernel_size=config.model.postnet_kernel_size,
postnet_conv_layers=config.model.postnet_conv_layers,
reduction_factor=config.model.reduction_factor,
p_encoder_dropout=config.model.p_encoder_dropout,
p_prenet_dropout=config.model.p_prenet_dropout,
p_attention_dropout=config.model.p_attention_dropout,
p_decoder_dropout=config.model.p_decoder_dropout,
p_postnet_dropout=config.model.p_postnet_dropout,
use_stop_token=config.model.use_stop_token)
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
return model
class Tacotron2Loss(nn.Layer):
""" Tacotron2 Loss module
"""
def __init__(self, sigma=0.2):
def __init__(self,
use_stop_token_loss=True,
use_guided_attention_loss=False,
sigma=0.2):
super().__init__()
self.spec_criterion = nn.MSELoss()
self.use_stop_token_loss = use_stop_token_loss
self.use_guided_attention_loss = use_guided_attention_loss
self.attn_criterion = guided_attention_loss
self.stop_criterion = paddle.nn.BCEWithLogitsLoss()
self.sigma = sigma
def forward(self, mel_outputs, mel_outputs_postnet, mel_targets,
attention_weights, slens, plens):
attention_weights=None, slens=None, plens=None, stop_logits=None):
"""Calculate tacotron2 loss.
Parameters
@ -775,13 +865,24 @@ class Tacotron2Loss(nn.Layer):
"""
mel_loss = self.spec_criterion(mel_outputs, mel_targets)
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
total_loss = mel_loss + post_mel_loss
if self.use_guided_attention_loss:
gal_loss = self.attn_criterion(attention_weights, slens, plens,
self.sigma)
total_loss = mel_loss + post_mel_loss + gal_loss
total_loss += gal_loss
if self.use_stop_token_loss:
T_dec = mel_targets.shape[1]
stop_labels = F.one_hot(slens - 1, num_classes=T_dec)
stop_token_loss = self.stop_criterion(stop_logits, stop_labels)
total_loss += stop_token_loss
losses = {
"loss": total_loss,
"mel_loss": mel_loss,
"post_mel_loss": post_mel_loss,
"guided_attn_loss": gal_loss
"post_mel_loss": post_mel_loss
}
if self.use_guided_attention_loss:
losses["guided_attn_loss"] = gal_loss
if self.use_stop_token_loss:
losses["stop_loss"] = stop_token_loss
return losses