1. import models into parakeet.models;

2. add predict for TransformerTTS and test its io.
This commit is contained in:
iclementine 2020-10-16 13:51:56 +08:00
parent 6aa7af1aa4
commit c1e0aecdde
7 changed files with 82 additions and 8 deletions

View File

@ -11,3 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parakeet.models.clarinet import *
from parakeet.models.waveflow import *
from parakeet.models.wavenet import *
from parakeet.models.transformer_tts import *
from parakeet.models.deepvoice3 import *
# from parakeet.models.fastspeech import *

View File

@ -5,6 +5,8 @@ from paddle import distribution as D
from parakeet.models.wavenet import WaveNet, UpsampleNet, crop
__all__ = ["Clarinet"]
class ParallelWaveNet(nn.LayerList):
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
filter_size):

View File

@ -8,6 +8,8 @@ from paddle.nn import initializer as I
from parakeet.modules import positional_encoding as pe
__all__ = ["SpectraNet"]
class ConvBlock(nn.Layer):
def __init__(self, in_channel, kernel_size, causal=False, has_bias=False,
bias_dim=None, keep_prob=1.):

View File

@ -9,6 +9,8 @@ from parakeet.modules import masking
from parakeet.modules.cbhg import Conv1dBatchNorm
from parakeet.modules import positional_encoding as pe
__all__ = ["TransformerTTS"]
# Transformer TTS's own implementation of transformer
class MultiheadAttention(nn.Layer):
"""
@ -255,12 +257,22 @@ class TransformerTTS(nn.Layer):
self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
self.final_proj = nn.Linear(d_model, max_reduction_factor * d_mel)
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
self.stop_conditioner = nn.Linear(d_mel, 3)
# specs
self.padding_idx = padding_idx
self.d_model = d_model
self.pe_scalar = positional_encoding_scalar
# start and end
dtype = paddle.get_default_dtype()
self.start_vec = paddle.fill_constant([1, d_mel], dtype=dtype, value=0)
self.end_vec = paddle.fill_constant([1, d_mel], dtype=dtype, value=0)
self.stop_prob_index = 2
self.max_r = max_reduction_factor
self.r = max_reduction_factor # set it every call
def forward(self, text, mel, stop):
# NOTE(review): stub — the only statement visible in this hunk is `pass`,
# so calling the layer directly returns None. The tests shown alongside
# this change drive the model through `encode`/`decode` instead; confirm
# whether a real forward pass is still planned.
pass
@ -292,11 +304,45 @@ class TransformerTTS(nn.Layer):
output_proj = self.final_proj(decoder_output)
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
stop_logits = self.stop_conditioner(mel_intermediate)
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
mel_output = self.decoder_postnet(mel_channel_first)
mel_output = paddle.transpose(mel_output, [0, 2, 1])
return mel_output, mel_intermediate, cross_attention_weights
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
def predict(self, input, max_length=1000, verbose=True):
    """Autoregressively synthesize mel frames for a single text sequence.

    The text is encoded once. The decoder is then called repeatedly; after
    each call the newest frame is appended to the decoder input and the
    newest ``self.r`` frames are appended to the output. Decoding ends when
    the arg-max of the stop logits for the newest frame equals
    ``self.stop_prob_index`` or when the step budget is exhausted.

    Args:
        input (Tensor): shape (T), dtype int, input text sequence.
        max_length (int, optional): decoding budget. The loop runs at most
            ``max_length // self.r + 1`` times. Must be non-negative.
            Defaults to 1000.
        verbose (bool, optional): print a message when the stop condition
            is hit. Defaults to True.

    Returns:
        tuple: ``(mel_output, encoder_attentions, cross_attention_weights)``.
        ``mel_output`` is the accumulated decoder output with a leading
        batch dimension of 1 and the initial start frame removed;
        ``encoder_attentions`` is passed through unchanged from
        ``self.encode``; ``cross_attention_weights`` is the value returned
        by the last ``self.decode`` call.

    Raises:
        ValueError: if ``max_length`` is negative. Without this check the
            loop would run zero times and the return statement would fail
            on an unbound ``cross_attention_weights``.
    """
    if max_length < 0:
        raise ValueError(
            "max_length must be non-negative, got {}".format(max_length))

    text_input = paddle.unsqueeze(input, 0)  # (1, T)
    # Both sequences are seeded with the start frame; it is sliced off the
    # returned output at the end.
    decoder_input = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)
    decoder_output = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)

    # encode the text sequence once; the result is reused at every step
    encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
        text_input)

    n_steps = int(max_length // self.r) + 1
    # A named loop variable is used so that the `_` placeholder in the
    # unpacking below does not overwrite the loop counter.
    for _step in range(n_steps):
        mel_output, _, cross_attention_weights, stop_logits = self.decode(
            encoder_output, decoder_input, encoder_padding_mask)

        # extract last step and append it to decoder input
        decoder_input = paddle.concat(
            [decoder_input, mel_output[:, -1:, :]], 1)
        # extract last r steps and append them to decoder output
        decoder_output = paddle.concat(
            [decoder_output, mel_output[:, -self.r:, :]], 1)

        # stop condition?
        if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index:
            if verbose:
                print("Hits stop condition.")
            break

    return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
def infer(self):
# NOTE(review): placeholder — only `pass` is visible here, so this method
# currently does nothing and returns None. Inference in this change goes
# through `predict`; confirm whether `infer` should be removed or filled in.
pass

View File

@ -11,7 +11,8 @@ import itertools
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules import weight_norm
__all__ = ["WaveFlow"]
def fold(x, n_group):
"""Fold audio or spectrogram's temporal dimension in to groups.

View File

@ -26,6 +26,8 @@ import paddle.fluid.layers.distributions as D
from parakeet.modules.conv import Conv1dCell
__all__ = ["ConditionalWavenet"]
def quantize(values, n_bands):
"""Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).

View File

@ -78,9 +78,10 @@ class TestTransformerTTS(unittest.TestCase):
text = text * mask
encoded, attention_weights, encoder_mask = net.encode(text)
print(encoded.numpy().shape)
print([item.shape for item in attention_weights])
print(encoder_mask.numpy().shape)
print("output shapes:")
print("encoded:", encoded.numpy().shape)
print("encoder_attentions:", [item.shape for item in attention_weights])
print("encoder_mask:", encoder_mask.numpy().shape)
def test_all_io(self):
net = self.net
@ -96,7 +97,7 @@ class TestTransformerTTS(unittest.TestCase):
mel = mel * mask.unsqueeze(-1)
encoded, encoder_attention_weights, encoder_mask = net.encode(text)
mel_output, mel_intermediate, cross_attention_weights = net.decode(encoded, mel, encoder_mask)
mel_output, mel_intermediate, cross_attention_weights, stop_logits = net.decode(encoded, mel, encoder_mask)
print("output shapes:")
print("encoder_output:", encoded.numpy().shape)
@ -105,5 +106,17 @@ class TestTransformerTTS(unittest.TestCase):
print("mel_output: ", mel_output.numpy().shape)
print("mel_intermediate: ", mel_intermediate.numpy().shape)
print("decoder_attentions:", [item.shape for item in cross_attention_weights])
print("stop_logits:", stop_logits.numpy().shape)
def test_predict_io(self):
    """Smoke-test ``predict`` on random token ids and print output shapes.

    Like the other IO tests shown alongside it, this only prints shapes
    for manual inspection; it makes no assertions.
    """
    model = self.net
    model.eval()

    with paddle.no_grad():
        token_ids = paddle.randint(0, 128, [176])
        outputs = model.predict(token_ids)

    mel, encoder_attns, decoder_attns = outputs
    encoder_attn_shapes = [item.shape for item in encoder_attns]
    decoder_attn_shapes = [item.shape for item in decoder_attns]

    print("output shapes:")
    print("mel_output: ", mel.numpy().shape)
    print("encoder_attentions:", encoder_attn_shapes)
    print("decoder_attentions:", decoder_attn_shapes)