modified fastspeech network

This commit is contained in:
lifuchen 2020-02-06 09:11:28 +00:00 committed by chenfeiyu
parent 47a618ce38
commit 6068374a3c
6 changed files with 25 additions and 14 deletions

View File

@ -3,8 +3,8 @@ audio:
n_fft: 2048
sr: 22050
preemphasis: 0.97
hop_length: 275
win_length: 1102
hop_length: 256
win_length: 1024
power: 1.2
min_level_db: -100
ref_level_db: 20

View File

@ -11,7 +11,7 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward
class FFTBlock(dg.Layer):
def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
super(FFTBlock, self).__init__()
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout)
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):

View File

@ -161,7 +161,8 @@ class FastSpeech(dg.Layer):
num_conv=5,
outputs_per_step=cfg.audio.outputs_per_step,
use_cudnn=True,
dropout=0.1)
dropout=0.1,
batchnorm_last=True)
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
"""

View File

@ -29,7 +29,6 @@ def load_checkpoint(step, model_path):
return new_state_dict, opti_dict
def main(cfg):
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1

View File

@ -47,21 +47,25 @@ class ScaledDotProductAttention(dg.Layer):
return result, attention
class MultiheadAttention(dg.Layer):
def __init__(self, num_hidden, d_k, d_q, num_head=4, dropout=0.1):
def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True):
super(MultiheadAttention, self).__init__()
self.num_hidden = num_hidden
self.num_head = num_head
self.d_k = d_k
self.d_q = d_q
self.dropout = dropout
self.is_concat = is_concat
self.key = Linear(num_hidden, num_head * d_k, is_bias=False)
self.value = Linear(num_hidden, num_head * d_k, is_bias=False)
self.query = Linear(num_hidden, num_head * d_q, is_bias=False)
self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)
self.scal_attn = ScaledDotProductAttention(d_k)
self.fc = Linear(num_head * d_q * 2, num_hidden)
if self.is_concat:
self.fc = Linear(num_head * d_q * 2, num_hidden)
else:
self.fc = Linear(num_head * d_q, num_hidden)
self.layer_norm = dg.LayerNorm(num_hidden)
@ -105,7 +109,8 @@ class MultiheadAttention(dg.Layer):
# concat all multihead result
result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1])
result = layers.concat([query_input,result], axis=-1)
if self.is_concat:
result = layers.concat([query_input,result], axis=-1)
result = layers.dropout(self.fc(result), self.dropout)
result = result + query_input

View File

@ -12,11 +12,13 @@ class PostConvNet(dg.Layer):
num_conv=5,
outputs_per_step=1,
use_cudnn=True,
dropout=0.1):
dropout=0.1,
batchnorm_last=False):
super(PostConvNet, self).__init__()
self.dropout = dropout
self.num_conv = num_conv
self.batchnorm_last = batchnorm_last
self.conv_list = []
self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step,
out_channels = num_hidden,
@ -45,8 +47,9 @@ class PostConvNet(dg.Layer):
self.batch_norm_list = [dg.BatchNorm(num_hidden,
data_layout='NCHW') for _ in range(num_conv-1)]
#self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
# data_layout='NCHW'))
if self.batchnorm_last:
self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
data_layout='NCHW'))
for i, layer in enumerate(self.batch_norm_list):
self.add_sublayer("batch_norm_list_{}".format(i), layer)
@ -70,5 +73,8 @@ class PostConvNet(dg.Layer):
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
conv = self.conv_list[self.num_conv-1]
input = conv(input)[:,:,:len]
if self.batchnorm_last:
batch_norm = self.batch_norm_list[self.num_conv-1]
input = layers.dropout(batch_norm(input), self.dropout)
output = layers.transpose(input, [0,2,1])
return output