Merge branch 'add_TranTTS' into 'master'
modified fastspeech network See merge request !6
This commit is contained in:
commit
9b54ab07f2
|
@ -3,8 +3,8 @@ audio:
|
|||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
hop_length: 256
|
||||
win_length: 1024
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
|
|
|
@ -11,7 +11,7 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward
|
|||
class FFTBlock(dg.Layer):
|
||||
def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2):
|
||||
super(FFTBlock, self).__init__()
|
||||
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout)
|
||||
self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False)
|
||||
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout)
|
||||
|
||||
def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
|
||||
|
|
|
@ -161,7 +161,8 @@ class FastSpeech(dg.Layer):
|
|||
num_conv=5,
|
||||
outputs_per_step=cfg.audio.outputs_per_step,
|
||||
use_cudnn=True,
|
||||
dropout=0.1)
|
||||
dropout=0.1,
|
||||
batchnorm_last=True)
|
||||
|
||||
def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0):
|
||||
"""
|
||||
|
|
|
@ -29,7 +29,6 @@ def load_checkpoint(step, model_path):
|
|||
return new_state_dict, opti_dict
|
||||
|
||||
def main(cfg):
|
||||
|
||||
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
||||
nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1
|
||||
|
||||
|
|
|
@ -47,21 +47,25 @@ class ScaledDotProductAttention(dg.Layer):
|
|||
return result, attention
|
||||
|
||||
class MultiheadAttention(dg.Layer):
|
||||
def __init__(self, num_hidden, d_k, d_q, num_head=4, dropout=0.1):
|
||||
def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
self.d_k = d_k
|
||||
self.d_q = d_q
|
||||
self.dropout = dropout
|
||||
self.is_concat = is_concat
|
||||
|
||||
self.key = Linear(num_hidden, num_head * d_k, is_bias=False)
|
||||
self.value = Linear(num_hidden, num_head * d_k, is_bias=False)
|
||||
self.query = Linear(num_hidden, num_head * d_q, is_bias=False)
|
||||
self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
|
||||
self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
|
||||
self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)
|
||||
|
||||
self.scal_attn = ScaledDotProductAttention(d_k)
|
||||
|
||||
self.fc = Linear(num_head * d_q * 2, num_hidden)
|
||||
if self.is_concat:
|
||||
self.fc = Linear(num_head * d_q * 2, num_hidden)
|
||||
else:
|
||||
self.fc = Linear(num_head * d_q, num_hidden)
|
||||
|
||||
self.layer_norm = dg.LayerNorm(num_hidden)
|
||||
|
||||
|
@ -105,7 +109,8 @@ class MultiheadAttention(dg.Layer):
|
|||
# concat all multihead result
|
||||
result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
||||
result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1])
|
||||
result = layers.concat([query_input,result], axis=-1)
|
||||
if self.is_concat:
|
||||
result = layers.concat([query_input,result], axis=-1)
|
||||
result = layers.dropout(self.fc(result), self.dropout)
|
||||
result = result + query_input
|
||||
|
||||
|
|
|
@ -12,11 +12,13 @@ class PostConvNet(dg.Layer):
|
|||
num_conv=5,
|
||||
outputs_per_step=1,
|
||||
use_cudnn=True,
|
||||
dropout=0.1):
|
||||
dropout=0.1,
|
||||
batchnorm_last=False):
|
||||
super(PostConvNet, self).__init__()
|
||||
|
||||
self.dropout = dropout
|
||||
self.num_conv = num_conv
|
||||
self.batchnorm_last = batchnorm_last
|
||||
self.conv_list = []
|
||||
self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step,
|
||||
out_channels = num_hidden,
|
||||
|
@ -45,8 +47,9 @@ class PostConvNet(dg.Layer):
|
|||
|
||||
self.batch_norm_list = [dg.BatchNorm(num_hidden,
|
||||
data_layout='NCHW') for _ in range(num_conv-1)]
|
||||
#self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
|
||||
# data_layout='NCHW'))
|
||||
if self.batchnorm_last:
|
||||
self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step,
|
||||
data_layout='NCHW'))
|
||||
for i, layer in enumerate(self.batch_norm_list):
|
||||
self.add_sublayer("batch_norm_list_{}".format(i), layer)
|
||||
|
||||
|
@ -70,5 +73,8 @@ class PostConvNet(dg.Layer):
|
|||
input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout)
|
||||
conv = self.conv_list[self.num_conv-1]
|
||||
input = conv(input)[:,:,:len]
|
||||
if self.batchnorm_last:
|
||||
batch_norm = self.batch_norm_list[self.num_conv-1]
|
||||
input = layers.dropout(batch_norm(input), self.dropout)
|
||||
output = layers.transpose(input, [0,2,1])
|
||||
return output
|
Loading…
Reference in New Issue