From 09f184008240805194799bfd6b25ed21902fc839 Mon Sep 17 00:00:00 2001
From: lfchener
Date: Fri, 11 Dec 2020 03:56:40 +0000
Subject: [PATCH] fix some bugs in tacotron2

---
 parakeet/frontend/normalizer/normalizer.py | 18 +++++-----
 parakeet/models/tacotron2.py               | 41 ++++++++++++----------
 parakeet/training/experiment.py            | 39 +++++++++++++-------
 3 files changed, 58 insertions(+), 40 deletions(-)

diff --git a/parakeet/frontend/normalizer/normalizer.py b/parakeet/frontend/normalizer/normalizer.py
index 96981f8..fe7d9f8 100644
--- a/parakeet/frontend/normalizer/normalizer.py
+++ b/parakeet/frontend/normalizer/normalizer.py
@@ -20,13 +20,13 @@ from parakeet.frontend.normalizer.numbers import normalize_numbers
 
 def normalize(sentence):
     # preprocessing
-    text = unicode(text)
-    text = normalize_numbers(text)
-    text = ''.join(
-        char for char in unicodedata.normalize('NFD', text)
+    sentence = unicode(sentence)
+    sentence = normalize_numbers(sentence)
+    sentence = ''.join(
+        char for char in unicodedata.normalize('NFD', sentence)
         if unicodedata.category(char) != 'Mn')  # Strip accents
-    text = text.lower()
-    text = re.sub(r"[^ a-z'.,?!\-]", "", text)
-    text = text.replace("i.e.", "that is")
-    text = text.replace("e.g.", "for example")
-    return text
+    sentence = sentence.lower()
+    sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
+    sentence = sentence.replace("i.e.", "that is")
+    sentence = sentence.replace("e.g.", "for example")
+    return sentence.split()
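The hunk above fixes two things at once: the old body referenced an undefined name (`text`) instead of its `sentence` argument, and the function now returns a list of words rather than a single string. A quick usage sketch, not part of the patch; the sample sentence and expected tokens are illustrative, and `unicode` is assumed to resolve (e.g. to `str`) inside the module:

    from parakeet.frontend.normalizer.normalizer import normalize

    # lowercases, strips accents and disallowed characters, expands numbers,
    # then splits; callers now receive a token list, not a string
    tokens = normalize("Hello there, General Kenobi!")
    # roughly: ['hello', 'there,', 'general', 'kenobi!']
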
diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py
index 194e068..912cbab 100644
--- a/parakeet/models/tacotron2.py
+++ b/parakeet/models/tacotron2.py
@@ -16,6 +16,7 @@ import math
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
+import parakeet
 from parakeet.modules.conv import Conv1dBatchNorm
 from parakeet.modules.attention import LocationSensitiveAttention
 from parakeet.modules import masking
@@ -31,6 +32,7 @@ class DecoderPreNet(nn.Layer):
                  dropout_rate: int=0.2):
         super().__init__()
 
+        self.dropout_rate = dropout_rate
         self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False)
         self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
 
@@ -50,6 +52,7 @@ class DecoderPostNet(nn.Layer):
                  dropout=0.1):
         super().__init__()
         self.dropout = dropout
+        self.num_layers = num_layers
 
         self.conv_batchnorms = nn.LayerList()
         k = math.sqrt(1.0 / (d_mels * kernel_size))
@@ -89,7 +92,8 @@ class DecoderPostNet(nn.Layer):
         for i in range(len(self.conv_batchnorms) - 1):
             input = F.dropout(
                 F.tanh(self.conv_batchnorms[i](input), self.dropout))
-        input = F.dropout(self.conv_batchnorms[-1](input), self.dropout)
+        input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
+                          self.dropout)
         return input
 
@@ -120,7 +124,7 @@ class Tacotron2Encoder(nn.Layer):
             d_hidden, self.hidden_size, direction="bidirectional")
 
     def forward(self, x, input_lens=None):
-        for conv_batchnorm in conv_batchnorms:
+        for conv_batchnorm in self.conv_batchnorms:
             x = F.dropout(F.relu(conv_batchnorm(x)),
                           self.p_dropout)  #(B, T, C)
 
@@ -209,7 +213,7 @@ class Tacotron2Decoder(nn.Layer):
             attention_weights_cat, self.mask)
         self.attention_weights_cum += self.attention_weights
 
-        # The second lstm layer
+        # The second lstm layer
         decoder_input = paddle.concat(
             [self.attention_hidden, self.attention_context], axis=-1)
         _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
@@ -225,29 +229,29 @@ class Tacotron2Decoder(nn.Layer):
         stop_logit = self.stop_layer(decoder_hidden_attention_context)
         return decoder_output, stop_logit, self.attention_weights
 
-    def forward(self, key, query, mask):
-        query = paddle.reshape(
-            query,
-            [query.shape[0], query.shape[1] // self.reduction_factor, -1])
-        query = paddle.concat(
+    def forward(self, keys, querys, mask):
+        querys = paddle.reshape(
+            querys,
+            [querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
+        querys = paddle.concat(
             [
                 paddle.zeros(
                     shape=[
-                        query.shape[0], 1,
-                        query.shape[-1] * self.reduction_factor
+                        querys.shape[0], 1,
+                        querys.shape[-1] * self.reduction_factor
                     ],
-                    dtype=query.dtype), query
+                    dtype=querys.dtype), querys
             ],
             axis=1)
-        query = self.prenet(query)
+        querys = self.prenet(querys)
 
-        self._initialize_decoder_states(key)
+        self._initialize_decoder_states(keys)
         self.mask = mask
 
         mel_outputs, stop_logits, alignments = [], [], []
-        while len(mel_outputs) < query.shape[
+        while len(mel_outputs) < querys.shape[
                 1] - 1:  # Ignore the last time step
-            query = query[:, len(mel_outputs), :]
+            query = querys[:, len(mel_outputs), :]
             mel_output, stop_logit, attention_weights = self._decode(query)
             mel_outputs += [mel_output]
             stop_logits += [stop_logit]
@@ -308,9 +312,8 @@ class Tacotron2(nn.Layer):
     def __init__(self,
                  frontend: parakeet.frontend.Phonetics,
                  d_mels: int=80,
-                 d_embedding: int=512,
-                 encoder_conv_layers: int=3,
                  d_encoder: int=512,
+                 encoder_conv_layers: int=3,
                  encoder_kernel_size: int=5,
                  d_prenet: int=256,
                  d_attention_rnn: int=1024,
@@ -329,11 +332,11 @@ class Tacotron2(nn.Layer):
                  p_postnet_dropout: float=0.5):
         super().__init__()
 
-        std = math.sqrt(2.0 / (frontend.vocab_size + d_embedding))
+        std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
         val = math.sqrt(3.0) * std  # uniform bounds for std
         self.embedding = nn.Embedding(
             frontend.vocab_size,
-            d_embedding,
+            d_encoder,
             weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
                 low=-val, high=val)))
         self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
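Two caller-visible changes in this file are worth noting: `Tacotron2.__init__` no longer accepts `d_embedding` (the embedding width is now tied to `d_encoder`), and `Tacotron2Decoder.forward` takes `keys`/`querys` instead of `key`/`query`. A minimal construction sketch under those assumptions; `frontend` stands for any `parakeet.frontend.Phonetics` instance, and the remaining hyper-parameters keep their defaults:

    from parakeet.models.tacotron2 import Tacotron2

    # the text embedding is created with size d_encoder after this patch, so
    # embedding and encoder widths can no longer be configured separately;
    # passing d_embedding= would now raise a TypeError
    model = Tacotron2(frontend, d_mels=80, d_encoder=512,
                      encoder_conv_layers=3)
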
diff --git a/parakeet/training/experiment.py b/parakeet/training/experiment.py
index 846947d..05d5b6e 100644
--- a/parakeet/training/experiment.py
+++ b/parakeet/training/experiment.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import logging
 from pathlib import Path
@@ -11,6 +25,7 @@ from collections import defaultdict
 import parakeet
 from parakeet.utils import checkpoint, mp_tools
 
+
 class ExperimentBase(object):
     """
     An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
@@ -22,7 +37,7 @@ class ExperimentBase(object):
     We have some conventions to follow.
     1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
     2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
-    3. There are three method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
+    3. There are four methods, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader`, that should be implemented.
 
     Feel free to add/overwrite other methods and standalone functions if you need.
 
@@ -54,6 +69,7 @@ class ExperimentBase(object):
             main(config, args)
     """
+
     def __init__(self, config, args):
         self.config = config
         self.args = args
@@ -67,7 +83,7 @@ class ExperimentBase(object):
         self.setup_visualizer()
         self.setup_logger()
         self.setup_checkpointer()
-        
+
         self.setup_dataloader()
         self.setup_model()
 
@@ -82,13 +98,13 @@ class ExperimentBase(object):
             dist.init_parallel_env()
 
     def save(self):
-        checkpoint.save_parameters(
-            self.checkpoint_dir, self.iteration, self.model, self.optimizer)
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
+                                   self.model, self.optimizer)
 
     def resume_or_load(self):
         iteration = checkpoint.load_parameters(
-            self.model, 
-            self.optimizer, 
+            self.model,
+            self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
         self.iteration = iteration
@@ -115,10 +131,10 @@ class ExperimentBase(object):
 
         if self.iteration % self.config.training.valid_interval == 0:
             self.valid()
-        
+
         if self.iteration % self.config.training.save_interval == 0:
             self.save()
-    
+
     def run(self):
         self.resume_or_load()
         try:
@@ -126,7 +142,7 @@ class ExperimentBase(object):
         except KeyboardInterrupt:
             self.save()
             exit(-1)
-        
+
     @mp_tools.rank_zero_only
     def setup_output_dir(self):
         # output dir
@@ -134,7 +150,7 @@ class ExperimentBase(object):
         output_dir.mkdir(exist_ok=True)
 
         self.output_dir = output_dir
-    
+
     @mp_tools.rank_zero_only
     def setup_checkpointer(self):
         # checkpoint dir
@@ -161,7 +177,7 @@ class ExperimentBase(object):
 
     @mp_tools.rank_zero_only
     def dump_config(self):
-        with open(self.output_dir / "config.yaml", 'wt') as f: 
+        with open(self.output_dir / "config.yaml", 'wt') as f:
             print(self.config, file=f)
 
     def train_batch(self):
@@ -177,4 +193,3 @@ class ExperimentBase(object):
 
     def setup_dataloader(self):
         raise NotImplementedError("setup_dataloader should be implemented.")
-
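As the corrected docstring now says, a subclass implements four methods: `train_batch`, `valid`, `setup_model` and `setup_dataloader`; the base class drives setup, checkpointing, and the training loop. A minimal sketch of that contract; the class name and method bodies are placeholders, not part of the patch:

    from parakeet.training.experiment import ExperimentBase

    class MyExperiment(ExperimentBase):
        def setup_model(self):
            self.model = ...      # build the model here,
            self.optimizer = ...  # plus its optimizer (convention 1)

        def setup_dataloader(self):
            self.train_loader = ...  # data loaders required by convention 1
            self.valid_loader = ...

        def train_batch(self):
            ...  # one optimization step; called repeatedly until
                 # config.training.max_iteration is reached

        def valid(self):
            ...  # invoked every config.training.valid_interval iterations

    # per the docstring example, a main(config, args) driver then constructs
    # the experiment and calls run()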