ParakeetEricRoss/parakeet/models/transformer_tts/vocoder.py

28 lines
1.0 KiB
Python
Raw Normal View History

2020-02-10 15:47:19 +08:00
import paddle.fluid.dygraph as dg
import paddle.fluid as fluid
2020-02-11 16:57:30 +08:00
from parakeet.modules.customized import Conv1D
from parakeet.models.transformer_tts.utils import *
2020-02-13 20:46:21 +08:00
from parakeet.models.transformer_tts.cbhg import CBHG
2020-02-10 15:47:19 +08:00
class Vocoder(dg.Layer):
    """CBHG network mapping mel spectrograms to linear magnitude spectrograms.

    Pipeline: 1x1 ``Conv1D`` projection (num_mels -> hidden_size), a ``CBHG``
    module, then a 1x1 ``Conv1D`` projection
    (hidden_size -> n_fft // 2 + 1).
    """

    def __init__(self, config, batch_size):
        """Build the input projection, the CBHG module and the output projection.

        Args:
            config (dict): model configuration. Reads ``config['hidden_size']``,
                ``config['audio']['num_mels']`` and ``config['audio']['n_fft']``.
            batch_size (int): passed through unchanged to ``CBHG``.
        """
        super(Vocoder, self).__init__()
        hidden_size = config['hidden_size']
        audio_config = config['audio']

        # Sublayers are created in the same order as before
        # (pre_proj, cbhg, post_proj) so auto-generated parameter names are
        # unchanged and existing checkpoints still load.
        # 1x1 convolution: num_mels input channels -> hidden_size channels.
        self.pre_proj = Conv1D(
            num_channels=audio_config['num_mels'],
            num_filters=hidden_size,
            filter_size=1)
        self.cbhg = CBHG(hidden_size, batch_size)
        # 1x1 convolution: hidden_size channels -> n_fft // 2 + 1 channels
        # (the one-sided spectrum size of an n_fft-point FFT).
        self.post_proj = Conv1D(
            num_channels=hidden_size,
            num_filters=(audio_config['n_fft'] // 2) + 1,
            filter_size=1)

    def forward(self, mel):
        """Predict a linear magnitude spectrogram from a mel spectrogram.

        Args:
            mel (Variable): mel spectrogram; presumably shaped
                (batch, time, num_mels) -- TODO confirm against callers.

        Returns:
            Variable: predicted magnitudes, with axes 1 and 2 swapped back to
            the input's layout (presumably (batch, time, n_fft // 2 + 1)).
        """
        # Use fluid.layers explicitly (fluid is imported at module level)
        # instead of a bare ``layers`` name that only resolved through a
        # wildcard import from transformer_tts.utils.
        # Swap axes 1 and 2: the convolutions below appear to expect the
        # feature/channel axis at position 1 -- NOTE(review): confirm Conv1D's
        # expected layout.
        mel = fluid.layers.transpose(mel, [0, 2, 1])
        mel = self.pre_proj(mel)
        mel = self.cbhg(mel)
        mag_pred = self.post_proj(mel)
        # Swap axes 1 and 2 back to the input layout.
        mag_pred = fluid.layers.transpose(mag_pred, [0, 2, 1])
        return mag_pred