From 263d3eb88b2cbd19a5c8cd51c05f2c0beef9234d Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Mon, 26 Apr 2021 21:18:29 +0800 Subject: [PATCH] add an optional to alter the loss and model structure of tacotron2, add an alternative config --- examples/tacotron2/config.py | 8 +- examples/tacotron2/ljspeech.py | 12 +-- examples/tacotron2/preprocess.py | 4 +- examples/tacotron2/synthesize.py | 9 +- examples/tacotron2/train.py | 25 +++-- parakeet/models/tacotron2.py | 151 ++++++++++++++++++++++++++----- 6 files changed, 160 insertions(+), 49 deletions(-) diff --git a/examples/tacotron2/config.py b/examples/tacotron2/config.py index 41e1bd9..87c28e6 100644 --- a/examples/tacotron2/config.py +++ b/examples/tacotron2/config.py @@ -23,8 +23,8 @@ _C.data = CN( n_fft=1024, # fft frame size win_length=1024, # window size hop_length=256, # hop size between ajacent frame - f_max=8000, # Hz, max frequency when converting to mel - f_min=0, # Hz, min frequency when converting to mel + fmax=8000, # Hz, max frequency when converting to mel + fmin=0, # Hz, min frequency when converting to mel d_mels=80, # mel bands padding_idx=0, # text embedding's padding index )) @@ -50,7 +50,9 @@ _C.model = CN( p_attention_dropout=0.1, # droput probability of first rnn layer in decoder p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder p_postnet_dropout=0.5, # droput probability in decoder postnet - guided_attn_loss_sigma=0.2 # sigma in guided attention loss + use_stop_token=True, # wherther to use binary classifier to predict when to stop + use_guided_attention_loss=False, # whether to use guided attention loss + guided_attention_loss_sigma=0.2 # sigma in guided attention loss )) _C.training = CN( diff --git a/examples/tacotron2/ljspeech.py b/examples/tacotron2/ljspeech.py index 9acebc4..a98ce5d 100644 --- a/examples/tacotron2/ljspeech.py +++ b/examples/tacotron2/ljspeech.py @@ -58,7 +58,7 @@ class LJSpeechCollector(object): mels = [] text_lens = [] mel_lens = [] - stop_tokens = [] + for data in examples: text, mel = data text = np.array(text, dtype=np.int64) @@ -66,8 +66,6 @@ class LJSpeechCollector(object): mels.append(mel) texts.append(text) mel_lens.append(mel.shape[1]) - stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32) - stop_tokens.append(np.append(stop_token, 1.0)) # Sort by text_len in descending order texts = [ @@ -86,15 +84,13 @@ class LJSpeechCollector(object): for i, _ in sorted( zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True) ] + mel_lens = np.array(mel_lens, dtype=np.int64) - - text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64) # Pad sequence with largest len of the batch texts, _ = batch_text_id(texts, pad_id=self.padding_idx) - mels, _ = np.transpose( - batch_spec( - mels, pad_value=self.padding_value), axes=(0, 2, 1)) + mels, _ = batch_spec(mels, pad_value=self.padding_value) + mels = np.transpose(mels, axes=(0, 2, 1)) return texts, mels, text_lens, mel_lens diff --git a/examples/tacotron2/preprocess.py b/examples/tacotron2/preprocess.py index 22f6443..3d8305d 100644 --- a/examples/tacotron2/preprocess.py +++ b/examples/tacotron2/preprocess.py @@ -39,8 +39,8 @@ def create_dataset(config, source_path, target_path, verbose=False): n_mels=config.data.d_mels, win_length=config.data.win_length, hop_length=config.data.hop_length, - f_max=config.data.f_max, - f_min=config.data.f_min) + f_max=config.data.fmax, + f_min=config.data.fmin) normalizer = LogMagnitude() records = [] diff --git a/examples/tacotron2/synthesize.py b/examples/tacotron2/synthesize.py index 8ea8ae0..91557ab 100644 --- a/examples/tacotron2/synthesize.py +++ b/examples/tacotron2/synthesize.py @@ -20,6 +20,8 @@ import paddle import parakeet from parakeet.frontend import EnglishCharacter from parakeet.models.tacotron2 import Tacotron2 +from parakeet.utils import display +from matplotlib import pyplot as plt from config import get_cfg_defaults @@ -46,10 +48,13 @@ def main(config, args): for i, sentence in enumerate(sentences): sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0) - mel_output, _ = model.predict(sentence) - mel_output = mel_output["mel_outputs_postnet"][0].numpy().T + outputs = model.infer(sentence) + mel_output = outputs["mel_outputs_postnet"][0].numpy().T + alignment = outputs["alignments"][0].numpy().T np.save(str(output_dir / f"sentence_{i}"), mel_output) + display.plot_alignment(alignment) + plt.savefig(str(output_dir / f"sentence_{i}.png")) if args.verbose: print("spectrogram saved at {}".format(output_dir / f"sentence_{i}.npy")) diff --git a/examples/tacotron2/train.py b/examples/tacotron2/train.py index 3b02394..bc3ae2c 100644 --- a/examples/tacotron2/train.py +++ b/examples/tacotron2/train.py @@ -34,14 +34,18 @@ from ljspeech import LJSpeech, LJSpeechCollector class Experiment(ExperimentBase): def compute_losses(self, inputs, outputs): - _, mel_targets, plens, slens, stop_tokens = inputs + texts, mel_targets, plens, slens = inputs mel_outputs = outputs["mel_output"] mel_outputs_postnet = outputs["mel_outputs_postnet"] attention_weight = outputs["alignments"] + if self.config.model.use_stop_token: + stop_logits = outputs["stop_logits"] + else: + stop_logits = None losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets, - attention_weight, slens, plens) + attention_weight, slens, plens, stop_logits) return losses def train_batch(self): @@ -52,7 +56,7 @@ class Experiment(ExperimentBase): self.optimizer.clear_grad() self.model.train() texts, mels, text_lens, output_lens = batch - outputs = self.model(texts, mels, text_lens, output_lens) + outputs = self.model(texts, text_lens, mels, output_lens) losses = self.compute_losses(batch, outputs) loss = losses["loss"] loss.backward() @@ -80,7 +84,7 @@ class Experiment(ExperimentBase): valid_losses = defaultdict(list) for i, batch in enumerate(self.valid_loader): texts, mels, text_lens, output_lens = batch - outputs = self.model(texts, mels, text_lens, output_lens) + outputs = self.model(texts, text_lens, mels, output_lens) losses = self.compute_losses(batch, outputs) for k, v in losses.items(): valid_losses[k].append(float(v)) @@ -88,7 +92,7 @@ class Experiment(ExperimentBase): attention_weights = outputs["alignments"] self.visualizer.add_figure( f"valid_sentence_{i}_alignments", - display.plot_alignment(attention_weights[0].numpy()), + display.plot_alignment(attention_weights[0].numpy().T), self.iteration) self.visualizer.add_figure( f"valid_sentence_{i}_target_spectrogram", @@ -114,9 +118,8 @@ class Experiment(ExperimentBase): def setup_model(self): config = self.config - frontend = EnglishCharacter() model = Tacotron2( - frontend, + vocab_size=config.model.vocab_size, d_mels=config.data.d_mels, d_encoder=config.model.d_encoder, encoder_conv_layers=config.model.encoder_conv_layers, @@ -135,7 +138,8 @@ class Experiment(ExperimentBase): p_prenet_dropout=config.model.p_prenet_dropout, p_attention_dropout=config.model.p_attention_dropout, p_decoder_dropout=config.model.p_decoder_dropout, - p_postnet_dropout=config.model.p_postnet_dropout) + p_postnet_dropout=config.model.p_postnet_dropout, + use_stop_token=config.model.use_stop_token) if self.parallel: model = paddle.DataParallel(model) @@ -148,7 +152,10 @@ class Experiment(ExperimentBase): weight_decay=paddle.regularizer.L2Decay( config.training.weight_decay), grad_clip=grad_clip) - criterion = Tacotron2Loss(config.mode.guided_attn_loss_sigma) + criterion = Tacotron2Loss( + use_stop_token_loss=config.model.use_stop_token, + use_guided_attention_loss=config.model.use_guided_attention_loss, + sigma=config.model.guided_attention_loss_sigma) self.model = model self.optimizer = optimizer self.criterion = criterion diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py index e3f5dfa..6eddecb 100644 --- a/parakeet/models/tacotron2.py +++ b/parakeet/models/tacotron2.py @@ -23,6 +23,7 @@ from paddle.fluid.layers import sequence_mask from parakeet.modules.conv import Conv1dBatchNorm from parakeet.modules.attention import LocationSensitiveAttention from parakeet.modules.losses import guided_attention_loss +from parakeet.utils import checkpoint from tqdm import trange __all__ = ["Tacotron2", "Tacotron2Loss"] @@ -274,7 +275,7 @@ class Tacotron2Decoder(nn.Layer): d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int, d_attention: int, attention_filters: int, attention_kernel_size: int, p_prenet_dropout: float, - p_attention_dropout: float, p_decoder_dropout: float): + p_attention_dropout: float, p_decoder_dropout: float, use_stop_token: bool=False): super().__init__() self.d_mels = d_mels self.reduction_factor = reduction_factor @@ -303,6 +304,10 @@ class Tacotron2Decoder(nn.Layer): d_decoder_rnn) self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder, d_mels * reduction_factor) + + self.use_stop_token = use_stop_token + if use_stop_token: + self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1) # states - temporary attributes self.attention_hidden = None @@ -379,6 +384,9 @@ class Tacotron2Decoder(nn.Layer): [self.decoder_hidden, self.attention_context], axis=-1) decoder_output = self.linear_projection( decoder_hidden_attention_context) + if self.use_stop_token: + stop_logit = self.stop_layer(decoder_hidden_attention_context) + return decoder_output, self.attention_weights, stop_logit return decoder_output, self.attention_weights def forward(self, keys, querys, mask): @@ -416,16 +424,24 @@ class Tacotron2Decoder(nn.Layer): querys = self.prenet(querys) mel_outputs, alignments = [], [] + stop_logits = [] # Ignore the last time step while len(mel_outputs) < querys.shape[1] - 1: query = querys[:, len(mel_outputs), :] - mel_output, attention_weights = self._decode(query) + if self.use_stop_token: + mel_output, attention_weights, stop_logit = self._decode(query) + else: + mel_output, attention_weights = self._decode(query) mel_outputs.append(mel_output) alignments.append(attention_weights) + if self.use_stop_token: + stop_logits.append(stop_logit) alignments = paddle.stack(alignments, axis=1) mel_outputs = paddle.stack(mel_outputs, axis=1) - + if self.use_stop_token: + stop_logits = paddle.concat(stop_logits, axis=1) + return mel_outputs, alignments, stop_logits return mel_outputs, alignments def infer(self, key, max_decoder_steps=1000): @@ -460,19 +476,30 @@ class Tacotron2Decoder(nn.Layer): first_hit_end = None mel_outputs, alignments = [], [] + stop_logits = [] for i in trange(max_decoder_steps): query = self.prenet(query) - mel_output, alignment = self._decode(query) + if self.use_stop_token: + mel_output, alignment, stop_logit = self._decode(query) + else: + mel_output, alignment = self._decode(query) mel_outputs.append(mel_output) alignments.append(alignment) # (B=1, T) + if self.use_stop_token: + stop_logits.append(stop_logit) - if int(paddle.argmax(alignment[0])) == encoder_steps - 1: - if first_hit_end is None: - first_hit_end = i - if first_hit_end is not None and i > (first_hit_end + 10): - print("content exhausted!") - break + if self.use_stop_token: + if F.sigmoid(stop_logit) > 0.5: + print("hit stop condition!") + break + else: + if int(paddle.argmax(alignment[0])) == encoder_steps - 1: + if first_hit_end is None: + first_hit_end = i + elif i > (first_hit_end + 10): + print("content exhausted!") + break if len(mel_outputs) == max_decoder_steps: print("Warning! Reached max decoder steps!!!") break @@ -481,7 +508,9 @@ class Tacotron2Decoder(nn.Layer): alignments = paddle.stack(alignments, axis=1) mel_outputs = paddle.stack(mel_outputs, axis=1) - + if self.use_stop_token: + stop_logits = paddle.concat(stop_logits, axis=1) + return mel_outputs, alignments, stop_logits return mel_outputs, alignments @@ -580,7 +609,8 @@ class Tacotron2(nn.Layer): p_attention_dropout: float = 0.1, p_decoder_dropout: float = 0.1, p_postnet_dropout: float = 0.5, - d_global_condition=None): + d_global_condition=None, + use_stop_token=True): super().__init__() std = math.sqrt(2.0 / (vocab_size + d_encoder)) @@ -606,7 +636,7 @@ class Tacotron2(nn.Layer): d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn, d_decoder_rnn, d_attention, attention_filters, attention_kernel_size, p_prenet_dropout, p_attention_dropout, - p_decoder_dropout) + p_decoder_dropout, use_stop_token=use_stop_token) self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor, d_hidden=d_postnet, kernel_size=postnet_kernel_size, @@ -664,10 +694,14 @@ class Tacotron2(nn.Layer): # [B, T_enc, 1] mask = sequence_mask(text_lens, dtype=encoder_outputs.dtype).unsqueeze(-1) - mel_outputs, alignments = self.decoder(encoder_outputs, - mels, - mask=mask) - + if self.decoder.use_stop_token: + mel_outputs, alignments, stop_logits = self.decoder(encoder_outputs, + mels, + mask=mask) + else: + mel_outputs, alignments = self.decoder(encoder_outputs, + mels, + mask=mask) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet @@ -681,6 +715,8 @@ class Tacotron2(nn.Layer): "mel_outputs_postnet": mel_outputs_postnet, "alignments": alignments } + if self.decoder.use_stop_token: + outputs["stop_logits"] = stop_logits return outputs @@ -723,8 +759,12 @@ class Tacotron2(nn.Layer): global_condition, [-1, encoder_outputs.shape[1], -1]) encoder_outputs = paddle.concat( [encoder_outputs, global_condition], -1) - mel_outputs, alignments = self.decoder.infer( - encoder_outputs, max_decoder_steps=max_decoder_steps) + if self.decoder.use_stop_token: + mel_outputs, alignments, stop_logits = self.decoder.infer( + encoder_outputs, max_decoder_steps=max_decoder_steps) + else: + mel_outputs, alignments = self.decoder.infer( + encoder_outputs, max_decoder_steps=max_decoder_steps) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet @@ -734,22 +774,72 @@ class Tacotron2(nn.Layer): "mel_outputs_postnet": mel_outputs_postnet, "alignments": alignments } + if self.decoder.use_stop_token: + outputs["stop_logits"] = stop_logits return outputs + @classmethod + def from_pretrained(cls, config, checkpoint_path): + """Build a Tacotron2 model from a pretrained model. + + Parameters + ---------- + config: yacs.config.CfgNode + model configs + + checkpoint_path: Path or str + the path of pretrained model checkpoint, without extension name + + Returns + ------- + ConditionalWaveFlow + The model built from pretrained result. + """ + model = cls( + vocab_size=config.model.vocab_size, + d_mels=config.data.d_mels, + d_encoder=config.model.d_encoder, + encoder_conv_layers=config.model.encoder_conv_layers, + encoder_kernel_size=config.model.encoder_kernel_size, + d_prenet=config.model.d_prenet, + d_attention_rnn=config.model.d_attention_rnn, + d_decoder_rnn=config.model.d_decoder_rnn, + attention_filters=config.model.attention_filters, + attention_kernel_size=config.model.attention_kernel_size, + d_attention=config.model.d_attention, + d_postnet=config.model.d_postnet, + postnet_kernel_size=config.model.postnet_kernel_size, + postnet_conv_layers=config.model.postnet_conv_layers, + reduction_factor=config.model.reduction_factor, + p_encoder_dropout=config.model.p_encoder_dropout, + p_prenet_dropout=config.model.p_prenet_dropout, + p_attention_dropout=config.model.p_attention_dropout, + p_decoder_dropout=config.model.p_decoder_dropout, + p_postnet_dropout=config.model.p_postnet_dropout, + use_stop_token=config.model.use_stop_token) + checkpoint.load_parameters(model, checkpoint_path=checkpoint_path) + return model + class Tacotron2Loss(nn.Layer): """ Tacotron2 Loss module """ - def __init__(self, sigma=0.2): + def __init__(self, + use_stop_token_loss=True, + use_guided_attention_loss=False, + sigma=0.2): super().__init__() self.spec_criterion = nn.MSELoss() + self.use_stop_token_loss = use_stop_token_loss + self.use_guided_attention_loss = use_guided_attention_loss self.attn_criterion = guided_attention_loss + self.stop_criterion = paddle.nn.BCEWithLogitsLoss() self.sigma = sigma def forward(self, mel_outputs, mel_outputs_postnet, mel_targets, - attention_weights, slens, plens): + attention_weights=None, slens=None, plens=None, stop_logits=None): """Calculate tacotron2 loss. Parameters @@ -775,13 +865,24 @@ class Tacotron2Loss(nn.Layer): """ mel_loss = self.spec_criterion(mel_outputs, mel_targets) post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets) - gal_loss = self.attn_criterion(attention_weights, slens, plens, + total_loss = mel_loss + post_mel_loss + if self.use_guided_attention_loss: + gal_loss = self.attn_criterion(attention_weights, slens, plens, self.sigma) - total_loss = mel_loss + post_mel_loss + gal_loss + total_loss += gal_loss + if self.use_stop_token_loss: + T_dec = mel_targets.shape[1] + stop_labels = F.one_hot(slens - 1, num_classes=T_dec) + stop_token_loss = self.stop_criterion(stop_logits, stop_labels) + total_loss += stop_token_loss + losses = { "loss": total_loss, "mel_loss": mel_loss, - "post_mel_loss": post_mel_loss, - "guided_attn_loss": gal_loss + "post_mel_loss": post_mel_loss } + if self.use_guided_attention_loss: + losses["guided_attn_loss"] = gal_loss + if self.use_stop_token_loss: + losses["stop_loss"] = stop_token_loss return losses