remove bn in postnet
This commit is contained in:
parent
929165b64a
commit
3c60fec900
|
@ -340,13 +340,13 @@ class CNNPostNet(nn.Layer):
|
|||
c_in = d_input if i == 0 else d_hidden
|
||||
c_out = d_output if i == n_layers - 1 else d_hidden
|
||||
self.convs.append(
|
||||
Conv1dBatchNorm(
|
||||
nn.Conv1D(
|
||||
c_in,
|
||||
c_out,
|
||||
kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding=padding))
|
||||
self.last_bn = nn.BatchNorm1D(d_output)
|
||||
# self.last_bn = nn.BatchNorm1D(d_output)
|
||||
# for a layer that ends with a normalization layer that is targeted to
|
||||
# output a non zero-central output, it may take a long time to
|
||||
# train the scale and bias
|
||||
|
@ -358,7 +358,9 @@ class CNNPostNet(nn.Layer):
|
|||
x = layer(x)
|
||||
if i != (len(self.convs) - 1):
|
||||
x = F.tanh(x)
|
||||
x = self.last_bn(x_in + x)
|
||||
# TODO: check it
|
||||
x = x_in + x
|
||||
# x = self.last_bn(x_in + x)
|
||||
return x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue