remove bn in postnet

This commit is contained in:
chenfeiyu 2021-02-27 03:26:41 +08:00
parent 929165b64a
commit 3c60fec900
1 changed files with 5 additions and 3 deletions

View File

@ -340,13 +340,13 @@ class CNNPostNet(nn.Layer):
c_in = d_input if i == 0 else d_hidden
c_out = d_output if i == n_layers - 1 else d_hidden
self.convs.append(
Conv1dBatchNorm(
nn.Conv1D(
c_in,
c_out,
kernel_size,
weight_attr=I.XavierUniform(),
padding=padding))
self.last_bn = nn.BatchNorm1D(d_output)
# self.last_bn = nn.BatchNorm1D(d_output)
# for a layer that ends with a normalization layer that is targeted to
# output a non zero-central output, it may take a long time to
# train the scale and bias
@ -358,7 +358,9 @@ class CNNPostNet(nn.Layer):
x = layer(x)
if i != (len(self.convs) - 1):
x = F.tanh(x)
x = self.last_bn(x_in + x)
# TODO: check it
x = x_in + x
# x = self.last_bn(x_in + x)
return x