remove bn in postnet

This commit is contained in:
chenfeiyu 2021-02-27 03:26:41 +08:00
parent 929165b64a
commit 3c60fec900
1 changed files with 5 additions and 3 deletions

View File

@ -340,13 +340,13 @@ class CNNPostNet(nn.Layer):
c_in = d_input if i == 0 else d_hidden c_in = d_input if i == 0 else d_hidden
c_out = d_output if i == n_layers - 1 else d_hidden c_out = d_output if i == n_layers - 1 else d_hidden
self.convs.append( self.convs.append(
Conv1dBatchNorm( nn.Conv1D(
c_in, c_in,
c_out, c_out,
kernel_size, kernel_size,
weight_attr=I.XavierUniform(), weight_attr=I.XavierUniform(),
padding=padding)) padding=padding))
self.last_bn = nn.BatchNorm1D(d_output) # self.last_bn = nn.BatchNorm1D(d_output)
# for a layer that ends with a normalization layer that is targeted to # for a layer that ends with a normalization layer that is targeted to
# output a non zero-central output, it may take a long time to # output a non zero-central output, it may take a long time to
# train the scale and bias # train the scale and bias
@ -358,7 +358,9 @@ class CNNPostNet(nn.Layer):
x = layer(x) x = layer(x)
if i != (len(self.convs) - 1): if i != (len(self.convs) - 1):
x = F.tanh(x) x = F.tanh(x)
x = self.last_bn(x_in + x) # TODO: check it
x = x_in + x
# x = self.last_bn(x_in + x)
return x return x