add tacotron2.py and a new frontend for en
This commit is contained in:
parent
e29502f634
commit
f375792c51
|
@ -14,4 +14,4 @@
|
||||||
|
|
||||||
__version__ = "0.2.0"
|
__version__ = "0.2.0"
|
||||||
|
|
||||||
from parakeet import audio, data, datastes, frontend, models, modules, training, utils
|
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
|
||||||
|
|
|
@ -1,3 +1,84 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# number expansion is not that easy
|
# number expansion is not that easy
|
||||||
import num2words
|
import inflect
|
||||||
import inflect
|
import re
|
||||||
|
|
||||||
|
_inflect = inflect.engine()
|
||||||
|
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||||
|
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||||
|
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||||
|
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||||
|
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||||
|
_number_re = re.compile(r'[0-9]+')
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_commas(m):
|
||||||
|
return m.group(1).replace(',', '')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_decimal_point(m):
|
||||||
|
return m.group(1).replace('.', ' point ')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_dollars(m):
|
||||||
|
match = m.group(1)
|
||||||
|
parts = match.split('.')
|
||||||
|
if len(parts) > 2:
|
||||||
|
return match + ' dollars' # Unexpected format
|
||||||
|
dollars = int(parts[0]) if parts[0] else 0
|
||||||
|
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||||
|
if dollars and cents:
|
||||||
|
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||||
|
elif dollars:
|
||||||
|
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||||
|
return '%s %s' % (dollars, dollar_unit)
|
||||||
|
elif cents:
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return '%s %s' % (cents, cent_unit)
|
||||||
|
else:
|
||||||
|
return 'zero dollars'
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_ordinal(m):
|
||||||
|
return _inflect.number_to_words(m.group(0))
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_number(m):
|
||||||
|
num = int(m.group(0))
|
||||||
|
if num > 1000 and num < 3000:
|
||||||
|
if num == 2000:
|
||||||
|
return 'two thousand'
|
||||||
|
elif num > 2000 and num < 2010:
|
||||||
|
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||||
|
elif num % 100 == 0:
|
||||||
|
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(
|
||||||
|
num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(num, andword='')
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_numbers(text):
|
||||||
|
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||||
|
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||||
|
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||||
|
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||||
|
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||||
|
text = re.sub(_number_re, _expand_number, text)
|
||||||
|
return text
|
||||||
|
|
|
@ -1,34 +1,53 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from g2p_en import G2p
|
from g2p_en import G2p
|
||||||
from g2pM import G2pM
|
from g2pM import G2pM
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
from builtins import str as unicode
|
||||||
from parakeet.frontend import Vocab
|
from parakeet.frontend import Vocab
|
||||||
from opencc import OpenCC
|
from opencc import OpenCC
|
||||||
from parakeet.frontend.punctuation import get_punctuations
|
from parakeet.frontend.punctuation import get_punctuations
|
||||||
|
from parakeet.frontend.normalizer.numbers import normalize_numbers
|
||||||
|
|
||||||
__all__ = ["Phonetics", "English", "Chinese"]
|
__all__ = ["Phonetics", "English", "EnglishCharacter", "Chinese"]
|
||||||
|
|
||||||
|
|
||||||
class Phonetics(ABC):
|
class Phonetics(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, sentence):
|
def __call__(self, sentence):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def phoneticize(self, sentence):
|
def phoneticize(self, sentence):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def numericalize(self, phonemes):
|
def numericalize(self, phonemes):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class English(Phonetics):
|
class English(Phonetics):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.backend = G2p()
|
self.backend = G2p()
|
||||||
self.phonemes = list(self.backend.phonemes)
|
self.phonemes = list(self.backend.phonemes)
|
||||||
self.punctuations = get_punctuations("en")
|
self.punctuations = get_punctuations("en")
|
||||||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||||
|
|
||||||
def phoneticize(self, sentence):
|
def phoneticize(self, sentence):
|
||||||
start = self.vocab.start_symbol
|
start = self.vocab.start_symbol
|
||||||
end = self.vocab.end_symbol
|
end = self.vocab.end_symbol
|
||||||
|
@ -36,17 +55,67 @@ class English(Phonetics):
|
||||||
+ self.backend(sentence) \
|
+ self.backend(sentence) \
|
||||||
+ ([] if end is None else [end])
|
+ ([] if end is None else [end])
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
||||||
def numericalize(self, phonemes):
|
def numericalize(self, phonemes):
|
||||||
ids = [self.vocab.lookup(item) for item in phonemes if item in self.vocab.stoi]
|
ids = [
|
||||||
|
self.vocab.lookup(item) for item in phonemes
|
||||||
|
if item in self.vocab.stoi
|
||||||
|
]
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def reverse(self, ids):
|
def reverse(self, ids):
|
||||||
return [self.vocab.reverse(i) for i in ids]
|
return [self.vocab.reverse(i) for i in ids]
|
||||||
|
|
||||||
def __call__(self, sentence):
|
def __call__(self, sentence):
|
||||||
return self.numericalize(self.phoneticize(sentence))
|
return self.numericalize(self.phoneticize(sentence))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self.vocab)
|
||||||
|
|
||||||
|
|
||||||
|
class EnglishCharacter(Phonetics):
|
||||||
|
def __init__(self):
|
||||||
|
self.backend = G2p()
|
||||||
|
self.phonemes = list(self.backend.graphemes)
|
||||||
|
self.punctuations = get_punctuations("en")
|
||||||
|
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||||
|
|
||||||
|
def _prepocessing(self, text):
|
||||||
|
# preprocessing
|
||||||
|
text = unicode(text)
|
||||||
|
text = normalize_numbers(text)
|
||||||
|
text = ''.join(
|
||||||
|
char for char in unicodedata.normalize('NFD', text)
|
||||||
|
if unicodedata.category(char) != 'Mn') # Strip accents
|
||||||
|
text = text.lower()
|
||||||
|
text = re.sub(r"[^ a-z'.,?!\-]", "", text)
|
||||||
|
text = text.replace("i.e.", "that is")
|
||||||
|
text = text.replace("e.g.", "for example")
|
||||||
|
return text
|
||||||
|
|
||||||
|
def phoneticize(self, sentence):
|
||||||
|
start = self.vocab.start_symbol
|
||||||
|
end = self.vocab.end_symbol
|
||||||
|
|
||||||
|
chars = ([] if start is None else [start]) \
|
||||||
|
+ _prepocessing(sentence) \
|
||||||
|
+ ([] if end is None else [end])
|
||||||
|
return chars
|
||||||
|
|
||||||
|
def numericalize(self, chars):
|
||||||
|
ids = [
|
||||||
|
self.vocab.lookup(item) for item in chars
|
||||||
|
if item in self.vocab.stoi
|
||||||
|
]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def reverse(self, ids):
|
||||||
|
return [self.vocab.reverse(i) for i in ids]
|
||||||
|
|
||||||
|
def __call__(self, sentence):
|
||||||
|
return self.numericalize(self.phoneticize(sentence))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.vocab)
|
return len(self.vocab)
|
||||||
|
@ -59,9 +128,11 @@ class Chinese(Phonetics):
|
||||||
self.phonemes = self._get_all_syllables()
|
self.phonemes = self._get_all_syllables()
|
||||||
self.punctuations = get_punctuations("cn")
|
self.punctuations = get_punctuations("cn")
|
||||||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||||
|
|
||||||
def _get_all_syllables(self):
|
def _get_all_syllables(self):
|
||||||
all_syllables = set([syllable for k, v in self.backend.cedict.items() for syllable in v])
|
all_syllables = set([
|
||||||
|
syllable for k, v in self.backend.cedict.items() for syllable in v
|
||||||
|
])
|
||||||
return list(all_syllables)
|
return list(all_syllables)
|
||||||
|
|
||||||
def phoneticize(self, sentence):
|
def phoneticize(self, sentence):
|
||||||
|
@ -73,7 +144,7 @@ class Chinese(Phonetics):
|
||||||
+ phonemes \
|
+ phonemes \
|
||||||
+ ([] if end is None else [end])
|
+ ([] if end is None else [end])
|
||||||
return self._filter_symbols(phonemes)
|
return self._filter_symbols(phonemes)
|
||||||
|
|
||||||
def _filter_symbols(self, phonemes):
|
def _filter_symbols(self, phonemes):
|
||||||
cleaned_phonemes = []
|
cleaned_phonemes = []
|
||||||
for item in phonemes:
|
for item in phonemes:
|
||||||
|
@ -84,17 +155,17 @@ class Chinese(Phonetics):
|
||||||
if char in self.vocab.stoi:
|
if char in self.vocab.stoi:
|
||||||
cleaned_phonemes.append(char)
|
cleaned_phonemes.append(char)
|
||||||
return cleaned_phonemes
|
return cleaned_phonemes
|
||||||
|
|
||||||
def numericalize(self, phonemes):
|
def numericalize(self, phonemes):
|
||||||
ids = [self.vocab.lookup(item) for item in phonemes]
|
ids = [self.vocab.lookup(item) for item in phonemes]
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def __call__(self, sentence):
|
def __call__(self, sentence):
|
||||||
return self.numericalize(self.phoneticize(sentence))
|
return self.numericalize(self.phoneticize(sentence))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.vocab)
|
return len(self.vocab)
|
||||||
|
|
||||||
def reverse(self, ids):
|
def reverse(self, ids):
|
||||||
return [self.vocab.reverse(i) for i in ids]
|
return [self.vocab.reverse(i) for i in ids]
|
||||||
|
|
|
@ -0,0 +1,424 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from parakeet.modules.conv import Conv1dBatchNorm
|
||||||
|
from parakeet.modules.attention import LocationSensitiveAttention
|
||||||
|
from parakeet.modules import masking
|
||||||
|
|
||||||
|
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderPreNet(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_input: int,
|
||||||
|
d_hidden: int,
|
||||||
|
d_output: int,
|
||||||
|
dropout_rate: int=0.2):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False)
|
||||||
|
self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
|
||||||
|
output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderPostNet(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_mels: int=80,
|
||||||
|
d_hidden: int=512,
|
||||||
|
kernel_size: int=5,
|
||||||
|
padding: int=0,
|
||||||
|
num_layers: int=5,
|
||||||
|
dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
self.conv_batchnorms = nn.LayerList()
|
||||||
|
k = math.sqrt(1.0 / (d_mels * kernel_size))
|
||||||
|
self.conv_batchnorms.append(
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_mels,
|
||||||
|
d_hidden,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC'))
|
||||||
|
|
||||||
|
k = math.sqrt(1.0 / (d_hidden * kernel_size))
|
||||||
|
self.conv_batchnorms.extend([
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_hidden,
|
||||||
|
d_hidden,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC') for i in range(1, num_layers - 1)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.conv_batchnorms.append(
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_hidden,
|
||||||
|
d_mels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC'))
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
for i in range(len(self.conv_batchnorms) - 1):
|
||||||
|
input = F.dropout(
|
||||||
|
F.tanh(self.conv_batchnorms[i](input), self.dropout))
|
||||||
|
input = F.dropout(self.conv_batchnorms[-1](input), self.dropout)
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2Encoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_hidden: int,
|
||||||
|
conv_layers: int,
|
||||||
|
kernel_size: int,
|
||||||
|
p_dropout: float):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
k = math.sqrt(1.0 / (d_hidden * kernel_size))
|
||||||
|
self.conv_batchnorms = paddle.nn.LayerList([
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_hidden,
|
||||||
|
d_hidden,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=int((kernel_size - 1) / 2),
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC') for i in range(conv_layers)
|
||||||
|
])
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
|
||||||
|
self.hidden_size = int(d_hidden / 2)
|
||||||
|
self.lstm = nn.LSTM(
|
||||||
|
d_hidden, self.hidden_size, direction="bidirectional")
|
||||||
|
|
||||||
|
def forward(self, x, input_lens=None):
|
||||||
|
for conv_batchnorm in conv_batchnorms:
|
||||||
|
x = F.dropout(F.relu(conv_batchnorm(x)),
|
||||||
|
self.p_dropout) #(B, T, C)
|
||||||
|
|
||||||
|
output, _ = self.lstm(inputs=x, sequence_length=input_lens)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2Decoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_mels: int,
|
||||||
|
reduction_factor: int,
|
||||||
|
d_encoder: int,
|
||||||
|
d_prenet: int,
|
||||||
|
d_attention_rnn: int,
|
||||||
|
d_decoder_rnn: int,
|
||||||
|
d_attention: int,
|
||||||
|
attention_filters: int,
|
||||||
|
attention_kernel_size: int,
|
||||||
|
p_prenet_dropout: float,
|
||||||
|
p_attention_dropout: float,
|
||||||
|
p_decoder_dropout: float):
|
||||||
|
super().__init__()
|
||||||
|
self.d_mels = d_mels
|
||||||
|
self.reduction_factor = reduction_factor
|
||||||
|
self.d_encoder = d_encoder
|
||||||
|
self.d_attention_rnn = d_attention_rnn
|
||||||
|
self.d_decoder_rnn = d_decoder_rnn
|
||||||
|
self.p_attention_dropout = p_attention_dropout
|
||||||
|
self.p_decoder_dropout = p_decoder_dropout
|
||||||
|
|
||||||
|
self.prenet = DecoderPreNet(
|
||||||
|
d_mels * reduction_factor,
|
||||||
|
d_prenet,
|
||||||
|
d_prenet,
|
||||||
|
dropout_rate=p_prenet_dropout)
|
||||||
|
|
||||||
|
self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn)
|
||||||
|
|
||||||
|
self.attention_layer = LocationSensitiveAttention(
|
||||||
|
d_attention_rnn, d_encoder, d_attention, attention_filters,
|
||||||
|
attention_kernel_size)
|
||||||
|
self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder,
|
||||||
|
d_decoder_rnn)
|
||||||
|
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
|
||||||
|
d_mels * reduction_factor)
|
||||||
|
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||||
|
|
||||||
|
def _initialize_decoder_states(self, key):
|
||||||
|
batch_size = key.shape[0]
|
||||||
|
MAX_TIME = key.shape[1]
|
||||||
|
|
||||||
|
self.attention_hidden = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
|
||||||
|
self.attention_cell = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
|
||||||
|
|
||||||
|
self.decoder_hidden = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
|
||||||
|
self.decoder_cell = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
|
||||||
|
|
||||||
|
self.attention_weights = paddle.zeros(
|
||||||
|
shape=[batch_size, MAX_TIME], dtype=key.dtype)
|
||||||
|
self.attention_weights_cum = paddle.zeros(
|
||||||
|
shape=[batch_size, MAX_TIME], dtype=key.dtype)
|
||||||
|
self.attention_context = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_encoder], dtype=key.dtype)
|
||||||
|
|
||||||
|
self.key = key #[B, T, C]
|
||||||
|
self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
|
||||||
|
|
||||||
|
def _decode(self, query):
|
||||||
|
cell_input = paddle.concat([query, self.attention_context], axis=-1)
|
||||||
|
|
||||||
|
# The first lstm layer
|
||||||
|
_, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
|
||||||
|
cell_input, (self.attention_hidden, self.attention_cell))
|
||||||
|
self.attention_hidden = F.dropout(self.attention_hidden,
|
||||||
|
self.p_attention_dropout)
|
||||||
|
|
||||||
|
# Loaction sensitive attention
|
||||||
|
attention_weights_cat = paddle.stack(
|
||||||
|
[self.attention_weights, self.attention_weights_cum], axis=-1)
|
||||||
|
self.attention_context, self.attention_weights = self.attention_layer(
|
||||||
|
self.attention_hidden, self.processed_key, self.key,
|
||||||
|
attention_weights_cat, self.mask)
|
||||||
|
self.attention_weights_cum += self.attention_weights
|
||||||
|
|
||||||
|
# The second lasm layer
|
||||||
|
decoder_input = paddle.concat(
|
||||||
|
[self.attention_hidden, self.attention_context], axis=-1)
|
||||||
|
_, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
|
||||||
|
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||||
|
self.decoder_hidden = F.dropout(
|
||||||
|
self.decoder_hidden, p=self.p_decoder_dropout)
|
||||||
|
|
||||||
|
# decode output one step
|
||||||
|
decoder_hidden_attention_context = paddle.concat(
|
||||||
|
[self.decoder_hidden, self.attention_context], axis=-1)
|
||||||
|
decoder_output = self.linear_projection(
|
||||||
|
decoder_hidden_attention_context)
|
||||||
|
stop_logit = self.stop_layer(decoder_hidden_attention_context)
|
||||||
|
return decoder_output, stop_logit, self.attention_weights
|
||||||
|
|
||||||
|
def forward(self, key, query, mask):
|
||||||
|
query = paddle.reshape(
|
||||||
|
query,
|
||||||
|
[query.shape[0], query.shape[1] // self.reduction_factor, -1])
|
||||||
|
query = paddle.concat(
|
||||||
|
[
|
||||||
|
paddle.zeros(
|
||||||
|
shape=[
|
||||||
|
query.shape[0], 1,
|
||||||
|
query.shape[-1] * self.reduction_factor
|
||||||
|
],
|
||||||
|
dtype=query.dtype), query
|
||||||
|
],
|
||||||
|
axis=1)
|
||||||
|
query = self.prenet(query)
|
||||||
|
|
||||||
|
self._initialize_decoder_states(key)
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
mel_outputs, stop_logits, alignments = [], [], []
|
||||||
|
while len(mel_outputs) < query.shape[
|
||||||
|
1] - 1: # Ignore the last time step
|
||||||
|
query = query[:, len(mel_outputs), :]
|
||||||
|
mel_output, stop_logit, attention_weights = self._decode(query)
|
||||||
|
mel_outputs += [mel_output]
|
||||||
|
stop_logits += [stop_logit]
|
||||||
|
alignments += [attention_weights]
|
||||||
|
|
||||||
|
alignments = paddle.stack(alignments, axis=1)
|
||||||
|
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||||
|
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||||
|
|
||||||
|
return mel_outputs, stop_logits, alignments
|
||||||
|
|
||||||
|
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
|
||||||
|
decoder_input = paddle.zeros(
|
||||||
|
shape=[key.shape[0], self.d_mels * self.reduction_factor],
|
||||||
|
dtype=key.dtype) #[B, C]
|
||||||
|
|
||||||
|
self.initialize_decoder_states(key)
|
||||||
|
self.mask = None
|
||||||
|
|
||||||
|
mel_outputs, stop_logits, alignments = [], [], []
|
||||||
|
while True:
|
||||||
|
decoder_input = self.prenet(decoder_input)
|
||||||
|
mel_output, stop_logit, alignment = self.decode(decoder_input)
|
||||||
|
|
||||||
|
mel_outputs += [mel_output]
|
||||||
|
stop_logits += [stop_logit]
|
||||||
|
alignments += [alignment]
|
||||||
|
|
||||||
|
if F.sigmoid(stop_logit) > stop_threshold:
|
||||||
|
break
|
||||||
|
elif len(mel_outputs) == max_decoder_steps:
|
||||||
|
print("Warning! Reached max decoder steps!!!")
|
||||||
|
break
|
||||||
|
|
||||||
|
decoder_input = mel_output
|
||||||
|
|
||||||
|
alignments = paddle.stack(alignments, axis=1)
|
||||||
|
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||||
|
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||||
|
|
||||||
|
return mel_outputs, stop_logits, alignments
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2(nn.Layer):
|
||||||
|
"""
|
||||||
|
Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
|
||||||
|
|
||||||
|
This is a module of Spectrogram prediction network in Tacotron2 described
|
||||||
|
in `Natural TTS Synthesis
|
||||||
|
by Conditioning WaveNet on Mel Spectrogram Predictions`_,
|
||||||
|
which converts the sequence of characters
|
||||||
|
into the sequence of mel spectrogram.
|
||||||
|
|
||||||
|
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
|
||||||
|
https://arxiv.org/abs/1712.05884
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
frontend: parakeet.frontend.Phonetics,
|
||||||
|
d_mels: int=80,
|
||||||
|
d_embedding: int=512,
|
||||||
|
encoder_conv_layers: int=3,
|
||||||
|
d_encoder: int=512,
|
||||||
|
encoder_kernel_size: int=5,
|
||||||
|
d_prenet: int=256,
|
||||||
|
d_attention_rnn: int=1024,
|
||||||
|
d_decoder_rnn: int=1024,
|
||||||
|
attention_filters: int=32,
|
||||||
|
attention_kernel_size: int=31,
|
||||||
|
d_attention: int=128,
|
||||||
|
d_postnet: int=512,
|
||||||
|
postnet_kernel_size: int=5,
|
||||||
|
postnet_conv_layers: int=5,
|
||||||
|
reduction_factor: int=1,
|
||||||
|
p_encoder_dropout: float=0.5,
|
||||||
|
p_prenet_dropout: float=0.5,
|
||||||
|
p_attention_dropout: float=0.1,
|
||||||
|
p_decoder_dropout: float=0.1,
|
||||||
|
p_postnet_dropout: float=0.5):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
std = math.sqrt(2.0 / (frontend.vocab_size + d_embedding))
|
||||||
|
val = math.sqrt(3.0) * std # uniform bounds for std
|
||||||
|
self.embedding = nn.Embedding(
|
||||||
|
frontend.vocab_size,
|
||||||
|
d_embedding,
|
||||||
|
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-val, high=val)))
|
||||||
|
self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
|
||||||
|
encoder_kernel_size, p_encoder_dropout)
|
||||||
|
self.decoder = Tacotron2Decoder(
|
||||||
|
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
|
||||||
|
d_decoder_rnn, d_attention, attention_filters,
|
||||||
|
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
|
||||||
|
p_decoder_dropout)
|
||||||
|
self.postnet = DecoderPostNet(
|
||||||
|
d_mels=d_mels,
|
||||||
|
d_hidden=d_postnet,
|
||||||
|
kernel_size=postnet_kernel_size,
|
||||||
|
padding=int((postnet_kernel_size - 1) / 2),
|
||||||
|
num_layers=postnet_conv_layers,
|
||||||
|
dropout=p_postnet_dropout)
|
||||||
|
|
||||||
|
def forward(self, text_inputs, mels, text_lens, output_lens=None):
|
||||||
|
embedded_inputs = self.embedding(text_inputs)
|
||||||
|
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
||||||
|
|
||||||
|
mask = paddle.tensor.unsqueeze(
|
||||||
|
paddle.fluid.layers.sequence_mask(
|
||||||
|
x=text_lens, dtype=encoder_outputs.dtype), [-1])
|
||||||
|
mel_outputs, stop_logits, alignments = self.decoder(
|
||||||
|
encoder_outputs, mels, mask=mask)
|
||||||
|
|
||||||
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
|
||||||
|
if output_lens is not None:
|
||||||
|
mask = paddle.tensor.unsqueeze(
|
||||||
|
paddle.fluid.layers.sequence_mask(x=output_lens),
|
||||||
|
[-1]) #[B, T, 1]
|
||||||
|
mel_outputs = mel_outputs * mask #[B, T, C]
|
||||||
|
mel_outputs_postnet = mel_outputs_postnet * mask #[B, T, C]
|
||||||
|
stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0]
|
||||||
|
) * 1e3 #[B, T]
|
||||||
|
outputs = {
|
||||||
|
"mel_output": mel_outputs,
|
||||||
|
"mel_outputs_postnet": mel_outputs_postnet,
|
||||||
|
"stop_logits": stop_logits,
|
||||||
|
"alignments": alignments
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
|
||||||
|
embedded_inputs = self.embedding(text_inputs)
|
||||||
|
encoder_outputs = self.encoder(embedded_inputs)
|
||||||
|
mel_outputs, stop_logits, alignments = self.decoder.inference(
|
||||||
|
encoder_outputs,
|
||||||
|
stop_threshold=stop_threshold,
|
||||||
|
max_decoder_steps=max_decoder_steps)
|
||||||
|
|
||||||
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
"mel_output": mel_outputs,
|
||||||
|
"mel_outputs_postnet": mel_outputs_postnet,
|
||||||
|
"stop_logits": stop_logits,
|
||||||
|
"alignments": alignments
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def predict(self, text):
|
||||||
|
# TODO(lifuchen): implement predict function to product mel from texts
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2Loss(nn.Layer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
|
||||||
|
mel_targets, stop_tokens):
|
||||||
|
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
||||||
|
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
||||||
|
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
|
||||||
|
total_loss = mel_loss + post_mel_loss + stop_loss
|
||||||
|
losses = dict(
|
||||||
|
loss=total_loss,
|
||||||
|
mel_loss=mel_loss,
|
||||||
|
post_mel_loss=post_mel_loss,
|
||||||
|
stop_loss=stop_loss)
|
||||||
|
return losses
|
Loading…
Reference in New Issue