From c1e0aecdde9e9f80402f3ef17a1ea983cda5b9f9 Mon Sep 17 00:00:00 2001 From: iclementine Date: Fri, 16 Oct 2020 13:51:56 +0800 Subject: [PATCH] 1. import models into parakeet.models; 2. add predict for TransformerTTS and test its io. --- parakeet/models/__init__.py | 8 +++++ parakeet/models/clarinet.py | 2 ++ parakeet/models/deepvoice3.py | 2 ++ parakeet/models/transformer_tts.py | 52 ++++++++++++++++++++++++++++-- parakeet/models/waveflow.py | 3 +- parakeet/models/wavenet.py | 2 ++ tests/test_transformer_tts.py | 21 +++++++++--- 7 files changed, 82 insertions(+), 8 deletions(-) diff --git a/parakeet/models/__init__.py b/parakeet/models/__init__.py index abf198b..d8521da 100644 --- a/parakeet/models/__init__.py +++ b/parakeet/models/__init__.py @@ -11,3 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from parakeet.models.clarinet import * +from parakeet.models.waveflow import * +from parakeet.models.wavenet import * + +from parakeet.models.transformer_tts import * +from parakeet.models.deepvoice3 import * +# from parakeet.models.fastspeech import * diff --git a/parakeet/models/clarinet.py b/parakeet/models/clarinet.py index 728fba3..ba859b2 100644 --- a/parakeet/models/clarinet.py +++ b/parakeet/models/clarinet.py @@ -5,6 +5,8 @@ from paddle import distribution as D from parakeet.models.wavenet import WaveNet, UpsampleNet, crop +__all__ = ["Clarinet"] + class ParallelWaveNet(nn.LayerList): def __init__(self, n_loops, n_layers, residual_channels, condition_dim, filter_size): diff --git a/parakeet/models/deepvoice3.py b/parakeet/models/deepvoice3.py index 3e6ff1c..e44edcb 100644 --- a/parakeet/models/deepvoice3.py +++ b/parakeet/models/deepvoice3.py @@ -8,6 +8,8 @@ from paddle.nn import initializer as I from parakeet.modules import positional_encoding as pe +__all__ = ["SpectraNet"] + class ConvBlock(nn.Layer): def 
__init__(self, in_channel, kernel_size, causal=False, has_bias=False, bias_dim=None, keep_prob=1.): diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index 2802638..86993dd 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -9,6 +9,8 @@ from parakeet.modules import masking from parakeet.modules.cbhg import Conv1dBatchNorm from parakeet.modules import positional_encoding as pe +__all__ = ["TransformerTTS"] + # Transformer TTS's own implementation of transformer class MultiheadAttention(nn.Layer): """ @@ -255,12 +257,22 @@ class TransformerTTS(nn.Layer): self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout) self.final_proj = nn.Linear(d_model, max_reduction_factor * d_mel) self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers) + self.stop_conditioner = nn.Linear(d_mel, 3) # specs self.padding_idx = padding_idx self.d_model = d_model self.pe_scalar = positional_encoding_scalar + # start and end + dtype = paddle.get_default_dtype() + self.start_vec = paddle.fill_constant([1, d_mel], dtype=dtype, value=0) + self.end_vec = paddle.fill_constant([1, d_mel], dtype=dtype, value=0) + self.stop_prob_index = 2 + + self.max_r = max_reduction_factor + self.r = max_reduction_factor # set it every call + def forward(self, text, mel, stop): pass @@ -292,11 +304,45 @@ class TransformerTTS(nn.Layer): output_proj = self.final_proj(decoder_output) mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim]) + stop_logits = self.stop_conditioner(mel_intermediate) mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1]) mel_output = self.decoder_postnet(mel_channel_first) mel_output = paddle.transpose(mel_output, [0, 2, 1]) - return mel_output, mel_intermediate, cross_attention_weights + return mel_output, mel_intermediate, cross_attention_weights, stop_logits - def infer(self): - pass \ No newline at end of file + def 
predict(self, input, max_length=1000, verbose=True): + """Autoregressively decode mel frames from an input text sequence. + + Args: + input (Tensor): shape (T), dtype int, input text sequence. + max_length (int, optional): max decoder steps. Defaults to 1000. + verbose (bool, optional): print a message when the stop condition is hit. Defaults to True. + """ + text_input = paddle.unsqueeze(input, 0) # (1, T) + decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) + decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) + + # encode the text sequence + encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input) + for _ in range(int(max_length // self.r) + 1): + mel_output, _, cross_attention_weights, stop_logits = self.decode( + encoder_output, decoder_input, encoder_padding_mask) + + # extract last step and append it to decoder input + decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1) + # extract last r steps and append it to decoder output + decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1) + + # stop condition? + if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index: + if verbose: + print("Hits stop condition.") + break + + return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights + + + + + \ No newline at end of file diff --git a/parakeet/models/waveflow.py b/parakeet/models/waveflow.py index c0158bc..fefd2f8 100644 --- a/parakeet/models/waveflow.py +++ b/parakeet/models/waveflow.py @@ -11,7 +11,8 @@ import itertools import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid -from parakeet.modules import weight_norm + +__all__ = ["WaveFlow"] def fold(x, n_group): """Fold audio or spectrogram's temporal dimension in to groups. 
diff --git a/parakeet/models/wavenet.py b/parakeet/models/wavenet.py index 327d53e..41a06be 100644 --- a/parakeet/models/wavenet.py +++ b/parakeet/models/wavenet.py @@ -26,6 +26,8 @@ import paddle.fluid.layers.distributions as D from parakeet.modules.conv import Conv1dCell +__all__ = ["ConditionalWavenet"] + def quantize(values, n_bands): """Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands). diff --git a/tests/test_transformer_tts.py b/tests/test_transformer_tts.py index 78ad523..04676bc 100644 --- a/tests/test_transformer_tts.py +++ b/tests/test_transformer_tts.py @@ -78,9 +78,10 @@ class TestTransformerTTS(unittest.TestCase): text = text * mask encoded, attention_weights, encoder_mask = net.encode(text) - print(encoded.numpy().shape) - print([item.shape for item in attention_weights]) - print(encoder_mask.numpy().shape) + print("output shapes:") + print("encoded:", encoded.numpy().shape) + print("encoder_attentions:", [item.shape for item in attention_weights]) + print("encoder_mask:", encoder_mask.numpy().shape) def test_all_io(self): net = self.net @@ -96,7 +97,7 @@ class TestTransformerTTS(unittest.TestCase): mel = mel * mask.unsqueeze(-1) encoded, encoder_attention_weights, encoder_mask = net.encode(text) - mel_output, mel_intermediate, cross_attention_weights = net.decode(encoded, mel, encoder_mask) + mel_output, mel_intermediate, cross_attention_weights, stop_logits = net.decode(encoded, mel, encoder_mask) print("output shapes:") print("encoder_output:", encoded.numpy().shape) @@ -105,5 +106,17 @@ class TestTransformerTTS(unittest.TestCase): print("mel_output: ", mel_output.numpy().shape) print("mel_intermediate: ", mel_intermediate.numpy().shape) print("decoder_attentions:", [item.shape for item in cross_attention_weights]) + print("stop_logits:", stop_logits.numpy().shape) + def test_predict_io(self): + net = self.net + net.eval() + with paddle.no_grad(): + text = paddle.randint(0, 128, [176]) + decoder_output, 
encoder_attention_weights, cross_attention_weights = net.predict(text) + + print("output shapes:") + print("mel_output: ", decoder_output.numpy().shape) + print("encoder_attentions:", [item.shape for item in encoder_attention_weights]) + print("decoder_attentions:", [item.shape for item in cross_attention_weights]) \ No newline at end of file