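Modified FastSpeech network: added is_bias/is_concat options to MultiheadAttention, added an optional final batch norm (batchnorm_last) to PostConvNet, and changed audio hop_length/win_length to 256/1024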
This commit is contained in:
parent
47a618ce38
commit
6068374a3c
|
@ -3,8 +3,8 @@ audio:
|
||||||
n_fft: 2048
|
n_fft: 2048
|
||||||
sr: 22050
|
sr: 22050
|
||||||
preemphasis: 0.97
|
preemphasis: 0.97
|
||||||
hop_length: 275
|
hop_length: 256
|
||||||
win_length: 1102
|
win_length: 1024
|
||||||
power: 1.2
|
power: 1.2
|
||||||
min_level_db: -100
|
min_level_db: -100
|
||||||
ref_level_db: 20
|
ref_level_db: 20
|
||||||
|
|
|
@ -11,7 +11,7 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward
|
||||||
class FFTBlock(dg.Layer):
|
class FFTBlock(dg.Layer):
|
||||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
|
def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
|
||||||
super(FFTBlock, self).__init__()
|
super(FFTBlock, self).__init__()
|
||||||
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout)
|
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False)
|
||||||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
|
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
|
||||||
|
|
||||||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
|
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
|
||||||
|
|
|
@ -161,7 +161,8 @@ class FastSpeech(dg.Layer):
|
||||||
num_conv=5,
|
num_conv=5,
|
||||||
outputs_per_step=cfg.audio.outputs_per_step,
|
outputs_per_step=cfg.audio.outputs_per_step,
|
||||||
use_cudnn=True,
|
use_cudnn=True,
|
||||||
dropout=0.1)
|
dropout=0.1,
|
||||||
|
batchnorm_last=True)
|
||||||
|
|
||||||
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
|
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -29,7 +29,6 @@ def load_checkpoint(step, model_path):
|
||||||
return new_state_dict, opti_dict
|
return new_state_dict, opti_dict
|
||||||
|
|
||||||
def main(cfg):
|
def main(cfg):
|
||||||
|
|
||||||
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
||||||
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
|
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
|
||||||
|
|
||||||
|
|
|
@ -47,21 +47,25 @@ class ScaledDotProductAttention(dg.Layer):
|
||||||
return result, attention
|
return result, attention
|
||||||
|
|
||||||
class MultiheadAttention(dg.Layer):
|
class MultiheadAttention(dg.Layer):
|
||||||
def __init__(self, num_hidden, d_k, d_q, num_head=4, dropout=0.1):
|
def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True):
|
||||||
super(MultiheadAttention, self).__init__()
|
super(MultiheadAttention, self).__init__()
|
||||||
self.num_hidden = num_hidden
|
self.num_hidden = num_hidden
|
||||||
self.num_head = num_head
|
self.num_head = num_head
|
||||||
self.d_k = d_k
|
self.d_k = d_k
|
||||||
self.d_q = d_q
|
self.d_q = d_q
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
self.is_concat = is_concat
|
||||||
|
|
||||||
self.key = Linear(num_hidden, num_head * d_k, is_bias=False)
|
self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
|
||||||
self.value = Linear(num_hidden, num_head * d_k, is_bias=False)
|
self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
|
||||||
self.query = Linear(num_hidden, num_head * d_q, is_bias=False)
|
self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)
|
||||||
|
|
||||||
self.scal_attn = ScaledDotProductAttention(d_k)
|
self.scal_attn = ScaledDotProductAttention(d_k)
|
||||||
|
|
||||||
self.fc = Linear(num_head * d_q * 2, num_hidden)
|
if self.is_concat:
|
||||||
|
self.fc = Linear(num_head * d_q * 2, num_hidden)
|
||||||
|
else:
|
||||||
|
self.fc = Linear(num_head * d_q, num_hidden)
|
||||||
|
|
||||||
self.layer_norm = dg.LayerNorm(num_hidden)
|
self.layer_norm = dg.LayerNorm(num_hidden)
|
||||||
|
|
||||||
|
@ -105,7 +109,8 @@ class MultiheadAttention(dg.Layer):
|
||||||
# concat all multihead result
|
# concat all multihead result
|
||||||
result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
||||||
result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1])
|
result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1])
|
||||||
result = layers.concat([query_input,result], axis=-1)
|
if self.is_concat:
|
||||||
|
result = layers.concat([query_input,result], axis=-1)
|
||||||
result = layers.dropout(self.fc(result), self.dropout)
|
result = layers.dropout(self.fc(result), self.dropout)
|
||||||
result = result + query_input
|
result = result + query_input
|
||||||
|
|
||||||
|
|
|
@ -12,11 +12,13 @@ class PostConvNet(dg.Layer):
|
||||||
num_conv=5,
|
num_conv=5,
|
||||||
outputs_per_step=1,
|
outputs_per_step=1,
|
||||||
use_cudnn=True,
|
use_cudnn=True,
|
||||||
dropout=0.1):
|
dropout=0.1,
|
||||||
|
batchnorm_last=False):
|
||||||
super(PostConvNet, self).__init__()
|
super(PostConvNet, self).__init__()
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.num_conv = num_conv
|
self.num_conv = num_conv
|
||||||
|
self.batchnorm_last = batchnorm_last
|
||||||
self.conv_list = []
|
self.conv_list = []
|
||||||
self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step,
|
self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step,
|
||||||
out_channels = num_hidden,
|
out_channels = num_hidden,
|
||||||
|
@ -45,8 +47,9 @@ class PostConvNet(dg.Layer):
|
||||||
|
|
||||||
self.batch_norm_list = [dg.BatchNorm(num_hidden,
|
self.batch_norm_list = [dg.BatchNorm(num_hidden,
|
||||||
data_layout='NCHW') for _ in range(num_conv-1)]
|
data_layout='NCHW') for _ in range(num_conv-1)]
|
||||||
#self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
|
if self.batchnorm_last:
|
||||||
# data_layout='NCHW'))
|
self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
|
||||||
|
data_layout='NCHW'))
|
||||||
for i, layer in enumerate(self.batch_norm_list):
|
for i, layer in enumerate(self.batch_norm_list):
|
||||||
self.add_sublayer("batch_norm_list_{}".format(i), layer)
|
self.add_sublayer("batch_norm_list_{}".format(i), layer)
|
||||||
|
|
||||||
|
@ -70,5 +73,8 @@ class PostConvNet(dg.Layer):
|
||||||
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
|
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
|
||||||
conv = self.conv_list[self.num_conv-1]
|
conv = self.conv_list[self.num_conv-1]
|
||||||
input = conv(input)[:,:,:len]
|
input = conv(input)[:,:,:len]
|
||||||
|
if self.batchnorm_last:
|
||||||
|
batch_norm = self.batch_norm_list[self.num_conv-1]
|
||||||
|
input = layers.dropout(batch_norm(input), self.dropout)
|
||||||
output = layers.transpose(input, [0,2,1])
|
output = layers.transpose(input, [0,2,1])
|
||||||
return output
|
return output
|
Loading…
Reference in New Issue