ParakeetEricRoss/parakeet/models/transformer_tts/transformer_tts.py

27 lines
865 B
Python

import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
from parakeet.models.transformer_tts.encoder import Encoder
from parakeet.models.transformer_tts.decoder import Decoder
class TransformerTTS(dg.Layer):
    """Text-to-speech model that chains an ``Encoder`` into a ``Decoder``.

    The encoder consumes the character inputs and their positions; its output
    is then handed to the decoder together with the mel inputs and their
    positions.
    """

    def __init__(self, config):
        """Build the encoder and decoder sublayers from ``config``.

        Args:
            config: Mapping indexed with the keys ``'embedding_size'`` and
                ``'hidden_size'``. The whole mapping is also forwarded to the
                ``Decoder`` constructor and kept on the instance as
                ``self.config``.
        """
        super(TransformerTTS, self).__init__()
        self.config = config

        hidden_size = config['hidden_size']
        # Sublayers are created encoder-first, decoder-second, as before.
        self.encoder = Encoder(config['embedding_size'], hidden_size)
        self.decoder = Decoder(hidden_size, config)

    def forward(self, characters, mel_input, pos_text, pos_mel):
        """Run the encoder, then the decoder, and collect all outputs.

        Args:
            characters: Input passed to the encoder as its first argument.
            mel_input: Input passed to the decoder as its third argument.
            pos_text: Positions passed to the encoder alongside ``characters``.
            pos_mel: Positions passed to the decoder alongside ``mel_input``.

        Returns:
            A 6-tuple ``(mel_output, postnet_output, attn_probs, stop_preds,
            attns_enc, attns_dec)``: the five decoder outputs with the
            encoder's third output (``attns_enc``) inserted before the
            decoder's last one.
        """
        encoded, text_mask, encoder_attns = self.encoder(characters, pos_text)

        # The encoder output fills both of the decoder's first two positional
        # slots (presumably attention key and value -- confirm in Decoder).
        decoder_result = self.decoder(
            encoded, encoded, mel_input, text_mask, pos_mel)
        # Unpacking exactly five values keeps the arity check on the
        # decoder's return value.
        mel_out, postnet_out, probs, stops, decoder_attns = decoder_result

        return (
            mel_out,
            postnet_out,
            probs,
            stops,
            encoder_attns,
            decoder_attns,
        )