Fix some bugs in tacotron2
commit 09f1840082
parent fb64c79f7a
@@ -20,13 +20,13 @@ from parakeet.frontend.normalizer.numbers import normalize_numbers
 
 def normalize(sentence):
     # preprocessing
-    text = unicode(text)
-    text = normalize_numbers(text)
-    text = ''.join(
-        char for char in unicodedata.normalize('NFD', text)
+    sentence = unicode(sentence)
+    sentence = normalize_numbers(sentence)
+    sentence = ''.join(
+        char for char in unicodedata.normalize('NFD', sentence)
         if unicodedata.category(char) != 'Mn')  # Strip accents
-    text = text.lower()
-    text = re.sub(r"[^ a-z'.,?!\-]", "", text)
-    text = text.replace("i.e.", "that is")
-    text = text.replace("e.g.", "for example")
-    return text
+    sentence = sentence.lower()
+    sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
+    sentence = sentence.replace("i.e.", "that is")
+    sentence = sentence.replace("e.g.", "for example")
+    return sentence.split()
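Note on the hunk above: the old body referenced a name `text` that was never defined (the parameter is `sentence`), so any call raised `NameError`. The fix renames consistently and also changes the return value from a string to a list of tokens, so callers must be updated. A minimal sketch of the fixed behavior; `unicode` and `normalize_numbers` come from elsewhere in parakeet and are stubbed here:

```python
import re
import unicodedata

def normalize_numbers(sentence):
    return sentence  # stub for parakeet.frontend.normalizer.numbers

def normalize(sentence):
    # `unicode(sentence)` in the original is assumed to be a str-coercion
    # helper; plain Python 3 strings need no conversion here.
    sentence = normalize_numbers(sentence)
    sentence = ''.join(
        char for char in unicodedata.normalize('NFD', sentence)
        if unicodedata.category(char) != 'Mn')  # strip accents
    sentence = sentence.lower()
    sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
    sentence = sentence.replace("i.e.", "that is")
    sentence = sentence.replace("e.g.", "for example")
    return sentence.split()

print(normalize("Héllo, wörld! e.g. Tacotron2"))
# ['hello,', 'world!', 'for', 'example', 'tacotron']
```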
@@ -16,6 +16,7 @@ import math
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
+import parakeet
 from parakeet.modules.conv import Conv1dBatchNorm
 from parakeet.modules.attention import LocationSensitiveAttention
 from parakeet.modules import masking
@@ -31,6 +32,7 @@ class DecoderPreNet(nn.Layer):
                  dropout_rate: int=0.2):
         super().__init__()
 
+        self.dropout_rate = dropout_rate
         self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False)
         self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
 
@@ -50,6 +52,7 @@ class DecoderPostNet(nn.Layer):
                  dropout=0.1):
         super().__init__()
         self.dropout = dropout
+        self.num_layers = num_layers
 
         self.conv_batchnorms = nn.LayerList()
         k = math.sqrt(1.0 / (d_mels * kernel_size))
@@ -89,7 +92,8 @@ class DecoderPostNet(nn.Layer):
         for i in range(len(self.conv_batchnorms) - 1):
             input = F.dropout(
                 F.tanh(self.conv_batchnorms[i](input), self.dropout))
-        input = F.dropout(self.conv_batchnorms[-1](input), self.dropout)
+        input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
+                          self.dropout)
         return input
 
 
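Note on the two DecoderPostNet hunks above: the layer count is now cached as `self.num_layers` so the final conv can be indexed as `conv_batchnorms[self.num_layers - 1]`, presumably because paddle's `nn.LayerList` did not support the negative index `[-1]` at the time (an assumption; the commit does not say). One pre-existing issue is left untouched: on the `F.tanh(...)` line the dropout rate is passed to `F.tanh` instead of `F.dropout`, a misplaced parenthesis. A runnable sketch with both points addressed, on toy shapes:

```python
import paddle
from paddle import nn
from paddle.nn import functional as F

# Toy stand-in for the postnet stack: 3 conv layers over (B, C, T) input.
num_layers, dropout = 3, 0.5
convs = nn.LayerList(
    [nn.Conv1D(4, 4, 3, padding=1) for _ in range(num_layers)])

x = paddle.randn([2, 4, 10])
for i in range(num_layers - 1):
    # The dropout rate goes to F.dropout, not F.tanh.
    x = F.dropout(F.tanh(convs[i](x)), dropout)
x = F.dropout(convs[num_layers - 1](x), dropout)  # avoids convs[-1]
print(x.shape)  # [2, 4, 10]
```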
@@ -120,7 +124,7 @@ class Tacotron2Encoder(nn.Layer):
             d_hidden, self.hidden_size, direction="bidirectional")
 
     def forward(self, x, input_lens=None):
-        for conv_batchnorm in conv_batchnorms:
+        for conv_batchnorm in self.conv_batchnorms:
             x = F.dropout(F.relu(conv_batchnorm(x)),
                           self.p_dropout)  #(B, T, C)
 
@@ -209,7 +213,7 @@ class Tacotron2Decoder(nn.Layer):
             attention_weights_cat, self.mask)
         self.attention_weights_cum += self.attention_weights
 
-        # The second lasm layer
+        # The second lstm layer
         decoder_input = paddle.concat(
             [self.attention_hidden, self.attention_context], axis=-1)
         _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
@@ -225,29 +229,29 @@ class Tacotron2Decoder(nn.Layer):
         stop_logit = self.stop_layer(decoder_hidden_attention_context)
         return decoder_output, stop_logit, self.attention_weights
 
-    def forward(self, key, query, mask):
-        query = paddle.reshape(
-            query,
-            [query.shape[0], query.shape[1] // self.reduction_factor, -1])
-        query = paddle.concat(
+    def forward(self, keys, querys, mask):
+        querys = paddle.reshape(
+            querys,
+            [querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
+        querys = paddle.concat(
             [
                 paddle.zeros(
                     shape=[
-                        query.shape[0], 1,
-                        query.shape[-1] * self.reduction_factor
+                        querys.shape[0], 1,
+                        querys.shape[-1] * self.reduction_factor
                     ],
-                    dtype=query.dtype), query
+                    dtype=querys.dtype), querys
             ],
             axis=1)
-        query = self.prenet(query)
+        querys = self.prenet(querys)
 
-        self._initialize_decoder_states(key)
+        self._initialize_decoder_states(keys)
         self.mask = mask
 
         mel_outputs, stop_logits, alignments = [], [], []
-        while len(mel_outputs) < query.shape[
+        while len(mel_outputs) < querys.shape[
                 1] - 1:  # Ignore the last time step
-            query = query[:, len(mel_outputs), :]
+            query = querys[:, len(mel_outputs), :]
             mel_output, stop_logit, attention_weights = self._decode(query)
             mel_outputs += [mel_output]
             stop_logits += [stop_logit]
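Note on the decoder hunk above: the rename from `query` to `querys` is more than cosmetic. Inside the generation loop the per-step slice is assigned to `query`; with the old single name, the first iteration overwrote the whole prenet output with a single time-step slice, breaking every later step. The surrounding reshape groups `reduction_factor` consecutive mel frames into one decoder step and prepends an all-zero "go" frame. A shape walk-through with illustrative numbers:

```python
import paddle

# Illustrative sizes: batch 2, 6 mel frames, 80 mel bins, reduction factor 3.
B, T, d_mels, r = 2, 6, 80, 3
querys = paddle.randn([B, T, d_mels])

# Group r consecutive frames per decoder step:
# (B, T, d_mels) -> (B, T // r, d_mels * r)
querys = paddle.reshape(querys, [B, T // r, -1])

# Prepend an all-zero "go" frame as the first decoder input.
go_frame = paddle.zeros([B, 1, d_mels * r], dtype=querys.dtype)
querys = paddle.concat([go_frame, querys], axis=1)
print(querys.shape)  # [2, 3, 240]

# Per-step slicing now leaves the full sequence intact:
query = querys[:, 0, :]  # (B, d_mels * r) input for the first decode step
```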
@@ -308,9 +312,8 @@ class Tacotron2(nn.Layer):
     def __init__(self,
                  frontend: parakeet.frontend.Phonetics,
                  d_mels: int=80,
-                 d_embedding: int=512,
-                 encoder_conv_layers: int=3,
                  d_encoder: int=512,
+                 encoder_conv_layers: int=3,
                  encoder_kernel_size: int=5,
                  d_prenet: int=256,
                  d_attention_rnn: int=1024,
@@ -329,11 +332,11 @@ class Tacotron2(nn.Layer):
                  p_postnet_dropout: float=0.5):
         super().__init__()
 
-        std = math.sqrt(2.0 / (frontend.vocab_size + d_embedding))
+        std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
         val = math.sqrt(3.0) * std  # uniform bounds for std
         self.embedding = nn.Embedding(
             frontend.vocab_size,
-            d_embedding,
+            d_encoder,
             weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
                 low=-val, high=val)))
         self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
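Note on the two Tacotron2 hunks above: the separate `d_embedding` hyperparameter is dropped and the character embedding is sized by `d_encoder`, so the embedding width always matches what `Tacotron2Encoder` consumes. Callers that passed arguments positionally should note that `encoder_conv_layers` now comes after `d_encoder`. The initializer bound is standard Xavier/Glorot uniform, as in this sketch (`vocab_size=100` stands in for `frontend.vocab_size`):

```python
import math
import paddle
from paddle import nn

vocab_size, d_encoder = 100, 512  # illustrative stand-ins

# std = sqrt(2 / (fan_in + fan_out)); a uniform(-val, val) draw with
# val = sqrt(3) * std has exactly that standard deviation.
std = math.sqrt(2.0 / (vocab_size + d_encoder))
val = math.sqrt(3.0) * std

embedding = nn.Embedding(
    vocab_size,
    d_encoder,
    weight_attr=paddle.ParamAttr(
        initializer=nn.initializer.Uniform(low=-val, high=val)))
print(embedding.weight.shape)  # [100, 512]
```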
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import time
 import logging
 from pathlib import Path
@@ -11,6 +25,7 @@ from collections import defaultdict
 import parakeet
 from parakeet.utils import checkpoint, mp_tools
 
+
 class ExperimentBase(object):
     """
     An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
@@ -22,7 +37,7 @@ class ExperimentBase(object):
     We have some conventions to follow.
     1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
     2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
-    3. There are three method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
+    3. There are four method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
 
     Feel free to add/overwrite other methods and standalone functions if you need.
 
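Note on the docstring hunk above: the commit corrects the count of required overrides from three to four: `train_batch`, `valid`, `setup_model` and `setup_dataloader`. A minimal sketch of a conforming subclass; `MyModel`, `make_loader` and the config keys are hypothetical placeholders, not part of this commit:

```python
import paddle

class MyExperiment(ExperimentBase):
    def setup_model(self):
        self.model = MyModel(self.config.model)  # hypothetical model class
        self.optimizer = paddle.optimizer.Adam(
            learning_rate=self.config.training.lr,  # hypothetical config key
            parameters=self.model.parameters())

    def setup_dataloader(self):
        self.train_loader = make_loader(self.config, "train")  # hypothetical
        self.valid_loader = make_loader(self.config, "valid")  # hypothetical

    def train_batch(self):
        batch = next(self.train_iterator)  # assumes an iterator over train_loader
        loss = self.model(*batch)
        loss.backward()
        self.optimizer.step()
        self.optimizer.clear_grad()

    def valid(self):
        for batch in self.valid_loader:
            pass  # compute and log validation metrics here
```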
@@ -54,6 +69,7 @@ class ExperimentBase(object):
         main(config, args)
 
     """
+
     def __init__(self, config, args):
         self.config = config
         self.args = args
@@ -67,7 +83,7 @@ class ExperimentBase(object):
         self.setup_visualizer()
         self.setup_logger()
         self.setup_checkpointer()
 
         self.setup_dataloader()
         self.setup_model()
 
@@ -82,13 +98,13 @@ class ExperimentBase(object):
         dist.init_parallel_env()
 
     def save(self):
-        checkpoint.save_parameters(
-            self.checkpoint_dir, self.iteration, self.model, self.optimizer)
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
+                                   self.model, self.optimizer)
 
     def resume_or_load(self):
         iteration = checkpoint.load_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
         self.iteration = iteration
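Note on the `save()` change above: only the line wrapping of `checkpoint.save_parameters` changes; the save/resume contract is unchanged. `resume_or_load()` passes `self.args.checkpoint_path` and the checkpoint directory to `checkpoint.load_parameters` and stores the returned iteration, so step counting presumably continues across restarts. A hedged usage sketch:

```python
# Restarting the same command resumes training (MyExperiment as sketched
# above; config/args come from the framework's usual entry point).
exp = MyExperiment(config, args)
exp.run()  # run() calls resume_or_load() first, restoring self.iteration;
           # a KeyboardInterrupt triggers a final save() before exiting.
```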
@@ -115,10 +131,10 @@ class ExperimentBase(object):
 
         if self.iteration % self.config.training.valid_interval == 0:
             self.valid()
 
         if self.iteration % self.config.training.save_interval == 0:
             self.save()
 
     def run(self):
         self.resume_or_load()
         try:
@@ -126,7 +142,7 @@ class ExperimentBase(object):
         except KeyboardInterrupt:
             self.save()
             exit(-1)
 
     @mp_tools.rank_zero_only
     def setup_output_dir(self):
         # output dir
@@ -134,7 +150,7 @@ class ExperimentBase(object):
         output_dir.mkdir(exist_ok=True)
 
         self.output_dir = output_dir
 
     @mp_tools.rank_zero_only
     def setup_checkpointer(self):
         # checkpoint dir
@@ -161,7 +177,7 @@ class ExperimentBase(object):
 
     @mp_tools.rank_zero_only
     def dump_config(self):
         with open(self.output_dir / "config.yaml", 'wt') as f:
             print(self.config, file=f)
 
     def train_batch(self):
@@ -177,4 +193,3 @@ class ExperimentBase(object):
 
     def setup_dataloader(self):
         raise NotImplementedError("setup_dataloader should be implemented.")
-