Enable the fp16 inference for waveflow
This commit is contained in:
parent
1c6cd10ae8
commit
737d142ae4
10
README.md
10
README.md
|
@ -36,14 +36,16 @@ nltk.download("cmudict")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Supported models
|
## Related Research
|
||||||
|
|
||||||
- [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654)
|
- [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654)
|
||||||
- [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
|
- [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895)
|
||||||
- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263).
|
- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263).
|
||||||
|
- [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219)
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
- [Train a deepvoice 3 model with ljspeech dataset](./parakeet/examples/deepvoice3)
|
- [Train a DeepVoice3 model with ljspeech dataset](./parakeet/examples/deepvoice3)
|
||||||
- [Train a transformer_tts model with ljspeech dataset](./parakeet/examples/transformer_tts)
|
- [Train a TransformerTTS model with ljspeech dataset](./parakeet/examples/transformer_tts)
|
||||||
- [Train a fastspeech model with ljspeech dataset](./parakeet/examples/fastspeech)
|
- [Train a FastSpeech model with ljspeech dataset](./parakeet/examples/fastspeech)
|
||||||
|
- [Train a WaveFlow model with ljspeech dataset](./parakeet/examples/waveflow)
|
||||||
|
|
|
@ -109,3 +109,13 @@ python -u benchmark.py \
|
||||||
--root=./data/LJSpeech-1.1 \
|
--root=./data/LJSpeech-1.1 \
|
||||||
--name=${ModelName} --use_gpu=true
|
--name=${ModelName} --use_gpu=true
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Low-precision inference
|
||||||
|
|
||||||
|
This model supports the float16 low-precsion inference. By appending the argument
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--use_fp16=true
|
||||||
|
```
|
||||||
|
|
||||||
|
to the command of synthesis and benchmarking, one can experience the fast speed of low-precision inference.
|
||||||
|
|
|
@ -24,9 +24,14 @@ def add_options_to_parser(parser):
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--use_gpu',
|
'--use_gpu',
|
||||||
type=bool,
|
type=utils.str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="option to use gpu training")
|
help="option to use gpu training")
|
||||||
|
parser.add_argument(
|
||||||
|
'--use_fp16',
|
||||||
|
type=utils.str2bool,
|
||||||
|
default=True,
|
||||||
|
help="option to use fp16 for inference")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--iteration',
|
'--iteration',
|
||||||
|
|
|
@ -24,9 +24,14 @@ def add_options_to_parser(parser):
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--use_gpu',
|
'--use_gpu',
|
||||||
type=bool,
|
type=utils.str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="option to use gpu training")
|
help="option to use gpu training")
|
||||||
|
parser.add_argument(
|
||||||
|
'--use_fp16',
|
||||||
|
type=utils.str2bool,
|
||||||
|
default=True,
|
||||||
|
help="option to use fp16 for inference")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--iteration',
|
'--iteration',
|
||||||
|
@ -74,7 +79,6 @@ def synthesize(config):
|
||||||
# Build model.
|
# Build model.
|
||||||
model = WaveFlow(config, checkpoint_dir)
|
model = WaveFlow(config, checkpoint_dir)
|
||||||
model.build(training=False)
|
model.build(training=False)
|
||||||
|
|
||||||
# Obtain the current iteration.
|
# Obtain the current iteration.
|
||||||
if config.checkpoint is None:
|
if config.checkpoint is None:
|
||||||
if config.iteration is None:
|
if config.iteration is None:
|
||||||
|
|
|
@ -127,4 +127,6 @@ if __name__ == "__main__":
|
||||||
# the preceding update will be overwritten by the following one.
|
# the preceding update will be overwritten by the following one.
|
||||||
config = parser.parse_args()
|
config = parser.parse_args()
|
||||||
config = utils.add_yaml_config(config)
|
config = utils.add_yaml_config(config)
|
||||||
|
# Force to use fp32 in model training
|
||||||
|
vars(config)["use_fp16"] = False
|
||||||
train(config)
|
train(config)
|
||||||
|
|
|
@ -126,7 +126,8 @@ def load_parameters(checkpoint_dir,
|
||||||
model,
|
model,
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
iteration=None,
|
iteration=None,
|
||||||
file_path=None):
|
file_path=None,
|
||||||
|
dtype="float32"):
|
||||||
if file_path is None:
|
if file_path is None:
|
||||||
if iteration is None:
|
if iteration is None:
|
||||||
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
||||||
|
@ -135,6 +136,12 @@ def load_parameters(checkpoint_dir,
|
||||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
||||||
|
|
||||||
model_dict, optimizer_dict = dg.load_dygraph(file_path)
|
model_dict, optimizer_dict = dg.load_dygraph(file_path)
|
||||||
|
if dtype == "float16":
|
||||||
|
for k, v in model_dict.items():
|
||||||
|
if "conv2d_transpose" in k:
|
||||||
|
model_dict[k] = v.astype("float32")
|
||||||
|
else:
|
||||||
|
model_dict[k] = v.astype(dtype)
|
||||||
model.set_dict(model_dict)
|
model.set_dict(model_dict)
|
||||||
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
|
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
|
||||||
if optimizer and optimizer_dict:
|
if optimizer and optimizer_dict:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from paddle import fluid
|
||||||
from scipy.io.wavfile import write
|
from scipy.io.wavfile import write
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
|
from parakeet.modules import weight_norm
|
||||||
from .data import LJSpeech
|
from .data import LJSpeech
|
||||||
from .waveflow_modules import WaveFlowLoss, WaveFlowModule
|
from .waveflow_modules import WaveFlowLoss, WaveFlowModule
|
||||||
|
|
||||||
|
@ -26,6 +27,7 @@ class WaveFlow():
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
self.nranks = nranks
|
self.nranks = nranks
|
||||||
self.tb_logger = tb_logger
|
self.tb_logger = tb_logger
|
||||||
|
self.dtype = "float16" if config.use_fp16 else "float32"
|
||||||
|
|
||||||
def build(self, training=True):
|
def build(self, training=True):
|
||||||
config = self.config
|
config = self.config
|
||||||
|
@ -36,9 +38,9 @@ class WaveFlow():
|
||||||
waveflow = WaveFlowModule(config)
|
waveflow = WaveFlowModule(config)
|
||||||
|
|
||||||
# Dry run once to create and initalize all necessary parameters.
|
# Dry run once to create and initalize all necessary parameters.
|
||||||
audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32))
|
audio = dg.to_variable(np.random.randn(1, 16000).astype(self.dtype))
|
||||||
mel = dg.to_variable(
|
mel = dg.to_variable(
|
||||||
np.random.randn(1, config.mel_bands, 63).astype(np.float32))
|
np.random.randn(1, config.mel_bands, 63).astype(self.dtype))
|
||||||
waveflow(audio, mel)
|
waveflow(audio, mel)
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
|
@ -72,9 +74,14 @@ class WaveFlow():
|
||||||
self.rank,
|
self.rank,
|
||||||
waveflow,
|
waveflow,
|
||||||
iteration=config.iteration,
|
iteration=config.iteration,
|
||||||
file_path=config.checkpoint)
|
file_path=config.checkpoint,
|
||||||
|
dtype=self.dtype)
|
||||||
print("Rank {}: checkpoint loaded.".format(self.rank))
|
print("Rank {}: checkpoint loaded.".format(self.rank))
|
||||||
|
|
||||||
|
for layer in waveflow.sublayers():
|
||||||
|
if isinstance(layer, weight_norm.WeightNormWrapper):
|
||||||
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
self.waveflow = waveflow
|
self.waveflow = waveflow
|
||||||
|
|
||||||
def train_step(self, iteration):
|
def train_step(self, iteration):
|
||||||
|
@ -173,7 +180,7 @@ class WaveFlow():
|
||||||
syn_time))
|
syn_time))
|
||||||
|
|
||||||
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
|
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
|
||||||
audio = audio.numpy() * 32768.0
|
audio = audio.numpy().astype("float32") * 32768.0
|
||||||
audio = audio.astype('int16')
|
audio = audio.astype('int16')
|
||||||
write(filename, config.sample_rate, audio)
|
write(filename, config.sample_rate, audio)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
@ -49,7 +48,7 @@ class WaveFlowLoss:
|
||||||
|
|
||||||
|
|
||||||
class Conditioner(dg.Layer):
|
class Conditioner(dg.Layer):
|
||||||
def __init__(self):
|
def __init__(self, dtype):
|
||||||
super(Conditioner, self).__init__()
|
super(Conditioner, self).__init__()
|
||||||
upsample_factors = [16, 16]
|
upsample_factors = [16, 16]
|
||||||
|
|
||||||
|
@ -65,7 +64,8 @@ class Conditioner(dg.Layer):
|
||||||
padding=(1, s // 2),
|
padding=(1, s // 2),
|
||||||
stride=(1, s),
|
stride=(1, s),
|
||||||
param_attr=param_attr,
|
param_attr=param_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr,
|
||||||
|
dtype="float32")
|
||||||
self.upsample_conv2d.append(conv_trans2d)
|
self.upsample_conv2d.append(conv_trans2d)
|
||||||
|
|
||||||
for i, layer in enumerate(self.upsample_conv2d):
|
for i, layer in enumerate(self.upsample_conv2d):
|
||||||
|
@ -74,19 +74,30 @@ class Conditioner(dg.Layer):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = fluid.layers.unsqueeze(x, 1)
|
x = fluid.layers.unsqueeze(x, 1)
|
||||||
for layer in self.upsample_conv2d:
|
for layer in self.upsample_conv2d:
|
||||||
x = fluid.layers.leaky_relu(layer(x), alpha=0.4)
|
in_dtype = x.dtype
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
x = fluid.layers.cast(x, "float32")
|
||||||
|
x = layer(x)
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
x = fluid.layers.cast(x, "float16")
|
||||||
|
x = fluid.layers.leaky_relu(x, alpha=0.4)
|
||||||
|
|
||||||
return fluid.layers.squeeze(x, [1])
|
return fluid.layers.reshape(x, [x.shape[0], x.shape[2], x.shape[3]])
|
||||||
|
|
||||||
def infer(self, x):
|
def infer(self, x):
|
||||||
x = fluid.layers.unsqueeze(x, 1)
|
x = fluid.layers.unsqueeze(x, 1)
|
||||||
for layer in self.upsample_conv2d:
|
for layer in self.upsample_conv2d:
|
||||||
|
in_dtype = x.dtype
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
x = fluid.layers.cast(x, "float32")
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
x = fluid.layers.cast(x, "float16")
|
||||||
# Trim conv artifacts.
|
# Trim conv artifacts.
|
||||||
time_cutoff = layer._filter_size[1] - layer._stride[1]
|
time_cutoff = layer._filter_size[1] - layer._stride[1]
|
||||||
x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4)
|
x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4)
|
||||||
|
|
||||||
return fluid.layers.squeeze(x, [1])
|
return fluid.layers.reshape(x, [x.shape[0], x.shape[2], x.shape[3]])
|
||||||
|
|
||||||
|
|
||||||
class Flow(dg.Layer):
|
class Flow(dg.Layer):
|
||||||
|
@ -96,6 +107,7 @@ class Flow(dg.Layer):
|
||||||
self.n_channels = config.n_channels
|
self.n_channels = config.n_channels
|
||||||
self.kernel_h = config.kernel_h
|
self.kernel_h = config.kernel_h
|
||||||
self.kernel_w = config.kernel_w
|
self.kernel_w = config.kernel_w
|
||||||
|
self.dtype = "float16" if config.use_fp16 else "float32"
|
||||||
|
|
||||||
# Transform audio: [batch, 1, n_group, time/n_group]
|
# Transform audio: [batch, 1, n_group, time/n_group]
|
||||||
# => [batch, n_channels, n_group, time/n_group]
|
# => [batch, n_channels, n_group, time/n_group]
|
||||||
|
@ -105,7 +117,8 @@ class Flow(dg.Layer):
|
||||||
num_filters=self.n_channels,
|
num_filters=self.n_channels,
|
||||||
filter_size=(1, 1),
|
filter_size=(1, 1),
|
||||||
param_attr=param_attr,
|
param_attr=param_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr,
|
||||||
|
dtype=self.dtype)
|
||||||
|
|
||||||
# Initializing last layer to 0 makes the affine coupling layers
|
# Initializing last layer to 0 makes the affine coupling layers
|
||||||
# do nothing at first. This helps with training stability
|
# do nothing at first. This helps with training stability
|
||||||
|
@ -117,7 +130,8 @@ class Flow(dg.Layer):
|
||||||
num_filters=2,
|
num_filters=2,
|
||||||
filter_size=(1, 1),
|
filter_size=(1, 1),
|
||||||
param_attr=param_attr,
|
param_attr=param_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr,
|
||||||
|
dtype=self.dtype)
|
||||||
|
|
||||||
# receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze
|
# receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze
|
||||||
dilation_dict = {
|
dilation_dict = {
|
||||||
|
@ -145,7 +159,8 @@ class Flow(dg.Layer):
|
||||||
filter_size=(self.kernel_h, self.kernel_w),
|
filter_size=(self.kernel_h, self.kernel_w),
|
||||||
dilation=(dilation_h, dilation_w),
|
dilation=(dilation_h, dilation_w),
|
||||||
param_attr=param_attr,
|
param_attr=param_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr,
|
||||||
|
dtype=self.dtype)
|
||||||
self.in_layers.append(in_layer)
|
self.in_layers.append(in_layer)
|
||||||
|
|
||||||
param_attr, bias_attr = get_param_attr(
|
param_attr, bias_attr = get_param_attr(
|
||||||
|
@ -155,7 +170,8 @@ class Flow(dg.Layer):
|
||||||
num_filters=2 * self.n_channels,
|
num_filters=2 * self.n_channels,
|
||||||
filter_size=(1, 1),
|
filter_size=(1, 1),
|
||||||
param_attr=param_attr,
|
param_attr=param_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr,
|
||||||
|
dtype=self.dtype)
|
||||||
self.cond_layers.append(cond_layer)
|
self.cond_layers.append(cond_layer)
|
||||||
|
|
||||||
if i < self.n_layers - 1:
|
if i < self.n_layers - 1:
|
||||||
|
@ -169,7 +185,8 @@ class Flow(dg.Layer):
|
||||||
num_filters=res_skip_channels,
|
num_filters=res_skip_channels,
|
||||||
filter_size=(1, 1),
|
filter_size=(1, 1),
|
||||||
param_attr=param_attr,
|
param_attr=param_attr,
|
||||||
bias_attr=bias_attr)
|
bias_attr=bias_attr,
|
||||||
|
dtype=self.dtype)
|
||||||
self.res_skip_layers.append(res_skip_layer)
|
self.res_skip_layers.append(res_skip_layer)
|
||||||
|
|
||||||
self.add_sublayer("in_layer_{}".format(i), in_layer)
|
self.add_sublayer("in_layer_{}".format(i), in_layer)
|
||||||
|
@ -191,7 +208,6 @@ class Flow(dg.Layer):
|
||||||
pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2)
|
pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2)
|
||||||
audio_pad = fluid.layers.pad2d(
|
audio_pad = fluid.layers.pad2d(
|
||||||
audio, paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
audio, paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||||
|
|
||||||
hidden = self.in_layers[i](audio_pad)
|
hidden = self.in_layers[i](audio_pad)
|
||||||
cond_hidden = self.cond_layers[i](mel)
|
cond_hidden = self.cond_layers[i](mel)
|
||||||
in_acts = hidden + cond_hidden
|
in_acts = hidden + cond_hidden
|
||||||
|
@ -239,7 +255,6 @@ class Flow(dg.Layer):
|
||||||
pad_right = int((self.kernel_w - 1) * dilation_w / 2)
|
pad_right = int((self.kernel_w - 1) * dilation_w / 2)
|
||||||
state = fluid.layers.pad2d(
|
state = fluid.layers.pad2d(
|
||||||
state, paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
state, paddings=[pad_top, pad_bottom, pad_left, pad_right])
|
||||||
|
|
||||||
hidden = self.in_layers[i](state)
|
hidden = self.in_layers[i](state)
|
||||||
cond_hidden = self.cond_layers[i](mel)
|
cond_hidden = self.cond_layers[i](mel)
|
||||||
in_acts = hidden + cond_hidden
|
in_acts = hidden + cond_hidden
|
||||||
|
@ -270,7 +285,8 @@ class WaveFlowModule(dg.Layer):
|
||||||
assert self.n_group % 2 == 0
|
assert self.n_group % 2 == 0
|
||||||
assert self.n_flows % 2 == 0
|
assert self.n_flows % 2 == 0
|
||||||
|
|
||||||
self.conditioner = Conditioner()
|
self.dtype = "float16" if config.use_fp16 else "float32"
|
||||||
|
self.conditioner = Conditioner(self.dtype)
|
||||||
self.flows = []
|
self.flows = []
|
||||||
for i in range(self.n_flows):
|
for i in range(self.n_flows):
|
||||||
flow = Flow(config)
|
flow = Flow(config)
|
||||||
|
@ -324,17 +340,21 @@ class WaveFlowModule(dg.Layer):
|
||||||
mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
|
mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
|
||||||
mel = fluid.layers.stack(mel_slices, axis=2)
|
mel = fluid.layers.stack(mel_slices, axis=2)
|
||||||
|
|
||||||
z = fluid.layers.squeeze(audio, [1])
|
z = fluid.layers.reshape(
|
||||||
|
audio, [audio.shape[0], audio.shape[2], audio.shape[3]])
|
||||||
return z, log_s_list
|
return z, log_s_list
|
||||||
|
|
||||||
def synthesize(self, mel, sigma=1.0):
|
def synthesize(self, mel, sigma=1.0):
|
||||||
|
if self.dtype == "float16":
|
||||||
|
mel = fluid.layers.cast(mel, self.dtype)
|
||||||
mel = self.conditioner.infer(mel)
|
mel = self.conditioner.infer(mel)
|
||||||
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
|
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
|
||||||
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
|
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
|
||||||
|
|
||||||
audio = fluid.layers.gaussian_random(
|
audio = fluid.layers.gaussian_random(
|
||||||
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
|
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
|
||||||
|
if self.dtype == "float16":
|
||||||
|
audio = fluid.layers.cast(audio, self.dtype)
|
||||||
for i in reversed(range(self.n_flows)):
|
for i in reversed(range(self.n_flows)):
|
||||||
# Permute over the height dimension.
|
# Permute over the height dimension.
|
||||||
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
|
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
|
||||||
|
@ -362,9 +382,9 @@ class WaveFlowModule(dg.Layer):
|
||||||
audio = fluid.layers.concat(audio_list, axis=2)
|
audio = fluid.layers.concat(audio_list, axis=2)
|
||||||
|
|
||||||
# audio: [bs, n_group, time/n_group]
|
# audio: [bs, n_group, time/n_group]
|
||||||
audio = fluid.layers.squeeze(audio, [1])
|
audio = fluid.layers.reshape(
|
||||||
|
audio, [audio.shape[0], audio.shape[2], audio.shape[3]])
|
||||||
# audio: [bs, time]
|
# audio: [bs, time]
|
||||||
audio = fluid.layers.reshape(
|
audio = fluid.layers.reshape(
|
||||||
fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
|
fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
|
||||||
|
|
||||||
return audio
|
return audio
|
||||||
|
|
|
@ -8,8 +8,13 @@ from parakeet.modules import customized as L
|
||||||
|
|
||||||
def norm(param, dim, power):
|
def norm(param, dim, power):
|
||||||
powered = F.pow(param, power)
|
powered = F.pow(param, power)
|
||||||
|
in_dtype = powered.dtype
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
powered = F.cast(powered, "float32")
|
||||||
powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
|
powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
|
||||||
norm_ = F.pow(powered_norm, 1. / power)
|
norm_ = F.pow(powered_norm, 1. / power)
|
||||||
|
if in_dtype == fluid.core.VarDesc.VarType.FP16:
|
||||||
|
norm_ = F.cast(norm_, "float16")
|
||||||
return norm_
|
return norm_
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,6 +51,15 @@ def compute_weight(v, g, dim, power):
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
def assign_by_cast(i, o):
|
||||||
|
fluid.default_main_program().current_block().append_op(
|
||||||
|
type="cast",
|
||||||
|
inputs={"X": i},
|
||||||
|
outputs={"Out": o},
|
||||||
|
attrs={"in_dtype": i.dtype,
|
||||||
|
"out_dtype": o.dtype})
|
||||||
|
|
||||||
|
|
||||||
class WeightNormWrapper(dg.Layer):
|
class WeightNormWrapper(dg.Layer):
|
||||||
def __init__(self, layer, param_name="weight", dim=0, power=2):
|
def __init__(self, layer, param_name="weight", dim=0, power=2):
|
||||||
super(WeightNormWrapper, self).__init__()
|
super(WeightNormWrapper, self).__init__()
|
||||||
|
@ -65,13 +79,13 @@ class WeightNormWrapper(dg.Layer):
|
||||||
w_v,
|
w_v,
|
||||||
self.create_parameter(
|
self.create_parameter(
|
||||||
shape=original_weight.shape, dtype=original_weight.dtype))
|
shape=original_weight.shape, dtype=original_weight.dtype))
|
||||||
F.assign(original_weight, getattr(self, w_v))
|
assign_by_cast(original_weight, getattr(self, w_v))
|
||||||
delattr(layer, param_name)
|
delattr(layer, param_name)
|
||||||
temp = norm_except(getattr(self, w_v), self.dim, self.power)
|
temp = norm_except(getattr(self, w_v), self.dim, self.power)
|
||||||
self.add_parameter(
|
self.add_parameter(
|
||||||
w_g, self.create_parameter(
|
w_g, self.create_parameter(
|
||||||
shape=temp.shape, dtype=temp.dtype))
|
shape=temp.shape, dtype=temp.dtype))
|
||||||
F.assign(temp, getattr(self, w_g))
|
assign_by_cast(temp, getattr(self, w_g))
|
||||||
|
|
||||||
# also set this when setting up
|
# also set this when setting up
|
||||||
setattr(self.layer, self.param_name,
|
setattr(self.layer, self.param_name,
|
||||||
|
|
Loading…
Reference in New Issue