remove bn in postnet
This commit is contained in:
parent
929165b64a
commit
3c60fec900
|
@ -340,13 +340,13 @@ class CNNPostNet(nn.Layer):
|
||||||
c_in = d_input if i == 0 else d_hidden
|
c_in = d_input if i == 0 else d_hidden
|
||||||
c_out = d_output if i == n_layers - 1 else d_hidden
|
c_out = d_output if i == n_layers - 1 else d_hidden
|
||||||
self.convs.append(
|
self.convs.append(
|
||||||
Conv1dBatchNorm(
|
nn.Conv1D(
|
||||||
c_in,
|
c_in,
|
||||||
c_out,
|
c_out,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
weight_attr=I.XavierUniform(),
|
weight_attr=I.XavierUniform(),
|
||||||
padding=padding))
|
padding=padding))
|
||||||
self.last_bn = nn.BatchNorm1D(d_output)
|
# self.last_bn = nn.BatchNorm1D(d_output)
|
||||||
# for a layer that ends with a normalization layer that is targeted to
|
# for a layer that ends with a normalization layer that is targeted to
|
||||||
# output a non zero-central output, it may take a long time to
|
# output a non zero-central output, it may take a long time to
|
||||||
# train the scale and bias
|
# train the scale and bias
|
||||||
|
@ -358,7 +358,9 @@ class CNNPostNet(nn.Layer):
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
if i != (len(self.convs) - 1):
|
if i != (len(self.convs) - 1):
|
||||||
x = F.tanh(x)
|
x = F.tanh(x)
|
||||||
x = self.last_bn(x_in + x)
|
# TODO: check it
|
||||||
|
x = x_in + x
|
||||||
|
# x = self.last_bn(x_in + x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue