1. import models into parakeet.models;
2. add predict for TransformerTTS and test its io.
This commit is contained in:
parent
6aa7af1aa4
commit
c1e0aecdde
|
@ -11,3 +11,11 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from parakeet.models.clarinet import *
|
||||
from parakeet.models.waveflow import *
|
||||
from parakeet.models.wavenet import *
|
||||
|
||||
from parakeet.models.transformer_tts import *
|
||||
from parakeet.models.deepvoice3 import *
|
||||
# from parakeet.models.fastspeech import *
|
||||
|
|
|
@ -5,6 +5,8 @@ from paddle import distribution as D
|
|||
|
||||
from parakeet.models.wavenet import WaveNet, UpsampleNet, crop
|
||||
|
||||
__all__ = ["Clarinet"]
|
||||
|
||||
class ParallelWaveNet(nn.LayerList):
|
||||
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
|
||||
filter_size):
|
||||
|
|
|
@ -8,6 +8,8 @@ from paddle.nn import initializer as I
|
|||
|
||||
from parakeet.modules import positional_encoding as pe
|
||||
|
||||
__all__ = ["SpectraNet"]
|
||||
|
||||
class ConvBlock(nn.Layer):
|
||||
def __init__(self, in_channel, kernel_size, causal=False, has_bias=False,
|
||||
bias_dim=None, keep_prob=1.):
|
||||
|
|
|
@ -9,6 +9,8 @@ from parakeet.modules import masking
|
|||
from parakeet.modules.cbhg import Conv1dBatchNorm
|
||||
from parakeet.modules import positional_encoding as pe
|
||||
|
||||
__all__ = ["TransformerTTS"]
|
||||
|
||||
# Transformer TTS's own implementation of transformer
|
||||
class MultiheadAttention(nn.Layer):
|
||||
"""
|
||||
|
@ -255,12 +257,22 @@ class TransformerTTS(nn.Layer):
|
|||
self.decoder = TransformerDecoder(d_model, n_heads, d_ffn, decoder_layers, dropout)
|
||||
self.final_proj = nn.Linear(d_model, max_reduction_factor * d_mel)
|
||||
self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
|
||||
self.stop_conditioner = nn.Linear(d_mel, 3)
|
||||
|
||||
# specs
|
||||
self.padding_idx = padding_idx
|
||||
self.d_model = d_model
|
||||
self.pe_scalar = positional_encoding_scalar
|
||||
|
||||
# start and end
|
||||
dtype = paddle.get_default_dtype()
|
||||
self.start_vec = paddle.fill_constant([1, d_mel], dtype=dtype, value=0)
|
||||
self.end_vec = paddle.fill_constant([1, d_mel], dtype=dtype, value=0)
|
||||
self.stop_prob_index = 2
|
||||
|
||||
self.max_r = max_reduction_factor
|
||||
self.r = max_reduction_factor # set it every call
|
||||
|
||||
|
||||
def forward(self, text, mel, stop):
|
||||
pass
|
||||
|
@ -292,11 +304,45 @@ class TransformerTTS(nn.Layer):
|
|||
|
||||
output_proj = self.final_proj(decoder_output)
|
||||
mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
|
||||
stop_logits = self.stop_conditioner(mel_intermediate)
|
||||
|
||||
mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
|
||||
mel_output = self.decoder_postnet(mel_channel_first)
|
||||
mel_output = paddle.transpose(mel_output, [0, 2, 1])
|
||||
return mel_output, mel_intermediate, cross_attention_weights
|
||||
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
|
||||
|
||||
def predict(self, input, max_length=1000, verbose=True):
|
||||
"""Autoregressively decode a mel spectrogram from one unbatched text sequence.
|
||||
|
||||
Args:
|
||||
input (Tensor): shape (T), dtype int, input text sequence.
|
||||
max_length (int, optional): max decoder steps. Defaults to 1000.
|
||||
verbose (bool, optional): display progress bar. Defaults to True.
|
||||
"""
|
||||
text_input = paddle.unsqueeze(input, 0) # (1, T)
|
||||
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||
|
||||
# encode the text sequence
|
||||
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
|
||||
for _ in range(int(max_length // self.r) + 1):
|
||||
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
||||
encoder_output, decoder_input, encoder_padding_mask)
|
||||
|
||||
# extract last step and append it to decoder input
|
||||
decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1)
|
||||
# extract last r steps and append it to decoder output
|
||||
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
|
||||
|
||||
# stop condition?
|
||||
if paddle.argmax(stop_logits[:, -1, :]) == self.stop_prob_index:
|
||||
if verbose:
|
||||
print("Hits stop condition.")
|
||||
break
|
||||
|
||||
return decoder_output[:, 1:, :], encoder_attentions, cross_attention_weights
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def infer(self):
|
||||
pass
|
|
@ -11,7 +11,8 @@ import itertools
|
|||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
from parakeet.modules import weight_norm
|
||||
|
||||
__all__ = ["WaveFlow"]
|
||||
|
||||
def fold(x, n_group):
|
||||
"""Fold audio or spectrogram's temporal dimension into groups.
|
||||
|
|
|
@ -26,6 +26,8 @@ import paddle.fluid.layers.distributions as D
|
|||
|
||||
from parakeet.modules.conv import Conv1dCell
|
||||
|
||||
__all__ = ["ConditionalWavenet"]
|
||||
|
||||
def quantize(values, n_bands):
|
||||
"""Linearly quantize a float Tensor in [-1, 1) to an integer Tensor in [0, n_bands).
|
||||
|
||||
|
|
|
@ -78,9 +78,10 @@ class TestTransformerTTS(unittest.TestCase):
|
|||
text = text * mask
|
||||
|
||||
encoded, attention_weights, encoder_mask = net.encode(text)
|
||||
print(encoded.numpy().shape)
|
||||
print([item.shape for item in attention_weights])
|
||||
print(encoder_mask.numpy().shape)
|
||||
print("output shapes:")
|
||||
print("encoded:", encoded.numpy().shape)
|
||||
print("encoder_attentions:", [item.shape for item in attention_weights])
|
||||
print("encoder_mask:", encoder_mask.numpy().shape)
|
||||
|
||||
def test_all_io(self):
|
||||
net = self.net
|
||||
|
@ -96,7 +97,7 @@ class TestTransformerTTS(unittest.TestCase):
|
|||
mel = mel * mask.unsqueeze(-1)
|
||||
|
||||
encoded, encoder_attention_weights, encoder_mask = net.encode(text)
|
||||
mel_output, mel_intermediate, cross_attention_weights = net.decode(encoded, mel, encoder_mask)
|
||||
mel_output, mel_intermediate, cross_attention_weights, stop_logits = net.decode(encoded, mel, encoder_mask)
|
||||
|
||||
print("output shapes:")
|
||||
print("encoder_output:", encoded.numpy().shape)
|
||||
|
@ -105,5 +106,17 @@ class TestTransformerTTS(unittest.TestCase):
|
|||
print("mel_output: ", mel_output.numpy().shape)
|
||||
print("mel_intermediate: ", mel_intermediate.numpy().shape)
|
||||
print("decoder_attentions:", [item.shape for item in cross_attention_weights])
|
||||
print("stop_logits:", stop_logits.numpy().shape)
|
||||
|
||||
def test_predict_io(self):
|
||||
net = self.net
|
||||
net.eval()
|
||||
with paddle.no_grad():
|
||||
text = paddle.randint(0, 128, [176])
|
||||
decoder_output, encoder_attention_weights, cross_attention_weights = net.predict(text)
|
||||
|
||||
print("output shapes:")
|
||||
print("mel_output: ", decoder_output.numpy().shape)
|
||||
print("encoder_attentions:", [item.shape for item in encoder_attention_weights])
|
||||
print("decoder_attentions:", [item.shape for item in cross_attention_weights])
|
||||
|
Loading…
Reference in New Issue