add an optional to alter the loss and model structure of tacotron2, add an alternative config

This commit is contained in:
chenfeiyu 2021-04-26 21:18:29 +08:00
parent 4fc86abf5a
commit 263d3eb88b
6 changed files with 160 additions and 49 deletions

View File

@ -23,8 +23,8 @@ _C.data = CN(
n_fft=1024, # fft frame size n_fft=1024, # fft frame size
win_length=1024, # window size win_length=1024, # window size
hop_length=256, # hop size between ajacent frame hop_length=256, # hop size between ajacent frame
f_max=8000, # Hz, max frequency when converting to mel fmax=8000, # Hz, max frequency when converting to mel
f_min=0, # Hz, min frequency when converting to mel fmin=0, # Hz, min frequency when converting to mel
d_mels=80, # mel bands d_mels=80, # mel bands
padding_idx=0, # text embedding's padding index padding_idx=0, # text embedding's padding index
)) ))
@ -50,7 +50,9 @@ _C.model = CN(
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
p_postnet_dropout=0.5, # droput probability in decoder postnet p_postnet_dropout=0.5, # droput probability in decoder postnet
guided_attn_loss_sigma=0.2 # sigma in guided attention loss use_stop_token=True, # wherther to use binary classifier to predict when to stop
use_guided_attention_loss=False, # whether to use guided attention loss
guided_attention_loss_sigma=0.2 # sigma in guided attention loss
)) ))
_C.training = CN( _C.training = CN(

View File

@ -58,7 +58,7 @@ class LJSpeechCollector(object):
mels = [] mels = []
text_lens = [] text_lens = []
mel_lens = [] mel_lens = []
stop_tokens = []
for data in examples: for data in examples:
text, mel = data text, mel = data
text = np.array(text, dtype=np.int64) text = np.array(text, dtype=np.int64)
@ -66,8 +66,6 @@ class LJSpeechCollector(object):
mels.append(mel) mels.append(mel)
texts.append(text) texts.append(text)
mel_lens.append(mel.shape[1]) mel_lens.append(mel.shape[1])
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
stop_tokens.append(np.append(stop_token, 1.0))
# Sort by text_len in descending order # Sort by text_len in descending order
texts = [ texts = [
@ -86,15 +84,13 @@ class LJSpeechCollector(object):
for i, _ in sorted( for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True) zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
] ]
mel_lens = np.array(mel_lens, dtype=np.int64) mel_lens = np.array(mel_lens, dtype=np.int64)
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64) text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
# Pad sequence with largest len of the batch # Pad sequence with largest len of the batch
texts, _ = batch_text_id(texts, pad_id=self.padding_idx) texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
mels, _ = np.transpose( mels, _ = batch_spec(mels, pad_value=self.padding_value)
batch_spec( mels = np.transpose(mels, axes=(0, 2, 1))
mels, pad_value=self.padding_value), axes=(0, 2, 1))
return texts, mels, text_lens, mel_lens return texts, mels, text_lens, mel_lens

View File

@ -39,8 +39,8 @@ def create_dataset(config, source_path, target_path, verbose=False):
n_mels=config.data.d_mels, n_mels=config.data.d_mels,
win_length=config.data.win_length, win_length=config.data.win_length,
hop_length=config.data.hop_length, hop_length=config.data.hop_length,
f_max=config.data.f_max, f_max=config.data.fmax,
f_min=config.data.f_min) f_min=config.data.fmin)
normalizer = LogMagnitude() normalizer = LogMagnitude()
records = [] records = []

View File

@ -20,6 +20,8 @@ import paddle
import parakeet import parakeet
from parakeet.frontend import EnglishCharacter from parakeet.frontend import EnglishCharacter
from parakeet.models.tacotron2 import Tacotron2 from parakeet.models.tacotron2 import Tacotron2
from parakeet.utils import display
from matplotlib import pyplot as plt
from config import get_cfg_defaults from config import get_cfg_defaults
@ -46,10 +48,13 @@ def main(config, args):
for i, sentence in enumerate(sentences): for i, sentence in enumerate(sentences):
sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0) sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0)
mel_output, _ = model.predict(sentence) outputs = model.infer(sentence)
mel_output = mel_output["mel_outputs_postnet"][0].numpy().T mel_output = outputs["mel_outputs_postnet"][0].numpy().T
alignment = outputs["alignments"][0].numpy().T
np.save(str(output_dir / f"sentence_{i}"), mel_output) np.save(str(output_dir / f"sentence_{i}"), mel_output)
display.plot_alignment(alignment)
plt.savefig(str(output_dir / f"sentence_{i}.png"))
if args.verbose: if args.verbose:
print("spectrogram saved at {}".format(output_dir / print("spectrogram saved at {}".format(output_dir /
f"sentence_{i}.npy")) f"sentence_{i}.npy"))

View File

@ -34,14 +34,18 @@ from ljspeech import LJSpeech, LJSpeechCollector
class Experiment(ExperimentBase): class Experiment(ExperimentBase):
def compute_losses(self, inputs, outputs): def compute_losses(self, inputs, outputs):
_, mel_targets, plens, slens, stop_tokens = inputs texts, mel_targets, plens, slens = inputs
mel_outputs = outputs["mel_output"] mel_outputs = outputs["mel_output"]
mel_outputs_postnet = outputs["mel_outputs_postnet"] mel_outputs_postnet = outputs["mel_outputs_postnet"]
attention_weight = outputs["alignments"] attention_weight = outputs["alignments"]
if self.config.model.use_stop_token:
stop_logits = outputs["stop_logits"]
else:
stop_logits = None
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets, losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
attention_weight, slens, plens) attention_weight, slens, plens, stop_logits)
return losses return losses
def train_batch(self): def train_batch(self):
@ -52,7 +56,7 @@ class Experiment(ExperimentBase):
self.optimizer.clear_grad() self.optimizer.clear_grad()
self.model.train() self.model.train()
texts, mels, text_lens, output_lens = batch texts, mels, text_lens, output_lens = batch
outputs = self.model(texts, mels, text_lens, output_lens) outputs = self.model(texts, text_lens, mels, output_lens)
losses = self.compute_losses(batch, outputs) losses = self.compute_losses(batch, outputs)
loss = losses["loss"] loss = losses["loss"]
loss.backward() loss.backward()
@ -80,7 +84,7 @@ class Experiment(ExperimentBase):
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
for i, batch in enumerate(self.valid_loader): for i, batch in enumerate(self.valid_loader):
texts, mels, text_lens, output_lens = batch texts, mels, text_lens, output_lens = batch
outputs = self.model(texts, mels, text_lens, output_lens) outputs = self.model(texts, text_lens, mels, output_lens)
losses = self.compute_losses(batch, outputs) losses = self.compute_losses(batch, outputs)
for k, v in losses.items(): for k, v in losses.items():
valid_losses[k].append(float(v)) valid_losses[k].append(float(v))
@ -88,7 +92,7 @@ class Experiment(ExperimentBase):
attention_weights = outputs["alignments"] attention_weights = outputs["alignments"]
self.visualizer.add_figure( self.visualizer.add_figure(
f"valid_sentence_{i}_alignments", f"valid_sentence_{i}_alignments",
display.plot_alignment(attention_weights[0].numpy()), display.plot_alignment(attention_weights[0].numpy().T),
self.iteration) self.iteration)
self.visualizer.add_figure( self.visualizer.add_figure(
f"valid_sentence_{i}_target_spectrogram", f"valid_sentence_{i}_target_spectrogram",
@ -114,9 +118,8 @@ class Experiment(ExperimentBase):
def setup_model(self): def setup_model(self):
config = self.config config = self.config
frontend = EnglishCharacter()
model = Tacotron2( model = Tacotron2(
frontend, vocab_size=config.model.vocab_size,
d_mels=config.data.d_mels, d_mels=config.data.d_mels,
d_encoder=config.model.d_encoder, d_encoder=config.model.d_encoder,
encoder_conv_layers=config.model.encoder_conv_layers, encoder_conv_layers=config.model.encoder_conv_layers,
@ -135,7 +138,8 @@ class Experiment(ExperimentBase):
p_prenet_dropout=config.model.p_prenet_dropout, p_prenet_dropout=config.model.p_prenet_dropout,
p_attention_dropout=config.model.p_attention_dropout, p_attention_dropout=config.model.p_attention_dropout,
p_decoder_dropout=config.model.p_decoder_dropout, p_decoder_dropout=config.model.p_decoder_dropout,
p_postnet_dropout=config.model.p_postnet_dropout) p_postnet_dropout=config.model.p_postnet_dropout,
use_stop_token=config.model.use_stop_token)
if self.parallel: if self.parallel:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
@ -148,7 +152,10 @@ class Experiment(ExperimentBase):
weight_decay=paddle.regularizer.L2Decay( weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay), config.training.weight_decay),
grad_clip=grad_clip) grad_clip=grad_clip)
criterion = Tacotron2Loss(config.mode.guided_attn_loss_sigma) criterion = Tacotron2Loss(
use_stop_token_loss=config.model.use_stop_token,
use_guided_attention_loss=config.model.use_guided_attention_loss,
sigma=config.model.guided_attention_loss_sigma)
self.model = model self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.criterion = criterion self.criterion = criterion

View File

@ -23,6 +23,7 @@ from paddle.fluid.layers import sequence_mask
from parakeet.modules.conv import Conv1dBatchNorm from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention from parakeet.modules.attention import LocationSensitiveAttention
from parakeet.modules.losses import guided_attention_loss from parakeet.modules.losses import guided_attention_loss
from parakeet.utils import checkpoint
from tqdm import trange from tqdm import trange
__all__ = ["Tacotron2", "Tacotron2Loss"] __all__ = ["Tacotron2", "Tacotron2Loss"]
@ -274,7 +275,7 @@ class Tacotron2Decoder(nn.Layer):
d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int, d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int,
d_attention: int, attention_filters: int, d_attention: int, attention_filters: int,
attention_kernel_size: int, p_prenet_dropout: float, attention_kernel_size: int, p_prenet_dropout: float,
p_attention_dropout: float, p_decoder_dropout: float): p_attention_dropout: float, p_decoder_dropout: float, use_stop_token: bool=False):
super().__init__() super().__init__()
self.d_mels = d_mels self.d_mels = d_mels
self.reduction_factor = reduction_factor self.reduction_factor = reduction_factor
@ -303,6 +304,10 @@ class Tacotron2Decoder(nn.Layer):
d_decoder_rnn) d_decoder_rnn)
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder, self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
d_mels * reduction_factor) d_mels * reduction_factor)
self.use_stop_token = use_stop_token
if use_stop_token:
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
# states - temporary attributes # states - temporary attributes
self.attention_hidden = None self.attention_hidden = None
@ -379,6 +384,9 @@ class Tacotron2Decoder(nn.Layer):
[self.decoder_hidden, self.attention_context], axis=-1) [self.decoder_hidden, self.attention_context], axis=-1)
decoder_output = self.linear_projection( decoder_output = self.linear_projection(
decoder_hidden_attention_context) decoder_hidden_attention_context)
if self.use_stop_token:
stop_logit = self.stop_layer(decoder_hidden_attention_context)
return decoder_output, self.attention_weights, stop_logit
return decoder_output, self.attention_weights return decoder_output, self.attention_weights
def forward(self, keys, querys, mask): def forward(self, keys, querys, mask):
@ -416,16 +424,24 @@ class Tacotron2Decoder(nn.Layer):
querys = self.prenet(querys) querys = self.prenet(querys)
mel_outputs, alignments = [], [] mel_outputs, alignments = [], []
stop_logits = []
# Ignore the last time step # Ignore the last time step
while len(mel_outputs) < querys.shape[1] - 1: while len(mel_outputs) < querys.shape[1] - 1:
query = querys[:, len(mel_outputs), :] query = querys[:, len(mel_outputs), :]
mel_output, attention_weights = self._decode(query) if self.use_stop_token:
mel_output, attention_weights, stop_logit = self._decode(query)
else:
mel_output, attention_weights = self._decode(query)
mel_outputs.append(mel_output) mel_outputs.append(mel_output)
alignments.append(attention_weights) alignments.append(attention_weights)
if self.use_stop_token:
stop_logits.append(stop_logit)
alignments = paddle.stack(alignments, axis=1) alignments = paddle.stack(alignments, axis=1)
mel_outputs = paddle.stack(mel_outputs, axis=1) mel_outputs = paddle.stack(mel_outputs, axis=1)
if self.use_stop_token:
stop_logits = paddle.concat(stop_logits, axis=1)
return mel_outputs, alignments, stop_logits
return mel_outputs, alignments return mel_outputs, alignments
def infer(self, key, max_decoder_steps=1000): def infer(self, key, max_decoder_steps=1000):
@ -460,19 +476,30 @@ class Tacotron2Decoder(nn.Layer):
first_hit_end = None first_hit_end = None
mel_outputs, alignments = [], [] mel_outputs, alignments = [], []
stop_logits = []
for i in trange(max_decoder_steps): for i in trange(max_decoder_steps):
query = self.prenet(query) query = self.prenet(query)
mel_output, alignment = self._decode(query) if self.use_stop_token:
mel_output, alignment, stop_logit = self._decode(query)
else:
mel_output, alignment = self._decode(query)
mel_outputs.append(mel_output) mel_outputs.append(mel_output)
alignments.append(alignment) # (B=1, T) alignments.append(alignment) # (B=1, T)
if self.use_stop_token:
stop_logits.append(stop_logit)
if int(paddle.argmax(alignment[0])) == encoder_steps - 1: if self.use_stop_token:
if first_hit_end is None: if F.sigmoid(stop_logit) > 0.5:
first_hit_end = i print("hit stop condition!")
if first_hit_end is not None and i > (first_hit_end + 10): break
print("content exhausted!") else:
break if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
if first_hit_end is None:
first_hit_end = i
elif i > (first_hit_end + 10):
print("content exhausted!")
break
if len(mel_outputs) == max_decoder_steps: if len(mel_outputs) == max_decoder_steps:
print("Warning! Reached max decoder steps!!!") print("Warning! Reached max decoder steps!!!")
break break
@ -481,7 +508,9 @@ class Tacotron2Decoder(nn.Layer):
alignments = paddle.stack(alignments, axis=1) alignments = paddle.stack(alignments, axis=1)
mel_outputs = paddle.stack(mel_outputs, axis=1) mel_outputs = paddle.stack(mel_outputs, axis=1)
if self.use_stop_token:
stop_logits = paddle.concat(stop_logits, axis=1)
return mel_outputs, alignments, stop_logits
return mel_outputs, alignments return mel_outputs, alignments
@ -580,7 +609,8 @@ class Tacotron2(nn.Layer):
p_attention_dropout: float = 0.1, p_attention_dropout: float = 0.1,
p_decoder_dropout: float = 0.1, p_decoder_dropout: float = 0.1,
p_postnet_dropout: float = 0.5, p_postnet_dropout: float = 0.5,
d_global_condition=None): d_global_condition=None,
use_stop_token=True):
super().__init__() super().__init__()
std = math.sqrt(2.0 / (vocab_size + d_encoder)) std = math.sqrt(2.0 / (vocab_size + d_encoder))
@ -606,7 +636,7 @@ class Tacotron2(nn.Layer):
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn, d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
d_decoder_rnn, d_attention, attention_filters, d_decoder_rnn, d_attention, attention_filters,
attention_kernel_size, p_prenet_dropout, p_attention_dropout, attention_kernel_size, p_prenet_dropout, p_attention_dropout,
p_decoder_dropout) p_decoder_dropout, use_stop_token=use_stop_token)
self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor, self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor,
d_hidden=d_postnet, d_hidden=d_postnet,
kernel_size=postnet_kernel_size, kernel_size=postnet_kernel_size,
@ -664,10 +694,14 @@ class Tacotron2(nn.Layer):
# [B, T_enc, 1] # [B, T_enc, 1]
mask = sequence_mask(text_lens, mask = sequence_mask(text_lens,
dtype=encoder_outputs.dtype).unsqueeze(-1) dtype=encoder_outputs.dtype).unsqueeze(-1)
mel_outputs, alignments = self.decoder(encoder_outputs, if self.decoder.use_stop_token:
mels, mel_outputs, alignments, stop_logits = self.decoder(encoder_outputs,
mask=mask) mels,
mask=mask)
else:
mel_outputs, alignments = self.decoder(encoder_outputs,
mels,
mask=mask)
mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@ -681,6 +715,8 @@ class Tacotron2(nn.Layer):
"mel_outputs_postnet": mel_outputs_postnet, "mel_outputs_postnet": mel_outputs_postnet,
"alignments": alignments "alignments": alignments
} }
if self.decoder.use_stop_token:
outputs["stop_logits"] = stop_logits
return outputs return outputs
@ -723,8 +759,12 @@ class Tacotron2(nn.Layer):
global_condition, [-1, encoder_outputs.shape[1], -1]) global_condition, [-1, encoder_outputs.shape[1], -1])
encoder_outputs = paddle.concat( encoder_outputs = paddle.concat(
[encoder_outputs, global_condition], -1) [encoder_outputs, global_condition], -1)
mel_outputs, alignments = self.decoder.infer( if self.decoder.use_stop_token:
encoder_outputs, max_decoder_steps=max_decoder_steps) mel_outputs, alignments, stop_logits = self.decoder.infer(
encoder_outputs, max_decoder_steps=max_decoder_steps)
else:
mel_outputs, alignments = self.decoder.infer(
encoder_outputs, max_decoder_steps=max_decoder_steps)
mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet mel_outputs_postnet = mel_outputs + mel_outputs_postnet
@ -734,22 +774,72 @@ class Tacotron2(nn.Layer):
"mel_outputs_postnet": mel_outputs_postnet, "mel_outputs_postnet": mel_outputs_postnet,
"alignments": alignments "alignments": alignments
} }
if self.decoder.use_stop_token:
outputs["stop_logits"] = stop_logits
return outputs return outputs
@classmethod
def from_pretrained(cls, config, checkpoint_path):
"""Build a Tacotron2 model from a pretrained model.
Parameters
----------
config: yacs.config.CfgNode
model configs
checkpoint_path: Path or str
the path of pretrained model checkpoint, without extension name
Returns
-------
ConditionalWaveFlow
The model built from pretrained result.
"""
model = cls(
vocab_size=config.model.vocab_size,
d_mels=config.data.d_mels,
d_encoder=config.model.d_encoder,
encoder_conv_layers=config.model.encoder_conv_layers,
encoder_kernel_size=config.model.encoder_kernel_size,
d_prenet=config.model.d_prenet,
d_attention_rnn=config.model.d_attention_rnn,
d_decoder_rnn=config.model.d_decoder_rnn,
attention_filters=config.model.attention_filters,
attention_kernel_size=config.model.attention_kernel_size,
d_attention=config.model.d_attention,
d_postnet=config.model.d_postnet,
postnet_kernel_size=config.model.postnet_kernel_size,
postnet_conv_layers=config.model.postnet_conv_layers,
reduction_factor=config.model.reduction_factor,
p_encoder_dropout=config.model.p_encoder_dropout,
p_prenet_dropout=config.model.p_prenet_dropout,
p_attention_dropout=config.model.p_attention_dropout,
p_decoder_dropout=config.model.p_decoder_dropout,
p_postnet_dropout=config.model.p_postnet_dropout,
use_stop_token=config.model.use_stop_token)
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
return model
class Tacotron2Loss(nn.Layer): class Tacotron2Loss(nn.Layer):
""" Tacotron2 Loss module """ Tacotron2 Loss module
""" """
def __init__(self, sigma=0.2): def __init__(self,
use_stop_token_loss=True,
use_guided_attention_loss=False,
sigma=0.2):
super().__init__() super().__init__()
self.spec_criterion = nn.MSELoss() self.spec_criterion = nn.MSELoss()
self.use_stop_token_loss = use_stop_token_loss
self.use_guided_attention_loss = use_guided_attention_loss
self.attn_criterion = guided_attention_loss self.attn_criterion = guided_attention_loss
self.stop_criterion = paddle.nn.BCEWithLogitsLoss()
self.sigma = sigma self.sigma = sigma
def forward(self, mel_outputs, mel_outputs_postnet, mel_targets, def forward(self, mel_outputs, mel_outputs_postnet, mel_targets,
attention_weights, slens, plens): attention_weights=None, slens=None, plens=None, stop_logits=None):
"""Calculate tacotron2 loss. """Calculate tacotron2 loss.
Parameters Parameters
@ -775,13 +865,24 @@ class Tacotron2Loss(nn.Layer):
""" """
mel_loss = self.spec_criterion(mel_outputs, mel_targets) mel_loss = self.spec_criterion(mel_outputs, mel_targets)
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets) post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
gal_loss = self.attn_criterion(attention_weights, slens, plens, total_loss = mel_loss + post_mel_loss
if self.use_guided_attention_loss:
gal_loss = self.attn_criterion(attention_weights, slens, plens,
self.sigma) self.sigma)
total_loss = mel_loss + post_mel_loss + gal_loss total_loss += gal_loss
if self.use_stop_token_loss:
T_dec = mel_targets.shape[1]
stop_labels = F.one_hot(slens - 1, num_classes=T_dec)
stop_token_loss = self.stop_criterion(stop_logits, stop_labels)
total_loss += stop_token_loss
losses = { losses = {
"loss": total_loss, "loss": total_loss,
"mel_loss": mel_loss, "mel_loss": mel_loss,
"post_mel_loss": post_mel_loss, "post_mel_loss": post_mel_loss
"guided_attn_loss": gal_loss
} }
if self.use_guided_attention_loss:
losses["guided_attn_loss"] = gal_loss
if self.use_stop_token_loss:
losses["stop_loss"] = stop_token_loss
return losses return losses