From 737d142ae44186b25d1294395060a8e0b671b0de Mon Sep 17 00:00:00 2001 From: liuyibing01 Date: Tue, 25 Feb 2020 15:53:54 +0000 Subject: [PATCH] Enable the fp16 inference for waveflow --- README.md | 10 ++-- examples/waveflow/README.md | 10 ++++ examples/waveflow/benchmark.py | 7 ++- examples/waveflow/synthesis.py | 8 ++- examples/waveflow/train.py | 2 + examples/waveflow/utils.py | 9 ++- parakeet/models/waveflow/waveflow.py | 15 +++-- parakeet/models/waveflow/waveflow_modules.py | 58 +++++++++++++------- parakeet/modules/weight_norm.py | 18 +++++- 9 files changed, 104 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 2c5b475..8b7587b 100644 --- a/README.md +++ b/README.md @@ -36,14 +36,16 @@ nltk.download("cmudict") ``` -## Supported models +## Related Research - [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654) - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895) - [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263). +- [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219) ## Examples -- [Train a deepvoice 3 model with ljspeech dataset](./parakeet/examples/deepvoice3) -- [Train a transformer_tts model with ljspeech dataset](./parakeet/examples/transformer_tts) -- [Train a fastspeech model with ljspeech dataset](./parakeet/examples/fastspeech) +- [Train a DeepVoice3 model with ljspeech dataset](./parakeet/examples/deepvoice3) +- [Train a TransformerTTS model with ljspeech dataset](./parakeet/examples/transformer_tts) +- [Train a FastSpeech model with ljspeech dataset](./parakeet/examples/fastspeech) +- [Train a WaveFlow model with ljspeech dataset](./parakeet/examples/waveflow) diff --git a/examples/waveflow/README.md b/examples/waveflow/README.md index 184396f..050bb17 100644 --- a/examples/waveflow/README.md +++ b/examples/waveflow/README.md @@ -109,3 +109,13 @@ python -u benchmark.py \ --root=./data/LJSpeech-1.1 \ --name=${ModelName} --use_gpu=true ``` + +### Low-precision inference + +This model supports the float16 low-precsion inference. By appending the argument + +```bash + --use_fp16=true +``` + +to the command of synthesis and benchmarking, one can experience the fast speed of low-precision inference. diff --git a/examples/waveflow/benchmark.py b/examples/waveflow/benchmark.py index eb6b6fc..24d83c4 100644 --- a/examples/waveflow/benchmark.py +++ b/examples/waveflow/benchmark.py @@ -24,9 +24,14 @@ def add_options_to_parser(parser): parser.add_argument( '--use_gpu', - type=bool, + type=utils.str2bool, default=True, help="option to use gpu training") + parser.add_argument( + '--use_fp16', + type=utils.str2bool, + default=True, + help="option to use fp16 for inference") parser.add_argument( '--iteration', diff --git a/examples/waveflow/synthesis.py b/examples/waveflow/synthesis.py index 1e3fb9e..76df229 100644 --- a/examples/waveflow/synthesis.py +++ b/examples/waveflow/synthesis.py @@ -24,9 +24,14 @@ def add_options_to_parser(parser): parser.add_argument( '--use_gpu', - type=bool, + type=utils.str2bool, default=True, help="option to use gpu training") + parser.add_argument( + '--use_fp16', + type=utils.str2bool, + default=True, + help="option to use fp16 for inference") parser.add_argument( '--iteration', @@ -74,7 +79,6 @@ def synthesize(config): # Build model. model = WaveFlow(config, checkpoint_dir) model.build(training=False) - # Obtain the current iteration. if config.checkpoint is None: if config.iteration is None: diff --git a/examples/waveflow/train.py b/examples/waveflow/train.py index e41597e..92bb9ef 100644 --- a/examples/waveflow/train.py +++ b/examples/waveflow/train.py @@ -127,4 +127,6 @@ if __name__ == "__main__": # the preceding update will be overwritten by the following one. config = parser.parse_args() config = utils.add_yaml_config(config) + # Force to use fp32 in model training + vars(config)["use_fp16"] = False train(config) diff --git a/examples/waveflow/utils.py b/examples/waveflow/utils.py index c088b1d..51f6296 100644 --- a/examples/waveflow/utils.py +++ b/examples/waveflow/utils.py @@ -126,7 +126,8 @@ def load_parameters(checkpoint_dir, model, optimizer=None, iteration=None, - file_path=None): + file_path=None, + dtype="float32"): if file_path is None: if iteration is None: iteration = load_latest_checkpoint(checkpoint_dir, rank) @@ -135,6 +136,12 @@ def load_parameters(checkpoint_dir, file_path = "{}/step-{}".format(checkpoint_dir, iteration) model_dict, optimizer_dict = dg.load_dygraph(file_path) + if dtype == "float16": + for k, v in model_dict.items(): + if "conv2d_transpose" in k: + model_dict[k] = v.astype("float32") + else: + model_dict[k] = v.astype(dtype) model.set_dict(model_dict) print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) if optimizer and optimizer_dict: diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index 569086e..1b1b8bf 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -8,6 +8,7 @@ from paddle import fluid from scipy.io.wavfile import write import utils +from parakeet.modules import weight_norm from .data import LJSpeech from .waveflow_modules import WaveFlowLoss, WaveFlowModule @@ -26,6 +27,7 @@ class WaveFlow(): self.rank = rank self.nranks = nranks self.tb_logger = tb_logger + self.dtype = "float16" if config.use_fp16 else "float32" def build(self, training=True): config = self.config @@ -36,9 +38,9 @@ class WaveFlow(): waveflow = WaveFlowModule(config) # Dry run once to create and initalize all necessary parameters. - audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32)) + audio = dg.to_variable(np.random.randn(1, 16000).astype(self.dtype)) mel = dg.to_variable( - np.random.randn(1, config.mel_bands, 63).astype(np.float32)) + np.random.randn(1, config.mel_bands, 63).astype(self.dtype)) waveflow(audio, mel) if training: @@ -72,9 +74,14 @@ class WaveFlow(): self.rank, waveflow, iteration=config.iteration, - file_path=config.checkpoint) + file_path=config.checkpoint, + dtype=self.dtype) print("Rank {}: checkpoint loaded.".format(self.rank)) + for layer in waveflow.sublayers(): + if isinstance(layer, weight_norm.WeightNormWrapper): + layer.remove_weight_norm() + self.waveflow = waveflow def train_step(self, iteration): @@ -173,7 +180,7 @@ class WaveFlow(): syn_time)) # Denormalize audio from [-1, 1] to [-32768, 32768] int16 range. - audio = audio.numpy() * 32768.0 + audio = audio.numpy().astype("float32") * 32768.0 audio = audio.astype('int16') write(filename, config.sample_rate, audio) diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index c981fe7..1b8938a 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -1,5 +1,4 @@ import itertools - import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid @@ -49,7 +48,7 @@ class WaveFlowLoss: class Conditioner(dg.Layer): - def __init__(self): + def __init__(self, dtype): super(Conditioner, self).__init__() upsample_factors = [16, 16] @@ -65,7 +64,8 @@ class Conditioner(dg.Layer): padding=(1, s // 2), stride=(1, s), param_attr=param_attr, - bias_attr=bias_attr) + bias_attr=bias_attr, + dtype="float32") self.upsample_conv2d.append(conv_trans2d) for i, layer in enumerate(self.upsample_conv2d): @@ -74,19 +74,30 @@ class Conditioner(dg.Layer): def forward(self, x): x = fluid.layers.unsqueeze(x, 1) for layer in self.upsample_conv2d: - x = fluid.layers.leaky_relu(layer(x), alpha=0.4) + in_dtype = x.dtype + if in_dtype == fluid.core.VarDesc.VarType.FP16: + x = fluid.layers.cast(x, "float32") + x = layer(x) + if in_dtype == fluid.core.VarDesc.VarType.FP16: + x = fluid.layers.cast(x, "float16") + x = fluid.layers.leaky_relu(x, alpha=0.4) - return fluid.layers.squeeze(x, [1]) + return fluid.layers.reshape(x, [x.shape[0], x.shape[2], x.shape[3]]) def infer(self, x): x = fluid.layers.unsqueeze(x, 1) for layer in self.upsample_conv2d: + in_dtype = x.dtype + if in_dtype == fluid.core.VarDesc.VarType.FP16: + x = fluid.layers.cast(x, "float32") x = layer(x) + if in_dtype == fluid.core.VarDesc.VarType.FP16: + x = fluid.layers.cast(x, "float16") # Trim conv artifacts. time_cutoff = layer._filter_size[1] - layer._stride[1] x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4) - return fluid.layers.squeeze(x, [1]) + return fluid.layers.reshape(x, [x.shape[0], x.shape[2], x.shape[3]]) class Flow(dg.Layer): @@ -96,6 +107,7 @@ class Flow(dg.Layer): self.n_channels = config.n_channels self.kernel_h = config.kernel_h self.kernel_w = config.kernel_w + self.dtype = "float16" if config.use_fp16 else "float32" # Transform audio: [batch, 1, n_group, time/n_group] # => [batch, n_channels, n_group, time/n_group] @@ -105,7 +117,8 @@ class Flow(dg.Layer): num_filters=self.n_channels, filter_size=(1, 1), param_attr=param_attr, - bias_attr=bias_attr) + bias_attr=bias_attr, + dtype=self.dtype) # Initializing last layer to 0 makes the affine coupling layers # do nothing at first. This helps with training stability @@ -117,7 +130,8 @@ class Flow(dg.Layer): num_filters=2, filter_size=(1, 1), param_attr=param_attr, - bias_attr=bias_attr) + bias_attr=bias_attr, + dtype=self.dtype) # receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze dilation_dict = { @@ -145,7 +159,8 @@ class Flow(dg.Layer): filter_size=(self.kernel_h, self.kernel_w), dilation=(dilation_h, dilation_w), param_attr=param_attr, - bias_attr=bias_attr) + bias_attr=bias_attr, + dtype=self.dtype) self.in_layers.append(in_layer) param_attr, bias_attr = get_param_attr( @@ -155,7 +170,8 @@ class Flow(dg.Layer): num_filters=2 * self.n_channels, filter_size=(1, 1), param_attr=param_attr, - bias_attr=bias_attr) + bias_attr=bias_attr, + dtype=self.dtype) self.cond_layers.append(cond_layer) if i < self.n_layers - 1: @@ -169,7 +185,8 @@ class Flow(dg.Layer): num_filters=res_skip_channels, filter_size=(1, 1), param_attr=param_attr, - bias_attr=bias_attr) + bias_attr=bias_attr, + dtype=self.dtype) self.res_skip_layers.append(res_skip_layer) self.add_sublayer("in_layer_{}".format(i), in_layer) @@ -191,7 +208,6 @@ class Flow(dg.Layer): pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2) audio_pad = fluid.layers.pad2d( audio, paddings=[pad_top, pad_bottom, pad_left, pad_right]) - hidden = self.in_layers[i](audio_pad) cond_hidden = self.cond_layers[i](mel) in_acts = hidden + cond_hidden @@ -239,12 +255,11 @@ class Flow(dg.Layer): pad_right = int((self.kernel_w - 1) * dilation_w / 2) state = fluid.layers.pad2d( state, paddings=[pad_top, pad_bottom, pad_left, pad_right]) - hidden = self.in_layers[i](state) cond_hidden = self.cond_layers[i](mel) in_acts = hidden + cond_hidden out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \ - fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) + fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) res_skip_acts = self.res_skip_layers[i](out_acts) if i < self.n_layers - 1: @@ -270,7 +285,8 @@ class WaveFlowModule(dg.Layer): assert self.n_group % 2 == 0 assert self.n_flows % 2 == 0 - self.conditioner = Conditioner() + self.dtype = "float16" if config.use_fp16 else "float32" + self.conditioner = Conditioner(self.dtype) self.flows = [] for i in range(self.n_flows): flow = Flow(config) @@ -324,17 +340,21 @@ class WaveFlowModule(dg.Layer): mel_slices = [mel[:, :, j, :] for j in self.perms[i]] mel = fluid.layers.stack(mel_slices, axis=2) - z = fluid.layers.squeeze(audio, [1]) + z = fluid.layers.reshape( + audio, [audio.shape[0], audio.shape[2], audio.shape[3]]) return z, log_s_list def synthesize(self, mel, sigma=1.0): + if self.dtype == "float16": + mel = fluid.layers.cast(mel, self.dtype) mel = self.conditioner.infer(mel) # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) audio = fluid.layers.gaussian_random( shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma) - + if self.dtype == "float16": + audio = fluid.layers.cast(audio, self.dtype) for i in reversed(range(self.n_flows)): # Permute over the height dimension. audio_slices = [audio[:, :, j, :] for j in self.perms[i]] @@ -362,9 +382,9 @@ class WaveFlowModule(dg.Layer): audio = fluid.layers.concat(audio_list, axis=2) # audio: [bs, n_group, time/n_group] - audio = fluid.layers.squeeze(audio, [1]) + audio = fluid.layers.reshape( + audio, [audio.shape[0], audio.shape[2], audio.shape[3]]) # audio: [bs, time] audio = fluid.layers.reshape( fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1]) - return audio diff --git a/parakeet/modules/weight_norm.py b/parakeet/modules/weight_norm.py index 992f099..9e28792 100644 --- a/parakeet/modules/weight_norm.py +++ b/parakeet/modules/weight_norm.py @@ -8,8 +8,13 @@ from parakeet.modules import customized as L def norm(param, dim, power): powered = F.pow(param, power) + in_dtype = powered.dtype + if in_dtype == fluid.core.VarDesc.VarType.FP16: + powered = F.cast(powered, "float32") powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False) norm_ = F.pow(powered_norm, 1. / power) + if in_dtype == fluid.core.VarDesc.VarType.FP16: + norm_ = F.cast(norm_, "float16") return norm_ @@ -46,6 +51,15 @@ def compute_weight(v, g, dim, power): return weight +def assign_by_cast(i, o): + fluid.default_main_program().current_block().append_op( + type="cast", + inputs={"X": i}, + outputs={"Out": o}, + attrs={"in_dtype": i.dtype, + "out_dtype": o.dtype}) + + class WeightNormWrapper(dg.Layer): def __init__(self, layer, param_name="weight", dim=0, power=2): super(WeightNormWrapper, self).__init__() @@ -65,13 +79,13 @@ class WeightNormWrapper(dg.Layer): w_v, self.create_parameter( shape=original_weight.shape, dtype=original_weight.dtype)) - F.assign(original_weight, getattr(self, w_v)) + assign_by_cast(original_weight, getattr(self, w_v)) delattr(layer, param_name) temp = norm_except(getattr(self, w_v), self.dim, self.power) self.add_parameter( w_g, self.create_parameter( shape=temp.shape, dtype=temp.dtype)) - F.assign(temp, getattr(self, w_g)) + assign_by_cast(temp, getattr(self, w_g)) # also set this when setting up setattr(self.layer, self.param_name,