diff --git a/parakeet/models/fastspeech/config/fastspeech.yaml b/parakeet/models/fastspeech/config/fastspeech.yaml index 3e53388..90f520f 100644 --- a/parakeet/models/fastspeech/config/fastspeech.yaml +++ b/parakeet/models/fastspeech/config/fastspeech.yaml @@ -3,8 +3,8 @@ audio: n_fft: 2048 sr: 22050 preemphasis: 0.97 - hop_length: 275 - win_length: 1102 + hop_length: 256 + win_length: 1024 power: 1.2 min_level_db: -100 ref_level_db: 20 diff --git a/parakeet/models/fastspeech/modules.py b/parakeet/models/fastspeech/modules.py index 7950728..68d4776 100644 --- a/parakeet/models/fastspeech/modules.py +++ b/parakeet/models/fastspeech/modules.py @@ -11,7 +11,7 @@ from parakeet.modules.feed_forward import PositionwiseFeedForward class FFTBlock(dg.Layer): def __init__(self, d_model, d_inner, n_head, d_k, d_v, filter_size, padding, dropout=0.2): super(FFTBlock, self).__init__() - self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, dropout=dropout) + self.slf_attn = MultiheadAttention(d_model, d_k, d_v, num_head=n_head, is_bias=True, dropout=dropout, is_concat=False) self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, filter_size =filter_size, padding =padding, dropout=dropout) def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): diff --git a/parakeet/models/fastspeech/network.py b/parakeet/models/fastspeech/network.py index 5005cba..f1a1e91 100644 --- a/parakeet/models/fastspeech/network.py +++ b/parakeet/models/fastspeech/network.py @@ -161,7 +161,8 @@ class FastSpeech(dg.Layer): num_conv=5, outputs_per_step=cfg.audio.outputs_per_step, use_cudnn=True, - dropout=0.1) + dropout=0.1, + batchnorm_last=True) def forward(self, character, text_pos, mel_pos=None, length_target=None, alpha=1.0): """ diff --git a/parakeet/models/fastspeech/train.py b/parakeet/models/fastspeech/train.py index c29bf6d..2caf1e9 100644 --- a/parakeet/models/fastspeech/train.py +++ b/parakeet/models/fastspeech/train.py @@ -29,7 +29,6 @@ def load_checkpoint(step, model_path): return new_state_dict, opti_dict def main(cfg): - local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0 nranks = dg.parallel.Env().nranks if cfg.use_data_parallel else 1 diff --git a/parakeet/modules/multihead_attention.py b/parakeet/modules/multihead_attention.py index 627ca32..7a1d63f 100644 --- a/parakeet/modules/multihead_attention.py +++ b/parakeet/modules/multihead_attention.py @@ -47,21 +47,25 @@ class ScaledDotProductAttention(dg.Layer): return result, attention class MultiheadAttention(dg.Layer): - def __init__(self, num_hidden, d_k, d_q, num_head=4, dropout=0.1): + def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True): super(MultiheadAttention, self).__init__() self.num_hidden = num_hidden self.num_head = num_head self.d_k = d_k self.d_q = d_q self.dropout = dropout + self.is_concat = is_concat - self.key = Linear(num_hidden, num_head * d_k, is_bias=False) - self.value = Linear(num_hidden, num_head * d_k, is_bias=False) - self.query = Linear(num_hidden, num_head * d_q, is_bias=False) + self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias) + self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias) + self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias) self.scal_attn = ScaledDotProductAttention(d_k) - self.fc = Linear(num_head * d_q * 2, num_hidden) + if self.is_concat: + self.fc = Linear(num_head * d_q * 2, num_hidden) + else: + self.fc = Linear(num_head * d_q, num_hidden) self.layer_norm = dg.LayerNorm(num_hidden) @@ -105,7 +109,8 @@ class MultiheadAttention(dg.Layer): # concat all multihead result result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q]) result = layers.reshape(layers.transpose(result, [1,2,0,3]),[batch_size, seq_len_query, -1]) - result = layers.concat([query_input,result], axis=-1) + if self.is_concat: + result = layers.concat([query_input,result], axis=-1) result = layers.dropout(self.fc(result), self.dropout) result = result + query_input diff --git a/parakeet/modules/post_convnet.py b/parakeet/modules/post_convnet.py index 3546c7a..8a6a490 100644 --- a/parakeet/modules/post_convnet.py +++ b/parakeet/modules/post_convnet.py @@ -12,11 +12,13 @@ class PostConvNet(dg.Layer): num_conv=5, outputs_per_step=1, use_cudnn=True, - dropout=0.1): + dropout=0.1, + batchnorm_last=False): super(PostConvNet, self).__init__() self.dropout = dropout self.num_conv = num_conv + self.batchnorm_last = batchnorm_last self.conv_list = [] self.conv_list.append(Conv(in_channels = n_mels * outputs_per_step, out_channels = num_hidden, @@ -45,8 +47,9 @@ class PostConvNet(dg.Layer): self.batch_norm_list = [dg.BatchNorm(num_hidden, data_layout='NCHW') for _ in range(num_conv-1)] - #self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step, - # data_layout='NCHW')) + if self.batchnorm_last: + self.batch_norm_list.append(dg.BatchNorm(n_mels * outputs_per_step, + data_layout='NCHW')) for i, layer in enumerate(self.batch_norm_list): self.add_sublayer("batch_norm_list_{}".format(i), layer) @@ -70,5 +73,8 @@ class PostConvNet(dg.Layer): input = layers.dropout(layers.tanh(batch_norm(conv(input)[:,:,:len])), self.dropout) conv = self.conv_list[self.num_conv-1] input = conv(input)[:,:,:len] + if self.batchnorm_last: + batch_norm = self.batch_norm_list[self.num_conv-1] + input = layers.dropout(batch_norm(input), self.dropout) output = layers.transpose(input, [0,2,1]) return output \ No newline at end of file