27 lines
865 B
Python
27 lines
865 B
Python
import paddle.fluid.dygraph as dg
|
|
import paddle.fluid as fluid
|
|
from parakeet.models.transformer_tts.encoder import Encoder
|
|
from parakeet.models.transformer_tts.decoder import Decoder
|
|
|
|
class TransformerTTS(dg.Layer):
|
|
def __init__(self, config):
|
|
super(TransformerTTS, self).__init__()
|
|
self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
|
|
self.decoder = Decoder(config['hidden_size'], config)
|
|
self.config = config
|
|
|
|
def forward(self, characters, mel_input, pos_text, pos_mel):
|
|
|
|
key, c_mask, attns_enc = self.encoder(characters, pos_text)
|
|
|
|
mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(key, key, mel_input, c_mask, pos_mel)
|
|
|
|
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|