commit
1d2e93c75f
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import re
|
||||||
|
import unicodedata
|
||||||
|
from builtins import str as unicode
|
||||||
|
from parakeet.frontend.normalizer.numbers import normalize_numbers
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(sentence):
|
||||||
|
# preprocessing
|
||||||
|
sentence = unicode(sentence)
|
||||||
|
sentence = normalize_numbers(sentence)
|
||||||
|
sentence = ''.join(
|
||||||
|
char for char in unicodedata.normalize('NFD', sentence)
|
||||||
|
if unicodedata.category(char) != 'Mn') # Strip accents
|
||||||
|
sentence = sentence.lower()
|
||||||
|
sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
|
||||||
|
sentence = sentence.replace("i.e.", "that is")
|
||||||
|
sentence = sentence.replace("e.g.", "for example")
|
||||||
|
return sentence.split()
|
|
@ -1,3 +1,84 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
# number expansion is not that easy
|
# number expansion is not that easy
|
||||||
import num2words
|
import inflect
|
||||||
import inflect
|
import re
|
||||||
|
|
||||||
|
_inflect = inflect.engine()
|
||||||
|
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
|
||||||
|
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||||
|
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
|
||||||
|
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
|
||||||
|
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
|
||||||
|
_number_re = re.compile(r'[0-9]+')
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_commas(m):
|
||||||
|
return m.group(1).replace(',', '')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_decimal_point(m):
|
||||||
|
return m.group(1).replace('.', ' point ')
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_dollars(m):
|
||||||
|
match = m.group(1)
|
||||||
|
parts = match.split('.')
|
||||||
|
if len(parts) > 2:
|
||||||
|
return match + ' dollars' # Unexpected format
|
||||||
|
dollars = int(parts[0]) if parts[0] else 0
|
||||||
|
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||||
|
if dollars and cents:
|
||||||
|
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
||||||
|
elif dollars:
|
||||||
|
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
||||||
|
return '%s %s' % (dollars, dollar_unit)
|
||||||
|
elif cents:
|
||||||
|
cent_unit = 'cent' if cents == 1 else 'cents'
|
||||||
|
return '%s %s' % (cents, cent_unit)
|
||||||
|
else:
|
||||||
|
return 'zero dollars'
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_ordinal(m):
|
||||||
|
return _inflect.number_to_words(m.group(0))
|
||||||
|
|
||||||
|
|
||||||
|
def _expand_number(m):
|
||||||
|
num = int(m.group(0))
|
||||||
|
if num > 1000 and num < 3000:
|
||||||
|
if num == 2000:
|
||||||
|
return 'two thousand'
|
||||||
|
elif num > 2000 and num < 2010:
|
||||||
|
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||||
|
elif num % 100 == 0:
|
||||||
|
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(
|
||||||
|
num, andword='', zero='oh', group=2).replace(', ', ' ')
|
||||||
|
else:
|
||||||
|
return _inflect.number_to_words(num, andword='')
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_numbers(text):
|
||||||
|
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||||
|
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||||
|
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||||
|
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||||
|
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||||
|
text = re.sub(_number_re, _expand_number, text)
|
||||||
|
return text
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from g2p_en import G2p
|
from g2p_en import G2p
|
||||||
|
@ -5,30 +19,32 @@ from g2pM import G2pM
|
||||||
from parakeet.frontend import Vocab
|
from parakeet.frontend import Vocab
|
||||||
from opencc import OpenCC
|
from opencc import OpenCC
|
||||||
from parakeet.frontend.punctuation import get_punctuations
|
from parakeet.frontend.punctuation import get_punctuations
|
||||||
|
from parakeet.frontend.normalizer.normalizer import normalize
|
||||||
|
|
||||||
__all__ = ["Phonetics", "English", "Chinese"]
|
__all__ = ["Phonetics", "English", "EnglishCharacter", "Chinese"]
|
||||||
|
|
||||||
|
|
||||||
class Phonetics(ABC):
|
class Phonetics(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, sentence):
|
def __call__(self, sentence):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def phoneticize(self, sentence):
|
def phoneticize(self, sentence):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def numericalize(self, phonemes):
|
def numericalize(self, phonemes):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class English(Phonetics):
|
class English(Phonetics):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.backend = G2p()
|
self.backend = G2p()
|
||||||
self.phonemes = list(self.backend.phonemes)
|
self.phonemes = list(self.backend.phonemes)
|
||||||
self.punctuations = get_punctuations("en")
|
self.punctuations = get_punctuations("en")
|
||||||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||||
|
|
||||||
def phoneticize(self, sentence):
|
def phoneticize(self, sentence):
|
||||||
start = self.vocab.start_symbol
|
start = self.vocab.start_symbol
|
||||||
end = self.vocab.end_symbol
|
end = self.vocab.end_symbol
|
||||||
|
@ -36,17 +52,58 @@ class English(Phonetics):
|
||||||
+ self.backend(sentence) \
|
+ self.backend(sentence) \
|
||||||
+ ([] if end is None else [end])
|
+ ([] if end is None else [end])
|
||||||
return phonemes
|
return phonemes
|
||||||
|
|
||||||
def numericalize(self, phonemes):
|
def numericalize(self, phonemes):
|
||||||
ids = [self.vocab.lookup(item) for item in phonemes if item in self.vocab.stoi]
|
ids = [
|
||||||
|
self.vocab.lookup(item) for item in phonemes
|
||||||
|
if item in self.vocab.stoi
|
||||||
|
]
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def reverse(self, ids):
|
def reverse(self, ids):
|
||||||
return [self.vocab.reverse(i) for i in ids]
|
return [self.vocab.reverse(i) for i in ids]
|
||||||
|
|
||||||
def __call__(self, sentence):
|
def __call__(self, sentence):
|
||||||
return self.numericalize(self.phoneticize(sentence))
|
return self.numericalize(self.phoneticize(sentence))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self):
|
||||||
|
return len(self.vocab)
|
||||||
|
|
||||||
|
|
||||||
|
class EnglishCharacter(Phonetics):
|
||||||
|
def __init__(self):
|
||||||
|
self.backend = G2p()
|
||||||
|
self.graphemes = list(self.backend.graphemes)
|
||||||
|
self.punctuations = get_punctuations("en")
|
||||||
|
self.vocab = Vocab(self.graphemes + self.punctuations)
|
||||||
|
|
||||||
|
def phoneticize(self, sentence):
|
||||||
|
start = self.vocab.start_symbol
|
||||||
|
end = self.vocab.end_symbol
|
||||||
|
|
||||||
|
words = ([] if start is None else [start]) \
|
||||||
|
+ normalize(sentence) \
|
||||||
|
+ ([] if end is None else [end])
|
||||||
|
return words
|
||||||
|
|
||||||
|
def numericalize(self, words):
|
||||||
|
ids = []
|
||||||
|
for word in words:
|
||||||
|
if word in self.vocab.stoi:
|
||||||
|
ids.append(self.vocab.lookup(word))
|
||||||
|
continue
|
||||||
|
for char in word:
|
||||||
|
if char in self.vocab.stoi:
|
||||||
|
ids.append(self.vocab.lookup(char))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def reverse(self, ids):
|
||||||
|
return [self.vocab.reverse(i) for i in ids]
|
||||||
|
|
||||||
|
def __call__(self, sentence):
|
||||||
|
return self.numericalize(self.phoneticize(sentence))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.vocab)
|
return len(self.vocab)
|
||||||
|
@ -59,9 +116,11 @@ class Chinese(Phonetics):
|
||||||
self.phonemes = self._get_all_syllables()
|
self.phonemes = self._get_all_syllables()
|
||||||
self.punctuations = get_punctuations("cn")
|
self.punctuations = get_punctuations("cn")
|
||||||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||||
|
|
||||||
def _get_all_syllables(self):
|
def _get_all_syllables(self):
|
||||||
all_syllables = set([syllable for k, v in self.backend.cedict.items() for syllable in v])
|
all_syllables = set([
|
||||||
|
syllable for k, v in self.backend.cedict.items() for syllable in v
|
||||||
|
])
|
||||||
return list(all_syllables)
|
return list(all_syllables)
|
||||||
|
|
||||||
def phoneticize(self, sentence):
|
def phoneticize(self, sentence):
|
||||||
|
@ -73,7 +132,7 @@ class Chinese(Phonetics):
|
||||||
+ phonemes \
|
+ phonemes \
|
||||||
+ ([] if end is None else [end])
|
+ ([] if end is None else [end])
|
||||||
return self._filter_symbols(phonemes)
|
return self._filter_symbols(phonemes)
|
||||||
|
|
||||||
def _filter_symbols(self, phonemes):
|
def _filter_symbols(self, phonemes):
|
||||||
cleaned_phonemes = []
|
cleaned_phonemes = []
|
||||||
for item in phonemes:
|
for item in phonemes:
|
||||||
|
@ -84,17 +143,17 @@ class Chinese(Phonetics):
|
||||||
if char in self.vocab.stoi:
|
if char in self.vocab.stoi:
|
||||||
cleaned_phonemes.append(char)
|
cleaned_phonemes.append(char)
|
||||||
return cleaned_phonemes
|
return cleaned_phonemes
|
||||||
|
|
||||||
def numericalize(self, phonemes):
|
def numericalize(self, phonemes):
|
||||||
ids = [self.vocab.lookup(item) for item in phonemes]
|
ids = [self.vocab.lookup(item) for item in phonemes]
|
||||||
return ids
|
return ids
|
||||||
|
|
||||||
def __call__(self, sentence):
|
def __call__(self, sentence):
|
||||||
return self.numericalize(self.phoneticize(sentence))
|
return self.numericalize(self.phoneticize(sentence))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab_size(self):
|
def vocab_size(self):
|
||||||
return len(self.vocab)
|
return len(self.vocab)
|
||||||
|
|
||||||
def reverse(self, ids):
|
def reverse(self, ids):
|
||||||
return [self.vocab.reverse(i) for i in ids]
|
return [self.vocab.reverse(i) for i in ids]
|
||||||
|
|
|
@ -0,0 +1,427 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import parakeet
|
||||||
|
from parakeet.modules.conv import Conv1dBatchNorm
|
||||||
|
from parakeet.modules.attention import LocationSensitiveAttention
|
||||||
|
from parakeet.modules import masking
|
||||||
|
|
||||||
|
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderPreNet(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_input: int,
|
||||||
|
d_hidden: int,
|
||||||
|
d_output: int,
|
||||||
|
dropout_rate: int=0.2):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False)
|
||||||
|
self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
|
||||||
|
output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderPostNet(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_mels: int=80,
|
||||||
|
d_hidden: int=512,
|
||||||
|
kernel_size: int=5,
|
||||||
|
padding: int=0,
|
||||||
|
num_layers: int=5,
|
||||||
|
dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.dropout = dropout
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
self.conv_batchnorms = nn.LayerList()
|
||||||
|
k = math.sqrt(1.0 / (d_mels * kernel_size))
|
||||||
|
self.conv_batchnorms.append(
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_mels,
|
||||||
|
d_hidden,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC'))
|
||||||
|
|
||||||
|
k = math.sqrt(1.0 / (d_hidden * kernel_size))
|
||||||
|
self.conv_batchnorms.extend([
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_hidden,
|
||||||
|
d_hidden,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC') for i in range(1, num_layers - 1)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.conv_batchnorms.append(
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_hidden,
|
||||||
|
d_mels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC'))
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
for i in range(len(self.conv_batchnorms) - 1):
|
||||||
|
input = F.dropout(
|
||||||
|
F.tanh(self.conv_batchnorms[i](input), self.dropout))
|
||||||
|
input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
|
||||||
|
self.dropout)
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2Encoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_hidden: int,
|
||||||
|
conv_layers: int,
|
||||||
|
kernel_size: int,
|
||||||
|
p_dropout: float):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
k = math.sqrt(1.0 / (d_hidden * kernel_size))
|
||||||
|
self.conv_batchnorms = paddle.nn.LayerList([
|
||||||
|
Conv1dBatchNorm(
|
||||||
|
d_hidden,
|
||||||
|
d_hidden,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=int((kernel_size - 1) / 2),
|
||||||
|
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-k, high=k)),
|
||||||
|
data_format='NLC') for i in range(conv_layers)
|
||||||
|
])
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
|
||||||
|
self.hidden_size = int(d_hidden / 2)
|
||||||
|
self.lstm = nn.LSTM(
|
||||||
|
d_hidden, self.hidden_size, direction="bidirectional")
|
||||||
|
|
||||||
|
def forward(self, x, input_lens=None):
|
||||||
|
for conv_batchnorm in self.conv_batchnorms:
|
||||||
|
x = F.dropout(F.relu(conv_batchnorm(x)),
|
||||||
|
self.p_dropout) #(B, T, C)
|
||||||
|
|
||||||
|
output, _ = self.lstm(inputs=x, sequence_length=input_lens)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2Decoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_mels: int,
|
||||||
|
reduction_factor: int,
|
||||||
|
d_encoder: int,
|
||||||
|
d_prenet: int,
|
||||||
|
d_attention_rnn: int,
|
||||||
|
d_decoder_rnn: int,
|
||||||
|
d_attention: int,
|
||||||
|
attention_filters: int,
|
||||||
|
attention_kernel_size: int,
|
||||||
|
p_prenet_dropout: float,
|
||||||
|
p_attention_dropout: float,
|
||||||
|
p_decoder_dropout: float):
|
||||||
|
super().__init__()
|
||||||
|
self.d_mels = d_mels
|
||||||
|
self.reduction_factor = reduction_factor
|
||||||
|
self.d_encoder = d_encoder
|
||||||
|
self.d_attention_rnn = d_attention_rnn
|
||||||
|
self.d_decoder_rnn = d_decoder_rnn
|
||||||
|
self.p_attention_dropout = p_attention_dropout
|
||||||
|
self.p_decoder_dropout = p_decoder_dropout
|
||||||
|
|
||||||
|
self.prenet = DecoderPreNet(
|
||||||
|
d_mels * reduction_factor,
|
||||||
|
d_prenet,
|
||||||
|
d_prenet,
|
||||||
|
dropout_rate=p_prenet_dropout)
|
||||||
|
|
||||||
|
self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn)
|
||||||
|
|
||||||
|
self.attention_layer = LocationSensitiveAttention(
|
||||||
|
d_attention_rnn, d_encoder, d_attention, attention_filters,
|
||||||
|
attention_kernel_size)
|
||||||
|
self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder,
|
||||||
|
d_decoder_rnn)
|
||||||
|
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
|
||||||
|
d_mels * reduction_factor)
|
||||||
|
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||||
|
|
||||||
|
def _initialize_decoder_states(self, key):
|
||||||
|
batch_size = key.shape[0]
|
||||||
|
MAX_TIME = key.shape[1]
|
||||||
|
|
||||||
|
self.attention_hidden = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
|
||||||
|
self.attention_cell = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
|
||||||
|
|
||||||
|
self.decoder_hidden = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
|
||||||
|
self.decoder_cell = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
|
||||||
|
|
||||||
|
self.attention_weights = paddle.zeros(
|
||||||
|
shape=[batch_size, MAX_TIME], dtype=key.dtype)
|
||||||
|
self.attention_weights_cum = paddle.zeros(
|
||||||
|
shape=[batch_size, MAX_TIME], dtype=key.dtype)
|
||||||
|
self.attention_context = paddle.zeros(
|
||||||
|
shape=[batch_size, self.d_encoder], dtype=key.dtype)
|
||||||
|
|
||||||
|
self.key = key #[B, T, C]
|
||||||
|
self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
|
||||||
|
|
||||||
|
def _decode(self, query):
|
||||||
|
cell_input = paddle.concat([query, self.attention_context], axis=-1)
|
||||||
|
|
||||||
|
# The first lstm layer
|
||||||
|
_, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
|
||||||
|
cell_input, (self.attention_hidden, self.attention_cell))
|
||||||
|
self.attention_hidden = F.dropout(self.attention_hidden,
|
||||||
|
self.p_attention_dropout)
|
||||||
|
|
||||||
|
# Loaction sensitive attention
|
||||||
|
attention_weights_cat = paddle.stack(
|
||||||
|
[self.attention_weights, self.attention_weights_cum], axis=-1)
|
||||||
|
self.attention_context, self.attention_weights = self.attention_layer(
|
||||||
|
self.attention_hidden, self.processed_key, self.key,
|
||||||
|
attention_weights_cat, self.mask)
|
||||||
|
self.attention_weights_cum += self.attention_weights
|
||||||
|
|
||||||
|
# The second lstm layer
|
||||||
|
decoder_input = paddle.concat(
|
||||||
|
[self.attention_hidden, self.attention_context], axis=-1)
|
||||||
|
_, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
|
||||||
|
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
||||||
|
self.decoder_hidden = F.dropout(
|
||||||
|
self.decoder_hidden, p=self.p_decoder_dropout)
|
||||||
|
|
||||||
|
# decode output one step
|
||||||
|
decoder_hidden_attention_context = paddle.concat(
|
||||||
|
[self.decoder_hidden, self.attention_context], axis=-1)
|
||||||
|
decoder_output = self.linear_projection(
|
||||||
|
decoder_hidden_attention_context)
|
||||||
|
stop_logit = self.stop_layer(decoder_hidden_attention_context)
|
||||||
|
return decoder_output, stop_logit, self.attention_weights
|
||||||
|
|
||||||
|
def forward(self, keys, querys, mask):
|
||||||
|
querys = paddle.reshape(
|
||||||
|
querys,
|
||||||
|
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
|
||||||
|
querys = paddle.concat(
|
||||||
|
[
|
||||||
|
paddle.zeros(
|
||||||
|
shape=[
|
||||||
|
querys.shape[0], 1,
|
||||||
|
querys.shape[-1] * self.reduction_factor
|
||||||
|
],
|
||||||
|
dtype=querys.dtype), querys
|
||||||
|
],
|
||||||
|
axis=1)
|
||||||
|
querys = self.prenet(querys)
|
||||||
|
|
||||||
|
self._initialize_decoder_states(keys)
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
mel_outputs, stop_logits, alignments = [], [], []
|
||||||
|
while len(mel_outputs) < querys.shape[
|
||||||
|
1] - 1: # Ignore the last time step
|
||||||
|
query = querys[:, len(mel_outputs), :]
|
||||||
|
mel_output, stop_logit, attention_weights = self._decode(query)
|
||||||
|
mel_outputs += [mel_output]
|
||||||
|
stop_logits += [stop_logit]
|
||||||
|
alignments += [attention_weights]
|
||||||
|
|
||||||
|
alignments = paddle.stack(alignments, axis=1)
|
||||||
|
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||||
|
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||||
|
|
||||||
|
return mel_outputs, stop_logits, alignments
|
||||||
|
|
||||||
|
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
|
||||||
|
decoder_input = paddle.zeros(
|
||||||
|
shape=[key.shape[0], self.d_mels * self.reduction_factor],
|
||||||
|
dtype=key.dtype) #[B, C]
|
||||||
|
|
||||||
|
self.initialize_decoder_states(key)
|
||||||
|
self.mask = None
|
||||||
|
|
||||||
|
mel_outputs, stop_logits, alignments = [], [], []
|
||||||
|
while True:
|
||||||
|
decoder_input = self.prenet(decoder_input)
|
||||||
|
mel_output, stop_logit, alignment = self.decode(decoder_input)
|
||||||
|
|
||||||
|
mel_outputs += [mel_output]
|
||||||
|
stop_logits += [stop_logit]
|
||||||
|
alignments += [alignment]
|
||||||
|
|
||||||
|
if F.sigmoid(stop_logit) > stop_threshold:
|
||||||
|
break
|
||||||
|
elif len(mel_outputs) == max_decoder_steps:
|
||||||
|
print("Warning! Reached max decoder steps!!!")
|
||||||
|
break
|
||||||
|
|
||||||
|
decoder_input = mel_output
|
||||||
|
|
||||||
|
alignments = paddle.stack(alignments, axis=1)
|
||||||
|
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||||
|
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||||
|
|
||||||
|
return mel_outputs, stop_logits, alignments
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2(nn.Layer):
|
||||||
|
"""
|
||||||
|
Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
|
||||||
|
|
||||||
|
This is a module of Spectrogram prediction network in Tacotron2 described
|
||||||
|
in `Natural TTS Synthesis
|
||||||
|
by Conditioning WaveNet on Mel Spectrogram Predictions`_,
|
||||||
|
which converts the sequence of characters
|
||||||
|
into the sequence of mel spectrogram.
|
||||||
|
|
||||||
|
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
|
||||||
|
https://arxiv.org/abs/1712.05884
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
frontend: parakeet.frontend.Phonetics,
|
||||||
|
d_mels: int=80,
|
||||||
|
d_encoder: int=512,
|
||||||
|
encoder_conv_layers: int=3,
|
||||||
|
encoder_kernel_size: int=5,
|
||||||
|
d_prenet: int=256,
|
||||||
|
d_attention_rnn: int=1024,
|
||||||
|
d_decoder_rnn: int=1024,
|
||||||
|
attention_filters: int=32,
|
||||||
|
attention_kernel_size: int=31,
|
||||||
|
d_attention: int=128,
|
||||||
|
d_postnet: int=512,
|
||||||
|
postnet_kernel_size: int=5,
|
||||||
|
postnet_conv_layers: int=5,
|
||||||
|
reduction_factor: int=1,
|
||||||
|
p_encoder_dropout: float=0.5,
|
||||||
|
p_prenet_dropout: float=0.5,
|
||||||
|
p_attention_dropout: float=0.1,
|
||||||
|
p_decoder_dropout: float=0.1,
|
||||||
|
p_postnet_dropout: float=0.5):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
|
||||||
|
val = math.sqrt(3.0) * std # uniform bounds for std
|
||||||
|
self.embedding = nn.Embedding(
|
||||||
|
frontend.vocab_size,
|
||||||
|
d_encoder,
|
||||||
|
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||||
|
low=-val, high=val)))
|
||||||
|
self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
|
||||||
|
encoder_kernel_size, p_encoder_dropout)
|
||||||
|
self.decoder = Tacotron2Decoder(
|
||||||
|
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
|
||||||
|
d_decoder_rnn, d_attention, attention_filters,
|
||||||
|
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
|
||||||
|
p_decoder_dropout)
|
||||||
|
self.postnet = DecoderPostNet(
|
||||||
|
d_mels=d_mels,
|
||||||
|
d_hidden=d_postnet,
|
||||||
|
kernel_size=postnet_kernel_size,
|
||||||
|
padding=int((postnet_kernel_size - 1) / 2),
|
||||||
|
num_layers=postnet_conv_layers,
|
||||||
|
dropout=p_postnet_dropout)
|
||||||
|
|
||||||
|
def forward(self, text_inputs, mels, text_lens, output_lens=None):
|
||||||
|
embedded_inputs = self.embedding(text_inputs)
|
||||||
|
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
||||||
|
|
||||||
|
mask = paddle.tensor.unsqueeze(
|
||||||
|
paddle.fluid.layers.sequence_mask(
|
||||||
|
x=text_lens, dtype=encoder_outputs.dtype), [-1])
|
||||||
|
mel_outputs, stop_logits, alignments = self.decoder(
|
||||||
|
encoder_outputs, mels, mask=mask)
|
||||||
|
|
||||||
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
|
||||||
|
if output_lens is not None:
|
||||||
|
mask = paddle.tensor.unsqueeze(
|
||||||
|
paddle.fluid.layers.sequence_mask(x=output_lens),
|
||||||
|
[-1]) #[B, T, 1]
|
||||||
|
mel_outputs = mel_outputs * mask #[B, T, C]
|
||||||
|
mel_outputs_postnet = mel_outputs_postnet * mask #[B, T, C]
|
||||||
|
stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0]
|
||||||
|
) * 1e3 #[B, T]
|
||||||
|
outputs = {
|
||||||
|
"mel_output": mel_outputs,
|
||||||
|
"mel_outputs_postnet": mel_outputs_postnet,
|
||||||
|
"stop_logits": stop_logits,
|
||||||
|
"alignments": alignments
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
|
||||||
|
embedded_inputs = self.embedding(text_inputs)
|
||||||
|
encoder_outputs = self.encoder(embedded_inputs)
|
||||||
|
mel_outputs, stop_logits, alignments = self.decoder.inference(
|
||||||
|
encoder_outputs,
|
||||||
|
stop_threshold=stop_threshold,
|
||||||
|
max_decoder_steps=max_decoder_steps)
|
||||||
|
|
||||||
|
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||||
|
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||||
|
|
||||||
|
outputs = {
|
||||||
|
"mel_output": mel_outputs,
|
||||||
|
"mel_outputs_postnet": mel_outputs_postnet,
|
||||||
|
"stop_logits": stop_logits,
|
||||||
|
"alignments": alignments
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def predict(self, text):
|
||||||
|
# TODO(lifuchen): implement predict function to product mel from texts
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Tacotron2Loss(nn.Layer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
|
||||||
|
mel_targets, stop_tokens):
|
||||||
|
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
||||||
|
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
||||||
|
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
|
||||||
|
total_loss = mel_loss + post_mel_loss + stop_loss
|
||||||
|
losses = dict(
|
||||||
|
loss=total_loss,
|
||||||
|
mel_loss=mel_loss,
|
||||||
|
post_mel_loss=post_mel_loss,
|
||||||
|
stop_loss=stop_loss)
|
||||||
|
return losses
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
import paddle
|
import paddle
|
||||||
|
@ -15,6 +29,7 @@ from parakeet.modules import losses as L
|
||||||
|
|
||||||
__all__ = ["TransformerTTS", "TransformerTTSLoss"]
|
__all__ = ["TransformerTTS", "TransformerTTSLoss"]
|
||||||
|
|
||||||
|
|
||||||
# Transformer TTS's own implementation of transformer
|
# Transformer TTS's own implementation of transformer
|
||||||
class MultiheadAttention(nn.Layer):
|
class MultiheadAttention(nn.Layer):
|
||||||
"""
|
"""
|
||||||
|
@ -25,7 +40,14 @@ class MultiheadAttention(nn.Layer):
|
||||||
Another deviation is that it concats the input query and context vector before
|
Another deviation is that it concats the input query and context vector before
|
||||||
applying the output projection.
|
applying the output projection.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model_dim, num_heads, k_dim=None, v_dim=None, k_input_dim=None, v_input_dim=None):
|
|
||||||
|
def __init__(self,
|
||||||
|
model_dim,
|
||||||
|
num_heads,
|
||||||
|
k_dim=None,
|
||||||
|
v_dim=None,
|
||||||
|
k_input_dim=None,
|
||||||
|
v_input_dim=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
model_dim (int): the feature size of query.
|
model_dim (int): the feature size of query.
|
||||||
|
@ -41,7 +63,7 @@ class MultiheadAttention(nn.Layer):
|
||||||
ValueError: if model_dim is not divisible by num_heads
|
ValueError: if model_dim is not divisible by num_heads
|
||||||
"""
|
"""
|
||||||
super(MultiheadAttention, self).__init__()
|
super(MultiheadAttention, self).__init__()
|
||||||
if model_dim % num_heads !=0:
|
if model_dim % num_heads != 0:
|
||||||
raise ValueError("model_dim must be divisible by num_heads")
|
raise ValueError("model_dim must be divisible by num_heads")
|
||||||
depth = model_dim // num_heads
|
depth = model_dim // num_heads
|
||||||
k_dim = k_dim or depth
|
k_dim = k_dim or depth
|
||||||
|
@ -52,10 +74,10 @@ class MultiheadAttention(nn.Layer):
|
||||||
self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
|
self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
|
||||||
self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
|
self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
|
||||||
self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)
|
self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
|
|
||||||
def forward(self, q, k, v, mask, drop_n_heads=0):
|
def forward(self, q, k, v, mask, drop_n_heads=0):
|
||||||
"""
|
"""
|
||||||
Compute context vector and attention weights.
|
Compute context vector and attention weights.
|
||||||
|
@ -72,17 +94,18 @@ class MultiheadAttention(nn.Layer):
|
||||||
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
|
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
|
||||||
"""
|
"""
|
||||||
q_in = q
|
q_in = q
|
||||||
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
|
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
|
||||||
k = _split_heads(self.affine_k(k), self.num_heads)
|
k = _split_heads(self.affine_k(k), self.num_heads)
|
||||||
v = _split_heads(self.affine_v(v), self.num_heads)
|
v = _split_heads(self.affine_v(v), self.num_heads)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
|
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
|
||||||
|
|
||||||
context_vectors, attention_weights = scaled_dot_product_attention(
|
context_vectors, attention_weights = scaled_dot_product_attention(
|
||||||
q, k, v, mask, training=self.training)
|
q, k, v, mask, training=self.training)
|
||||||
context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
|
context_vectors = drop_head(context_vectors, drop_n_heads,
|
||||||
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
self.training)
|
||||||
|
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
||||||
|
|
||||||
concat_feature = paddle.concat([q_in, context_vectors], -1)
|
concat_feature = paddle.concat([q_in, context_vectors], -1)
|
||||||
out = self.affine_o(concat_feature)
|
out = self.affine_o(concat_feature)
|
||||||
return out, attention_weights
|
return out, attention_weights
|
||||||
|
@ -92,6 +115,7 @@ class TransformerEncoderLayer(nn.Layer):
|
||||||
"""
|
"""
|
||||||
Transformer encoder layer.
|
Transformer encoder layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
|
def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -114,8 +138,10 @@ class TransformerEncoderLayer(nn.Layer):
|
||||||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||||
x_in = x
|
x_in = x
|
||||||
x = self.layer_norm1(x)
|
x = self.layer_norm1(x)
|
||||||
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
|
context_vector, attn_weights = self.self_mha(x, x, x, mask,
|
||||||
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
|
drop_n_heads)
|
||||||
|
context_vector = x_in + F.dropout(
|
||||||
|
context_vector, self.dropout, training=self.training)
|
||||||
return context_vector, attn_weights
|
return context_vector, attn_weights
|
||||||
|
|
||||||
def _forward_ffn(self, x):
|
def _forward_ffn(self, x):
|
||||||
|
@ -123,9 +149,9 @@ class TransformerEncoderLayer(nn.Layer):
|
||||||
x_in = x
|
x_in = x
|
||||||
x = self.layer_norm2(x)
|
x = self.layer_norm2(x)
|
||||||
x = self.ffn(x)
|
x = self.ffn(x)
|
||||||
out= x_in + F.dropout(x, self.dropout, training=self.training)
|
out = x_in + F.dropout(x, self.dropout, training=self.training)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward(self, x, mask, drop_n_heads=0):
|
def forward(self, x, mask, drop_n_heads=0):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -145,6 +171,7 @@ class TransformerDecoderLayer(nn.Layer):
|
||||||
"""
|
"""
|
||||||
Transformer decoder layer.
|
Transformer decoder layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
|
def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -157,37 +184,42 @@ class TransformerDecoderLayer(nn.Layer):
|
||||||
super(TransformerDecoderLayer, self).__init__()
|
super(TransformerDecoderLayer, self).__init__()
|
||||||
self.self_mha = MultiheadAttention(d_model, n_heads)
|
self.self_mha = MultiheadAttention(d_model, n_heads)
|
||||||
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
|
self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
|
||||||
|
|
||||||
self.cross_mha = MultiheadAttention(d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
|
self.cross_mha = MultiheadAttention(
|
||||||
|
d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
|
||||||
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
|
self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
|
||||||
|
|
||||||
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
|
self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
|
||||||
self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
|
self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def _forward_self_mha(self, x, mask, drop_n_heads):
|
def _forward_self_mha(self, x, mask, drop_n_heads):
|
||||||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||||
x_in = x
|
x_in = x
|
||||||
x = self.layer_norm1(x)
|
x = self.layer_norm1(x)
|
||||||
context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
|
context_vector, attn_weights = self.self_mha(x, x, x, mask,
|
||||||
context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
|
drop_n_heads)
|
||||||
|
context_vector = x_in + F.dropout(
|
||||||
|
context_vector, self.dropout, training=self.training)
|
||||||
return context_vector, attn_weights
|
return context_vector, attn_weights
|
||||||
|
|
||||||
def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
|
def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
|
||||||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||||
q_in = q
|
q_in = q
|
||||||
q = self.layer_norm2(q)
|
q = self.layer_norm2(q)
|
||||||
context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
|
context_vector, attn_weights = self.cross_mha(q, k, v, mask,
|
||||||
context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
|
drop_n_heads)
|
||||||
|
context_vector = q_in + F.dropout(
|
||||||
|
context_vector, self.dropout, training=self.training)
|
||||||
return context_vector, attn_weights
|
return context_vector, attn_weights
|
||||||
|
|
||||||
def _forward_ffn(self, x):
|
def _forward_ffn(self, x):
|
||||||
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
# PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
|
||||||
x_in = x
|
x_in = x
|
||||||
x = self.layer_norm3(x)
|
x = self.layer_norm3(x)
|
||||||
x = self.ffn(x)
|
x = self.ffn(x)
|
||||||
out= x_in + F.dropout(x, self.dropout, training=self.training)
|
out = x_in + F.dropout(x, self.dropout, training=self.training)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
|
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
|
||||||
|
@ -204,8 +236,10 @@ class TransformerDecoderLayer(nn.Layer):
|
||||||
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
|
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
|
||||||
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
|
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
|
||||||
"""
|
"""
|
||||||
q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
|
q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
|
||||||
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
|
drop_n_heads)
|
||||||
|
q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
|
||||||
|
drop_n_heads)
|
||||||
q = self._forward_ffn(q)
|
q = self._forward_ffn(q)
|
||||||
return q, self_attn_weights, cross_attn_weights
|
return q, self_attn_weights, cross_attn_weights
|
||||||
|
|
||||||
|
@ -214,7 +248,8 @@ class TransformerEncoder(nn.LayerList):
|
||||||
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
|
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
|
||||||
super(TransformerEncoder, self).__init__()
|
super(TransformerEncoder, self).__init__()
|
||||||
for _ in range(n_layers):
|
for _ in range(n_layers):
|
||||||
self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
|
self.append(
|
||||||
|
TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
|
||||||
|
|
||||||
def forward(self, x, mask, drop_n_heads=0):
|
def forward(self, x, mask, drop_n_heads=0):
|
||||||
"""
|
"""
|
||||||
|
@ -236,10 +271,18 @@ class TransformerEncoder(nn.LayerList):
|
||||||
|
|
||||||
|
|
||||||
class TransformerDecoder(nn.LayerList):
|
class TransformerDecoder(nn.LayerList):
|
||||||
def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0., d_encoder=None):
|
def __init__(self,
|
||||||
|
d_model,
|
||||||
|
n_heads,
|
||||||
|
d_ffn,
|
||||||
|
n_layers,
|
||||||
|
dropout=0.,
|
||||||
|
d_encoder=None):
|
||||||
super(TransformerDecoder, self).__init__()
|
super(TransformerDecoder, self).__init__()
|
||||||
for _ in range(n_layers):
|
for _ in range(n_layers):
|
||||||
self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
|
self.append(
|
||||||
|
TransformerDecoderLayer(
|
||||||
|
d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
|
||||||
|
|
||||||
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
|
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
|
||||||
"""[summary]
|
"""[summary]
|
||||||
|
@ -260,7 +303,8 @@ class TransformerDecoder(nn.LayerList):
|
||||||
self_attention_weights = []
|
self_attention_weights = []
|
||||||
cross_attention_weights = []
|
cross_attention_weights = []
|
||||||
for layer in self:
|
for layer in self:
|
||||||
q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask, drop_n_heads)
|
q, self_attention_weights_i, cross_attention_weights_i = layer(
|
||||||
|
q, k, v, encoder_mask, decoder_mask, drop_n_heads)
|
||||||
self_attention_weights.append(self_attention_weights_i)
|
self_attention_weights.append(self_attention_weights_i)
|
||||||
cross_attention_weights.append(cross_attention_weights_i)
|
cross_attention_weights.append(cross_attention_weights_i)
|
||||||
return q, self_attention_weights, cross_attention_weights
|
return q, self_attention_weights, cross_attention_weights
|
||||||
|
@ -268,6 +312,7 @@ class TransformerDecoder(nn.LayerList):
|
||||||
|
|
||||||
class MLPPreNet(nn.Layer):
|
class MLPPreNet(nn.Layer):
|
||||||
"""Decoder's prenet."""
|
"""Decoder's prenet."""
|
||||||
|
|
||||||
def __init__(self, d_input, d_hidden, d_output, dropout):
|
def __init__(self, d_input, d_hidden, d_output, dropout):
|
||||||
# (lin + relu + dropout) * n + last projection
|
# (lin + relu + dropout) * n + last projection
|
||||||
super(MLPPreNet, self).__init__()
|
super(MLPPreNet, self).__init__()
|
||||||
|
@ -275,16 +320,24 @@ class MLPPreNet(nn.Layer):
|
||||||
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
||||||
self.lin3 = nn.Linear(d_hidden, d_hidden)
|
self.lin3 = nn.Linear(d_hidden, d_hidden)
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, x, dropout):
|
def forward(self, x, dropout):
|
||||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
l1 = F.dropout(
|
||||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||||
|
l2 = F.dropout(
|
||||||
|
F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||||
l3 = self.lin3(l2)
|
l3 = self.lin3(l2)
|
||||||
return l3
|
return l3
|
||||||
|
|
||||||
|
|
||||||
# NOTE: not used in
|
# NOTE: not used in
|
||||||
class CNNPreNet(nn.Layer):
|
class CNNPreNet(nn.Layer):
|
||||||
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
|
def __init__(self,
|
||||||
|
d_input,
|
||||||
|
d_hidden,
|
||||||
|
d_output,
|
||||||
|
kernel_size,
|
||||||
|
n_layers,
|
||||||
dropout=0.):
|
dropout=0.):
|
||||||
# (conv + bn + relu + dropout) * n + last projection
|
# (conv + bn + relu + dropout) * n + last projection
|
||||||
super(CNNPreNet, self).__init__()
|
super(CNNPreNet, self).__init__()
|
||||||
|
@ -292,16 +345,21 @@ class CNNPreNet(nn.Layer):
|
||||||
c_in = d_input
|
c_in = d_input
|
||||||
for _ in range(n_layers):
|
for _ in range(n_layers):
|
||||||
self.convs.append(
|
self.convs.append(
|
||||||
Conv1dBatchNorm(c_in, d_hidden, kernel_size,
|
Conv1dBatchNorm(
|
||||||
weight_attr=I.XavierUniform(),
|
c_in,
|
||||||
padding="same", data_format="NLC"))
|
d_hidden,
|
||||||
|
kernel_size,
|
||||||
|
weight_attr=I.XavierUniform(),
|
||||||
|
padding="same",
|
||||||
|
data_format="NLC"))
|
||||||
c_in = d_hidden
|
c_in = d_hidden
|
||||||
self.affine_out = nn.Linear(d_hidden, d_output)
|
self.affine_out = nn.Linear(d_hidden, d_output)
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for layer in self.convs:
|
for layer in self.convs:
|
||||||
x = F.dropout(F.relu(layer(x)), self.dropout, training=self.training)
|
x = F.dropout(
|
||||||
|
F.relu(layer(x)), self.dropout, training=self.training)
|
||||||
x = self.affine_out(x)
|
x = self.affine_out(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -310,21 +368,25 @@ class CNNPostNet(nn.Layer):
|
||||||
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
|
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
|
||||||
super(CNNPostNet, self).__init__()
|
super(CNNPostNet, self).__init__()
|
||||||
self.convs = nn.LayerList()
|
self.convs = nn.LayerList()
|
||||||
kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
|
kernel_size = kernel_size if isinstance(kernel_size, (
|
||||||
|
tuple, list)) else (kernel_size, )
|
||||||
padding = (kernel_size[0] - 1, 0)
|
padding = (kernel_size[0] - 1, 0)
|
||||||
for i in range(n_layers):
|
for i in range(n_layers):
|
||||||
c_in = d_input if i == 0 else d_hidden
|
c_in = d_input if i == 0 else d_hidden
|
||||||
c_out = d_output if i == n_layers - 1 else d_hidden
|
c_out = d_output if i == n_layers - 1 else d_hidden
|
||||||
self.convs.append(
|
self.convs.append(
|
||||||
Conv1dBatchNorm(c_in, c_out, kernel_size,
|
Conv1dBatchNorm(
|
||||||
weight_attr=I.XavierUniform(),
|
c_in,
|
||||||
padding=padding))
|
c_out,
|
||||||
|
kernel_size,
|
||||||
|
weight_attr=I.XavierUniform(),
|
||||||
|
padding=padding))
|
||||||
self.last_bn = nn.BatchNorm1D(d_output)
|
self.last_bn = nn.BatchNorm1D(d_output)
|
||||||
# for a layer that ends with a normalization layer that is targeted to
|
# for a layer that ends with a normalization layer that is targeted to
|
||||||
# output a non zero-central output, it may take a long time to
|
# output a non zero-central output, it may take a long time to
|
||||||
# train the scale and bias
|
# train the scale and bias
|
||||||
# NOTE: it can also be a non-causal conv
|
# NOTE: it can also be a non-causal conv
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_in = x
|
x_in = x
|
||||||
for i, layer in enumerate(self.convs):
|
for i, layer in enumerate(self.convs):
|
||||||
|
@ -336,19 +398,19 @@ class CNNPostNet(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class TransformerTTS(nn.Layer):
|
class TransformerTTS(nn.Layer):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
frontend: parakeet.frontend.Phonetics,
|
frontend: parakeet.frontend.Phonetics,
|
||||||
d_encoder: int,
|
d_encoder: int,
|
||||||
d_decoder: int,
|
d_decoder: int,
|
||||||
d_mel: int,
|
d_mel: int,
|
||||||
n_heads: int,
|
n_heads: int,
|
||||||
d_ffn: int,
|
d_ffn: int,
|
||||||
encoder_layers: int,
|
encoder_layers: int,
|
||||||
decoder_layers: int,
|
decoder_layers: int,
|
||||||
d_prenet: int,
|
d_prenet: int,
|
||||||
d_postnet: int,
|
d_postnet: int,
|
||||||
postnet_layers: int,
|
postnet_layers: int,
|
||||||
postnet_kernel_size: int,
|
postnet_kernel_size: int,
|
||||||
max_reduction_factor: int,
|
max_reduction_factor: int,
|
||||||
decoder_prenet_dropout: float,
|
decoder_prenet_dropout: float,
|
||||||
dropout: float):
|
dropout: float):
|
||||||
|
@ -359,29 +421,34 @@ class TransformerTTS(nn.Layer):
|
||||||
|
|
||||||
# encoder
|
# encoder
|
||||||
self.encoder_prenet = nn.Embedding(
|
self.encoder_prenet = nn.Embedding(
|
||||||
frontend.vocab_size, d_encoder,
|
frontend.vocab_size,
|
||||||
padding_idx=frontend.vocab.padding_index,
|
d_encoder,
|
||||||
|
padding_idx=frontend.vocab.padding_index,
|
||||||
weight_attr=I.Uniform(-0.05, 0.05))
|
weight_attr=I.Uniform(-0.05, 0.05))
|
||||||
# position encoding matrix may be extended later
|
# position encoding matrix may be extended later
|
||||||
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
|
self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
|
||||||
self.encoder_pe_scalar = self.create_parameter(
|
self.encoder_pe_scalar = self.create_parameter(
|
||||||
[1], attr=I.Constant(1.))
|
[1], attr=I.Constant(1.))
|
||||||
self.encoder = TransformerEncoder(
|
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
|
||||||
d_encoder, n_heads, d_ffn, encoder_layers, dropout)
|
encoder_layers, dropout)
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
|
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
|
||||||
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
|
self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
|
||||||
self.decoder_pe_scalar = self.create_parameter(
|
self.decoder_pe_scalar = self.create_parameter(
|
||||||
[1], attr=I.Constant(1.))
|
[1], attr=I.Constant(1.))
|
||||||
self.decoder = TransformerDecoder(
|
self.decoder = TransformerDecoder(
|
||||||
d_decoder, n_heads, d_ffn, decoder_layers, dropout,
|
d_decoder,
|
||||||
|
n_heads,
|
||||||
|
d_ffn,
|
||||||
|
decoder_layers,
|
||||||
|
dropout,
|
||||||
d_encoder=d_encoder)
|
d_encoder=d_encoder)
|
||||||
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
|
self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
|
||||||
self.decoder_postnet = CNNPostNet(
|
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
|
||||||
d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
|
postnet_kernel_size, postnet_layers)
|
||||||
self.stop_conditioner = nn.Linear(d_mel, 3)
|
self.stop_conditioner = nn.Linear(d_mel, 3)
|
||||||
|
|
||||||
# specs
|
# specs
|
||||||
self.padding_idx = frontend.vocab.padding_index
|
self.padding_idx = frontend.vocab.padding_index
|
||||||
self.d_encoder = d_encoder
|
self.d_encoder = d_encoder
|
||||||
|
@ -390,21 +457,22 @@ class TransformerTTS(nn.Layer):
|
||||||
self.max_r = max_reduction_factor
|
self.max_r = max_reduction_factor
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.decoder_prenet_dropout = decoder_prenet_dropout
|
self.decoder_prenet_dropout = decoder_prenet_dropout
|
||||||
|
|
||||||
# start and end: though it is only used in predict
|
# start and end: though it is only used in predict
|
||||||
# it can also be used in training
|
# it can also be used in training
|
||||||
dtype = paddle.get_default_dtype()
|
dtype = paddle.get_default_dtype()
|
||||||
self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
|
self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
|
||||||
self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
|
self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
|
||||||
self.stop_prob_index = 2
|
self.stop_prob_index = 2
|
||||||
|
|
||||||
# mutables
|
# mutables
|
||||||
self.r = max_reduction_factor # set it every call
|
self.r = max_reduction_factor # set it every call
|
||||||
self.drop_n_heads = 0
|
self.drop_n_heads = 0
|
||||||
|
|
||||||
def forward(self, text, mel):
|
def forward(self, text, mel):
|
||||||
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
|
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
|
||||||
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
|
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
|
||||||
|
encoded, mel, encoder_mask)
|
||||||
outputs = {
|
outputs = {
|
||||||
"mel_output": mel_output,
|
"mel_output": mel_output,
|
||||||
"mel_intermediate": mel_intermediate,
|
"mel_intermediate": mel_intermediate,
|
||||||
|
@ -420,51 +488,54 @@ class TransformerTTS(nn.Layer):
|
||||||
if embed.shape[1] > self.encoder_pe.shape[0]:
|
if embed.shape[1] > self.encoder_pe.shape[0]:
|
||||||
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
|
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
|
||||||
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
|
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
|
||||||
pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
|
pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
|
||||||
x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar
|
x = embed.scale(math.sqrt(
|
||||||
|
self.d_encoder)) + pos_enc * self.encoder_pe_scalar
|
||||||
x = F.dropout(x, self.dropout, training=self.training)
|
x = F.dropout(x, self.dropout, training=self.training)
|
||||||
|
|
||||||
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
|
# TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
|
||||||
encoder_padding_mask = paddle.unsqueeze(
|
encoder_padding_mask = paddle.unsqueeze(
|
||||||
masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
|
masking.id_mask(
|
||||||
x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
|
text, self.padding_idx, dtype=x.dtype), 1)
|
||||||
|
x, attention_weights = self.encoder(x, encoder_padding_mask,
|
||||||
|
self.drop_n_heads)
|
||||||
return x, attention_weights, encoder_padding_mask
|
return x, attention_weights, encoder_padding_mask
|
||||||
|
|
||||||
def decode(self, encoder_output, input, encoder_padding_mask):
|
def decode(self, encoder_output, input, encoder_padding_mask):
|
||||||
batch_size, T_dec, mel_dim = input.shape
|
batch_size, T_dec, mel_dim = input.shape
|
||||||
|
|
||||||
x = self.decoder_prenet(input, self.decoder_prenet_dropout)
|
x = self.decoder_prenet(input, self.decoder_prenet_dropout)
|
||||||
# twice its length if needed
|
# twice its length if needed
|
||||||
if x.shape[1] * self.r > self.decoder_pe.shape[0]:
|
if x.shape[1] * self.r > self.decoder_pe.shape[0]:
|
||||||
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
|
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
|
||||||
self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
|
self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
|
||||||
pos_enc = self.decoder_pe[:T_dec*self.r:self.r, :]
|
pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
|
||||||
x = x.scale(math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar
|
x = x.scale(math.sqrt(
|
||||||
|
self.d_decoder)) + pos_enc * self.decoder_pe_scalar
|
||||||
x = F.dropout(x, self.dropout, training=self.training)
|
x = F.dropout(x, self.dropout, training=self.training)
|
||||||
|
|
||||||
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
|
no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
|
||||||
decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
|
decoder_padding_mask = masking.feature_mask(
|
||||||
decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
|
input, axis=-1, dtype=input.dtype)
|
||||||
|
decoder_mask = masking.combine_mask(
|
||||||
|
decoder_padding_mask.unsqueeze(-1), no_future_mask)
|
||||||
decoder_output, _, cross_attention_weights = self.decoder(
|
decoder_output, _, cross_attention_weights = self.decoder(
|
||||||
x,
|
x, encoder_output, encoder_output, encoder_padding_mask,
|
||||||
encoder_output,
|
decoder_mask, self.drop_n_heads)
|
||||||
encoder_output,
|
|
||||||
encoder_padding_mask,
|
|
||||||
decoder_mask,
|
|
||||||
self.drop_n_heads)
|
|
||||||
|
|
||||||
# use only parts of it
|
# use only parts of it
|
||||||
output_proj = self.final_proj(decoder_output)[:, :, : self.r * mel_dim]
|
output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
|
||||||
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
|
mel_intermediate = paddle.reshape(output_proj,
|
||||||
|
[batch_size, -1, mel_dim])
|
||||||
stop_logits = self.stop_conditioner(mel_intermediate)
|
stop_logits = self.stop_conditioner(mel_intermediate)
|
||||||
|
|
||||||
# cnn postnet
|
# cnn postnet
|
||||||
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
|
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
|
||||||
mel_output = self.decoder_postnet(mel_channel_first)
|
mel_output = self.decoder_postnet(mel_channel_first)
|
||||||
mel_output = paddle.transpose(mel_output, [0, 2, 1])
|
mel_output = paddle.transpose(mel_output, [0, 2, 1])
|
||||||
|
|
||||||
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
|
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
|
||||||
|
|
||||||
def predict(self, input, raw_input=True, max_length=1000, verbose=True):
|
def predict(self, input, raw_input=True, max_length=1000, verbose=True):
|
||||||
"""Predict log scale magnitude mel spectrogram from text input.
|
"""Predict log scale magnitude mel spectrogram from text input.
|
||||||
|
|
||||||
|
@ -475,26 +546,32 @@ class TransformerTTS(nn.Layer):
|
||||||
"""
|
"""
|
||||||
if raw_input:
|
if raw_input:
|
||||||
text_ids = paddle.to_tensor(self.frontend(input))
|
text_ids = paddle.to_tensor(self.frontend(input))
|
||||||
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
||||||
else:
|
else:
|
||||||
text_input = input
|
text_input = input
|
||||||
|
|
||||||
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||||
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||||
|
|
||||||
# encoder the text sequence
|
# encoder the text sequence
|
||||||
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
|
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
|
||||||
for _ in trange(int(max_length // self.r) + 1):
|
text_input)
|
||||||
|
for _ in range(int(max_length // self.r) + 1):
|
||||||
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
||||||
encoder_output, decoder_input, encoder_padding_mask)
|
encoder_output, decoder_input, encoder_padding_mask)
|
||||||
|
|
||||||
# extract last step and append it to decoder input
|
# extract last step and append it to decoder input
|
||||||
decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1)
|
decoder_input = paddle.concat(
|
||||||
|
[decoder_input, mel_output[:, -1:, :]], 1)
|
||||||
# extract last r steps and append it to decoder output
|
# extract last r steps and append it to decoder output
|
||||||
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
|
decoder_output = paddle.concat(
|
||||||
|
[decoder_output, mel_output[:, -self.r:, :]], 1)
|
||||||
|
|
||||||
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
|
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
|
||||||
if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index):
|
if paddle.any(
|
||||||
|
paddle.argmax(
|
||||||
|
stop_logits[0, -self.r:, :], axis=-1) ==
|
||||||
|
self.stop_prob_index):
|
||||||
if verbose:
|
if verbose:
|
||||||
print("Hits stop condition.")
|
print("Hits stop condition.")
|
||||||
break
|
break
|
||||||
|
@ -516,24 +593,28 @@ class TransformerTTSLoss(nn.Layer):
|
||||||
def __init__(self, stop_loss_scale):
|
def __init__(self, stop_loss_scale):
|
||||||
super(TransformerTTSLoss, self).__init__()
|
super(TransformerTTSLoss, self).__init__()
|
||||||
self.stop_loss_scale = stop_loss_scale
|
self.stop_loss_scale = stop_loss_scale
|
||||||
|
|
||||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
|
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
|
||||||
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
|
stop_probs):
|
||||||
|
mask = masking.feature_mask(
|
||||||
|
mel_target, axis=-1, dtype=mel_target.dtype)
|
||||||
mask1 = paddle.unsqueeze(mask, -1)
|
mask1 = paddle.unsqueeze(mask, -1)
|
||||||
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||||
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||||
|
|
||||||
mel_len = mask.shape[-1]
|
mel_len = mask.shape[-1]
|
||||||
last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
|
last_position = F.one_hot(
|
||||||
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
|
mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
|
||||||
|
mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(
|
||||||
|
mask.dtype)
|
||||||
stop_loss = L.masked_softmax_with_cross_entropy(
|
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||||
|
|
||||||
loss = mel_loss1 + mel_loss2 + stop_loss
|
loss = mel_loss1 + mel_loss2 + stop_loss
|
||||||
losses = dict(
|
losses = dict(
|
||||||
loss=loss, # total loss
|
loss=loss, # total loss
|
||||||
mel_loss1=mel_loss1, # ouput mel loss
|
mel_loss1=mel_loss1, # ouput mel loss
|
||||||
mel_loss2=mel_loss2, # intermediate mel loss
|
mel_loss2=mel_loss2, # intermediate mel loss
|
||||||
stop_loss=stop_loss # stop prob loss
|
stop_loss=stop_loss # stop prob loss
|
||||||
)
|
)
|
||||||
return losses
|
return losses
|
||||||
|
@ -542,26 +623,29 @@ class TransformerTTSLoss(nn.Layer):
|
||||||
class AdaptiveTransformerTTSLoss(nn.Layer):
|
class AdaptiveTransformerTTSLoss(nn.Layer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(AdaptiveTransformerTTSLoss, self).__init__()
|
super(AdaptiveTransformerTTSLoss, self).__init__()
|
||||||
|
|
||||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
|
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
|
||||||
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
|
stop_probs):
|
||||||
|
mask = masking.feature_mask(
|
||||||
|
mel_target, axis=-1, dtype=mel_target.dtype)
|
||||||
mask1 = paddle.unsqueeze(mask, -1)
|
mask1 = paddle.unsqueeze(mask, -1)
|
||||||
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||||
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||||
|
|
||||||
batch_size, mel_len = mask.shape
|
batch_size, mel_len = mask.shape
|
||||||
valid_lengths = mask.sum(-1).astype("int64")
|
valid_lengths = mask.sum(-1).astype("int64")
|
||||||
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
|
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
|
||||||
stop_loss_scale = valid_lengths.sum() / batch_size - 1
|
stop_loss_scale = valid_lengths.sum() / batch_size - 1
|
||||||
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
|
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
|
||||||
|
mask.dtype)
|
||||||
stop_loss = L.masked_softmax_with_cross_entropy(
|
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||||
|
|
||||||
loss = mel_loss1 + mel_loss2 + stop_loss
|
loss = mel_loss1 + mel_loss2 + stop_loss
|
||||||
losses = dict(
|
losses = dict(
|
||||||
loss=loss, # total loss
|
loss=loss, # total loss
|
||||||
mel_loss1=mel_loss1, # ouput mel loss
|
mel_loss1=mel_loss1, # ouput mel loss
|
||||||
mel_loss2=mel_loss2, # intermediate mel loss
|
mel_loss2=mel_loss2, # intermediate mel loss
|
||||||
stop_loss=stop_loss # stop prob loss
|
stop_loss=stop_loss # stop prob loss
|
||||||
)
|
)
|
||||||
return losses
|
return losses
|
||||||
|
|
|
@ -1,10 +1,30 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True):
|
|
||||||
|
def scaled_dot_product_attention(q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
mask=None,
|
||||||
|
dropout=0.0,
|
||||||
|
training=True):
|
||||||
"""
|
"""
|
||||||
scaled dot product attention with mask. Assume q, k, v all have the same
|
scaled dot product attention with mask. Assume q, k, v all have the same
|
||||||
leader dimensions(denoted as * in descriptions below). Dropout is applied to
|
leader dimensions(denoted as * in descriptions below). Dropout is applied to
|
||||||
|
@ -22,18 +42,19 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
|
||||||
out (Tensor): shape(*, T_q, d_v), the context vector.
|
out (Tensor): shape(*, T_q, d_v), the context vector.
|
||||||
attn_weights (Tensor): shape(*, T_q, T_k), the attention weights.
|
attn_weights (Tensor): shape(*, T_q, T_k), the attention weights.
|
||||||
"""
|
"""
|
||||||
d = q.shape[-1] # we only support imperative execution
|
d = q.shape[-1] # we only support imperative execution
|
||||||
qk = paddle.matmul(q, k, transpose_y=True)
|
qk = paddle.matmul(q, k, transpose_y=True)
|
||||||
scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
|
scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
|
scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
|
||||||
|
|
||||||
attn_weights = F.softmax(scaled_logit, axis=-1)
|
attn_weights = F.softmax(scaled_logit, axis=-1)
|
||||||
attn_weights = F.dropout(attn_weights, dropout, training=training)
|
attn_weights = F.dropout(attn_weights, dropout, training=training)
|
||||||
out = paddle.matmul(attn_weights, v)
|
out = paddle.matmul(attn_weights, v)
|
||||||
return out, attn_weights
|
return out, attn_weights
|
||||||
|
|
||||||
|
|
||||||
def drop_head(x, drop_n_heads, training):
|
def drop_head(x, drop_n_heads, training):
|
||||||
"""
|
"""
|
||||||
Drop n heads from multiple context vectors.
|
Drop n heads from multiple context vectors.
|
||||||
|
@ -48,12 +69,12 @@ def drop_head(x, drop_n_heads, training):
|
||||||
"""
|
"""
|
||||||
if not training or (drop_n_heads == 0):
|
if not training or (drop_n_heads == 0):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
batch_size, num_heads, _, _ = x.shape
|
batch_size, num_heads, _, _ = x.shape
|
||||||
# drop all heads
|
# drop all heads
|
||||||
if num_heads == drop_n_heads:
|
if num_heads == drop_n_heads:
|
||||||
return paddle.zeros_like(x)
|
return paddle.zeros_like(x)
|
||||||
|
|
||||||
mask = np.ones([batch_size, num_heads])
|
mask = np.ones([batch_size, num_heads])
|
||||||
mask[:, :drop_n_heads] = 0
|
mask[:, :drop_n_heads] = 0
|
||||||
for subarray in mask:
|
for subarray in mask:
|
||||||
|
@ -63,18 +84,21 @@ def drop_head(x, drop_n_heads, training):
|
||||||
out = x * paddle.to_tensor(mask)
|
out = x * paddle.to_tensor(mask)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _split_heads(x, num_heads):
|
def _split_heads(x, num_heads):
|
||||||
batch_size, time_steps, _ = x.shape
|
batch_size, time_steps, _ = x.shape
|
||||||
x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
|
x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
|
||||||
x = paddle.transpose(x, [0, 2, 1, 3])
|
x = paddle.transpose(x, [0, 2, 1, 3])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _concat_heads(x):
|
def _concat_heads(x):
|
||||||
batch_size, _, time_steps, _ = x.shape
|
batch_size, _, time_steps, _ = x.shape
|
||||||
x = paddle.transpose(x, [0, 2, 1, 3])
|
x = paddle.transpose(x, [0, 2, 1, 3])
|
||||||
x = paddle.reshape(x, [batch_size, time_steps, -1])
|
x = paddle.reshape(x, [batch_size, time_steps, -1])
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
# Standard implementations of Monohead Attention & Multihead Attention
|
# Standard implementations of Monohead Attention & Multihead Attention
|
||||||
class MonoheadAttention(nn.Layer):
|
class MonoheadAttention(nn.Layer):
|
||||||
def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
|
def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
|
||||||
|
@ -99,10 +123,10 @@ class MonoheadAttention(nn.Layer):
|
||||||
self.affine_k = nn.Linear(model_dim, k_dim)
|
self.affine_k = nn.Linear(model_dim, k_dim)
|
||||||
self.affine_v = nn.Linear(model_dim, v_dim)
|
self.affine_v = nn.Linear(model_dim, v_dim)
|
||||||
self.affine_o = nn.Linear(v_dim, model_dim)
|
self.affine_o = nn.Linear(v_dim, model_dim)
|
||||||
|
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, q, k, v, mask):
|
def forward(self, q, k, v, mask):
|
||||||
"""
|
"""
|
||||||
Compute context vector and attention weights.
|
Compute context vector and attention weights.
|
||||||
|
@ -119,22 +143,28 @@ class MonoheadAttention(nn.Layer):
|
||||||
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
|
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
|
||||||
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
|
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
|
||||||
"""
|
"""
|
||||||
q = self.affine_q(q) # (B, T, C)
|
q = self.affine_q(q) # (B, T, C)
|
||||||
k = self.affine_k(k)
|
k = self.affine_k(k)
|
||||||
v = self.affine_v(v)
|
v = self.affine_v(v)
|
||||||
|
|
||||||
context_vectors, attention_weights = scaled_dot_product_attention(
|
context_vectors, attention_weights = scaled_dot_product_attention(
|
||||||
q, k, v, mask, self.dropout, self.training)
|
q, k, v, mask, self.dropout, self.training)
|
||||||
|
|
||||||
out = self.affine_o(context_vectors)
|
out = self.affine_o(context_vectors)
|
||||||
return out, attention_weights
|
return out, attention_weights
|
||||||
|
|
||||||
|
|
||||||
class MultiheadAttention(nn.Layer):
|
class MultiheadAttention(nn.Layer):
|
||||||
"""
|
"""
|
||||||
Multihead scaled dot product attention.
|
Multihead scaled dot product attention.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model_dim, num_heads, dropout=0.0, k_dim=None, v_dim=None):
|
|
||||||
|
def __init__(self,
|
||||||
|
model_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.0,
|
||||||
|
k_dim=None,
|
||||||
|
v_dim=None):
|
||||||
"""
|
"""
|
||||||
Multihead Attention module.
|
Multihead Attention module.
|
||||||
|
|
||||||
|
@ -154,7 +184,7 @@ class MultiheadAttention(nn.Layer):
|
||||||
ValueError: if model_dim is not divisible by num_heads
|
ValueError: if model_dim is not divisible by num_heads
|
||||||
"""
|
"""
|
||||||
super(MultiheadAttention, self).__init__()
|
super(MultiheadAttention, self).__init__()
|
||||||
if model_dim % num_heads !=0:
|
if model_dim % num_heads != 0:
|
||||||
raise ValueError("model_dim must be divisible by num_heads")
|
raise ValueError("model_dim must be divisible by num_heads")
|
||||||
depth = model_dim // num_heads
|
depth = model_dim // num_heads
|
||||||
k_dim = k_dim or depth
|
k_dim = k_dim or depth
|
||||||
|
@ -163,11 +193,11 @@ class MultiheadAttention(nn.Layer):
|
||||||
self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
|
self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
|
||||||
self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
|
self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
|
||||||
self.affine_o = nn.Linear(num_heads * v_dim, model_dim)
|
self.affine_o = nn.Linear(num_heads * v_dim, model_dim)
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, q, k, v, mask):
|
def forward(self, q, k, v, mask):
|
||||||
"""
|
"""
|
||||||
Compute context vector and attention weights.
|
Compute context vector and attention weights.
|
||||||
|
@ -184,14 +214,67 @@ class MultiheadAttention(nn.Layer):
|
||||||
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
|
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
|
||||||
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
|
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
|
||||||
"""
|
"""
|
||||||
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
|
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
|
||||||
k = _split_heads(self.affine_k(k), self.num_heads)
|
k = _split_heads(self.affine_k(k), self.num_heads)
|
||||||
v = _split_heads(self.affine_v(v), self.num_heads)
|
v = _split_heads(self.affine_v(v), self.num_heads)
|
||||||
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
|
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
|
||||||
|
|
||||||
context_vectors, attention_weights = scaled_dot_product_attention(
|
context_vectors, attention_weights = scaled_dot_product_attention(
|
||||||
q, k, v, mask, self.dropout, self.training)
|
q, k, v, mask, self.dropout, self.training)
|
||||||
# NOTE: there is more sophisticated implementation: Scheduled DropHead
|
# NOTE: there is more sophisticated implementation: Scheduled DropHead
|
||||||
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
||||||
out = self.affine_o(context_vectors)
|
out = self.affine_o(context_vectors)
|
||||||
return out, attention_weights
|
return out, attention_weights
|
||||||
|
|
||||||
|
|
||||||
|
class LocationSensitiveAttention(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_query: int,
|
||||||
|
d_key: int,
|
||||||
|
d_attention: int,
|
||||||
|
location_filters: int,
|
||||||
|
location_kernel_size: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
|
||||||
|
self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
|
||||||
|
self.value = nn.Linear(d_attention, 1, bias_attr=False)
|
||||||
|
|
||||||
|
#Location Layer
|
||||||
|
self.location_conv = nn.Conv1D(
|
||||||
|
2,
|
||||||
|
location_filters,
|
||||||
|
location_kernel_size,
|
||||||
|
1,
|
||||||
|
int((location_kernel_size - 1) / 2),
|
||||||
|
1,
|
||||||
|
bias_attr=False,
|
||||||
|
data_format='NLC')
|
||||||
|
self.location_layer = nn.Linear(
|
||||||
|
location_filters, d_attention, bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
query,
|
||||||
|
processed_key,
|
||||||
|
value,
|
||||||
|
attention_weights_cat,
|
||||||
|
mask=None):
|
||||||
|
|
||||||
|
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
||||||
|
processed_attention_weights = self.location_layer(
|
||||||
|
self.location_conv(attention_weights_cat))
|
||||||
|
alignment = self.value(
|
||||||
|
paddle.tanh(processed_attention_weights + processed_key +
|
||||||
|
processed_query))
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
alignment = alignment + (1.0 - mask) * -1e9
|
||||||
|
|
||||||
|
attention_weights = F.softmax(alignment, axis=1)
|
||||||
|
attention_context = paddle.matmul(
|
||||||
|
attention_weights, value, transpose_x=True)
|
||||||
|
|
||||||
|
attention_weights = paddle.squeeze(attention_weights, axis=[-1])
|
||||||
|
attention_context = paddle.squeeze(attention_context, axis=[1])
|
||||||
|
|
||||||
|
return attention_context, attention_weights
|
||||||
|
|
|
@ -1,6 +1,21 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
class Conv1dCell(nn.Conv1D):
|
class Conv1dCell(nn.Conv1D):
|
||||||
"""
|
"""
|
||||||
A subclass of Conv1d layer, which can be used like an RNN cell. It can take
|
A subclass of Conv1d layer, which can be used like an RNN cell. It can take
|
||||||
|
@ -14,30 +29,33 @@ class Conv1dCell(nn.Conv1D):
|
||||||
|
|
||||||
As a result, these arguments are removed form the initializer.
|
As a result, these arguments are removed form the initializer.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
dilation=1,
|
dilation=1,
|
||||||
weight_attr=None,
|
weight_attr=None,
|
||||||
bias_attr=None):
|
bias_attr=None):
|
||||||
_dilation = dilation[0] if isinstance(dilation, (tuple, list)) else dilation
|
_dilation = dilation[0] if isinstance(dilation,
|
||||||
_kernel_size = kernel_size[0] if isinstance(kernel_size, (tuple, list)) else kernel_size
|
(tuple, list)) else dilation
|
||||||
|
_kernel_size = kernel_size[0] if isinstance(kernel_size, (
|
||||||
|
tuple, list)) else kernel_size
|
||||||
self._r = 1 + (_kernel_size - 1) * _dilation
|
self._r = 1 + (_kernel_size - 1) * _dilation
|
||||||
super(Conv1dCell, self).__init__(
|
super(Conv1dCell, self).__init__(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
padding=(self._r - 1, 0),
|
padding=(self._r - 1, 0),
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
weight_attr=weight_attr,
|
weight_attr=weight_attr,
|
||||||
bias_attr=bias_attr,
|
bias_attr=bias_attr,
|
||||||
data_format="NCL")
|
data_format="NCL")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def receptive_field(self):
|
def receptive_field(self):
|
||||||
return self._r
|
return self._r
|
||||||
|
|
||||||
def start_sequence(self):
|
def start_sequence(self):
|
||||||
if self.training:
|
if self.training:
|
||||||
raise Exception("only use start_sequence in evaluation")
|
raise Exception("only use start_sequence in evaluation")
|
||||||
|
@ -50,15 +68,15 @@ class Conv1dCell(nn.Conv1D):
|
||||||
# see also: https://github.com/pytorch/pytorch/issues/47588
|
# see also: https://github.com/pytorch/pytorch/issues/47588
|
||||||
for hook in self._forward_pre_hooks.values():
|
for hook in self._forward_pre_hooks.values():
|
||||||
hook(self, None)
|
hook(self, None)
|
||||||
self._reshaped_weight = paddle.reshape(
|
self._reshaped_weight = paddle.reshape(self.weight,
|
||||||
self.weight, (self._out_channels, -1))
|
(self._out_channels, -1))
|
||||||
|
|
||||||
def initialize_buffer(self, x_t):
|
def initialize_buffer(self, x_t):
|
||||||
batch_size, _ = x_t.shape
|
batch_size, _ = x_t.shape
|
||||||
self._buffer = paddle.zeros(
|
self._buffer = paddle.zeros(
|
||||||
(batch_size, self._in_channels, self.receptive_field),
|
(batch_size, self._in_channels, self.receptive_field),
|
||||||
dtype=x_t.dtype)
|
dtype=x_t.dtype)
|
||||||
|
|
||||||
def update_buffer(self, x_t):
|
def update_buffer(self, x_t):
|
||||||
self._buffer = paddle.concat(
|
self._buffer = paddle.concat(
|
||||||
[self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
|
[self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
|
||||||
|
@ -74,7 +92,7 @@ class Conv1dCell(nn.Conv1D):
|
||||||
if self.receptive_field > 1:
|
if self.receptive_field > 1:
|
||||||
if self._buffer is None:
|
if self._buffer is None:
|
||||||
self.initialize_buffer(x_t)
|
self.initialize_buffer(x_t)
|
||||||
|
|
||||||
# update buffer
|
# update buffer
|
||||||
self.update_buffer(x_t)
|
self.update_buffer(x_t)
|
||||||
if self._dilation[0] > 1:
|
if self._dilation[0] > 1:
|
||||||
|
@ -90,20 +108,34 @@ class Conv1dCell(nn.Conv1D):
|
||||||
|
|
||||||
|
|
||||||
class Conv1dBatchNorm(nn.Layer):
|
class Conv1dBatchNorm(nn.Layer):
|
||||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
def __init__(self,
|
||||||
weight_attr=None, bias_attr=None, data_format="NCL"):
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
weight_attr=None,
|
||||||
|
bias_attr=None,
|
||||||
|
data_format="NCL",
|
||||||
|
momentum=0.9,
|
||||||
|
epsilon=1e-05):
|
||||||
super(Conv1dBatchNorm, self).__init__()
|
super(Conv1dBatchNorm, self).__init__()
|
||||||
# TODO(chenfeiyu): carefully initialize Conv1d's weight
|
self.conv = nn.Conv1D(
|
||||||
self.conv = nn.Conv1D(in_channels, out_channels, kernel_size, stride,
|
in_channels,
|
||||||
padding=padding,
|
out_channels,
|
||||||
weight_attr=weight_attr,
|
kernel_size,
|
||||||
bias_attr=bias_attr,
|
stride,
|
||||||
data_format=data_format)
|
padding=padding,
|
||||||
# TODO: channel last, but BatchNorm1d does not support channel last layout
|
weight_attr=weight_attr,
|
||||||
self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)
|
bias_attr=bias_attr,
|
||||||
|
data_format=data_format)
|
||||||
|
self.bn = nn.BatchNorm1D(
|
||||||
|
out_channels,
|
||||||
|
momentum=momentum,
|
||||||
|
epsilon=epsilon,
|
||||||
|
data_format=data_format)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.bn(x)
|
x = self.bn(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -11,6 +25,7 @@ from collections import defaultdict
|
||||||
import parakeet
|
import parakeet
|
||||||
from parakeet.utils import checkpoint, mp_tools
|
from parakeet.utils import checkpoint, mp_tools
|
||||||
|
|
||||||
|
|
||||||
class ExperimentBase(object):
|
class ExperimentBase(object):
|
||||||
"""
|
"""
|
||||||
An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
|
An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
|
||||||
|
@ -22,7 +37,7 @@ class ExperimentBase(object):
|
||||||
We have some conventions to follow.
|
We have some conventions to follow.
|
||||||
1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
|
1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
|
||||||
2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
|
2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
|
||||||
3. There are three method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
|
3. There are four method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
|
||||||
|
|
||||||
Feel free to add/overwrite other methods and standalone functions if you need.
|
Feel free to add/overwrite other methods and standalone functions if you need.
|
||||||
|
|
||||||
|
@ -54,6 +69,7 @@ class ExperimentBase(object):
|
||||||
main(config, args)
|
main(config, args)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, args):
|
def __init__(self, config, args):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.args = args
|
self.args = args
|
||||||
|
@ -67,7 +83,7 @@ class ExperimentBase(object):
|
||||||
self.setup_visualizer()
|
self.setup_visualizer()
|
||||||
self.setup_logger()
|
self.setup_logger()
|
||||||
self.setup_checkpointer()
|
self.setup_checkpointer()
|
||||||
|
|
||||||
self.setup_dataloader()
|
self.setup_dataloader()
|
||||||
self.setup_model()
|
self.setup_model()
|
||||||
|
|
||||||
|
@ -82,13 +98,13 @@ class ExperimentBase(object):
|
||||||
dist.init_parallel_env()
|
dist.init_parallel_env()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
checkpoint.save_parameters(
|
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
|
||||||
self.checkpoint_dir, self.iteration, self.model, self.optimizer)
|
self.model, self.optimizer)
|
||||||
|
|
||||||
def resume_or_load(self):
|
def resume_or_load(self):
|
||||||
iteration = checkpoint.load_parameters(
|
iteration = checkpoint.load_parameters(
|
||||||
self.model,
|
self.model,
|
||||||
self.optimizer,
|
self.optimizer,
|
||||||
checkpoint_dir=self.checkpoint_dir,
|
checkpoint_dir=self.checkpoint_dir,
|
||||||
checkpoint_path=self.args.checkpoint_path)
|
checkpoint_path=self.args.checkpoint_path)
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
|
@ -115,10 +131,10 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
if self.iteration % self.config.training.valid_interval == 0:
|
if self.iteration % self.config.training.valid_interval == 0:
|
||||||
self.valid()
|
self.valid()
|
||||||
|
|
||||||
if self.iteration % self.config.training.save_interval == 0:
|
if self.iteration % self.config.training.save_interval == 0:
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self.resume_or_load()
|
self.resume_or_load()
|
||||||
try:
|
try:
|
||||||
|
@ -126,7 +142,7 @@ class ExperimentBase(object):
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
self.save()
|
self.save()
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def setup_output_dir(self):
|
def setup_output_dir(self):
|
||||||
# output dir
|
# output dir
|
||||||
|
@ -134,7 +150,7 @@ class ExperimentBase(object):
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def setup_checkpointer(self):
|
def setup_checkpointer(self):
|
||||||
# checkpoint dir
|
# checkpoint dir
|
||||||
|
@ -161,7 +177,7 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def dump_config(self):
|
def dump_config(self):
|
||||||
with open(self.output_dir / "config.yaml", 'wt') as f:
|
with open(self.output_dir / "config.yaml", 'wt') as f:
|
||||||
print(self.config, file=f)
|
print(self.config, file=f)
|
||||||
|
|
||||||
def train_batch(self):
|
def train_batch(self):
|
||||||
|
@ -177,4 +193,3 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
def setup_dataloader(self):
|
def setup_dataloader(self):
|
||||||
raise NotImplementedError("setup_dataloader should be implemented.")
|
raise NotImplementedError("setup_dataloader should be implemented.")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue