Merge branch 'master' into 'master'

fixes for wavenet and modules

See merge request !47
This commit is contained in:
liuyibing01 2020-03-22 11:44:42 +08:00
commit be70b41fd1
5 changed files with 24 additions and 8 deletions

View File

@ -101,6 +101,8 @@ if __name__ == "__main__":
state, _ = dg.load_dygraph(args.checkpoint)
dv3.set_dict(state)
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
# removing weight norm also speeds up computation
for layer in dv3.sublayers():
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()

View File

@ -21,6 +21,7 @@ from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
from parakeet.modules.weight_norm import WeightNormWrapper
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
@ -114,6 +115,12 @@ if __name__ == "__main__":
print("Loading from {}.pdparams".format(args.checkpoint))
model.set_dict(model_dict)
# WARNING: don't forget to remove weight norm to re-compute each wrapped layer's weight
# removing weight norm also speeds up computation
for layer in model.sublayers():
if isinstance(layer, WeightNormWrapper):
layer.remove_weight_norm()
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)

View File

@ -313,6 +313,7 @@ class WaveNet(dg.Layer):
"""
# Causal Conv
if self.loss_type == "softmax":
x = F.clip(x, min=-1., max=0.99999)
x = quantize(x, self.output_dim)
x = self.embed(x) # (B, T, C), T=1
else:

View File

@ -86,7 +86,7 @@ class Conv1D(dg.Conv2D):
stride=1,
padding=0,
dilation=1,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -128,7 +128,7 @@ class Conv1DTranspose(dg.Conv2DTranspose):
padding=0,
stride=1,
dilation=1,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -179,7 +179,7 @@ class Conv1DCell(Conv1D):
filter_size,
dilation=1,
causal=False,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -225,6 +225,12 @@ class Conv1DCell(Conv1D):
def start_sequence(self):
"""Prepare the Conv1DCell to generate a new sequence, this method should be called before calling add_input multiple times.
WARNING:
This method accesses `self.weight` directly. If a `Conv1DCell` object is wrapped in a `WeightNormWrapper`, make sure this method is called only after the `WeightNormWrapper`'s hook is called.
`WeightNormWrapper` removes the wrapped layer's `weight`, add has a `weight_v` and `weight_g` to re-compute the wrapped layer's weight as $weight = weight_g * weight_v / ||weight_v||$. (Recomputing the `weight` is a hook before calling the wrapped layer's `forward` method.)
Whenever a `WeightNormWrapper`'s `forward` method is called, the wrapped layer's weight is updated. But when loading from a checkpoint, `weight_v` and `weight_g` are updated but the wrapped layer's weight is not, since it is no longer a `Parameter`. You should manually call `remove_weight_norm` or `hook` to re-compute the wrapped layer's weight before calling this method if you don't call `forward` first.
So when loading a model which uses `Conv1DCell` objects wrapped in `WeightNormWrapper`s, remember to call `remove_weight_norm` for all `WeightNormWrapper`s before synthesizing. Also, removing weight norm speeds up computation.
"""
if not self.causal:
raise ValueError(

View File

@ -151,7 +151,7 @@ def Conv1D(num_channels,
stride=1,
padding=0,
dilation=1,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -170,7 +170,7 @@ def Conv1DTranspose(num_channels,
padding=0,
stride=1,
dilation=1,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -188,7 +188,7 @@ def Conv1DCell(num_channels,
filter_size,
dilation=1,
causal=False,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -207,7 +207,7 @@ def Conv2D(num_channels,
stride=1,
padding=0,
dilation=1,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,
@ -228,7 +228,7 @@ def Conv2DTranspose(num_channels,
padding=0,
stride=1,
dilation=1,
groups=None,
groups=1,
param_attr=None,
bias_attr=None,
use_cudnn=True,