From 3c60fec900c0f209652987a0245d7176786486b4 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Sat, 27 Feb 2021 03:26:41 +0800 Subject: [PATCH] remove bn in postnet --- parakeet/models/transformer_tts.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index ab941e7..64945eb 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -340,13 +340,13 @@ class CNNPostNet(nn.Layer): c_in = d_input if i == 0 else d_hidden c_out = d_output if i == n_layers - 1 else d_hidden self.convs.append( - Conv1dBatchNorm( + nn.Conv1D( c_in, c_out, kernel_size, weight_attr=I.XavierUniform(), padding=padding)) - self.last_bn = nn.BatchNorm1D(d_output) + # self.last_bn = nn.BatchNorm1D(d_output) # for a layer that ends with a normalization layer that is targeted to # output a non zero-central output, it may take a long time to # train the scale and bias @@ -358,7 +358,9 @@ class CNNPostNet(nn.Layer): x = layer(x) if i != (len(self.convs) - 1): x = F.tanh(x) - x = self.last_bn(x_in + x) + # TODO: check it + x = x_in + x + # x = self.last_bn(x_in + x) return x