fix some bugs of tacotron2

This commit is contained in:
lfchener 2020-12-11 03:56:40 +00:00
parent fb64c79f7a
commit 09f1840082
3 changed files with 58 additions and 40 deletions

View File

@ -20,13 +20,13 @@ from parakeet.frontend.normalizer.numbers import normalize_numbers
def normalize(sentence):
# preprocessing
text = unicode(text)
text = normalize_numbers(text)
text = ''.join(
char for char in unicodedata.normalize('NFD', text)
sentence = unicode(sentence)
sentence = normalize_numbers(sentence)
sentence = ''.join(
char for char in unicodedata.normalize('NFD', sentence)
if unicodedata.category(char) != 'Mn') # Strip accents
text = text.lower()
text = re.sub(r"[^ a-z'.,?!\-]", "", text)
text = text.replace("i.e.", "that is")
text = text.replace("e.g.", "for example")
return text
sentence = sentence.lower()
sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
sentence = sentence.replace("i.e.", "that is")
sentence = sentence.replace("e.g.", "for example")
return sentence.split()

View File

@ -16,6 +16,7 @@ import math
import paddle
from paddle import nn
from paddle.nn import functional as F
import parakeet
from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention
from parakeet.modules import masking
@ -31,6 +32,7 @@ class DecoderPreNet(nn.Layer):
dropout_rate: int=0.2):
super().__init__()
self.dropout_rate = dropout_rate
self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False)
self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
@ -50,6 +52,7 @@ class DecoderPostNet(nn.Layer):
dropout=0.1):
super().__init__()
self.dropout = dropout
self.num_layers = num_layers
self.conv_batchnorms = nn.LayerList()
k = math.sqrt(1.0 / (d_mels * kernel_size))
@ -89,7 +92,8 @@ class DecoderPostNet(nn.Layer):
for i in range(len(self.conv_batchnorms) - 1):
input = F.dropout(
F.tanh(self.conv_batchnorms[i](input), self.dropout))
input = F.dropout(self.conv_batchnorms[-1](input), self.dropout)
input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
self.dropout)
return input
@ -120,7 +124,7 @@ class Tacotron2Encoder(nn.Layer):
d_hidden, self.hidden_size, direction="bidirectional")
def forward(self, x, input_lens=None):
for conv_batchnorm in conv_batchnorms:
for conv_batchnorm in self.conv_batchnorms:
x = F.dropout(F.relu(conv_batchnorm(x)),
self.p_dropout) #(B, T, C)
@ -209,7 +213,7 @@ class Tacotron2Decoder(nn.Layer):
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
# The second lasm layer
# The second lstm layer
decoder_input = paddle.concat(
[self.attention_hidden, self.attention_context], axis=-1)
_, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
@ -225,29 +229,29 @@ class Tacotron2Decoder(nn.Layer):
stop_logit = self.stop_layer(decoder_hidden_attention_context)
return decoder_output, stop_logit, self.attention_weights
def forward(self, key, query, mask):
query = paddle.reshape(
query,
[query.shape[0], query.shape[1] // self.reduction_factor, -1])
query = paddle.concat(
def forward(self, keys, querys, mask):
querys = paddle.reshape(
querys,
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
querys = paddle.concat(
[
paddle.zeros(
shape=[
query.shape[0], 1,
query.shape[-1] * self.reduction_factor
querys.shape[0], 1,
querys.shape[-1] * self.reduction_factor
],
dtype=query.dtype), query
dtype=querys.dtype), querys
],
axis=1)
query = self.prenet(query)
querys = self.prenet(querys)
self._initialize_decoder_states(key)
self._initialize_decoder_states(keys)
self.mask = mask
mel_outputs, stop_logits, alignments = [], [], []
while len(mel_outputs) < query.shape[
while len(mel_outputs) < querys.shape[
1] - 1: # Ignore the last time step
query = query[:, len(mel_outputs), :]
query = querys[:, len(mel_outputs), :]
mel_output, stop_logit, attention_weights = self._decode(query)
mel_outputs += [mel_output]
stop_logits += [stop_logit]
@ -308,9 +312,8 @@ class Tacotron2(nn.Layer):
def __init__(self,
frontend: parakeet.frontend.Phonetics,
d_mels: int=80,
d_embedding: int=512,
encoder_conv_layers: int=3,
d_encoder: int=512,
encoder_conv_layers: int=3,
encoder_kernel_size: int=5,
d_prenet: int=256,
d_attention_rnn: int=1024,
@ -329,11 +332,11 @@ class Tacotron2(nn.Layer):
p_postnet_dropout: float=0.5):
super().__init__()
std = math.sqrt(2.0 / (frontend.vocab_size + d_embedding))
std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
val = math.sqrt(3.0) * std # uniform bounds for std
self.embedding = nn.Embedding(
frontend.vocab_size,
d_embedding,
d_encoder,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
low=-val, high=val)))
self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,

View File

@ -1,3 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import logging
from pathlib import Path
@ -11,6 +25,7 @@ from collections import defaultdict
import parakeet
from parakeet.utils import checkpoint, mp_tools
class ExperimentBase(object):
"""
An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
@ -22,7 +37,7 @@ class ExperimentBase(object):
We have some conventions to follow.
1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
3. There are three method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
3. There are four method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
Feel free to add/overwrite other methods and standalone functions if you need.
@ -54,6 +69,7 @@ class ExperimentBase(object):
main(config, args)
"""
def __init__(self, config, args):
self.config = config
self.args = args
@ -67,7 +83,7 @@ class ExperimentBase(object):
self.setup_visualizer()
self.setup_logger()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
@ -82,13 +98,13 @@ class ExperimentBase(object):
dist.init_parallel_env()
def save(self):
checkpoint.save_parameters(
self.checkpoint_dir, self.iteration, self.model, self.optimizer)
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer)
def resume_or_load(self):
iteration = checkpoint.load_parameters(
self.model,
self.optimizer,
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration
@ -115,10 +131,10 @@ class ExperimentBase(object):
if self.iteration % self.config.training.valid_interval == 0:
self.valid()
if self.iteration % self.config.training.save_interval == 0:
self.save()
def run(self):
self.resume_or_load()
try:
@ -126,7 +142,7 @@ class ExperimentBase(object):
except KeyboardInterrupt:
self.save()
exit(-1)
@mp_tools.rank_zero_only
def setup_output_dir(self):
# output dir
@ -134,7 +150,7 @@ class ExperimentBase(object):
output_dir.mkdir(exist_ok=True)
self.output_dir = output_dir
@mp_tools.rank_zero_only
def setup_checkpointer(self):
# checkpoint dir
@ -161,7 +177,7 @@ class ExperimentBase(object):
@mp_tools.rank_zero_only
def dump_config(self):
with open(self.output_dir / "config.yaml", 'wt') as f:
with open(self.output_dir / "config.yaml", 'wt') as f:
print(self.config, file=f)
def train_batch(self):
@ -177,4 +193,3 @@ class ExperimentBase(object):
def setup_dataloader(self):
raise NotImplementedError("setup_dataloader should be implemented.")