Merge pull request #45 from lfchener/reborn

add TTS model tacotron2
Feiyu Chan 2020-12-11 16:33:22 +08:00 committed by GitHub
commit 1d2e93c75f
8 changed files with 1016 additions and 203 deletions

View File

@ -0,0 +1,32 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import unicodedata
from builtins import str as unicode
from parakeet.frontend.normalizer.numbers import normalize_numbers
def normalize(sentence):
# preprocessing
sentence = unicode(sentence)
sentence = normalize_numbers(sentence)
sentence = ''.join(
char for char in unicodedata.normalize('NFD', sentence)
if unicodedata.category(char) != 'Mn') # Strip accents
sentence = sentence.lower()
sentence = re.sub(r"[^ a-z'.,?!\-]", "", sentence)
sentence = sentence.replace("i.e.", "that is")
sentence = sentence.replace("e.g.", "for example")
return sentence.split()
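
A quick usage sketch for the normalizer above (an aside, not part of the diff). The import path matches the one added to the frontend later in this PR; the expected output is approximate:

from parakeet.frontend.normalizer.normalizer import normalize

tokens = normalize("I bought 3 books for $5.50 on Dec. 2nd!")
print(tokens)
# roughly: ['i', 'bought', 'three', 'books', 'for', 'five', 'dollars,',
#           'fifty', 'cents', 'on', 'dec.', 'second!']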

View File

@ -1,3 +1,84 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# number expansion is not that easy
import num2words
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(
num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
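
A small check of the number expansion above (an aside, not part of the diff), showing the year-style reading produced by the 1000–3000 branch of `_expand_number`; the output wording is approximate:

from parakeet.frontend.normalizer.numbers import normalize_numbers

print(normalize_numbers("born in 2005, class of 1999"))
# roughly: "born in two thousand five, class of nineteen ninety-nine"
print(normalize_numbers("It costs $3.50."))
# roughly: "It costs three dollars, fifty cents."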

View File

@ -1,3 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Union
from g2p_en import G2p
@ -5,8 +19,9 @@ from g2pM import G2pM
from parakeet.frontend import Vocab
from opencc import OpenCC
from parakeet.frontend.punctuation import get_punctuations
from parakeet.frontend.normalizer.normalizer import normalize

__all__ = ["Phonetics", "English", "EnglishCharacter", "Chinese"]


class Phonetics(ABC):
@ -22,6 +37,7 @@ class Phonetics(ABC):
    def numericalize(self, phonemes):
        pass


class English(Phonetics):
    def __init__(self):
        self.backend = G2p()
@ -38,7 +54,48 @@ class English(Phonetics):
        return phonemes

    def numericalize(self, phonemes):
        ids = [
            self.vocab.lookup(item) for item in phonemes
            if item in self.vocab.stoi
        ]
return ids
def reverse(self, ids):
return [self.vocab.reverse(i) for i in ids]
def __call__(self, sentence):
return self.numericalize(self.phoneticize(sentence))
@property
def vocab_size(self):
return len(self.vocab)
class EnglishCharacter(Phonetics):
def __init__(self):
self.backend = G2p()
self.graphemes = list(self.backend.graphemes)
self.punctuations = get_punctuations("en")
self.vocab = Vocab(self.graphemes + self.punctuations)
def phoneticize(self, sentence):
start = self.vocab.start_symbol
end = self.vocab.end_symbol
words = ([] if start is None else [start]) \
+ normalize(sentence) \
+ ([] if end is None else [end])
return words
def numericalize(self, words):
ids = []
for word in words:
if word in self.vocab.stoi:
ids.append(self.vocab.lookup(word))
continue
for char in word:
if char in self.vocab.stoi:
ids.append(self.vocab.lookup(char))
        return ids

    def reverse(self, ids):
@ -61,7 +118,9 @@ class Chinese(Phonetics):
        self.vocab = Vocab(self.phonemes + self.punctuations)

    def _get_all_syllables(self):
        all_syllables = set([
            syllable for k, v in self.backend.cedict.items() for syllable in v
        ])
        return list(all_syllables)

    def phoneticize(self, sentence):
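
A hedged usage sketch for the new character-level frontend in this file (an aside, not part of the diff). It assumes `EnglishCharacter` is importable from `parakeet.frontend`, as the updated `__all__` above suggests, and that `g2p_en` is installed:

from parakeet.frontend import EnglishCharacter

frontend = EnglishCharacter()
words = frontend.phoneticize("The 2 cats sat.")  # normalized words, plus optional start/end symbols
ids = frontend.numericalize(words)               # whole-word vocab hits are used directly,
                                                 # unknown words fall back to per-character ids
print(words, ids)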

View File

@ -0,0 +1,427 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
import parakeet
from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention
from parakeet.modules import masking
__all__ = ["Tacotron2", "Tacotron2Loss"]
class DecoderPreNet(nn.Layer):
def __init__(self,
d_input: int,
d_hidden: int,
d_output: int,
                 dropout_rate: float=0.2):
super().__init__()
self.dropout_rate = dropout_rate
self.linear1 = nn.Linear(d_input, d_hidden, bias_attr=False)
self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
def forward(self, x):
x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
return output
class DecoderPostNet(nn.Layer):
def __init__(self,
d_mels: int=80,
d_hidden: int=512,
kernel_size: int=5,
padding: int=0,
num_layers: int=5,
dropout=0.1):
super().__init__()
self.dropout = dropout
self.num_layers = num_layers
self.conv_batchnorms = nn.LayerList()
k = math.sqrt(1.0 / (d_mels * kernel_size))
self.conv_batchnorms.append(
Conv1dBatchNorm(
d_mels,
d_hidden,
kernel_size=kernel_size,
padding=padding,
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
low=-k, high=k)),
data_format='NLC'))
k = math.sqrt(1.0 / (d_hidden * kernel_size))
self.conv_batchnorms.extend([
Conv1dBatchNorm(
d_hidden,
d_hidden,
kernel_size=kernel_size,
padding=padding,
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
low=-k, high=k)),
data_format='NLC') for i in range(1, num_layers - 1)
])
self.conv_batchnorms.append(
Conv1dBatchNorm(
d_hidden,
d_mels,
kernel_size=kernel_size,
padding=padding,
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
low=-k, high=k)),
data_format='NLC'))
def forward(self, input):
for i in range(len(self.conv_batchnorms) - 1):
            input = F.dropout(
                F.tanh(self.conv_batchnorms[i](input)), self.dropout)
input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
self.dropout)
return input
class Tacotron2Encoder(nn.Layer):
def __init__(self,
d_hidden: int,
conv_layers: int,
kernel_size: int,
p_dropout: float):
super().__init__()
k = math.sqrt(1.0 / (d_hidden * kernel_size))
self.conv_batchnorms = paddle.nn.LayerList([
Conv1dBatchNorm(
d_hidden,
d_hidden,
kernel_size,
stride=1,
padding=int((kernel_size - 1) / 2),
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
low=-k, high=k)),
data_format='NLC') for i in range(conv_layers)
])
self.p_dropout = p_dropout
self.hidden_size = int(d_hidden / 2)
self.lstm = nn.LSTM(
d_hidden, self.hidden_size, direction="bidirectional")
def forward(self, x, input_lens=None):
for conv_batchnorm in self.conv_batchnorms:
x = F.dropout(F.relu(conv_batchnorm(x)),
self.p_dropout) #(B, T, C)
output, _ = self.lstm(inputs=x, sequence_length=input_lens)
return output
class Tacotron2Decoder(nn.Layer):
def __init__(self,
d_mels: int,
reduction_factor: int,
d_encoder: int,
d_prenet: int,
d_attention_rnn: int,
d_decoder_rnn: int,
d_attention: int,
attention_filters: int,
attention_kernel_size: int,
p_prenet_dropout: float,
p_attention_dropout: float,
p_decoder_dropout: float):
super().__init__()
self.d_mels = d_mels
self.reduction_factor = reduction_factor
self.d_encoder = d_encoder
self.d_attention_rnn = d_attention_rnn
self.d_decoder_rnn = d_decoder_rnn
self.p_attention_dropout = p_attention_dropout
self.p_decoder_dropout = p_decoder_dropout
self.prenet = DecoderPreNet(
d_mels * reduction_factor,
d_prenet,
d_prenet,
dropout_rate=p_prenet_dropout)
self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn)
self.attention_layer = LocationSensitiveAttention(
d_attention_rnn, d_encoder, d_attention, attention_filters,
attention_kernel_size)
self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder,
d_decoder_rnn)
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
d_mels * reduction_factor)
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
def _initialize_decoder_states(self, key):
batch_size = key.shape[0]
MAX_TIME = key.shape[1]
self.attention_hidden = paddle.zeros(
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
self.attention_cell = paddle.zeros(
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
self.decoder_hidden = paddle.zeros(
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
self.decoder_cell = paddle.zeros(
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
self.attention_weights = paddle.zeros(
shape=[batch_size, MAX_TIME], dtype=key.dtype)
self.attention_weights_cum = paddle.zeros(
shape=[batch_size, MAX_TIME], dtype=key.dtype)
self.attention_context = paddle.zeros(
shape=[batch_size, self.d_encoder], dtype=key.dtype)
self.key = key #[B, T, C]
self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
def _decode(self, query):
cell_input = paddle.concat([query, self.attention_context], axis=-1)
# The first lstm layer
_, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(self.attention_hidden,
self.p_attention_dropout)
        # Location sensitive attention
attention_weights_cat = paddle.stack(
[self.attention_weights, self.attention_weights_cum], axis=-1)
self.attention_context, self.attention_weights = self.attention_layer(
self.attention_hidden, self.processed_key, self.key,
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
# The second lstm layer
decoder_input = paddle.concat(
[self.attention_hidden, self.attention_context], axis=-1)
_, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(
self.decoder_hidden, p=self.p_decoder_dropout)
# decode output one step
decoder_hidden_attention_context = paddle.concat(
[self.decoder_hidden, self.attention_context], axis=-1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
stop_logit = self.stop_layer(decoder_hidden_attention_context)
return decoder_output, stop_logit, self.attention_weights
def forward(self, keys, querys, mask):
querys = paddle.reshape(
querys,
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
querys = paddle.concat(
[
paddle.zeros(
shape=[
querys.shape[0], 1,
querys.shape[-1] * self.reduction_factor
],
dtype=querys.dtype), querys
],
axis=1)
querys = self.prenet(querys)
self._initialize_decoder_states(keys)
self.mask = mask
mel_outputs, stop_logits, alignments = [], [], []
while len(mel_outputs) < querys.shape[
1] - 1: # Ignore the last time step
query = querys[:, len(mel_outputs), :]
mel_output, stop_logit, attention_weights = self._decode(query)
mel_outputs += [mel_output]
stop_logits += [stop_logit]
alignments += [attention_weights]
alignments = paddle.stack(alignments, axis=1)
stop_logits = paddle.concat(stop_logits, axis=1)
mel_outputs = paddle.stack(mel_outputs, axis=1)
return mel_outputs, stop_logits, alignments
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
decoder_input = paddle.zeros(
shape=[key.shape[0], self.d_mels * self.reduction_factor],
dtype=key.dtype) #[B, C]
        self._initialize_decoder_states(key)
self.mask = None
mel_outputs, stop_logits, alignments = [], [], []
while True:
decoder_input = self.prenet(decoder_input)
            mel_output, stop_logit, alignment = self._decode(decoder_input)
mel_outputs += [mel_output]
stop_logits += [stop_logit]
alignments += [alignment]
if F.sigmoid(stop_logit) > stop_threshold:
break
elif len(mel_outputs) == max_decoder_steps:
print("Warning! Reached max decoder steps!!!")
break
decoder_input = mel_output
alignments = paddle.stack(alignments, axis=1)
stop_logits = paddle.concat(stop_logits, axis=1)
mel_outputs = paddle.stack(mel_outputs, axis=1)
return mel_outputs, stop_logits, alignments
class Tacotron2(nn.Layer):
"""
Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
    This is the spectrogram prediction network of Tacotron2 described
    in `Natural TTS Synthesis
    by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts a sequence of characters
    into a sequence of mel spectrogram frames.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self,
frontend: parakeet.frontend.Phonetics,
d_mels: int=80,
d_encoder: int=512,
encoder_conv_layers: int=3,
encoder_kernel_size: int=5,
d_prenet: int=256,
d_attention_rnn: int=1024,
d_decoder_rnn: int=1024,
attention_filters: int=32,
attention_kernel_size: int=31,
d_attention: int=128,
d_postnet: int=512,
postnet_kernel_size: int=5,
postnet_conv_layers: int=5,
reduction_factor: int=1,
p_encoder_dropout: float=0.5,
p_prenet_dropout: float=0.5,
p_attention_dropout: float=0.1,
p_decoder_dropout: float=0.1,
p_postnet_dropout: float=0.5):
super().__init__()
std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
val = math.sqrt(3.0) * std # uniform bounds for std
self.embedding = nn.Embedding(
frontend.vocab_size,
d_encoder,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
low=-val, high=val)))
self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
encoder_kernel_size, p_encoder_dropout)
self.decoder = Tacotron2Decoder(
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
d_decoder_rnn, d_attention, attention_filters,
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
p_decoder_dropout)
self.postnet = DecoderPostNet(
d_mels=d_mels,
d_hidden=d_postnet,
kernel_size=postnet_kernel_size,
padding=int((postnet_kernel_size - 1) / 2),
num_layers=postnet_conv_layers,
dropout=p_postnet_dropout)
def forward(self, text_inputs, mels, text_lens, output_lens=None):
embedded_inputs = self.embedding(text_inputs)
encoder_outputs = self.encoder(embedded_inputs, text_lens)
mask = paddle.tensor.unsqueeze(
paddle.fluid.layers.sequence_mask(
x=text_lens, dtype=encoder_outputs.dtype), [-1])
mel_outputs, stop_logits, alignments = self.decoder(
encoder_outputs, mels, mask=mask)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
if output_lens is not None:
mask = paddle.tensor.unsqueeze(
paddle.fluid.layers.sequence_mask(x=output_lens),
[-1]) #[B, T, 1]
mel_outputs = mel_outputs * mask #[B, T, C]
mel_outputs_postnet = mel_outputs_postnet * mask #[B, T, C]
stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0]
) * 1e3 #[B, T]
outputs = {
"mel_output": mel_outputs,
"mel_outputs_postnet": mel_outputs_postnet,
"stop_logits": stop_logits,
"alignments": alignments
}
return outputs
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
embedded_inputs = self.embedding(text_inputs)
encoder_outputs = self.encoder(embedded_inputs)
        mel_outputs, stop_logits, alignments = self.decoder.infer(
encoder_outputs,
stop_threshold=stop_threshold,
max_decoder_steps=max_decoder_steps)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
outputs = {
"mel_output": mel_outputs,
"mel_outputs_postnet": mel_outputs_postnet,
"stop_logits": stop_logits,
"alignments": alignments
}
return outputs
def predict(self, text):
        # TODO(lifuchen): implement predict function to produce mel from texts
pass
class Tacotron2Loss(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
mel_targets, stop_tokens):
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
total_loss = mel_loss + post_mel_loss + stop_loss
losses = dict(
loss=total_loss,
mel_loss=mel_loss,
post_mel_loss=post_mel_loss,
stop_loss=stop_loss)
return losses
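
A hypothetical single training step wiring `Tacotron2` and `Tacotron2Loss` together (an aside, not part of the diff). The module path `parakeet.models.tacotron2` and all shapes/lengths are illustrative assumptions; with the default `reduction_factor=1`, mels are laid out as (batch, time, d_mels):

import paddle
from parakeet.frontend import English
from parakeet.models.tacotron2 import Tacotron2, Tacotron2Loss  # assumed path

frontend = English()
model = Tacotron2(frontend)
criterion = Tacotron2Loss()

B, T_text, T_mel = 4, 32, 120
text = paddle.randint(0, frontend.vocab_size, shape=[B, T_text])
text_lens = paddle.full([B], T_text, dtype="int64")
mels = paddle.rand([B, T_mel, 80])                 # (B, T, d_mels)
output_lens = paddle.full([B], T_mel, dtype="int64")
stop_tokens = paddle.zeros([B, T_mel])             # in practice 1.0 from the last valid frame on

outputs = model(text, mels, text_lens, output_lens)
losses = criterion(outputs["mel_output"], outputs["mel_outputs_postnet"],
                   outputs["stop_logits"], mels, stop_tokens)
print(losses["loss"])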

View File

@ -1,3 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from tqdm import trange
import paddle
@ -15,6 +29,7 @@ from parakeet.modules import losses as L
__all__ = ["TransformerTTS", "TransformerTTSLoss"]

# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
    """
@ -25,7 +40,14 @@ class MultiheadAttention(nn.Layer):
    Another deviation is that it concats the input query and context vector before
    applying the output projection.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 k_dim=None,
                 v_dim=None,
                 k_input_dim=None,
                 v_input_dim=None):
        """
        Args:
            model_dim (int): the feature size of query.
@ -80,7 +102,8 @@ class MultiheadAttention(nn.Layer):
        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, training=self.training)
        context_vectors = drop_head(context_vectors, drop_n_heads,
                                    self.training)
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)

        concat_feature = paddle.concat([q_in, context_vectors], -1)
@ -92,6 +115,7 @@ class TransformerEncoderLayer(nn.Layer):
    """
    Transformer encoder layer.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
        """
        Args:
@ -114,8 +138,10 @@ class TransformerEncoderLayer(nn.Layer):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
@ -145,6 +171,7 @@ class TransformerDecoderLayer(nn.Layer):
    """
    Transformer decoder layer.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
        """
        Args:
@ -158,7 +185,8 @@ class TransformerDecoderLayer(nn.Layer):
        self.self_mha = MultiheadAttention(d_model, n_heads)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.cross_mha = MultiheadAttention(
            d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
@ -170,16 +198,20 @@ class TransformerDecoderLayer(nn.Layer):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        x_in = x
        x = self.layer_norm1(x)
        context_vector, attn_weights = self.self_mha(x, x, x, mask,
                                                     drop_n_heads)
        context_vector = x_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
        # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
        q_in = q
        q = self.layer_norm2(q)
        context_vector, attn_weights = self.cross_mha(q, k, v, mask,
                                                      drop_n_heads)
        context_vector = q_in + F.dropout(
            context_vector, self.dropout, training=self.training)
        return context_vector, attn_weights

    def _forward_ffn(self, x):
@ -204,8 +236,10 @@ class TransformerDecoderLayer(nn.Layer):
            self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
            cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
        """
        q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
                                                      drop_n_heads)
        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
                                                        drop_n_heads)
        q = self._forward_ffn(q)
        return q, self_attn_weights, cross_attn_weights
@ -214,7 +248,8 @@ class TransformerEncoder(nn.LayerList):
    def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
        super(TransformerEncoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))

    def forward(self, x, mask, drop_n_heads=0):
        """
@ -236,10 +271,18 @@ class TransformerEncoder(nn.LayerList):
class TransformerDecoder(nn.LayerList):
    def __init__(self,
                 d_model,
                 n_heads,
                 d_ffn,
                 n_layers,
                 dropout=0.,
                 d_encoder=None):
        super(TransformerDecoder, self).__init__()
        for _ in range(n_layers):
            self.append(
                TransformerDecoderLayer(
                    d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))

    def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
        """[summary]
@ -260,7 +303,8 @@ class TransformerDecoder(nn.LayerList):
        self_attention_weights = []
        cross_attention_weights = []
        for layer in self:
            q, self_attention_weights_i, cross_attention_weights_i = layer(
                q, k, v, encoder_mask, decoder_mask, drop_n_heads)
            self_attention_weights.append(self_attention_weights_i)
            cross_attention_weights.append(cross_attention_weights_i)
        return q, self_attention_weights, cross_attention_weights
@ -268,6 +312,7 @@ class TransformerDecoder(nn.LayerList):
class MLPPreNet(nn.Layer):
    """Decoder's prenet."""

    def __init__(self, d_input, d_hidden, d_output, dropout):
        # (lin + relu + dropout) * n + last projection
        super(MLPPreNet, self).__init__()
@ -277,14 +322,22 @@ class MLPPreNet(nn.Layer):
        self.dropout = dropout

    def forward(self, x, dropout):
        l1 = F.dropout(
            F.relu(self.lin1(x)), self.dropout, training=self.training)
        l2 = F.dropout(
            F.relu(self.lin2(l1)), self.dropout, training=self.training)
        l3 = self.lin3(l2)
        return l3

# NOTE: not used in
class CNNPreNet(nn.Layer):
    def __init__(self,
                 d_input,
                 d_hidden,
                 d_output,
                 kernel_size,
                 n_layers,
                 dropout=0.):
        # (conv + bn + relu + dropout) * n + last projection
        super(CNNPreNet, self).__init__()
@ -292,16 +345,21 @@ class CNNPreNet(nn.Layer):
        c_in = d_input
        for _ in range(n_layers):
            self.convs.append(
                Conv1dBatchNorm(
                    c_in,
                    d_hidden,
                    kernel_size,
                    weight_attr=I.XavierUniform(),
                    padding="same",
                    data_format="NLC"))
            c_in = d_hidden
        self.affine_out = nn.Linear(d_hidden, d_output)
        self.dropout = dropout

    def forward(self, x):
        for layer in self.convs:
            x = F.dropout(
                F.relu(layer(x)), self.dropout, training=self.training)
        x = self.affine_out(x)
        return x
@ -310,13 +368,17 @@ class CNNPostNet(nn.Layer):
    def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
        super(CNNPostNet, self).__init__()
        self.convs = nn.LayerList()
        kernel_size = kernel_size if isinstance(kernel_size, (
            tuple, list)) else (kernel_size, )
        padding = (kernel_size[0] - 1, 0)
        for i in range(n_layers):
            c_in = d_input if i == 0 else d_hidden
            c_out = d_output if i == n_layers - 1 else d_hidden
            self.convs.append(
                Conv1dBatchNorm(
                    c_in,
                    c_out,
                    kernel_size,
                    weight_attr=I.XavierUniform(),
                    padding=padding))
        self.last_bn = nn.BatchNorm1D(d_output)
@ -359,15 +421,16 @@ class TransformerTTS(nn.Layer):
        # encoder
        self.encoder_prenet = nn.Embedding(
            frontend.vocab_size,
            d_encoder,
            padding_idx=frontend.vocab.padding_index,
            weight_attr=I.Uniform(-0.05, 0.05))
        # position encoding matrix may be extended later
        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
        self.encoder_pe_scalar = self.create_parameter(
            [1], attr=I.Constant(1.))
        self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
                                          encoder_layers, dropout)

        # decoder
        self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
@ -375,11 +438,15 @@ class TransformerTTS(nn.Layer):
        self.decoder_pe_scalar = self.create_parameter(
            [1], attr=I.Constant(1.))
        self.decoder = TransformerDecoder(
            d_decoder,
            n_heads,
            d_ffn,
            decoder_layers,
            dropout,
            d_encoder=d_encoder)
        self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
        self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
                                          postnet_kernel_size, postnet_layers)
        self.stop_conditioner = nn.Linear(d_mel, 3)

        # specs
@ -404,7 +471,8 @@ class TransformerTTS(nn.Layer):
    def forward(self, text, mel):
        encoded, encoder_attention_weights, encoder_mask = self.encode(text)
        mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
            encoded, mel, encoder_mask)
        outputs = {
            "mel_output": mel_output,
            "mel_intermediate": mel_intermediate,
@ -421,13 +489,16 @@ class TransformerTTS(nn.Layer):
            new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
            self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
        pos_enc = self.encoder_pe[:T_enc, :]  # (T, C)
        x = embed.scale(math.sqrt(
            self.d_encoder)) + pos_enc * self.encoder_pe_scalar
        x = F.dropout(x, self.dropout, training=self.training)

        # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
        encoder_padding_mask = paddle.unsqueeze(
            masking.id_mask(
                text, self.padding_idx, dtype=x.dtype), 1)
        x, attention_weights = self.encoder(x, encoder_padding_mask,
                                            self.drop_n_heads)
        return x, attention_weights, encoder_padding_mask

    def decode(self, encoder_output, input, encoder_padding_mask):
@ -439,23 +510,23 @@ class TransformerTTS(nn.Layer):
            new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
            self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
        pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
        x = x.scale(math.sqrt(
            self.d_decoder)) + pos_enc * self.decoder_pe_scalar
        x = F.dropout(x, self.dropout, training=self.training)

        no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
        decoder_padding_mask = masking.feature_mask(
            input, axis=-1, dtype=input.dtype)
        decoder_mask = masking.combine_mask(
            decoder_padding_mask.unsqueeze(-1), no_future_mask)
        decoder_output, _, cross_attention_weights = self.decoder(
            x, encoder_output, encoder_output, encoder_padding_mask,
            decoder_mask, self.drop_n_heads)

        # use only parts of it
        output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
        mel_intermediate = paddle.reshape(output_proj,
                                          [batch_size, -1, mel_dim])
        stop_logits = self.stop_conditioner(mel_intermediate)

        # cnn postnet
@ -483,18 +554,24 @@ class TransformerTTS(nn.Layer):
        decoder_output = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)

        # encode the text sequence
        encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
            text_input)
        for _ in range(int(max_length // self.r) + 1):
            mel_output, _, cross_attention_weights, stop_logits = self.decode(
                encoder_output, decoder_input, encoder_padding_mask)

            # extract last step and append it to decoder input
            decoder_input = paddle.concat(
                [decoder_input, mel_output[:, -1:, :]], 1)
            # extract last r steps and append it to decoder output
            decoder_output = paddle.concat(
                [decoder_output, mel_output[:, -self.r:, :]], 1)

            # stop condition: (if any output frame of the output multiframes hits the stop condition)
            if paddle.any(
                    paddle.argmax(
                        stop_logits[0, -self.r:, :], axis=-1) ==
                    self.stop_prob_index):
                if verbose:
                    print("Hits stop condition.")
                break
@ -517,15 +594,19 @@ class TransformerTTSLoss(nn.Layer):
        super(TransformerTTSLoss, self).__init__()
        self.stop_loss_scale = stop_loss_scale

    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
                stop_probs):
        mask = masking.feature_mask(
            mel_target, axis=-1, dtype=mel_target.dtype)
        mask1 = paddle.unsqueeze(mask, -1)
        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)

        mel_len = mask.shape[-1]
        last_position = F.one_hot(
            mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
        mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(
            mask.dtype)
        stop_loss = L.masked_softmax_with_cross_entropy(
            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
@ -543,8 +624,10 @@ class AdaptiveTransformerTTSLoss(nn.Layer):
    def __init__(self):
        super(AdaptiveTransformerTTSLoss, self).__init__()

    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
                stop_probs):
        mask = masking.feature_mask(
            mel_target, axis=-1, dtype=mel_target.dtype)
        mask1 = paddle.unsqueeze(mask, -1)
        mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
        mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
@ -553,7 +636,8 @@ class AdaptiveTransformerTTSLoss(nn.Layer):
        valid_lengths = mask.sum(-1).astype("int64")
        last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
        stop_loss_scale = valid_lengths.sum() / batch_size - 1
        mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
            mask.dtype)
        stop_loss = L.masked_softmax_with_cross_entropy(
            stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))

View File

@ -1,10 +1,30 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F


def scaled_dot_product_attention(q,
                                 k,
                                 v,
                                 mask=None,
                                 dropout=0.0,
                                 training=True):
    """
    scaled dot product attention with mask. Assume q, k, v all have the same
    leading dimensions (denoted as * in descriptions below). Dropout is applied to
@ -34,6 +54,7 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
    out = paddle.matmul(attn_weights, v)
    return out, attn_weights


def drop_head(x, drop_n_heads, training):
    """
    Drop n heads from multiple context vectors.
@ -63,18 +84,21 @@ def drop_head(x, drop_n_heads, training):
    out = x * paddle.to_tensor(mask)
    return out


def _split_heads(x, num_heads):
    batch_size, time_steps, _ = x.shape
    x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
    x = paddle.transpose(x, [0, 2, 1, 3])
    return x


def _concat_heads(x):
    batch_size, _, time_steps, _ = x.shape
    x = paddle.transpose(x, [0, 2, 1, 3])
    x = paddle.reshape(x, [batch_size, time_steps, -1])
    return x


# Standard implementations of Monohead Attention & Multihead Attention
class MonoheadAttention(nn.Layer):
    def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
@ -134,7 +158,13 @@ class MultiheadAttention(nn.Layer):
    """
    Multihead scaled dot product attention.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 dropout=0.0,
                 k_dim=None,
                 v_dim=None):
        """
        Multihead Attention module.
@ -195,3 +225,56 @@ class MultiheadAttention(nn.Layer):
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
        out = self.affine_o(context_vectors)
        return out, attention_weights
class LocationSensitiveAttention(nn.Layer):
def __init__(self,
d_query: int,
d_key: int,
d_attention: int,
location_filters: int,
location_kernel_size: int):
super().__init__()
self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
self.value = nn.Linear(d_attention, 1, bias_attr=False)
        # Location Layer
self.location_conv = nn.Conv1D(
2,
location_filters,
location_kernel_size,
1,
int((location_kernel_size - 1) / 2),
1,
bias_attr=False,
data_format='NLC')
self.location_layer = nn.Linear(
location_filters, d_attention, bias_attr=False)
def forward(self,
query,
processed_key,
value,
attention_weights_cat,
mask=None):
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
processed_attention_weights = self.location_layer(
self.location_conv(attention_weights_cat))
alignment = self.value(
paddle.tanh(processed_attention_weights + processed_key +
processed_query))
if mask is not None:
alignment = alignment + (1.0 - mask) * -1e9
attention_weights = F.softmax(alignment, axis=1)
attention_context = paddle.matmul(
attention_weights, value, transpose_x=True)
attention_weights = paddle.squeeze(attention_weights, axis=[-1])
attention_context = paddle.squeeze(attention_context, axis=[1])
return attention_context, attention_weights
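
A quick shape check for `LocationSensitiveAttention` (an aside, not part of the diff); the dimensions mirror the Tacotron2 defaults but are otherwise arbitrary:

import paddle
from parakeet.modules.attention import LocationSensitiveAttention

attn = LocationSensitiveAttention(
    d_query=1024, d_key=512, d_attention=128,
    location_filters=32, location_kernel_size=31)

B, T = 4, 60
query = paddle.rand([B, 1024])            # e.g. the attention-RNN hidden state
key = paddle.rand([B, T, 512])            # encoder outputs
processed_key = attn.key_layer(key)       # (B, T, 128), computed once per utterance
weights_cat = paddle.rand([B, T, 2])      # stacked current + cumulated weights
context, weights = attn(query, processed_key, key, weights_cat)
print(context.shape, weights.shape)       # [4, 512] and [4, 60]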

View File

@ -1,6 +1,21 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn


class Conv1dCell(nn.Conv1D):
    """
    A subclass of Conv1d layer, which can be used like an RNN cell. It can take
@ -14,6 +29,7 @@ class Conv1dCell(nn.Conv1D):
    As a result, these arguments are removed from the initializer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
@ -21,8 +37,10 @@ class Conv1dCell(nn.Conv1D):
                 dilation=1,
                 weight_attr=None,
                 bias_attr=None):
        _dilation = dilation[0] if isinstance(dilation,
                                              (tuple, list)) else dilation
        _kernel_size = kernel_size[0] if isinstance(kernel_size, (
            tuple, list)) else kernel_size
        self._r = 1 + (_kernel_size - 1) * _dilation
        super(Conv1dCell, self).__init__(
            in_channels,
@ -50,8 +68,8 @@ class Conv1dCell(nn.Conv1D):
        # see also: https://github.com/pytorch/pytorch/issues/47588
        for hook in self._forward_pre_hooks.values():
            hook(self, None)
        self._reshaped_weight = paddle.reshape(self.weight,
                                               (self._out_channels, -1))

    def initialize_buffer(self, x_t):
        batch_size, _ = x_t.shape
@ -90,20 +108,34 @@ class Conv1dCell(nn.Conv1D):
class Conv1dBatchNorm(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCL",
                 momentum=0.9,
                 epsilon=1e-05):
        super(Conv1dBatchNorm, self).__init__()
        self.conv = nn.Conv1D(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=padding,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format)
        self.bn = nn.BatchNorm1D(
            out_channels,
            momentum=momentum,
            epsilon=epsilon,
            data_format=data_format)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
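
A minimal usage sketch for the updated `Conv1dBatchNorm` (an aside, not part of the diff), using the channels-last layout the Tacotron2 code relies on:

import paddle
from parakeet.modules.conv import Conv1dBatchNorm

layer = Conv1dBatchNorm(80, 512, kernel_size=5, padding=2, data_format="NLC")
x = paddle.rand([4, 120, 80])    # (B, T, C) because of NLC
y = layer(x)
print(y.shape)                   # [4, 120, 512]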

View File

@ -1,3 +1,17 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import logging
from pathlib import Path
@ -11,6 +25,7 @@ from collections import defaultdict
import parakeet
from parakeet.utils import checkpoint, mp_tools


class ExperimentBase(object):
    """
    An experiment template in order to structure the training code and take care of saving, loading, logging and visualization. It's intended to be flexible and simple.
@ -22,7 +37,7 @@ class ExperimentBase(object):
    We have some conventions to follow.
    1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
    2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
    3. There are four methods, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader`, that should be implemented.

    Feel free to add/overwrite other methods and standalone functions if you need.
@ -54,6 +69,7 @@ class ExperimentBase(object):
            main(config, args)
    """

    def __init__(self, config, args):
        self.config = config
        self.args = args
@ -82,8 +98,8 @@ class ExperimentBase(object):
        dist.init_parallel_env()

    def save(self):
        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
                                   self.model, self.optimizer)

    def resume_or_load(self):
        iteration = checkpoint.load_parameters(
@ -177,4 +193,3 @@ class ExperimentBase(object):
    def setup_dataloader(self):
        raise NotImplementedError("setup_dataloader should be implemented.")
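
A minimal subclass sketch following the four-method convention described in the docstring above (an aside, not part of the diff). The module path and the toy model/data are assumptions, not part of the PR:

import paddle
from paddle.io import DataLoader, Dataset
from parakeet.training.experiment import ExperimentBase  # assumed path


class ToyDataset(Dataset):
    def __init__(self, n=16):
        self.data = paddle.rand([n, 80])

    def __getitem__(self, i):
        return self.data[i]

    def __len__(self):
        return len(self.data)


class ToyExperiment(ExperimentBase):
    def setup_model(self):
        self.model = paddle.nn.Linear(80, 80)     # stand-in for a real TTS model
        self.optimizer = paddle.optimizer.Adam(
            learning_rate=1e-3, parameters=self.model.parameters())

    def setup_dataloader(self):
        self.train_loader = DataLoader(ToyDataset(), batch_size=4)
        self.valid_loader = DataLoader(ToyDataset(), batch_size=4)

    def train_batch(self):
        batch, = next(iter(self.train_loader))    # real code would keep a persistent iterator
        loss = paddle.nn.functional.mse_loss(self.model(batch), batch)
        loss.backward()
        self.optimizer.step()
        self.optimizer.clear_grad()

    @paddle.no_grad()
    def valid(self):
        pass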