Merge branch 'develop' of https://github.com/PaddlePaddle/Parakeet into use_mfa
This commit is contained in:
commit
96b8e44015
|
@ -142,3 +142,5 @@ dmypy.json
|
||||||
*.swp
|
*.swp
|
||||||
runs
|
runs
|
||||||
syn_audios
|
syn_audios
|
||||||
|
exp/
|
||||||
|
dump/
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
|
repos:
|
||||||
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
|
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
|
||||||
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
|
rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
files: \.py$
|
files: \.py$
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
sha: a11d9314b22d8f8c7556443875b731ef05965464
|
rev: a11d9314b22d8f8c7556443875b731ef05965464
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-merge-conflict
|
- id: check-merge-conflict
|
||||||
- id: check-symlinks
|
- id: check-symlinks
|
||||||
|
@ -15,7 +16,7 @@
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
files: \.md$
|
files: \.md$
|
||||||
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||||
sha: v1.0.1
|
rev: v1.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: forbid-crlf
|
- id: forbid-crlf
|
||||||
files: \.md$
|
files: \.md$
|
||||||
|
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
class Clip(object):
|
||||||
|
"""Collate functor for training vocoders.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
batch_max_steps=20480,
|
||||||
|
hop_size=256,
|
||||||
|
aux_context_window=0, ):
|
||||||
|
"""Initialize customized collater for DataLoader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_max_steps (int): The maximum length of input signal in batch.
|
||||||
|
hop_size (int): Hop size of auxiliary features.
|
||||||
|
aux_context_window (int): Context window size for auxiliary feature conv.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if batch_max_steps % hop_size != 0:
|
||||||
|
batch_max_steps += -(batch_max_steps % hop_size)
|
||||||
|
assert batch_max_steps % hop_size == 0
|
||||||
|
self.batch_max_steps = batch_max_steps
|
||||||
|
self.batch_max_frames = batch_max_steps // hop_size
|
||||||
|
self.hop_size = hop_size
|
||||||
|
self.aux_context_window = aux_context_window
|
||||||
|
|
||||||
|
# set useful values in random cutting
|
||||||
|
self.start_offset = aux_context_window
|
||||||
|
self.end_offset = -(self.batch_max_frames + aux_context_window)
|
||||||
|
self.mel_threshold = self.batch_max_frames + 2 * aux_context_window
|
||||||
|
|
||||||
|
def __call__(self, examples):
|
||||||
|
"""Convert into batch tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (list): list of tuple of the pair of audio and features. Audio shape
|
||||||
|
(T, ), features shape(T', C).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Auxiliary feature batch (B, C, T'), where
|
||||||
|
T = (T' - 2 * aux_context_window) * hop_size.
|
||||||
|
Tensor: Target signal batch (B, 1, T).
|
||||||
|
|
||||||
|
"""
|
||||||
|
# check length
|
||||||
|
examples = [
|
||||||
|
self._adjust_length(b['wave'], b['feats']) for b in examples
|
||||||
|
if b['feats'].shape[0] > self.mel_threshold
|
||||||
|
]
|
||||||
|
xs, cs = [b[0] for b in examples], [b[1] for b in examples]
|
||||||
|
|
||||||
|
# make batch with random cut
|
||||||
|
c_lengths = [c.shape[0] for c in cs]
|
||||||
|
start_frames = np.array([
|
||||||
|
np.random.randint(self.start_offset, cl + self.end_offset)
|
||||||
|
for cl in c_lengths
|
||||||
|
])
|
||||||
|
x_starts = start_frames * self.hop_size
|
||||||
|
x_ends = x_starts + self.batch_max_steps
|
||||||
|
|
||||||
|
c_starts = start_frames - self.aux_context_window
|
||||||
|
c_ends = start_frames + self.batch_max_frames + self.aux_context_window
|
||||||
|
y_batch = np.stack(
|
||||||
|
[x[start:end] for x, start, end in zip(xs, x_starts, x_ends)])
|
||||||
|
c_batch = np.stack(
|
||||||
|
[c[start:end] for c, start, end in zip(cs, c_starts, c_ends)])
|
||||||
|
|
||||||
|
# convert each batch to tensor, asuume that each item in batch has the same length
|
||||||
|
y_batch = paddle.to_tensor(
|
||||||
|
y_batch, dtype=paddle.float32).unsqueeze(1) # (B, 1, T)
|
||||||
|
c_batch = paddle.to_tensor(
|
||||||
|
c_batch, dtype=paddle.float32).transpose([0, 2, 1]) # (B, C, T')
|
||||||
|
|
||||||
|
return y_batch, c_batch
|
||||||
|
|
||||||
|
def _adjust_length(self, x, c):
|
||||||
|
"""Adjust the audio and feature lengths.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Basically we assume that the length of x and c are adjusted
|
||||||
|
through preprocessing stage, but if we use other library processed
|
||||||
|
features, this process will be needed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if len(x) < c.shape[1] * self.hop_size:
|
||||||
|
x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)),
|
||||||
|
mode="edge")
|
||||||
|
|
||||||
|
# check the legnth is valid
|
||||||
|
assert len(x) == c.shape[
|
||||||
|
0] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[0]})"
|
||||||
|
|
||||||
|
return x, c
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Calculate statistics of feature files."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from parakeet.datasets.data_table import DataTable
|
||||||
|
from parakeet.utils.h5_utils import read_hdf5
|
||||||
|
from parakeet.utils.h5_utils import write_hdf5
|
||||||
|
|
||||||
|
from config import get_cfg_default
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run preprocessing process."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Compute mean and variance of dumped raw features.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata", type=str, help="json file with id and file paths ")
|
||||||
|
parser.add_argument(
|
||||||
|
"--field-name",
|
||||||
|
type=str,
|
||||||
|
help="name of the field to compute statistics for.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="yaml format configuration file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dumpdir",
|
||||||
|
type=str,
|
||||||
|
help="directory to save statistics. if not provided, "
|
||||||
|
"stats will be saved in the above root directory. (default=None)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="logging level. higher is more logging. (default=1)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# set logger
|
||||||
|
if args.verbose > 1:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
elif args.verbose > 0:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARN,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
logging.warning('Skip DEBUG/INFO messages')
|
||||||
|
|
||||||
|
config = get_cfg_default()
|
||||||
|
# load config
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
|
||||||
|
# check directory existence
|
||||||
|
if args.dumpdir is None:
|
||||||
|
args.dumpdir = os.path.dirname(args.metadata)
|
||||||
|
if not os.path.exists(args.dumpdir):
|
||||||
|
os.makedirs(args.dumpdir)
|
||||||
|
|
||||||
|
with jsonlines.open(args.metadata, 'r') as reader:
|
||||||
|
metadata = list(reader)
|
||||||
|
dataset = DataTable(
|
||||||
|
metadata,
|
||||||
|
fields=[args.field_name],
|
||||||
|
converters={args.field_name: np.load}, )
|
||||||
|
logging.info(f"The number of files = {len(dataset)}.")
|
||||||
|
|
||||||
|
# calculate statistics
|
||||||
|
scaler = StandardScaler()
|
||||||
|
for datum in tqdm(dataset):
|
||||||
|
# StandardScalar supports (*, num_features) by default
|
||||||
|
scaler.partial_fit(datum[args.field_name])
|
||||||
|
|
||||||
|
stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
|
||||||
|
np.save(
|
||||||
|
os.path.join(args.dumpdir, "stats.npy"),
|
||||||
|
stats.astype(np.float32),
|
||||||
|
allow_pickle=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,128 @@
|
||||||
|
# This is the hyperparameter configuration file for Parallel WaveGAN.
|
||||||
|
# Please make sure this is adjusted for the CSMSC dataset. If you want to
|
||||||
|
# apply to the other dataset, you might need to carefully change some parameters.
|
||||||
|
# This configuration requires 12 GB GPU memory and takes ~3 days on RTX TITAN.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
sr: 24000 # Sampling rate.
|
||||||
|
n_fft: 2048 # FFT size.
|
||||||
|
hop_length: 300 # Hop size.
|
||||||
|
win_length: 1200 # Window length.
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
n_mels: 80 # Number of mel basis.
|
||||||
|
fmin: 80 # Minimum freq in mel basis calculation.
|
||||||
|
fmax: 7600 # Maximum frequency in mel basis calculation.
|
||||||
|
# global_gain_scale: 1.0 # Will be multiplied to all of waveform.
|
||||||
|
trim_silence: false # Whether to trim the start and end of silence.
|
||||||
|
top_db: 60 # Need to tune carefully if the recording is not good.
|
||||||
|
trim_frame_length: 2048 # Frame size in trimming.(in samples)
|
||||||
|
trim_hop_length: 512 # Hop size in trimming.(in samples)
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# GENERATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_params:
|
||||||
|
in_channels: 1 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_size: 3 # Kernel size of dilated convolution.
|
||||||
|
layers: 30 # Number of residual block layers.
|
||||||
|
stacks: 3 # Number of stacks i.e., dilation cycles.
|
||||||
|
residual_channels: 64 # Number of channels in residual conv.
|
||||||
|
gate_channels: 128 # Number of channels in gated conv.
|
||||||
|
skip_channels: 64 # Number of channels in skip conv.
|
||||||
|
aux_channels: 80 # Number of channels for auxiliary feature conv.
|
||||||
|
# Must be the same as num_mels.
|
||||||
|
aux_context_window: 2 # Context window size for auxiliary feature.
|
||||||
|
# If set to 2, previous 2 and future 2 frames will be considered.
|
||||||
|
dropout: 0.0 # Dropout rate. 0.0 means no dropout applied.
|
||||||
|
bias: true # use bias in residual blocks
|
||||||
|
use_weight_norm: true # Whether to use weight norm.
|
||||||
|
# If set to true, it will be applied to all of the conv layers.
|
||||||
|
use_causal_conv: false # use causal conv in residual blocks and upsample layers
|
||||||
|
# upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
|
||||||
|
upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size.
|
||||||
|
interpolate_mode: "nearest" # upsample net interpolate mode
|
||||||
|
freq_axis_kernel_size: 1 # upsamling net: convolution kernel size in frequencey axis
|
||||||
|
nonlinear_activation: null
|
||||||
|
nonlinear_activation_params: {}
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
|
||||||
|
###########################################################
|
||||||
|
discriminator_params:
|
||||||
|
in_channels: 1 # Number of input channels.
|
||||||
|
out_channels: 1 # Number of output channels.
|
||||||
|
kernel_size: 3 # Number of output channels.
|
||||||
|
layers: 10 # Number of conv layers.
|
||||||
|
conv_channels: 64 # Number of chnn layers.
|
||||||
|
bias: true # Whether to use bias parameter in conv.
|
||||||
|
use_weight_norm: true # Whether to use weight norm.
|
||||||
|
# If set to true, it will be applied to all of the conv layers.
|
||||||
|
nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
|
||||||
|
nonlinear_activation_params: # Nonlinear function parameters
|
||||||
|
negative_slope: 0.2 # Alpha in LeakyReLU.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# STFT LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
stft_loss_params:
|
||||||
|
fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss.
|
||||||
|
hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss
|
||||||
|
win_lengths: [600, 1200, 240] # List of window length for STFT-based loss.
|
||||||
|
window: "hann" # Window function for STFT-based loss
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# ADVERSARIAL LOSS SETTING #
|
||||||
|
###########################################################
|
||||||
|
lambda_adv: 4.0 # Loss balancing coefficient.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA LOADER SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 6 # Batch size.
|
||||||
|
batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by hop_size.
|
||||||
|
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
|
||||||
|
num_workers: 4 # Number of workers in Pytorch DataLoader.
|
||||||
|
remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.
|
||||||
|
allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER & SCHEDULER SETTING #
|
||||||
|
###########################################################
|
||||||
|
generator_optimizer_params:
|
||||||
|
epsilon: 1.0e-6 # Generator's epsilon.
|
||||||
|
weight_decay: 0.0 # Generator's weight decay coefficient.
|
||||||
|
generator_scheduler_params:
|
||||||
|
learning_rate: 0.0001 # Generator's learning rate.
|
||||||
|
step_size: 200000 # Generator's scheduler step size.
|
||||||
|
gamma: 0.5 # Generator's scheduler gamma.
|
||||||
|
# At each step size, lr will be multiplied by this parameter.
|
||||||
|
generator_grad_norm: 10 # Generator's gradient norm.
|
||||||
|
discriminator_optimizer_params:
|
||||||
|
epsilon: 1.0e-6 # Discriminator's epsilon.
|
||||||
|
weight_decay: 0.0 # Discriminator's weight decay coefficient.
|
||||||
|
discriminator_scheduler_params:
|
||||||
|
learning_rate: 0.00005 # Discriminator's learning rate.
|
||||||
|
step_size: 200000 # Discriminator's scheduler step size.
|
||||||
|
gamma: 0.5 # Discriminator's scheduler gamma.
|
||||||
|
# At each step size, lr will be multiplied by this parameter.
|
||||||
|
discriminator_grad_norm: 1 # Discriminator's gradient norm.
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# INTERVAL SETTING #
|
||||||
|
###########################################################
|
||||||
|
discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator.
|
||||||
|
train_max_steps: 400000 # Number of training steps.
|
||||||
|
save_interval_steps: 5000 # Interval steps to save checkpoint.
|
||||||
|
eval_interval_steps: 1000 # Interval steps to evaluate the network.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
|
||||||
|
num_snapshots: 10 # max number of snapshots to keep while training
|
||||||
|
seed: 42 # random seed for paddle, random, and np.random
|
|
@ -0,0 +1,25 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from yacs.config import CfgNode as Configuration
|
||||||
|
|
||||||
|
with open("conf/default.yaml", 'rt') as f:
|
||||||
|
_C = yaml.safe_load(f)
|
||||||
|
_C = Configuration(_C)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cfg_default():
|
||||||
|
config = _C.clone()
|
||||||
|
return config
|
|
@ -0,0 +1,145 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Normalize feature files and dump them."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from operator import itemgetter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import yaml
|
||||||
|
import jsonlines
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from parakeet.datasets.data_table import DataTable
|
||||||
|
|
||||||
|
from config import get_cfg_default
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Run preprocessing process."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metadata",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="directory including feature files to be normalized. "
|
||||||
|
"you need to specify either *-scp or rootdir.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dumpdir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="directory to dump normalized feature files.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--stats", type=str, required=True, help="statistics file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-wav-copy",
|
||||||
|
default=False,
|
||||||
|
action="store_true",
|
||||||
|
help="whether to skip the copy of wav files.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="yaml format configuration file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="logging level. higher is more logging. (default=1)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# set logger
|
||||||
|
if args.verbose > 1:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
elif args.verbose > 0:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARN,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
|
||||||
|
)
|
||||||
|
logging.warning('Skip DEBUG/INFO messages')
|
||||||
|
|
||||||
|
# load config
|
||||||
|
config = get_cfg_default()
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
|
||||||
|
# check directory existence
|
||||||
|
dumpdir = Path(args.dumpdir).resolve()
|
||||||
|
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# get dataset
|
||||||
|
with jsonlines.open(args.metadata, 'r') as reader:
|
||||||
|
metadata = list(reader)
|
||||||
|
dataset = DataTable(
|
||||||
|
metadata,
|
||||||
|
fields=["utt_id", "wave", "feats"],
|
||||||
|
converters={
|
||||||
|
'utt_id': None,
|
||||||
|
'wave': None if args.skip_wav_copy else np.load,
|
||||||
|
'feats': np.load,
|
||||||
|
})
|
||||||
|
logging.info(f"The number of files = {len(dataset)}.")
|
||||||
|
|
||||||
|
# restore scaler
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaler.mean_ = np.load(args.stats)[0]
|
||||||
|
scaler.scale_ = np.load(args.stats)[1]
|
||||||
|
|
||||||
|
# from version 0.23.0, this information is needed
|
||||||
|
scaler.n_features_in_ = scaler.mean_.shape[0]
|
||||||
|
|
||||||
|
# process each file
|
||||||
|
output_metadata = []
|
||||||
|
|
||||||
|
for item in tqdm(dataset):
|
||||||
|
utt_id = item['utt_id']
|
||||||
|
wave = item['wave']
|
||||||
|
mel = item['feats']
|
||||||
|
# normalize
|
||||||
|
mel = scaler.transform(mel)
|
||||||
|
|
||||||
|
# save
|
||||||
|
mel_path = dumpdir / f"{utt_id}-feats.npy"
|
||||||
|
np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
|
||||||
|
if not args.skip_wav_copy:
|
||||||
|
wav_path = dumpdir / f"{utt_id}-wave.npy"
|
||||||
|
np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
|
||||||
|
else:
|
||||||
|
wav_path = wave
|
||||||
|
output_metadata.append({
|
||||||
|
'utt_id': utt_id,
|
||||||
|
'wave': str(wav_path),
|
||||||
|
'feats': str(mel_path),
|
||||||
|
})
|
||||||
|
output_metadata.sort(key=itemgetter('utt_id'))
|
||||||
|
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
|
||||||
|
with jsonlines.open(output_metadata_path, 'w') as writer:
|
||||||
|
for item in output_metadata:
|
||||||
|
writer.write(item)
|
||||||
|
logging.info(f"metadata dumped into {output_metadata_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,287 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import soundfile as sf
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
import jsonlines
|
||||||
|
import concurrent.futures
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||||
|
from pathlib import Path
|
||||||
|
import tqdm
|
||||||
|
from operator import itemgetter
|
||||||
|
from praatio import tgio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from config import get_cfg_default
|
||||||
|
|
||||||
|
|
||||||
|
def logmelfilterbank(audio,
|
||||||
|
sr,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=None,
|
||||||
|
window="hann",
|
||||||
|
n_mels=80,
|
||||||
|
fmin=None,
|
||||||
|
fmax=None,
|
||||||
|
eps=1e-10):
|
||||||
|
"""Compute log-Mel filterbank feature.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio : ndarray
|
||||||
|
Audio signal (T,).
|
||||||
|
sr : int
|
||||||
|
Sampling rate.
|
||||||
|
n_fft : int
|
||||||
|
FFT size. (Default value = 1024)
|
||||||
|
hop_length : int
|
||||||
|
Hop size. (Default value = 256)
|
||||||
|
win_length : int
|
||||||
|
Window length. If set to None, it will be the same as fft_size. (Default value = None)
|
||||||
|
window : str
|
||||||
|
Window function type. (Default value = "hann")
|
||||||
|
n_mels : int
|
||||||
|
Number of mel basis. (Default value = 80)
|
||||||
|
fmin : int
|
||||||
|
Minimum frequency in mel basis calculation. (Default value = None)
|
||||||
|
fmax : int
|
||||||
|
Maximum frequency in mel basis calculation. (Default value = None)
|
||||||
|
eps : float
|
||||||
|
Epsilon value to avoid inf in log calculation. (Default value = 1e-10)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
Log Mel filterbank feature (#frames, num_mels).
|
||||||
|
|
||||||
|
"""
|
||||||
|
# get amplitude spectrogram
|
||||||
|
x_stft = librosa.stft(
|
||||||
|
audio,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window,
|
||||||
|
pad_mode="reflect")
|
||||||
|
spc = np.abs(x_stft) # (#bins, #frames,)
|
||||||
|
|
||||||
|
# get mel basis
|
||||||
|
fmin = 0 if fmin is None else fmin
|
||||||
|
fmax = sr / 2 if fmax is None else fmax
|
||||||
|
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
||||||
|
|
||||||
|
return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))
|
||||||
|
|
||||||
|
|
||||||
|
def process_sentence(config: Dict[str, Any],
|
||||||
|
fp: Path,
|
||||||
|
alignment_fp: Path,
|
||||||
|
output_dir: Path):
|
||||||
|
utt_id = fp.stem
|
||||||
|
|
||||||
|
# reading
|
||||||
|
y, sr = librosa.load(fp, sr=config.sr) # resampling may occur
|
||||||
|
assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
|
||||||
|
assert np.abs(y).max(
|
||||||
|
) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
|
||||||
|
duration = librosa.get_duration(y, sr=sr)
|
||||||
|
|
||||||
|
# trim according to the alignment file
|
||||||
|
alignment = tgio.openTextgrid(alignment_fp)
|
||||||
|
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
|
||||||
|
first, last = intervals[0], intervals[-1]
|
||||||
|
start = 0
|
||||||
|
end = last.end
|
||||||
|
if first.label == "sil" and first.end < duration:
|
||||||
|
start = first.end
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f" There is something wrong with the fisrt interval {first} in utterance: {utt_id}"
|
||||||
|
)
|
||||||
|
if last.label == "sil" and last.start < duration:
|
||||||
|
end = last.start
|
||||||
|
else:
|
||||||
|
end = duration
|
||||||
|
logging.warning(
|
||||||
|
f" There is something wrong with the last interval {last} in utterance: {utt_id}"
|
||||||
|
)
|
||||||
|
# silence trimmed
|
||||||
|
start, end = librosa.time_to_samples([first.end, last.start], sr=sr)
|
||||||
|
y = y[start:end]
|
||||||
|
|
||||||
|
# energy based silence trimming
|
||||||
|
if config.trim_silence:
|
||||||
|
y, _ = librosa.effects.trim(
|
||||||
|
y,
|
||||||
|
top_db=config.top_db,
|
||||||
|
frame_length=config.trim_frame_length,
|
||||||
|
hop_length=config.trim_hop_length)
|
||||||
|
|
||||||
|
logmel = logmelfilterbank(
|
||||||
|
y,
|
||||||
|
sr=sr,
|
||||||
|
n_fft=config.n_fft,
|
||||||
|
window=config.window,
|
||||||
|
win_length=config.win_length,
|
||||||
|
hop_length=config.hop_length,
|
||||||
|
n_mels=config.n_mels,
|
||||||
|
fmin=config.fmin,
|
||||||
|
fmax=config.fmax)
|
||||||
|
|
||||||
|
# adjust time to make num_samples == num_frames * hop_length
|
||||||
|
num_frames = logmel.shape[1]
|
||||||
|
if y.size < num_frames * config.hop_length:
|
||||||
|
y = np.pad(y, (0, num_frames * config.hop_length - y.size),
|
||||||
|
mode="reflect")
|
||||||
|
else:
|
||||||
|
y = y[:num_frames * config.hop_length]
|
||||||
|
num_sample = y.shape[0]
|
||||||
|
|
||||||
|
mel_path = output_dir / (utt_id + "_feats.npy")
|
||||||
|
wav_path = output_dir / (utt_id + "_wave.npy")
|
||||||
|
np.save(wav_path, y) # (num_samples, )
|
||||||
|
np.save(mel_path, logmel.T) # (num_frames, n_mels)
|
||||||
|
record = {
|
||||||
|
"utt_id": utt_id,
|
||||||
|
"num_samples": num_sample,
|
||||||
|
"num_frames": num_frames,
|
||||||
|
"feats": str(mel_path.resolve()),
|
||||||
|
"wave": str(wav_path.resolve()),
|
||||||
|
}
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
def process_sentences(config,
|
||||||
|
fps: List[Path],
|
||||||
|
alignment_fps: List[Path],
|
||||||
|
output_dir: Path,
|
||||||
|
nprocs: int=1):
|
||||||
|
if nprocs == 1:
|
||||||
|
results = []
|
||||||
|
for fp, alignment_fp in tqdm.tqdm(zip(fps, alignment_fps)):
|
||||||
|
results.append(
|
||||||
|
process_sentence(config, fp, alignment_fp, output_dir))
|
||||||
|
else:
|
||||||
|
with ThreadPoolExecutor(nprocs) as pool:
|
||||||
|
futures = []
|
||||||
|
with tqdm.tqdm(total=len(fps)) as progress:
|
||||||
|
for fp, alignment_fp in zip(fps, alignment_fps):
|
||||||
|
future = pool.submit(process_sentence, config, fp,
|
||||||
|
alignment_fp, output_dir)
|
||||||
|
future.add_done_callback(lambda p: progress.update())
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for ft in futures:
|
||||||
|
results.append(ft.result())
|
||||||
|
|
||||||
|
results.sort(key=itemgetter("utt_id"))
|
||||||
|
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
|
||||||
|
for item in results:
|
||||||
|
writer.write(item)
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse config and args
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rootdir",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="directory including wav files. you need to specify either scp or rootdir."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dumpdir",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="directory to dump feature files.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="yaml format configuration file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="logging level. higher is more logging. (default=1)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_cpu", type=int, default=1, help="number of process.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
C = get_cfg_default()
|
||||||
|
if args.config:
|
||||||
|
C.merge_from_file(args.config)
|
||||||
|
C.freeze()
|
||||||
|
|
||||||
|
if args.verbose > 1:
|
||||||
|
print(vars(args))
|
||||||
|
print(C)
|
||||||
|
|
||||||
|
root_dir = Path(args.rootdir).expanduser()
|
||||||
|
dumpdir = Path(args.dumpdir).expanduser()
|
||||||
|
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
|
||||||
|
alignment_files = sorted(
|
||||||
|
list((root_dir / "PhoneLabeling").rglob("*.interval")))
|
||||||
|
|
||||||
|
# split data into 3 sections
|
||||||
|
num_train = 9800
|
||||||
|
num_dev = 100
|
||||||
|
|
||||||
|
train_wav_files = wav_files[:num_train]
|
||||||
|
dev_wav_files = wav_files[num_train:num_train + num_dev]
|
||||||
|
test_wav_files = wav_files[num_train + num_dev:]
|
||||||
|
|
||||||
|
train_alignment_files = alignment_files[:num_train]
|
||||||
|
dev_alignment_files = alignment_files[num_train:num_train + num_dev]
|
||||||
|
test_alignment_files = alignment_files[num_train + num_dev:]
|
||||||
|
|
||||||
|
train_dump_dir = dumpdir / "train" / "raw"
|
||||||
|
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dev_dump_dir = dumpdir / "dev" / "raw"
|
||||||
|
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
test_dump_dir = dumpdir / "test" / "raw"
|
||||||
|
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# process for the 3 sections
|
||||||
|
process_sentences(
|
||||||
|
C,
|
||||||
|
train_wav_files,
|
||||||
|
train_alignment_files,
|
||||||
|
train_dump_dir,
|
||||||
|
nprocs=args.num_cpu)
|
||||||
|
process_sentences(
|
||||||
|
C,
|
||||||
|
dev_wav_files,
|
||||||
|
dev_alignment_files,
|
||||||
|
dev_dump_dir,
|
||||||
|
nprocs=args.num_cpu)
|
||||||
|
process_sentences(
|
||||||
|
C,
|
||||||
|
test_wav_files,
|
||||||
|
test_alignment_files,
|
||||||
|
test_dump_dir,
|
||||||
|
nprocs=args.num_cpu)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,184 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle.nn import Layer
|
||||||
|
from paddle.optimizer import Optimizer
|
||||||
|
from paddle.optimizer.lr import LRScheduler
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
from timer import timer
|
||||||
|
|
||||||
|
from parakeet.datasets.data_table import DataTable
|
||||||
|
from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
|
||||||
|
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||||
|
from parakeet.training.trainer import Trainer
|
||||||
|
from parakeet.training.reporter import report
|
||||||
|
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||||
|
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||||
|
from parakeet.utils.profile import synchronize
|
||||||
|
|
||||||
|
|
||||||
|
class PWGUpdater(StandardUpdater):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
models: Dict[str, Layer],
|
||||||
|
optimizers: Dict[str, Optimizer],
|
||||||
|
criterions: Dict[str, Layer],
|
||||||
|
schedulers: Dict[str, LRScheduler],
|
||||||
|
dataloader: DataLoader,
|
||||||
|
discriminator_train_start_steps: int,
|
||||||
|
lambda_adv: float, ):
|
||||||
|
self.models = models
|
||||||
|
self.generator: Layer = models['generator']
|
||||||
|
self.discriminator: Layer = models['discriminator']
|
||||||
|
|
||||||
|
self.optimizers = optimizers
|
||||||
|
self.optimizer_g: Optimizer = optimizers['generator']
|
||||||
|
self.optimizer_d: Optimizer = optimizers['discriminator']
|
||||||
|
|
||||||
|
self.criterions = criterions
|
||||||
|
self.criterion_stft = criterions['stft']
|
||||||
|
self.criterion_mse = criterions['mse']
|
||||||
|
|
||||||
|
self.schedulers = schedulers
|
||||||
|
self.scheduler_g = schedulers['generator']
|
||||||
|
self.scheduler_d = schedulers['discriminator']
|
||||||
|
|
||||||
|
self.dataloader = dataloader
|
||||||
|
|
||||||
|
self.discriminator_train_start_steps = discriminator_train_start_steps
|
||||||
|
self.lambda_adv = lambda_adv
|
||||||
|
self.state = UpdaterState(iteration=0, epoch=0)
|
||||||
|
|
||||||
|
self.train_iterator = iter(self.dataloader)
|
||||||
|
|
||||||
|
def update_core(self, batch):
|
||||||
|
# parse batch
|
||||||
|
wav, mel = batch
|
||||||
|
|
||||||
|
# Generator
|
||||||
|
noise = paddle.randn(wav.shape)
|
||||||
|
|
||||||
|
with timer() as t:
|
||||||
|
wav_ = self.generator(noise, mel)
|
||||||
|
logging.debug(f"Generator takes {t.elapse}s.")
|
||||||
|
|
||||||
|
## Multi-resolution stft loss
|
||||||
|
|
||||||
|
with timer() as t:
|
||||||
|
sc_loss, mag_loss = self.criterion_stft(
|
||||||
|
wav_.squeeze(1), wav.squeeze(1))
|
||||||
|
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
|
||||||
|
|
||||||
|
report("train/spectral_convergence_loss", float(sc_loss))
|
||||||
|
report("train/log_stft_magnitude_loss", float(mag_loss))
|
||||||
|
gen_loss = sc_loss + mag_loss
|
||||||
|
|
||||||
|
## Adversarial loss
|
||||||
|
if self.state.iteration > self.discriminator_train_start_steps:
|
||||||
|
with timer() as t:
|
||||||
|
p_ = self.discriminator(wav_)
|
||||||
|
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||||
|
logging.debug(
|
||||||
|
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||||
|
report("train/adversarial_loss", float(adv_loss))
|
||||||
|
gen_loss += self.lambda_adv * adv_loss
|
||||||
|
|
||||||
|
report("train/generator_loss", float(gen_loss))
|
||||||
|
|
||||||
|
with timer() as t:
|
||||||
|
self.optimizer_g.clear_grad()
|
||||||
|
gen_loss.backward()
|
||||||
|
logging.debug(f"Backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
with timer() as t:
|
||||||
|
self.optimizer_g.step()
|
||||||
|
self.scheduler_g.step()
|
||||||
|
logging.debug(f"Update takes {t.elapse}s.")
|
||||||
|
|
||||||
|
# Disctiminator
|
||||||
|
if self.state.iteration > self.discriminator_train_start_steps:
|
||||||
|
with paddle.no_grad():
|
||||||
|
wav_ = self.generator(noise, mel)
|
||||||
|
p = self.discriminator(wav)
|
||||||
|
p_ = self.discriminator(wav_.detach())
|
||||||
|
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
||||||
|
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
||||||
|
report("train/real_loss", float(real_loss))
|
||||||
|
report("train/fake_loss", float(fake_loss))
|
||||||
|
dis_loss = real_loss + fake_loss
|
||||||
|
report("train/discriminator_loss", float(dis_loss))
|
||||||
|
|
||||||
|
self.optimizer_d.clear_grad()
|
||||||
|
dis_loss.backward()
|
||||||
|
|
||||||
|
self.optimizer_d.step()
|
||||||
|
self.scheduler_d.step()
|
||||||
|
|
||||||
|
|
||||||
|
class PWGEvaluator(StandardEvaluator):
|
||||||
|
def __init__(self, models, criterions, dataloader, lambda_adv):
|
||||||
|
self.models = models
|
||||||
|
self.generator = models['generator']
|
||||||
|
self.discriminator = models['discriminator']
|
||||||
|
|
||||||
|
self.criterions = criterions
|
||||||
|
self.criterion_stft = criterions['stft']
|
||||||
|
self.criterion_mse = criterions['mse']
|
||||||
|
|
||||||
|
self.dataloader = dataloader
|
||||||
|
self.lambda_adv = lambda_adv
|
||||||
|
|
||||||
|
def evaluate_core(self, batch):
|
||||||
|
logging.debug("Evaluate: ")
|
||||||
|
wav, mel = batch
|
||||||
|
noise = paddle.randn(wav.shape)
|
||||||
|
|
||||||
|
with timer() as t:
|
||||||
|
wav_ = self.generator(noise, mel)
|
||||||
|
logging.debug(f"Generator takes {t.elapse}s")
|
||||||
|
|
||||||
|
## Adversarial loss
|
||||||
|
with timer() as t:
|
||||||
|
p_ = self.discriminator(wav_)
|
||||||
|
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||||
|
logging.debug(
|
||||||
|
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||||
|
report("eval/adversarial_loss", float(adv_loss))
|
||||||
|
gen_loss = self.lambda_adv * adv_loss
|
||||||
|
|
||||||
|
# stft loss
|
||||||
|
with timer() as t:
|
||||||
|
sc_loss, mag_loss = self.criterion_stft(
|
||||||
|
wav_.squeeze(1), wav.squeeze(1))
|
||||||
|
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||||
|
|
||||||
|
report("eval/spectral_convergence_loss", float(sc_loss))
|
||||||
|
report("eval/log_stft_magnitude_loss", float(mag_loss))
|
||||||
|
gen_loss += sc_loss + mag_loss
|
||||||
|
|
||||||
|
report("eval/generator_loss", float(gen_loss))
|
||||||
|
|
||||||
|
# Disctiminator
|
||||||
|
p = self.discriminator(wav)
|
||||||
|
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
||||||
|
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
||||||
|
report("eval/real_loss", float(real_loss))
|
||||||
|
report("eval/fake_loss", float(fake_loss))
|
||||||
|
dis_loss = real_loss + fake_loss
|
||||||
|
report("eval/discriminator_loss", float(dis_loss))
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from timer import timer
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import jsonlines
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
from paddle import distributed as dist
|
||||||
|
|
||||||
|
from parakeet.datasets.data_table import DataTable
|
||||||
|
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||||
|
|
||||||
|
from config import get_cfg_default
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="synthesize with parallel wavegan.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="config file to overwrite default config")
|
||||||
|
parser.add_argument("--checkpoint", type=str, help="snapshot to load")
|
||||||
|
parser.add_argument("--test-metadata", type=str, help="dev data")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||||
|
parser.add_argument("--device", type=str, default="gpu", help="device to run")
|
||||||
|
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
config = get_cfg_default()
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
print(
|
||||||
|
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
paddle.set_device(args.device)
|
||||||
|
generator = PWGGenerator(**config["generator_params"])
|
||||||
|
state_dict = paddle.load(args.checkpoint)
|
||||||
|
generator.set_state_dict(state_dict["generator_params"])
|
||||||
|
|
||||||
|
generator.remove_weight_norm()
|
||||||
|
generator.eval()
|
||||||
|
with jsonlines.open(args.test_metadata, 'r') as reader:
|
||||||
|
metadata = list(reader)
|
||||||
|
|
||||||
|
test_dataset = DataTable(
|
||||||
|
metadata,
|
||||||
|
fields=['utt_id', 'feats'],
|
||||||
|
converters={
|
||||||
|
'utt_id': None,
|
||||||
|
'feats': np.load,
|
||||||
|
})
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
N = 0
|
||||||
|
T = 0
|
||||||
|
for example in test_dataset:
|
||||||
|
utt_id = example['utt_id']
|
||||||
|
mel = example['feats']
|
||||||
|
mel = paddle.to_tensor(mel) # (T, C)
|
||||||
|
with timer() as t:
|
||||||
|
wav = generator.inference(c=mel)
|
||||||
|
wav = wav.numpy()
|
||||||
|
N += wav.size
|
||||||
|
T += t.elapse
|
||||||
|
speed = wav.size / t.elapse
|
||||||
|
print(
|
||||||
|
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.sr / speed}."
|
||||||
|
)
|
||||||
|
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=config.sr)
|
||||||
|
print(f"generation speed: {N / T}Hz, RTF: {config.sr / (N / T) }")
|
|
@ -0,0 +1,246 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import jsonlines
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader, DistributedBatchSampler
|
||||||
|
from paddle.optimizer import Adam # No RAdaom
|
||||||
|
from paddle.optimizer.lr import StepDecay
|
||||||
|
from paddle import DataParallel
|
||||||
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
from parakeet.datasets.data_table import DataTable
|
||||||
|
from parakeet.training.updater import UpdaterBase
|
||||||
|
from parakeet.training.trainer import Trainer
|
||||||
|
from parakeet.training.reporter import report
|
||||||
|
from parakeet.training import extension
|
||||||
|
from parakeet.training.extensions.snapshot import Snapshot
|
||||||
|
from parakeet.training.extensions.visualizer import VisualDL
|
||||||
|
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||||
|
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||||
|
from parakeet.training.seeding import seed_everything
|
||||||
|
|
||||||
|
from batch_fn import Clip
|
||||||
|
from config import get_cfg_default
|
||||||
|
from pwg_updater import PWGUpdater, PWGEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
def train_sp(args, config):
|
||||||
|
# decides device type and whether to run in parallel
|
||||||
|
# setup running environment correctly
|
||||||
|
if not paddle.is_compiled_with_cuda:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
else:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
world_size = paddle.distributed.get_world_size()
|
||||||
|
if world_size > 1:
|
||||||
|
paddle.distributed.init_parallel_env()
|
||||||
|
|
||||||
|
# set the random seed, it is a must for multiprocess training
|
||||||
|
seed_everything(config.seed)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# dataloader has been too verbose
|
||||||
|
logging.getLogger("DataLoader").disabled = True
|
||||||
|
|
||||||
|
# construct dataset for training and validation
|
||||||
|
with jsonlines.open(args.train_metadata, 'r') as reader:
|
||||||
|
train_metadata = list(reader)
|
||||||
|
train_dataset = DataTable(
|
||||||
|
data=train_metadata,
|
||||||
|
fields=["wave", "feats"],
|
||||||
|
converters={
|
||||||
|
"wave": np.load,
|
||||||
|
"feats": np.load,
|
||||||
|
}, )
|
||||||
|
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
||||||
|
dev_metadata = list(reader)
|
||||||
|
dev_dataset = DataTable(
|
||||||
|
data=dev_metadata,
|
||||||
|
fields=["wave", "feats"],
|
||||||
|
converters={
|
||||||
|
"wave": np.load,
|
||||||
|
"feats": np.load,
|
||||||
|
}, )
|
||||||
|
|
||||||
|
# collate function and dataloader
|
||||||
|
train_sampler = DistributedBatchSampler(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
drop_last=True)
|
||||||
|
dev_sampler = DistributedBatchSampler(
|
||||||
|
dev_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False)
|
||||||
|
print("samplers done!")
|
||||||
|
|
||||||
|
train_batch_fn = Clip(
|
||||||
|
batch_max_steps=config.batch_max_steps,
|
||||||
|
hop_size=config.hop_length,
|
||||||
|
aux_context_window=config.generator_params.aux_context_window)
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=train_sampler,
|
||||||
|
collate_fn=train_batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
dev_dataloader = DataLoader(
|
||||||
|
dev_dataset,
|
||||||
|
batch_sampler=dev_sampler,
|
||||||
|
collate_fn=train_batch_fn,
|
||||||
|
num_workers=config.num_workers)
|
||||||
|
print("dataloaders done!")
|
||||||
|
|
||||||
|
generator = PWGGenerator(**config["generator_params"])
|
||||||
|
discriminator = PWGDiscriminator(**config["discriminator_params"])
|
||||||
|
if world_size > 1:
|
||||||
|
generator = DataParallel(generator)
|
||||||
|
discriminator = DataParallel(discriminator)
|
||||||
|
print("models done!")
|
||||||
|
|
||||||
|
criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
|
||||||
|
criterion_mse = nn.MSELoss()
|
||||||
|
print("criterions done!")
|
||||||
|
|
||||||
|
lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
|
||||||
|
gradient_clip_g = nn.ClipGradByGlobalNorm(config["generator_grad_norm"])
|
||||||
|
optimizer_g = Adam(
|
||||||
|
learning_rate=lr_schedule_g,
|
||||||
|
grad_clip=gradient_clip_g,
|
||||||
|
parameters=generator.parameters(),
|
||||||
|
**config["generator_optimizer_params"])
|
||||||
|
lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
|
||||||
|
gradient_clip_d = nn.ClipGradByGlobalNorm(config[
|
||||||
|
"discriminator_grad_norm"])
|
||||||
|
optimizer_d = Adam(
|
||||||
|
learning_rate=lr_schedule_d,
|
||||||
|
grad_clip=gradient_clip_d,
|
||||||
|
parameters=discriminator.parameters(),
|
||||||
|
**config["discriminator_optimizer_params"])
|
||||||
|
print("optimizers done!")
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
checkpoint_dir = output_dir / "checkpoints"
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(output_dir / "config.yaml", 'wt') as f:
|
||||||
|
f.write(config.dump(default_flow_style=None))
|
||||||
|
|
||||||
|
updater = PWGUpdater(
|
||||||
|
models={
|
||||||
|
"generator": generator,
|
||||||
|
"discriminator": discriminator,
|
||||||
|
},
|
||||||
|
optimizers={
|
||||||
|
"generator": optimizer_g,
|
||||||
|
"discriminator": optimizer_d,
|
||||||
|
},
|
||||||
|
criterions={
|
||||||
|
"stft": criterion_stft,
|
||||||
|
"mse": criterion_mse,
|
||||||
|
},
|
||||||
|
schedulers={
|
||||||
|
"generator": lr_schedule_g,
|
||||||
|
"discriminator": lr_schedule_d,
|
||||||
|
},
|
||||||
|
dataloader=train_dataloader,
|
||||||
|
discriminator_train_start_steps=config.discriminator_train_start_steps,
|
||||||
|
lambda_adv=config.lambda_adv, )
|
||||||
|
|
||||||
|
evaluator = PWGEvaluator(
|
||||||
|
models={
|
||||||
|
"generator": generator,
|
||||||
|
"discriminator": discriminator,
|
||||||
|
},
|
||||||
|
criterions={
|
||||||
|
"stft": criterion_stft,
|
||||||
|
"mse": criterion_mse,
|
||||||
|
},
|
||||||
|
dataloader=dev_dataloader,
|
||||||
|
lambda_adv=config.lambda_adv, )
|
||||||
|
trainer = Trainer(
|
||||||
|
updater,
|
||||||
|
stop_trigger=(config.train_max_steps, "iteration"),
|
||||||
|
out=output_dir, )
|
||||||
|
|
||||||
|
trainer.extend(
|
||||||
|
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
writer = LogWriter(str(trainer.out))
|
||||||
|
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
|
||||||
|
trainer.extend(
|
||||||
|
Snapshot(max_size=config.num_snapshots),
|
||||||
|
trigger=(config.save_interval_steps, 'iteration'))
|
||||||
|
|
||||||
|
print(trainer.extensions.keys())
|
||||||
|
print("Trainer Done!")
|
||||||
|
trainer.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse args and config and redirect to train_sp
|
||||||
|
parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
|
||||||
|
"model with Baker Mandrin TTS dataset.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--config", type=str, help="config file to overwrite default config")
|
||||||
|
parser.add_argument("--train-metadata", type=str, help="training data")
|
||||||
|
parser.add_argument("--dev-metadata", type=str, help="dev data")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, default="gpu", help="device type to use")
|
||||||
|
parser.add_argument(
|
||||||
|
"--nprocs", type=int, default=1, help="number of processes")
|
||||||
|
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.device == "cpu" and args.nprocs > 1:
|
||||||
|
raise RuntimeError("Multiprocess training on CPU is not supported.")
|
||||||
|
config = get_cfg_default()
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(config)
|
||||||
|
print(
|
||||||
|
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# dispatch
|
||||||
|
if args.nprocs > 1:
|
||||||
|
dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
|
||||||
|
else:
|
||||||
|
train_sp(args, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -58,10 +58,10 @@ For more help on arguments
|
||||||
|
|
||||||
## Synthesis
|
## Synthesis
|
||||||
|
|
||||||
After training the Tacotron2, spectrogram can be synthesized by running ``synthesis.py``.
|
After training the Tacotron2, spectrogram can be synthesized by running ``synthesize.py``.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python synthesis.py \
|
python synthesize.py \
|
||||||
--config=${CONFIGPATH} \
|
--config=${CONFIGPATH} \
|
||||||
--checkpoint_path=${CHECKPOINTPATH} \
|
--checkpoint_path=${CHECKPOINTPATH} \
|
||||||
--input=${TEXTPATH} \
|
--input=${TEXTPATH} \
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
__version__ = "0.2.0-beta.0"
|
__version__ = "0.0.0"
|
||||||
|
|
||||||
|
import logging
|
||||||
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
|
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
|
||||||
|
|
|
@ -0,0 +1,151 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from multiprocessing import Manager
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from paddle.io import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
class DataTable(Dataset):
|
||||||
|
"""Dataset to load and convert data for general purpose.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : List[Dict[str, Any]]
|
||||||
|
Metadata, a list of meta datum, each of which is composed of
|
||||||
|
several fields
|
||||||
|
fields : List[str], optional
|
||||||
|
Fields to use, if not specified, all the fields in the data are
|
||||||
|
used, by default None
|
||||||
|
converters : Dict[str, Callable], optional
|
||||||
|
Converters used to process each field, by default None
|
||||||
|
use_cache : bool, optional
|
||||||
|
Whether to use cache, by default False
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If there is some field that does not exist in data.
|
||||||
|
ValueError
|
||||||
|
If there is some field in converters that does not exist in fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
data: List[Dict[str, Any]],
|
||||||
|
fields: List[str]=None,
|
||||||
|
converters: Dict[str, Callable]=None,
|
||||||
|
use_cache: bool=False):
|
||||||
|
# metadata
|
||||||
|
self.data = data
|
||||||
|
assert len(data) > 0, "This dataset has no examples"
|
||||||
|
|
||||||
|
# peak an example to get existing fields.
|
||||||
|
first_example = self.data[0]
|
||||||
|
fields_in_data = first_example.keys()
|
||||||
|
|
||||||
|
# check all the requested fields exist
|
||||||
|
if fields is None:
|
||||||
|
self.fields = fields_in_data
|
||||||
|
else:
|
||||||
|
for field in fields:
|
||||||
|
if field not in fields_in_data:
|
||||||
|
raise ValueError(
|
||||||
|
f"The requested field ({field}) is not found"
|
||||||
|
f"in the data. Fields in the data is {fields_in_data}")
|
||||||
|
self.fields = fields
|
||||||
|
|
||||||
|
# check converters
|
||||||
|
if converters is None:
|
||||||
|
self.converters = {}
|
||||||
|
else:
|
||||||
|
for field in converters.keys():
|
||||||
|
if field not in self.fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"The converter has a non existing field ({field})")
|
||||||
|
self.converters = converters
|
||||||
|
|
||||||
|
self.use_cache = use_cache
|
||||||
|
if use_cache:
|
||||||
|
self._initialize_cache()
|
||||||
|
|
||||||
|
def _initialize_cache(self):
|
||||||
|
self.manager = Manager()
|
||||||
|
self.caches = self.manager.list()
|
||||||
|
self.caches += [None for _ in range(len(self))]
|
||||||
|
|
||||||
|
def _get_metadata(self, idx: int) -> Dict[str, Any]:
|
||||||
|
"""Return a meta-datum given an index."""
|
||||||
|
return self.data[idx]
|
||||||
|
|
||||||
|
def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Convert a meta datum to an example by applying the corresponding
|
||||||
|
converters to each fields requested.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
meta_datum : Dict[str, Any]
|
||||||
|
Meta datum
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dict[str, Any]
|
||||||
|
Converted example
|
||||||
|
"""
|
||||||
|
example = {}
|
||||||
|
for field in self.fields:
|
||||||
|
converter = self.converters.get(field, None)
|
||||||
|
meta_datum_field = meta_datum[field]
|
||||||
|
if converter is not None:
|
||||||
|
converted_field = converter(meta_datum_field)
|
||||||
|
else:
|
||||||
|
converted_field = meta_datum_field
|
||||||
|
example[field] = converted_field
|
||||||
|
return example
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||||
|
"""Get an example given an index.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
idx : int
|
||||||
|
Index of the example to get
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dict[str, Any]
|
||||||
|
A converted example
|
||||||
|
"""
|
||||||
|
if self.use_cache and self.caches[idx] is not None:
|
||||||
|
return self.caches[idx]
|
||||||
|
|
||||||
|
meta_datum = self._get_metadata(idx)
|
||||||
|
example = self._convert(meta_datum)
|
||||||
|
|
||||||
|
if self.use_cache:
|
||||||
|
self.caches[idx] = example
|
||||||
|
|
||||||
|
return example
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Returns the size of the dataset.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
The length of the dataset
|
||||||
|
"""
|
||||||
|
return len(self.data)
|
|
@ -0,0 +1,770 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Dict, Any, Union, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class Stretch2D(nn.Layer):
|
||||||
|
def __init__(self, w_scale: int, h_scale: int, mode: str="nearest"):
|
||||||
|
"""Strech an image (or image-like object) with some interpolation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
w_scale : int
|
||||||
|
Scalar of width.
|
||||||
|
h_scale : int
|
||||||
|
Scalar of the height.
|
||||||
|
mode : str, optional
|
||||||
|
Interpolation mode, modes suppored are "nearest", "bilinear",
|
||||||
|
"trilinear", "bicubic", "linear" and "area",by default "nearest"
|
||||||
|
|
||||||
|
For more details about interpolation, see
|
||||||
|
`paddle.nn.functional.interpolate <https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/nn/functional/interpolate_en.html>`_.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.w_scale = w_scale
|
||||||
|
self.h_scale = h_scale
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, C, H, W)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, C, H', W'), where ``H'=h_scale * H``, ``W'=w_scale * W``.
|
||||||
|
The stretched image.
|
||||||
|
"""
|
||||||
|
out = F.interpolate(
|
||||||
|
x, scale_factor=(self.h_scale, self.w_scale), mode=self.mode)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UpsampleNet(nn.Layer):
|
||||||
|
"""A Layer to upsample spectrogram by applying consecutive stretch and
|
||||||
|
convolutions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
upsample_scales : List[int]
|
||||||
|
Upsampling factors for each strech.
|
||||||
|
nonlinear_activation : Optional[str], optional
|
||||||
|
Activation after each convolution, by default None
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters passed to construct the activation, by default {}
|
||||||
|
interpolate_mode : str, optional
|
||||||
|
Interpolation mode of the strech, by default "nearest"
|
||||||
|
freq_axis_kernel_size : int, optional
|
||||||
|
Convolution kernel size along the frequency axis, by default 1
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding before convolution, by default False
|
||||||
|
|
||||||
|
If True, Causal padding is used along the time axis, i.e. padding
|
||||||
|
amount is ``receptive field - 1`` and 0 for before and after,
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
If False, "same" padding is used along the time axis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
upsample_scales: List[int],
|
||||||
|
nonlinear_activation: Optional[str]=None,
|
||||||
|
nonlinear_activation_params: Dict[str, Any]={},
|
||||||
|
interpolate_mode: str="nearest",
|
||||||
|
freq_axis_kernel_size: int=1,
|
||||||
|
use_causal_conv: bool=False):
|
||||||
|
super().__init__()
|
||||||
|
self.use_causal_conv = use_causal_conv
|
||||||
|
self.up_layers = nn.LayerList()
|
||||||
|
for scale in upsample_scales:
|
||||||
|
stretch = Stretch2D(scale, 1, interpolate_mode)
|
||||||
|
assert freq_axis_kernel_size % 2 == 1
|
||||||
|
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
|
||||||
|
kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
|
||||||
|
if use_causal_conv:
|
||||||
|
padding = (freq_axis_padding, scale * 2)
|
||||||
|
else:
|
||||||
|
padding = (freq_axis_padding, scale)
|
||||||
|
conv = nn.Conv2D(
|
||||||
|
1, 1, kernel_size, padding=padding, bias_attr=False)
|
||||||
|
self.up_layers.extend([stretch, conv])
|
||||||
|
if nonlinear_activation is not None:
|
||||||
|
nonlinear = getattr(
|
||||||
|
nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||||
|
self.up_layers.append(nonlinear)
|
||||||
|
|
||||||
|
def forward(self, c: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, F, T), spectrogram
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
|
||||||
|
spectrogram
|
||||||
|
"""
|
||||||
|
c = c.unsqueeze(1)
|
||||||
|
for f in self.up_layers:
|
||||||
|
if self.use_causal_conv and isinstance(f, nn.Conv2D):
|
||||||
|
c = f(c)[:, :, :, c.shape[-1]]
|
||||||
|
else:
|
||||||
|
c = f(c)
|
||||||
|
return c.squeeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvInUpsampleNet(nn.Layer):
|
||||||
|
"""A Layer to upsample spectrogram composed of a convolution and an
|
||||||
|
UpsampleNet.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
upsample_scales : List[int]
|
||||||
|
Upsampling factors for each strech.
|
||||||
|
nonlinear_activation : Optional[str], optional
|
||||||
|
Activation after each convolution, by default None
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters passed to construct the activation, by default {}
|
||||||
|
interpolate_mode : str, optional
|
||||||
|
Interpolation mode of the strech, by default "nearest"
|
||||||
|
freq_axis_kernel_size : int, optional
|
||||||
|
Convolution kernel size along the frequency axis, by default 1
|
||||||
|
aux_channels : int, optional
|
||||||
|
Feature size of the input, by default 80
|
||||||
|
aux_context_window : int, optional
|
||||||
|
Context window of the first 1D convolution applied to the input. It
|
||||||
|
related to the kernel size of the convolution, by default 0
|
||||||
|
|
||||||
|
If use causal convolution, the kernel size is ``window + 1``, else
|
||||||
|
the kernel size is ``2 * window + 1``.
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding before convolution, by default False
|
||||||
|
|
||||||
|
If True, Causal padding is used along the time axis, i.e. padding
|
||||||
|
amount is ``receptive field - 1`` and 0 for before and after,
|
||||||
|
respectively.
|
||||||
|
|
||||||
|
If False, "same" padding is used along the time axis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
upsample_scales: List[int],
|
||||||
|
nonlinear_activation: Optional[str]=None,
|
||||||
|
nonlinear_activation_params: Dict[str, Any]={},
|
||||||
|
interpolate_mode: str="nearest",
|
||||||
|
freq_axis_kernel_size: int=1,
|
||||||
|
aux_channels: int=80,
|
||||||
|
aux_context_window: int=0,
|
||||||
|
use_causal_conv: bool=False):
|
||||||
|
super().__init__()
|
||||||
|
self.aux_context_window = aux_context_window
|
||||||
|
self.use_causal_conv = use_causal_conv and aux_context_window > 0
|
||||||
|
kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
|
||||||
|
self.conv_in = nn.Conv1D(
|
||||||
|
aux_channels,
|
||||||
|
aux_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
bias_attr=False)
|
||||||
|
self.upsample = UpsampleNet(
|
||||||
|
upsample_scales=upsample_scales,
|
||||||
|
nonlinear_activation=nonlinear_activation,
|
||||||
|
nonlinear_activation_params=nonlinear_activation_params,
|
||||||
|
interpolate_mode=interpolate_mode,
|
||||||
|
freq_axis_kernel_size=freq_axis_kernel_size,
|
||||||
|
use_causal_conv=use_causal_conv)
|
||||||
|
|
||||||
|
def forward(self, c: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, F, T), spectrogram
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensors
|
||||||
|
Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
|
||||||
|
spectrogram
|
||||||
|
"""
|
||||||
|
c_ = self.conv_in(c)
|
||||||
|
c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
|
||||||
|
return self.upsample(c)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Layer):
|
||||||
|
"""A gated activation unit composed of an 1D convolution, a gated tanh
|
||||||
|
unit and parametric redidual and skip connections. For more details,
|
||||||
|
refer to `WaveNet: A Generative Model for Raw Audio <https://arxiv.org/abs/1609.03499>`_.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
kernel_size : int, optional
|
||||||
|
Kernel size of the 1D convolution, by default 3
|
||||||
|
residual_channels : int, optional
|
||||||
|
Feature size of the resiaudl output(and also the input), by default 64
|
||||||
|
gate_channels : int, optional
|
||||||
|
Output feature size of the 1D convolution, by default 128
|
||||||
|
skip_channels : int, optional
|
||||||
|
Feature size of the skip output, by default 64
|
||||||
|
aux_channels : int, optional
|
||||||
|
Feature size of the auxiliary input (e.g. spectrogram), by default 80
|
||||||
|
dropout : float, optional
|
||||||
|
Probability of the dropout before the 1D convolution, by default 0.
|
||||||
|
dilation : int, optional
|
||||||
|
Dilation of the 1D convolution, by default 1
|
||||||
|
bias : bool, optional
|
||||||
|
Whether to use bias in the 1D convolution, by default True
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding for the 1D convolution, by default False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
kernel_size: int=3,
|
||||||
|
residual_channels: int=64,
|
||||||
|
gate_channels: int=128,
|
||||||
|
skip_channels: int=64,
|
||||||
|
aux_channels: int=80,
|
||||||
|
dropout: float=0.,
|
||||||
|
dilation: int=1,
|
||||||
|
bias: bool=True,
|
||||||
|
use_causal_conv: bool=False):
|
||||||
|
super().__init__()
|
||||||
|
self.dropout = dropout
|
||||||
|
if use_causal_conv:
|
||||||
|
padding = (kernel_size - 1) * dilation
|
||||||
|
else:
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
padding = (kernel_size - 1) // 2 * dilation
|
||||||
|
self.use_causal_conv = use_causal_conv
|
||||||
|
|
||||||
|
self.conv = nn.Conv1D(
|
||||||
|
residual_channels,
|
||||||
|
gate_channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias_attr=bias)
|
||||||
|
if aux_channels is not None:
|
||||||
|
self.conv1x1_aux = nn.Conv1D(
|
||||||
|
aux_channels, gate_channels, kernel_size=1, bias_attr=False)
|
||||||
|
else:
|
||||||
|
self.conv1x1_aux = None
|
||||||
|
|
||||||
|
gate_out_channels = gate_channels // 2
|
||||||
|
self.conv1x1_out = nn.Conv1D(
|
||||||
|
gate_out_channels,
|
||||||
|
residual_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
bias_attr=bias)
|
||||||
|
self.conv1x1_skip = nn.Conv1D(
|
||||||
|
gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, C_res, T), the input features.
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, C_aux, T), the auxiliary input.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
res : Tensor
|
||||||
|
Shape (N, C_res, T), the residual output, which is used as the
|
||||||
|
input of the next ResidualBlock in a stack of ResidualBlocks.
|
||||||
|
skip : Tensor
|
||||||
|
Shape (N, C_skip, T), the skip output, which is collected among
|
||||||
|
each layer in a stack of ResidualBlocks.
|
||||||
|
"""
|
||||||
|
x_input = x
|
||||||
|
x = F.dropout(x, self.dropout, training=self.training)
|
||||||
|
x = self.conv(x)
|
||||||
|
x = x[:, :, x_input.shape[-1]] if self.use_causal_conv else x
|
||||||
|
if c is not None:
|
||||||
|
c = self.conv1x1_aux(c)
|
||||||
|
x += c
|
||||||
|
|
||||||
|
a, b = paddle.chunk(x, 2, axis=1)
|
||||||
|
x = paddle.tanh(a) * F.sigmoid(b)
|
||||||
|
|
||||||
|
skip = self.conv1x1_skip(x)
|
||||||
|
res = (self.conv1x1_out(x) + x_input) * math.sqrt(0.5)
|
||||||
|
return res, skip
|
||||||
|
|
||||||
|
|
||||||
|
class PWGGenerator(nn.Layer):
|
||||||
|
"""Wave Generator for Parallel WaveGAN
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_channels : int, optional
|
||||||
|
Number of channels of the input waveform, by default 1
|
||||||
|
out_channels : int, optional
|
||||||
|
Number of channels of the output waveform, by default 1
|
||||||
|
kernel_size : int, optional
|
||||||
|
Kernel size of the residual blocks inside, by default 3
|
||||||
|
layers : int, optional
|
||||||
|
Number of residual blocks inside, by default 30
|
||||||
|
stacks : int, optional
|
||||||
|
The number of groups to split the residual blocks into, by default 3
|
||||||
|
|
||||||
|
Within each group, the dilation of the residual block grows
|
||||||
|
exponentially.
|
||||||
|
residual_channels : int, optional
|
||||||
|
Residual channel of the residual blocks, by default 64
|
||||||
|
gate_channels : int, optional
|
||||||
|
Gate channel of the residual blocks, by default 128
|
||||||
|
skip_channels : int, optional
|
||||||
|
Skip channel of the residual blocks, by default 64
|
||||||
|
aux_channels : int, optional
|
||||||
|
Auxiliary channel of the residual blocks, by default 80
|
||||||
|
aux_context_window : int, optional
|
||||||
|
The context window size of the first convolution applied to the
|
||||||
|
auxiliary input, by default 2
|
||||||
|
dropout : float, optional
|
||||||
|
Dropout of the residual blocks, by default 0.
|
||||||
|
bias : bool, optional
|
||||||
|
Whether to use bias in residual blocks, by default True
|
||||||
|
use_weight_norm : bool, optional
|
||||||
|
Whether to use weight norm in all convolutions, by default True
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal padding in the upsample network and residual
|
||||||
|
blocks, by default False
|
||||||
|
upsample_scales : List[int], optional
|
||||||
|
Upsample scales of the upsample network, by default [4, 4, 4, 4]
|
||||||
|
nonlinear_activation : Optional[str], optional
|
||||||
|
Non linear activation in upsample network, by default None
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters passed to the linear activation in the upsample network,
|
||||||
|
by default {}
|
||||||
|
interpolate_mode : str, optional
|
||||||
|
Interpolation mode of the upsample network, by default "nearest"
|
||||||
|
freq_axis_kernel_size : int, optional
|
||||||
|
Kernel size along the frequency axis of the upsample network, by default 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int=1,
|
||||||
|
out_channels: int=1,
|
||||||
|
kernel_size: int=3,
|
||||||
|
layers: int=30,
|
||||||
|
stacks: int=3,
|
||||||
|
residual_channels: int=64,
|
||||||
|
gate_channels: int=128,
|
||||||
|
skip_channels: int=64,
|
||||||
|
aux_channels: int=80,
|
||||||
|
aux_context_window: int=2,
|
||||||
|
dropout: float=0.,
|
||||||
|
bias: bool=True,
|
||||||
|
use_weight_norm: bool=True,
|
||||||
|
use_causal_conv: bool=False,
|
||||||
|
upsample_scales: List[int]=[4, 4, 4, 4],
|
||||||
|
nonlinear_activation: Optional[str]=None,
|
||||||
|
nonlinear_activation_params: Dict[str, Any]={},
|
||||||
|
interpolate_mode: str="nearest",
|
||||||
|
freq_axis_kernel_size: int=1):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.aux_channels = aux_channels
|
||||||
|
self.aux_context_window = aux_context_window
|
||||||
|
self.layers = layers
|
||||||
|
self.stacks = stacks
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
assert layers % stacks == 0
|
||||||
|
layers_per_stack = layers // stacks
|
||||||
|
|
||||||
|
self.first_conv = nn.Conv1D(
|
||||||
|
in_channels, residual_channels, 1, bias_attr=True)
|
||||||
|
self.upsample_net = ConvInUpsampleNet(
|
||||||
|
upsample_scales=upsample_scales,
|
||||||
|
nonlinear_activation=nonlinear_activation,
|
||||||
|
nonlinear_activation_params=nonlinear_activation_params,
|
||||||
|
interpolate_mode=interpolate_mode,
|
||||||
|
freq_axis_kernel_size=freq_axis_kernel_size,
|
||||||
|
aux_channels=aux_channels,
|
||||||
|
aux_context_window=aux_context_window,
|
||||||
|
use_causal_conv=use_causal_conv)
|
||||||
|
self.upsample_factor = np.prod(upsample_scales)
|
||||||
|
|
||||||
|
self.conv_layers = nn.LayerList()
|
||||||
|
for layer in range(layers):
|
||||||
|
dilation = 2**(layer % layers_per_stack)
|
||||||
|
conv = ResidualBlock(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
residual_channels=residual_channels,
|
||||||
|
gate_channels=gate_channels,
|
||||||
|
skip_channels=skip_channels,
|
||||||
|
aux_channels=aux_channels,
|
||||||
|
dilation=dilation,
|
||||||
|
dropout=dropout,
|
||||||
|
bias=bias,
|
||||||
|
use_causal_conv=use_causal_conv)
|
||||||
|
self.conv_layers.append(conv)
|
||||||
|
|
||||||
|
self.last_conv_layers = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1D(
|
||||||
|
skip_channels, skip_channels, 1, bias_attr=True),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1D(
|
||||||
|
skip_channels, out_channels, 1, bias_attr=True))
|
||||||
|
|
||||||
|
if use_weight_norm:
|
||||||
|
self.apply_weight_norm()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, c: Tensor) -> Tensor:
|
||||||
|
"""Generate waveform.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, C_in, T), The input waveform.
|
||||||
|
c : Tensor
|
||||||
|
Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram). It
|
||||||
|
is upsampled to match the time resolution of the input.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, C_out, T), the generated waveform.
|
||||||
|
"""
|
||||||
|
c = self.upsample_net(c)
|
||||||
|
assert c.shape[-1] == x.shape[-1]
|
||||||
|
|
||||||
|
x = self.first_conv(x)
|
||||||
|
skips = 0
|
||||||
|
for f in self.conv_layers:
|
||||||
|
x, s = f(x, c)
|
||||||
|
skips += s
|
||||||
|
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
||||||
|
|
||||||
|
x = self.last_conv_layers(skips)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def apply_weight_norm(self):
|
||||||
|
"""Recursively apply weight normalization to all the Convolution layers
|
||||||
|
in the sublayers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _apply_weight_norm(layer):
|
||||||
|
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
||||||
|
nn.utils.weight_norm(layer)
|
||||||
|
|
||||||
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
"""Recursively remove weight normalization from all the Convolution
|
||||||
|
layers in the sublayers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _remove_weight_norm(layer):
|
||||||
|
try:
|
||||||
|
nn.utils.remove_weight_norm(layer)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.apply(_remove_weight_norm)
|
||||||
|
|
||||||
|
def inference(self, c: Optional[Tensor]=None,
|
||||||
|
x: Optional[Tensor]=None) -> Tensor:
|
||||||
|
"""Waveform generation. This function is used for single instance
|
||||||
|
inference.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
c : Tensor, optional
|
||||||
|
Shape (T', C_aux), the auxiliary input, by default None
|
||||||
|
x : Tensor, optional
|
||||||
|
Shape (T, C_in), the noise waveform, by default None
|
||||||
|
If not provided, a sample is drawn from a gaussian distribution.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (T, C_out), the generated waveform
|
||||||
|
"""
|
||||||
|
if x is not None:
|
||||||
|
x = paddle.transpose(x, [1, 0]).unsqueeze(0) # pseudo batch
|
||||||
|
else:
|
||||||
|
assert c is not None
|
||||||
|
x = paddle.randn(
|
||||||
|
[1, self.in_channels, c.shape[0] * self.upsample_factor])
|
||||||
|
|
||||||
|
if c is not None:
|
||||||
|
c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
|
||||||
|
c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
|
||||||
|
out = self.forward(x, c).squeeze(0).transpose([1, 0])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PWGDiscriminator(nn.Layer):
|
||||||
|
"""A convolutional discriminator for audio.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_channels : int, optional
|
||||||
|
Number of channels of the input audio, by default 1
|
||||||
|
out_channels : int, optional
|
||||||
|
Output feature size, by default 1
|
||||||
|
kernel_size : int, optional
|
||||||
|
Kernel size of convolutional sublayers, by default 3
|
||||||
|
layers : int, optional
|
||||||
|
Number of layers, by default 10
|
||||||
|
conv_channels : int, optional
|
||||||
|
Feature size of the convolutional sublayers, by default 64
|
||||||
|
dilation_factor : int, optional
|
||||||
|
The factor with which dilation of each convolutional sublayers grows
|
||||||
|
exponentially if it is greater than 1, else the dilation of each
|
||||||
|
convolutional sublayers grows linearly, by default 1
|
||||||
|
nonlinear_activation : str, optional
|
||||||
|
The activation after each convolutional sublayer, by default "LeakyReLU"
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
The parameters passed to the activation's initializer, by default
|
||||||
|
{"negative_slope": 0.2}
|
||||||
|
bias : bool, optional
|
||||||
|
Whether to use bias in convolutional sublayers, by default True
|
||||||
|
use_weight_norm : bool, optional
|
||||||
|
Whether to use weight normalization at all convolutional sublayers,
|
||||||
|
by default True
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int=1,
|
||||||
|
out_channels: int=1,
|
||||||
|
kernel_size: int=3,
|
||||||
|
layers: int=10,
|
||||||
|
conv_channels: int=64,
|
||||||
|
dilation_factor: int=1,
|
||||||
|
nonlinear_activation: str="LeakyReLU",
|
||||||
|
nonlinear_activation_params: Dict[
|
||||||
|
str, Any]={"negative_slope": 0.2},
|
||||||
|
bias: bool=True,
|
||||||
|
use_weight_norm: bool=True):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
assert dilation_factor > 0
|
||||||
|
conv_layers = []
|
||||||
|
conv_in_channels = in_channels
|
||||||
|
for i in range(layers - 1):
|
||||||
|
if i == 0:
|
||||||
|
dilation = 1
|
||||||
|
else:
|
||||||
|
dilation = i if dilation_factor == 1 else dilation_factor**i
|
||||||
|
conv_in_channels = conv_channels
|
||||||
|
padding = (kernel_size - 1) // 2 * dilation
|
||||||
|
conv_layer = nn.Conv1D(
|
||||||
|
conv_in_channels,
|
||||||
|
conv_channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
bias_attr=bias)
|
||||||
|
nonlinear = getattr(
|
||||||
|
nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||||
|
conv_layers.append(conv_layer)
|
||||||
|
conv_layers.append(nonlinear)
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
|
last_conv = nn.Conv1D(
|
||||||
|
conv_in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=padding,
|
||||||
|
bias_attr=bias)
|
||||||
|
conv_layers.append(last_conv)
|
||||||
|
self.conv_layers = nn.Sequential(*conv_layers)
|
||||||
|
|
||||||
|
if use_weight_norm:
|
||||||
|
self.apply_weight_norm()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, in_channels, num_samples), the input audio.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, out_channels, num_samples), the predicted logits.
|
||||||
|
"""
|
||||||
|
return self.conv_layers(x)
|
||||||
|
|
||||||
|
def apply_weight_norm(self):
|
||||||
|
def _apply_weight_norm(layer):
|
||||||
|
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
||||||
|
nn.utils.weight_norm(layer)
|
||||||
|
|
||||||
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
def _remove_weight_norm(layer):
|
||||||
|
try:
|
||||||
|
nn.utils.remove_weight_norm(layer)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.apply(_remove_weight_norm)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualPWGDiscriminator(nn.Layer):
|
||||||
|
"""A wavenet-style discriminator for audio.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_channels : int, optional
|
||||||
|
Number of channels of the input audio, by default 1
|
||||||
|
out_channels : int, optional
|
||||||
|
Output feature size, by default 1
|
||||||
|
kernel_size : int, optional
|
||||||
|
Kernel size of residual blocks, by default 3
|
||||||
|
layers : int, optional
|
||||||
|
Number of residual blocks, by default 30
|
||||||
|
stacks : int, optional
|
||||||
|
Number of groups of residual blocks, within which the dilation
|
||||||
|
of each residual blocks grows exponentially, by default 3
|
||||||
|
residual_channels : int, optional
|
||||||
|
Residual channels of residual blocks, by default 64
|
||||||
|
gate_channels : int, optional
|
||||||
|
Gate channels of residual blocks, by default 128
|
||||||
|
skip_channels : int, optional
|
||||||
|
Skip channels of residual blocks, by default 64
|
||||||
|
dropout : float, optional
|
||||||
|
Dropout probability of residual blocks, by default 0.
|
||||||
|
bias : bool, optional
|
||||||
|
Whether to use bias in residual blocks, by default True
|
||||||
|
use_weight_norm : bool, optional
|
||||||
|
Whether to use weight normalization in all convolutional layers,
|
||||||
|
by default True
|
||||||
|
use_causal_conv : bool, optional
|
||||||
|
Whether to use causal convolution in residual blocks, by default False
|
||||||
|
nonlinear_activation : str, optional
|
||||||
|
Activation after convolutions other than those in residual blocks,
|
||||||
|
by default "LeakyReLU"
|
||||||
|
nonlinear_activation_params : Dict[str, Any], optional
|
||||||
|
Parameters to pass to the activation, by default {"negative_slope": 0.2}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int=1,
|
||||||
|
out_channels: int=1,
|
||||||
|
kernel_size: int=3,
|
||||||
|
layers: int=30,
|
||||||
|
stacks: int=3,
|
||||||
|
residual_channels: int=64,
|
||||||
|
gate_channels: int=128,
|
||||||
|
skip_channels: int=64,
|
||||||
|
dropout: float=0.,
|
||||||
|
bias: bool=True,
|
||||||
|
use_weight_norm: bool=True,
|
||||||
|
use_causal_conv: bool=False,
|
||||||
|
nonlinear_activation: str="LeakyReLU",
|
||||||
|
nonlinear_activation_params: Dict[
|
||||||
|
str, Any]={"negative_slope": 0.2}):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.layers = layers
|
||||||
|
self.stacks = stacks
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
assert layers % stacks == 0
|
||||||
|
layers_per_stack = layers // stacks
|
||||||
|
|
||||||
|
self.first_conv = nn.Sequential(
|
||||||
|
nn.Conv1D(
|
||||||
|
in_channels, residual_channels, 1, bias_attr=True),
|
||||||
|
getattr(nn, nonlinear_activation)(**nonlinear_activation_params))
|
||||||
|
|
||||||
|
self.conv_layers = nn.LayerList()
|
||||||
|
for layer in range(layers):
|
||||||
|
dilation = 2**(layer % layers_per_stack)
|
||||||
|
conv = ResidualBlock(
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
residual_channels=residual_channels,
|
||||||
|
gate_channels=gate_channels,
|
||||||
|
skip_channels=skip_channels,
|
||||||
|
aux_channels=None, # no auxiliary input
|
||||||
|
dropout=dropout,
|
||||||
|
dilation=dilation,
|
||||||
|
bias=bias,
|
||||||
|
use_causal_conv=use_causal_conv)
|
||||||
|
self.conv_layers.append(conv)
|
||||||
|
|
||||||
|
self.last_conv_layers = nn.Sequential(
|
||||||
|
getattr(nn, nonlinear_activation)(**nonlinear_activation_params),
|
||||||
|
nn.Conv1D(
|
||||||
|
skip_channels, skip_channels, 1, bias_attr=True),
|
||||||
|
getattr(nn, nonlinear_activation)(**nonlinear_activation_params),
|
||||||
|
nn.Conv1D(
|
||||||
|
skip_channels, out_channels, 1, bias_attr=True))
|
||||||
|
|
||||||
|
if use_weight_norm:
|
||||||
|
self.apply_weight_norm()
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor
|
||||||
|
Shape (N, in_channels, num_samples), the input audio.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
Shape (N, out_channels, num_samples), the predicted logits.
|
||||||
|
"""
|
||||||
|
x = self.first_conv(x)
|
||||||
|
skip = 0
|
||||||
|
for f in self.conv_layers:
|
||||||
|
x, h = f(x, None)
|
||||||
|
skip += h
|
||||||
|
skip *= math.sqrt(1 / len(self.conv_layers))
|
||||||
|
|
||||||
|
x = skip
|
||||||
|
x = self.last_conv_layers(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def apply_weight_norm(self):
|
||||||
|
def _apply_weight_norm(layer):
|
||||||
|
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
||||||
|
nn.utils.weight_norm(layer)
|
||||||
|
|
||||||
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
def _remove_weight_norm(layer):
|
||||||
|
try:
|
||||||
|
nn.utils.remove_weight_norm(layer)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.apply(_remove_weight_norm)
|
|
@ -44,7 +44,8 @@ def fold(x, n_group):
|
||||||
Tensor : [shape=(\*, time_steps // n_group, group)]
|
Tensor : [shape=(\*, time_steps // n_group, group)]
|
||||||
Folded tensor.
|
Folded tensor.
|
||||||
"""
|
"""
|
||||||
*spatial_shape, time_steps = x.shape
|
spatial_shape = list(x.shape[:-1])
|
||||||
|
time_steps = paddle.shape(x)[-1]
|
||||||
new_shape = spatial_shape + [time_steps // n_group, n_group]
|
new_shape = spatial_shape + [time_steps // n_group, n_group]
|
||||||
return paddle.reshape(x, new_shape)
|
return paddle.reshape(x, new_shape)
|
||||||
|
|
||||||
|
@ -232,7 +233,7 @@ class ResidualBlock(nn.Layer):
|
||||||
"""
|
"""
|
||||||
if self.training:
|
if self.training:
|
||||||
raise ValueError("Only use start sequence at evaluation mode.")
|
raise ValueError("Only use start sequence at evaluation mode.")
|
||||||
self._conv_buffer = None
|
self._conv_buffer = paddle.zeros([1])
|
||||||
|
|
||||||
# NOTE: call self.conv's weight norm hook expliccitly since
|
# NOTE: call self.conv's weight norm hook expliccitly since
|
||||||
# its weight will be visited directly in `add_input` without
|
# its weight will be visited directly in `add_input` without
|
||||||
|
@ -263,10 +264,9 @@ class ResidualBlock(nn.Layer):
|
||||||
A row of the skip output.
|
A row of the skip output.
|
||||||
"""
|
"""
|
||||||
x_row_in = x_row
|
x_row_in = x_row
|
||||||
if self._conv_buffer is None:
|
if len(paddle.shape(self._conv_buffer)) == 1:
|
||||||
self._init_buffer(x_row)
|
self._init_buffer(x_row)
|
||||||
self._update_buffer(x_row)
|
self._update_buffer(x_row)
|
||||||
|
|
||||||
rw = self.rw
|
rw = self.rw
|
||||||
x_row = F.conv2d(
|
x_row = F.conv2d(
|
||||||
self._conv_buffer,
|
self._conv_buffer,
|
||||||
|
@ -275,7 +275,6 @@ class ResidualBlock(nn.Layer):
|
||||||
padding=[0, 0, rw // 2, (rw - 1) // 2],
|
padding=[0, 0, rw // 2, (rw - 1) // 2],
|
||||||
dilation=self.dilations)
|
dilation=self.dilations)
|
||||||
x_row += self.condition_proj(condition_row)
|
x_row += self.condition_proj(condition_row)
|
||||||
|
|
||||||
content, gate = paddle.chunk(x_row, 2, axis=1)
|
content, gate = paddle.chunk(x_row, 2, axis=1)
|
||||||
x_row = paddle.tanh(content) * F.sigmoid(gate)
|
x_row = paddle.tanh(content) * F.sigmoid(gate)
|
||||||
|
|
||||||
|
@ -329,7 +328,7 @@ class ResidualNet(nn.LayerList):
|
||||||
if len(dilations_h) != n_layer:
|
if len(dilations_h) != n_layer:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"number of dilations_h should equals num of layers")
|
"number of dilations_h should equals num of layers")
|
||||||
super().__init__()
|
super(ResidualNet, self).__init__()
|
||||||
for i in range(n_layer):
|
for i in range(n_layer):
|
||||||
dilation = (dilations_h[i], 2**i)
|
dilation = (dilations_h[i], 2**i)
|
||||||
layer = ResidualBlock(residual_channels, condition_channels,
|
layer = ResidualBlock(residual_channels, condition_channels,
|
||||||
|
@ -539,27 +538,21 @@ class Flow(nn.Layer):
|
||||||
transformation from x to z.
|
transformation from x to z.
|
||||||
"""
|
"""
|
||||||
z_0 = z[:, :, :1, :]
|
z_0 = z[:, :, :1, :]
|
||||||
x = []
|
x = paddle.zeros_like(z)
|
||||||
logs_list = []
|
x[:, :, :1, :] = z_0
|
||||||
b_list = []
|
|
||||||
x.append(z_0)
|
|
||||||
|
|
||||||
self._start_sequence()
|
self._start_sequence()
|
||||||
for i in range(1, self.n_group):
|
|
||||||
x_row = x[-1] # actuallt i-1:i
|
num_step = paddle.ones([1], dtype='int32') * (self.n_group)
|
||||||
|
for i in range(1, num_step):
|
||||||
|
x_row = x[:, :, i - 1:i, :]
|
||||||
z_row = z[:, :, i:i + 1, :]
|
z_row = z[:, :, i:i + 1, :]
|
||||||
condition_row = condition[:, :, i:i + 1, :]
|
condition_row = condition[:, :, i:i + 1, :]
|
||||||
|
|
||||||
x_next_row, (logs, b) = self._inverse_row(z_row, x_row,
|
x_next_row, (logs, b) = self._inverse_row(z_row, x_row,
|
||||||
condition_row)
|
condition_row)
|
||||||
x.append(x_next_row)
|
x[:, :, i:i+1, :] = x_next_row
|
||||||
logs_list.append(logs)
|
|
||||||
b_list.append(b)
|
return x
|
||||||
|
|
||||||
x = paddle.concat(x, 2)
|
|
||||||
logs = paddle.concat(logs_list, 2)
|
|
||||||
b = paddle.concat(b_list, 2)
|
|
||||||
return x, (logs, b)
|
|
||||||
|
|
||||||
|
|
||||||
class WaveFlow(nn.LayerList):
|
class WaveFlow(nn.LayerList):
|
||||||
|
@ -611,16 +604,18 @@ class WaveFlow(nn.LayerList):
|
||||||
perms = []
|
perms = []
|
||||||
for i in range(n_flows):
|
for i in range(n_flows):
|
||||||
if i < n_flows // 2:
|
if i < n_flows // 2:
|
||||||
perms.append(indices[::-1])
|
perm = indices[::-1]
|
||||||
else:
|
else:
|
||||||
perm = list(reversed(indices[:half])) + list(
|
perm = list(reversed(indices[:half])) + list(
|
||||||
reversed(indices[half:]))
|
reversed(indices[half:]))
|
||||||
perms.append(perm)
|
perm = paddle.to_tensor(perm)
|
||||||
|
self.register_buffer(perm.name, perm)
|
||||||
|
perms.append(perm)
|
||||||
return perms
|
return perms
|
||||||
|
|
||||||
def _trim(self, x, condition):
|
def _trim(self, x, condition):
|
||||||
assert condition.shape[-1] >= x.shape[-1]
|
assert condition.shape[-1] >= x.shape[-1]
|
||||||
pruned_len = int(x.shape[-1] // self.n_group * self.n_group)
|
pruned_len = int(paddle.shape(x)[-1] // self.n_group * self.n_group)
|
||||||
|
|
||||||
if x.shape[-1] > pruned_len:
|
if x.shape[-1] > pruned_len:
|
||||||
x = x[:, :pruned_len]
|
x = x[:, :pruned_len]
|
||||||
|
@ -707,7 +702,7 @@ class WaveFlow(nn.LayerList):
|
||||||
for i in reversed(range(self.n_flows)):
|
for i in reversed(range(self.n_flows)):
|
||||||
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
|
z = geo.shuffle_dim(z, 2, perm=self.perms[i])
|
||||||
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
|
condition = geo.shuffle_dim(condition, 2, perm=self.perms[i])
|
||||||
z, (logs, b) = self[i].inverse(z, condition)
|
z = self[i].inverse(z, condition)
|
||||||
|
|
||||||
x = paddle.squeeze(z, 1) # (B, H, W)
|
x = paddle.squeeze(z, 1) # (B, H, W)
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
|
@ -893,3 +888,21 @@ class WaveFlowLoss(nn.Layer):
|
||||||
) - log_det_jacobian
|
) - log_det_jacobian
|
||||||
loss = loss / np.prod(z.shape)
|
loss = loss / np.prod(z.shape)
|
||||||
return loss + self.const
|
return loss + self.const
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionalWaveFlow2Infer(ConditionalWaveFlow):
|
||||||
|
def forward(self, mel):
|
||||||
|
"""Generate raw audio given mel spectrogram.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
mel : np.ndarray [shape=(C_mel, T_mel)]
|
||||||
|
Mel spectrogram of an utterance(in log-magnitude).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray [shape=(T,)]
|
||||||
|
The synthesized audio.
|
||||||
|
"""
|
||||||
|
audio = self.predict(mel)
|
||||||
|
return audio
|
||||||
|
|
|
@ -20,7 +20,7 @@ import librosa
|
||||||
from librosa.util import pad_center
|
from librosa.util import pad_center
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
__all__ = ["quantize", "dequantize", "STFT"]
|
__all__ = ["quantize", "dequantize", "STFT", "MelScale"]
|
||||||
|
|
||||||
|
|
||||||
def quantize(values, n_bands):
|
def quantize(values, n_bands):
|
||||||
|
@ -96,10 +96,10 @@ class STFT(nn.Layer):
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
|
|
||||||
pad_mode : string or function
|
pad_mode : string or function
|
||||||
If center=True, this argument is passed to np.pad for padding the edges
|
If center=True, this argument is passed to np.pad for padding the edges
|
||||||
of the signal y. By default (pad_mode="reflect"), y is padded on both
|
of the signal y. By default (pad_mode="reflect"), y is padded on both
|
||||||
sides with its own reflection, mirrored around its first and last
|
sides with its own reflection, mirrored around its first and last
|
||||||
sample respectively. If center=False, this argument is ignored.
|
sample respectively. If center=False, this argument is ignored.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -163,17 +163,15 @@ class STFT(nn.Layer):
|
||||||
w = np.concatenate([w_real, w_imag], axis=0)
|
w = np.concatenate([w_real, w_imag], axis=0)
|
||||||
w = w * window
|
w = w * window
|
||||||
w = np.expand_dims(w, 1)
|
w = np.expand_dims(w, 1)
|
||||||
self.weight = paddle.cast(
|
weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype())
|
||||||
paddle.to_tensor(w), paddle.get_default_dtype())
|
self.register_buffer("weight", weight)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Compute the stft transform.
|
"""Compute the stft transform.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
------------
|
------------
|
||||||
x : Tensor [shape=(B, T)]
|
x : Tensor [shape=(B, T)]
|
||||||
The input waveform.
|
The input waveform.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
------------
|
------------
|
||||||
real : Tensor [shape=(B, C, frames)]
|
real : Tensor [shape=(B, C, frames)]
|
||||||
|
@ -195,36 +193,32 @@ class STFT(nn.Layer):
|
||||||
|
|
||||||
def power(self, x):
|
def power(self, x):
|
||||||
"""Compute the power spectrum.
|
"""Compute the power spectrum.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
------------
|
------------
|
||||||
x : Tensor [shape=(B, T)]
|
x : Tensor [shape=(B, T)]
|
||||||
The input waveform.
|
The input waveform.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
------------
|
------------
|
||||||
Tensor [shape=(B, C, T)]
|
Tensor [shape=(B, C, T)]
|
||||||
The power spectrum.
|
The power spectrum.
|
||||||
"""
|
"""
|
||||||
real, imag = self(x)
|
real, imag = self.forward(x)
|
||||||
power = real**2 + imag**2
|
power = real**2 + imag**2
|
||||||
return power
|
return power
|
||||||
|
|
||||||
def magnitude(self, x):
|
def magnitude(self, x):
|
||||||
"""Compute the magnitude of the spectrum.
|
"""Compute the magnitude of the spectrum.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
------------
|
------------
|
||||||
x : Tensor [shape=(B, T)]
|
x : Tensor [shape=(B, T)]
|
||||||
The input waveform.
|
The input waveform.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
------------
|
------------
|
||||||
Tensor [shape=(B, C, T)]
|
Tensor [shape=(B, C, T)]
|
||||||
The magnitude of the spectrum.
|
The magnitude of the spectrum.
|
||||||
"""
|
"""
|
||||||
power = self.power(x)
|
power = self.power(x)
|
||||||
magnitude = paddle.sqrt(power)
|
magnitude = paddle.sqrt(power) # TODO(chenfeiyu): maybe clipping
|
||||||
return magnitude
|
return magnitude
|
||||||
|
|
||||||
|
|
||||||
|
@ -232,7 +226,9 @@ class MelScale(nn.Layer):
|
||||||
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
|
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
||||||
self.weight = paddle.to_tensor(mel_basis)
|
# self.weight = paddle.to_tensor(mel_basis)
|
||||||
|
weight = paddle.to_tensor(mel_basis, dtype=paddle.get_default_dtype())
|
||||||
|
self.register_buffer("weight", weight)
|
||||||
|
|
||||||
def forward(self, spec):
|
def forward(self, spec):
|
||||||
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
|
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
from parakeet.modules.audio import STFT
|
||||||
|
|
||||||
|
|
||||||
|
class SpectralConvergenceLoss(nn.Layer):
|
||||||
|
"""Spectral convergence loss module."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initilize spectral convergence loss module."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x_mag, y_mag):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, C, T).
|
||||||
|
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, C, T).
|
||||||
|
Returns:
|
||||||
|
Tensor: Spectral convergence loss value.
|
||||||
|
"""
|
||||||
|
return paddle.norm(
|
||||||
|
y_mag - x_mag, p="fro") / paddle.clip(
|
||||||
|
paddle.norm(
|
||||||
|
y_mag, p="fro"), min=1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
class LogSTFTMagnitudeLoss(nn.Layer):
|
||||||
|
"""Log STFT magnitude loss module."""
|
||||||
|
|
||||||
|
def __init__(self, epsilon=1e-10):
|
||||||
|
"""Initilize los STFT magnitude loss module."""
|
||||||
|
super().__init__()
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
def forward(self, x_mag, y_mag):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||||
|
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||||
|
Returns:
|
||||||
|
Tensor: Log STFT magnitude loss value.
|
||||||
|
"""
|
||||||
|
return F.l1_loss(
|
||||||
|
paddle.log(paddle.clip(
|
||||||
|
y_mag, min=self.epsilon)),
|
||||||
|
paddle.log(paddle.clip(
|
||||||
|
x_mag, min=self.epsilon)))
|
||||||
|
|
||||||
|
|
||||||
|
class STFTLoss(nn.Layer):
|
||||||
|
"""STFT loss module."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
fft_size=1024,
|
||||||
|
shift_size=120,
|
||||||
|
win_length=600,
|
||||||
|
window="hann"):
|
||||||
|
"""Initialize STFT loss module."""
|
||||||
|
super().__init__()
|
||||||
|
self.fft_size = fft_size
|
||||||
|
self.shift_size = shift_size
|
||||||
|
self.win_length = win_length
|
||||||
|
self.stft = STFT(
|
||||||
|
n_fft=fft_size,
|
||||||
|
hop_length=shift_size,
|
||||||
|
win_length=win_length,
|
||||||
|
window=window)
|
||||||
|
self.spectral_convergence_loss = SpectralConvergenceLoss()
|
||||||
|
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Predicted signal (B, T).
|
||||||
|
y (Tensor): Groundtruth signal (B, T).
|
||||||
|
Returns:
|
||||||
|
Tensor: Spectral convergence loss value.
|
||||||
|
Tensor: Log STFT magnitude loss value.
|
||||||
|
"""
|
||||||
|
x_mag = self.stft.magnitude(x)
|
||||||
|
y_mag = self.stft.magnitude(y)
|
||||||
|
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
|
||||||
|
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||||
|
|
||||||
|
return sc_loss, mag_loss
|
||||||
|
|
||||||
|
|
||||||
|
class MultiResolutionSTFTLoss(nn.Layer):
|
||||||
|
"""Multi resolution STFT loss module."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fft_sizes=[1024, 2048, 512],
|
||||||
|
hop_sizes=[120, 240, 50],
|
||||||
|
win_lengths=[600, 1200, 240],
|
||||||
|
window="hann", ):
|
||||||
|
"""Initialize Multi resolution STFT loss module.
|
||||||
|
Args:
|
||||||
|
fft_sizes (list): List of FFT sizes.
|
||||||
|
hop_sizes (list): List of hop sizes.
|
||||||
|
win_lengths (list): List of window lengths.
|
||||||
|
window (str): Window function type.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
||||||
|
self.stft_losses = nn.LayerList()
|
||||||
|
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
||||||
|
self.stft_losses.append(STFTLoss(fs, ss, wl, window))
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
"""Calculate forward propagation.
|
||||||
|
Args:
|
||||||
|
x (Tensor): Predicted signal (B, T).
|
||||||
|
y (Tensor): Groundtruth signal (B, T).
|
||||||
|
Returns:
|
||||||
|
Tensor: Multi resolution spectral convergence loss value.
|
||||||
|
Tensor: Multi resolution log STFT magnitude loss value.
|
||||||
|
"""
|
||||||
|
sc_loss = 0.0
|
||||||
|
mag_loss = 0.0
|
||||||
|
for f in self.stft_losses:
|
||||||
|
sc_l, mag_l = f(x, y)
|
||||||
|
sc_loss += sc_l
|
||||||
|
mag_loss += mag_l
|
||||||
|
sc_loss /= len(self.stft_losses)
|
||||||
|
mag_loss /= len(self.stft_losses)
|
||||||
|
|
||||||
|
return sc_loss, mag_loss
|
|
@ -1,162 +0,0 @@
|
||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from typing import Callable, Mapping, List
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
class KBest(object):
|
|
||||||
"""
|
|
||||||
A utility class to help save the hard drive by only keeping K best
|
|
||||||
checkpoints.
|
|
||||||
|
|
||||||
To be as modularized as possible, this class does not assume anything like
|
|
||||||
a Trainer class or anything like a checkpoint directory, it does not know
|
|
||||||
about the model or the optimizer, etc.
|
|
||||||
|
|
||||||
It is basically a dynamically mantained K-bset Mapping. When a new item is
|
|
||||||
added to the map, save_fn is called. And when an item is removed from the
|
|
||||||
map, del_fn is called. `save_fn` and `del_fn` takes a Path object as input
|
|
||||||
and returns nothing.
|
|
||||||
|
|
||||||
Though it is designed to control checkpointing behaviors, it can be used
|
|
||||||
to do something else if you pass some save_fn and del_fn.
|
|
||||||
|
|
||||||
Example
|
|
||||||
--------
|
|
||||||
|
|
||||||
>>> from pathlib import Path
|
|
||||||
>>> import shutil
|
|
||||||
>>> import paddle
|
|
||||||
>>> from paddle import nn
|
|
||||||
|
|
||||||
>>> model = nn.Linear(2, 3)
|
|
||||||
>>> def save_model(path):
|
|
||||||
... paddle.save(model.state_dict(), path)
|
|
||||||
|
|
||||||
>>> kbest_manager = KBest(max_size=5, save_fn=save_model)
|
|
||||||
>>> checkpoint_dir = Path("checkpoints")
|
|
||||||
>>> shutil.rmtree(checkpoint_dir)
|
|
||||||
>>> checkpoint_dir.mkdir(parents=True)
|
|
||||||
>>> a = np.random.rand(20)
|
|
||||||
>>> for i, score in enumerate(a):
|
|
||||||
... path = checkpoint_dir / f"step_{i}"
|
|
||||||
... kbest_manager.add_checkpoint(score, path)
|
|
||||||
>>> assert len(list(checkpoint_dir.glob("step_*"))) == 5
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
max_size: int=5,
|
|
||||||
save_fn: Callable[[Path], None]=None,
|
|
||||||
del_fn: Callable[[Path], None]=lambda f: f.unlink()):
|
|
||||||
self.best_records: Mapping[Path, float] = {}
|
|
||||||
self.save_fn = save_fn
|
|
||||||
self.del_fn = del_fn
|
|
||||||
self.max_size = max_size
|
|
||||||
self._save_all = (max_size == -1)
|
|
||||||
|
|
||||||
def should_save(self, metric: float) -> bool:
|
|
||||||
if not self.full():
|
|
||||||
return True
|
|
||||||
|
|
||||||
# already full
|
|
||||||
worst_record_path = max(self.best_records, key=self.best_records.get)
|
|
||||||
worst_metric = self.best_records[worst_record_path]
|
|
||||||
return metric < worst_metric
|
|
||||||
|
|
||||||
def full(self):
|
|
||||||
return (not self._save_all) and len(self.best_records) == self.max_size
|
|
||||||
|
|
||||||
def add_checkpoint(self, metric, path):
|
|
||||||
if self.should_save(metric):
|
|
||||||
self.save_checkpoint_and_update(metric, path)
|
|
||||||
|
|
||||||
def save_checkpoint_and_update(self, metric, path):
|
|
||||||
# remove the worst
|
|
||||||
if self.full():
|
|
||||||
worst_record_path = max(self.best_records,
|
|
||||||
key=self.best_records.get)
|
|
||||||
self.best_records.pop(worst_record_path)
|
|
||||||
self.del_fn(worst_record_path)
|
|
||||||
|
|
||||||
# add the new one
|
|
||||||
self.save_fn(path)
|
|
||||||
self.best_records[path] = metric
|
|
||||||
|
|
||||||
|
|
||||||
class KLatest(object):
|
|
||||||
"""
|
|
||||||
A utility class to help save the hard drive by only keeping K latest
|
|
||||||
checkpoints.
|
|
||||||
|
|
||||||
To be as modularized as possible, this class does not assume anything like
|
|
||||||
a Trainer class or anything like a checkpoint directory, it does not know
|
|
||||||
about the model or the optimizer, etc.
|
|
||||||
|
|
||||||
It is basically a dynamically mantained Queue. When a new item is
|
|
||||||
added to the queue, save_fn is called. And when an item is removed from the
|
|
||||||
queue, del_fn is called. `save_fn` and `del_fn` takes a Path object as input
|
|
||||||
and returns nothing.
|
|
||||||
|
|
||||||
Though it is designed to control checkpointing behaviors, it can be used
|
|
||||||
to do something else if you pass some save_fn and del_fn.
|
|
||||||
|
|
||||||
Example
|
|
||||||
--------
|
|
||||||
|
|
||||||
>>> from pathlib import Path
|
|
||||||
>>> import shutil
|
|
||||||
>>> import paddle
|
|
||||||
>>> from paddle import nn
|
|
||||||
|
|
||||||
>>> model = nn.Linear(2, 3)
|
|
||||||
>>> def save_model(path):
|
|
||||||
... paddle.save(model.state_dict(), path)
|
|
||||||
|
|
||||||
>>> klatest_manager = KLatest(max_size=5, save_fn=save_model)
|
|
||||||
>>> checkpoint_dir = Path("checkpoints")
|
|
||||||
>>> shutil.rmtree(checkpoint_dir)
|
|
||||||
>>> checkpoint_dir.mkdir(parents=True)
|
|
||||||
>>> for i in range(20):
|
|
||||||
... path = checkpoint_dir / f"step_{i}"
|
|
||||||
... klatest_manager.add_checkpoint(path)
|
|
||||||
>>> assert len(list(checkpoint_dir.glob("step_*"))) == 5
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
max_size: int=5,
|
|
||||||
save_fn: Callable[[Path], None]=None,
|
|
||||||
del_fn: Callable[[Path], None]=lambda f: f.unlink()):
|
|
||||||
self.latest_records: List[Path] = []
|
|
||||||
self.save_fn = save_fn
|
|
||||||
self.del_fn = del_fn
|
|
||||||
self.max_size = max_size
|
|
||||||
self._save_all = (max_size == -1)
|
|
||||||
|
|
||||||
def full(self):
|
|
||||||
return (
|
|
||||||
not self._save_all) and len(self.latest_records) == self.max_size
|
|
||||||
|
|
||||||
def add_checkpoint(self, path):
|
|
||||||
self.save_checkpoint_and_update(path)
|
|
||||||
|
|
||||||
def save_checkpoint_and_update(self, path):
|
|
||||||
# remove the earist
|
|
||||||
if self.full():
|
|
||||||
eariest_record_path = self.latest_records.pop(0)
|
|
||||||
self.del_fn(eariest_record_path)
|
|
||||||
|
|
||||||
# add the new one
|
|
||||||
self.save_fn(path)
|
|
||||||
self.latest_records.append(path)
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
PRIORITY_WRITER = 300
|
||||||
|
PRIORITY_EDITOR = 200
|
||||||
|
PRIORITY_READER = 100
|
||||||
|
|
||||||
|
|
||||||
|
class Extension(object):
|
||||||
|
"""Extension to customize the behavior of Trainer."""
|
||||||
|
trigger = (1, 'iteration')
|
||||||
|
priority = PRIORITY_READER
|
||||||
|
name = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_name(self):
|
||||||
|
"""Default name of the extension, class name by default."""
|
||||||
|
return type(self).__name__
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
"""Main action of the extention. After each update, it is executed
|
||||||
|
when the trigger fires."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
'Extension implementation must override __call__.')
|
||||||
|
|
||||||
|
def initialize(self, trainer):
|
||||||
|
"""Action that is executed once to get the corect trainer state.
|
||||||
|
It is called before training normally, but if the trainer restores
|
||||||
|
states with an Snapshot extension, this method should also be called.g
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_error(self, trainer, exc, tb):
|
||||||
|
"""Handles the error raised during training before finalization.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finalize(self, trainer):
|
||||||
|
"""Action that is executed when training is done.
|
||||||
|
For example, visualizers would need to be closed.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def make_extension(trigger: Callable=None,
|
||||||
|
default_name: str=None,
|
||||||
|
priority: int=None,
|
||||||
|
finalizer: Callable=None,
|
||||||
|
initializer: Callable=None,
|
||||||
|
on_error: Callable=None):
|
||||||
|
"""Make an Extension-like object by injecting required attributes to it.
|
||||||
|
"""
|
||||||
|
if trigger is None:
|
||||||
|
trigger = Extension.trigger
|
||||||
|
if priority is None:
|
||||||
|
priority = Extension.priority
|
||||||
|
|
||||||
|
def decorator(ext):
|
||||||
|
ext.trigger = trigger
|
||||||
|
ext.default_name = default_name or ext.__name__
|
||||||
|
ext.priority = priority
|
||||||
|
ext.finalize = finalizer
|
||||||
|
ext.on_error = on_error
|
||||||
|
ext.initialize = initializer
|
||||||
|
return ext
|
||||||
|
|
||||||
|
return decorator
|
|
@ -0,0 +1,73 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Optional, Callable, Dict
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
|
from paddle.nn import Layer
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
|
||||||
|
from parakeet.training.reporter import scope, report, DictSummary
|
||||||
|
from parakeet.training import extension
|
||||||
|
|
||||||
|
|
||||||
|
class StandardEvaluator(extension.Extension):
|
||||||
|
|
||||||
|
trigger = (1, 'epoch')
|
||||||
|
default_name = 'validation'
|
||||||
|
priority = extension.PRIORITY_WRITER
|
||||||
|
|
||||||
|
name = None
|
||||||
|
|
||||||
|
def __init__(self, model: Layer, dataloader: DataLoader):
|
||||||
|
# it is designed to hold multiple models
|
||||||
|
models = {"main": model}
|
||||||
|
self.models: Dict[str, Layer] = models
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
# dataloaders
|
||||||
|
self.dataloader = dataloader
|
||||||
|
|
||||||
|
def evaluate_core(self, batch):
|
||||||
|
# compute
|
||||||
|
self.model(batch) # you may report here
|
||||||
|
|
||||||
|
def evaluate(self):
|
||||||
|
# switch to eval mode
|
||||||
|
for layer in self.models.values():
|
||||||
|
layer.eval()
|
||||||
|
|
||||||
|
# to average evaluation metrics
|
||||||
|
summary = DictSummary()
|
||||||
|
for batch in self.dataloader:
|
||||||
|
observation = {}
|
||||||
|
with scope(observation):
|
||||||
|
# main evaluation computation here.
|
||||||
|
with paddle.no_grad():
|
||||||
|
self.evaluate_core(batch)
|
||||||
|
summary.add(observation)
|
||||||
|
summary = summary.compute_mean()
|
||||||
|
return summary
|
||||||
|
|
||||||
|
def __call__(self, trainer=None):
|
||||||
|
# evaluate and report the averaged metric to current observation
|
||||||
|
# if it is used to extend a trainer, the metrics is reported to
|
||||||
|
# to observation of the trainer
|
||||||
|
# or otherwise, you can use your own observation
|
||||||
|
summary = self.evaluate()
|
||||||
|
for k, v in summary.items():
|
||||||
|
report(k, v)
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
from parakeet.utils.mp_tools import rank_zero_only
|
||||||
|
from parakeet.training.trainer import Trainer
|
||||||
|
from parakeet.training import extension
|
||||||
|
|
||||||
|
|
||||||
|
def load_records(records_fp):
|
||||||
|
"""Load record files (json lines.)"""
|
||||||
|
with jsonlines.open(records_fp, 'r') as reader:
|
||||||
|
records = list(reader)
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
class Snapshot(extension.Extension):
|
||||||
|
"""An extension to make snapshot of the updater object inside
|
||||||
|
the trainer. It is done by calling the updater's `save` method.
|
||||||
|
|
||||||
|
An Updater save its state_dict by default, which contains the
|
||||||
|
updater state, (i.e. epoch and iteration) and all the model
|
||||||
|
parameters and optimizer states. If the updater inside the trainer
|
||||||
|
subclasses StandardUpdater, everything is good to go.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
checkpoint_dir : Union[str, Path]
|
||||||
|
The directory to save checkpoints into.
|
||||||
|
"""
|
||||||
|
|
||||||
|
trigger = (1, 'epoch')
|
||||||
|
priority = -100
|
||||||
|
default_name = "snapshot"
|
||||||
|
|
||||||
|
def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
|
||||||
|
self.records: List[Dict[str, Any]] = []
|
||||||
|
self.max_size = max_size
|
||||||
|
self._snapshot_on_error = snapshot_on_error
|
||||||
|
self._save_all = (max_size == -1)
|
||||||
|
self.checkpoint_dir =...
|
||||||
|
|
||||||
|
def initialize(self, trainer: Trainer):
|
||||||
|
"""Setting up this extention."""
|
||||||
|
self.checkpoint_dir = trainer.out / "checkpoints"
|
||||||
|
|
||||||
|
# load existing records
|
||||||
|
record_path: Path = self.checkpoint_dir / "records.jsonl"
|
||||||
|
if record_path.exists():
|
||||||
|
logging.debug("Loading from an existing checkpoint dir")
|
||||||
|
self.records = load_records(record_path)
|
||||||
|
trainer.updater.load(self.records[-1]['path'])
|
||||||
|
|
||||||
|
def on_error(self, trainer, exc, tb):
|
||||||
|
if self._snapshot_on_error:
|
||||||
|
self.save_checkpoint_and_update(trainer)
|
||||||
|
|
||||||
|
def __call__(self, trainer: Trainer):
|
||||||
|
self.save_checkpoint_and_update(trainer)
|
||||||
|
|
||||||
|
def full(self):
|
||||||
|
"""Whether the number of snapshots it keeps track of is greater
|
||||||
|
than the max_size."""
|
||||||
|
return (not self._save_all) and len(self.records) > self.max_size
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def save_checkpoint_and_update(self, trainer: Trainer):
|
||||||
|
"""Saving new snapshot and remove the oldest snapshot if needed."""
|
||||||
|
iteration = trainer.updater.state.iteration
|
||||||
|
path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
|
||||||
|
|
||||||
|
# add the new one
|
||||||
|
trainer.updater.save(path)
|
||||||
|
record = {
|
||||||
|
"time": str(datetime.now()),
|
||||||
|
'path': str(path.resolve()), # use absolute path
|
||||||
|
'iteration': iteration
|
||||||
|
}
|
||||||
|
self.records.append(record)
|
||||||
|
|
||||||
|
# remove the earist
|
||||||
|
if self.full():
|
||||||
|
eariest_record = self.records[0]
|
||||||
|
os.remove(eariest_record["path"])
|
||||||
|
self.records.pop(0)
|
||||||
|
|
||||||
|
# update the record file
|
||||||
|
record_path = self.checkpoint_dir / "records.jsonl"
|
||||||
|
with jsonlines.open(record_path, 'w') as writer:
|
||||||
|
for record in self.records:
|
||||||
|
# jsonlines.open may return a Writer or a Reader
|
||||||
|
writer.write(record) # pylint: disable=no-member
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
from parakeet.training.trainer import Trainer
|
||||||
|
from parakeet.training import extension
|
||||||
|
|
||||||
|
|
||||||
|
class VisualDL(extension.Extension):
|
||||||
|
"""A wrapper of visualdl log writer. It assumes that the metrics to be visualized
|
||||||
|
are all scalars which are recorded into the `.observation` dictionary of the
|
||||||
|
trainer object. The dictionary is created for each step, thus the visualdl log
|
||||||
|
writer uses the iteration from the updater's `iteration` as the global step to
|
||||||
|
add records.
|
||||||
|
"""
|
||||||
|
trigger = (1, 'iteration')
|
||||||
|
default_name = 'visualdl'
|
||||||
|
priority = extension.PRIORITY_READER
|
||||||
|
|
||||||
|
def __init__(self, writer):
|
||||||
|
self.writer = writer
|
||||||
|
|
||||||
|
def __call__(self, trainer: Trainer):
|
||||||
|
for k, v in trainer.observation.items():
|
||||||
|
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
|
||||||
|
|
||||||
|
def finalize(self, trainer):
|
||||||
|
self.writer.close()
|
|
@ -12,7 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
OBSERVATIONS = None
|
OBSERVATIONS = None
|
||||||
|
|
||||||
|
@ -45,3 +47,113 @@ def report(name, value):
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
observations[name] = value
|
observations[name] = value
|
||||||
|
|
||||||
|
|
||||||
|
class Summary(object):
|
||||||
|
"""Online summarization of a sequence of scalars.
|
||||||
|
Summary computes the statistics of given scalars online.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._x = 0.0
|
||||||
|
self._x2 = 0.0
|
||||||
|
self._n = 0
|
||||||
|
|
||||||
|
def add(self, value, weight=1):
|
||||||
|
"""Adds a scalar value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Scalar value to accumulate. It is either a NumPy scalar or
|
||||||
|
a zero-dimensional array (on CPU or GPU).
|
||||||
|
weight: An optional weight for the value. It is a NumPy scalar or
|
||||||
|
a zero-dimensional array (on CPU or GPU).
|
||||||
|
Default is 1 (integer).
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._x += weight * value
|
||||||
|
self._x2 += weight * value * value
|
||||||
|
self._n += weight
|
||||||
|
|
||||||
|
def compute_mean(self):
|
||||||
|
"""Computes the mean."""
|
||||||
|
x, n = self._x, self._n
|
||||||
|
return x / n
|
||||||
|
|
||||||
|
def make_statistics(self):
|
||||||
|
"""Computes and returns the mean and standard deviation values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Mean and standard deviation values.
|
||||||
|
|
||||||
|
"""
|
||||||
|
x, n = self._x, self._n
|
||||||
|
mean = x / n
|
||||||
|
var = self._x2 / n - mean * mean
|
||||||
|
std = math.sqrt(var)
|
||||||
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
|
class DictSummary(object):
|
||||||
|
"""Online summarization of a sequence of dictionaries.
|
||||||
|
|
||||||
|
``DictSummary`` computes the statistics of a given set of scalars online.
|
||||||
|
It only computes the statistics for scalar values and variables of scalar
|
||||||
|
values in the dictionaries.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._summaries = defaultdict(Summary)
|
||||||
|
|
||||||
|
def add(self, d):
|
||||||
|
"""Adds a dictionary of scalars.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d (dict): Dictionary of scalars to accumulate. Only elements of
|
||||||
|
scalars, zero-dimensional arrays, and variables of
|
||||||
|
zero-dimensional arrays are accumulated. When the value
|
||||||
|
is a tuple, the second element is interpreted as a weight.
|
||||||
|
|
||||||
|
"""
|
||||||
|
summaries = self._summaries
|
||||||
|
for k, v in d.items():
|
||||||
|
w = 1
|
||||||
|
if isinstance(v, tuple):
|
||||||
|
w = v[1]
|
||||||
|
v = v[0]
|
||||||
|
summaries[k].add(v, weight=w)
|
||||||
|
|
||||||
|
def compute_mean(self):
|
||||||
|
"""Creates a dictionary of mean values.
|
||||||
|
|
||||||
|
It returns a single dictionary that holds a mean value for each entry
|
||||||
|
added to the summary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary of mean values.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: summary.compute_mean()
|
||||||
|
for name, summary in self._summaries.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def make_statistics(self):
|
||||||
|
"""Creates a dictionary of statistics.
|
||||||
|
|
||||||
|
It returns a single dictionary that holds mean and standard deviation
|
||||||
|
values for every entry added to the summary. For an entry of name
|
||||||
|
``'key'``, these values are added to the dictionary by names ``'key'``
|
||||||
|
and ``'key.std'``, respectively.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary of statistics of all entries.
|
||||||
|
|
||||||
|
"""
|
||||||
|
stats = {}
|
||||||
|
for name, summary in self._summaries.items():
|
||||||
|
mean, std = summary.make_statistics()
|
||||||
|
stats[name] = mean
|
||||||
|
stats[name + '.std'] = std
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed: int):
|
||||||
|
"""Seed paddle, random and np.random to help reproductivity."""
|
||||||
|
paddle.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
logging.debug(f"Set the seed of paddle, random, np.random to {seed}.")
|
|
@ -12,16 +12,22 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import six
|
||||||
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tqdm
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass
|
from typing import Callable, Union, List
|
||||||
|
|
||||||
from parakeet.training.trigger import get_trigger, IntervalTrigger
|
import tqdm
|
||||||
|
|
||||||
|
from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger
|
||||||
from parakeet.training.updater import UpdaterBase
|
from parakeet.training.updater import UpdaterBase
|
||||||
from parakeet.training.reporter import scope
|
from parakeet.training.reporter import scope
|
||||||
|
from parakeet.training.extension import Extension, PRIORITY_READER
|
||||||
|
|
||||||
|
|
||||||
class ExtensionEntry(object):
|
class _ExtensionEntry(object):
|
||||||
def __init__(self, extension, trigger, priority):
|
def __init__(self, extension, trigger, priority):
|
||||||
self.extension = extension
|
self.extension = extension
|
||||||
self.trigger = trigger
|
self.trigger = trigger
|
||||||
|
@ -31,31 +37,76 @@ class ExtensionEntry(object):
|
||||||
class Trainer(object):
|
class Trainer(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
updater: UpdaterBase,
|
updater: UpdaterBase,
|
||||||
stop_trigger=None,
|
stop_trigger: Callable=None,
|
||||||
out='result',
|
out: Union[str, Path]='result',
|
||||||
extensions=None):
|
extensions: List[Extension]=None):
|
||||||
self.updater = updater
|
self.updater = updater
|
||||||
self.extensions = {}
|
self.extensions = OrderedDict()
|
||||||
self.stop_trigger = get_trigger(stop_trigger)
|
self.stop_trigger = LimitTrigger(*stop_trigger)
|
||||||
self.out = Path(out)
|
self.out = Path(out)
|
||||||
self.observation = {}
|
self.observation =...
|
||||||
|
|
||||||
def setup(self):
|
self._done = False
|
||||||
pass
|
if extensions:
|
||||||
|
for ext in extensions:
|
||||||
|
self.extend(ext)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_before_training(self):
|
||||||
|
return self.updater.state.iteration == 0
|
||||||
|
|
||||||
def extend(self, extension, name=None, trigger=None, priority=None):
|
def extend(self, extension, name=None, trigger=None, priority=None):
|
||||||
|
# get name for the extension
|
||||||
|
# argument \
|
||||||
|
# -> extention's name \
|
||||||
|
# -> default_name (class name, when it is an object) \
|
||||||
|
# -> function name when it is a function \
|
||||||
|
# -> error
|
||||||
|
|
||||||
|
if name is None:
|
||||||
|
name = getattr(extension, 'name', None)
|
||||||
|
if name is None:
|
||||||
|
name = getattr(extension, 'default_name', None)
|
||||||
|
if name is None:
|
||||||
|
name = getattr(extension, '__name__', None)
|
||||||
|
if name is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Name is not given for the extension.")
|
||||||
|
if name == 'training':
|
||||||
|
raise ValueError("training is a reserved name.")
|
||||||
|
|
||||||
|
if trigger is None:
|
||||||
|
trigger = getattr(extension, 'trigger', (1, 'iteration'))
|
||||||
trigger = get_trigger(trigger)
|
trigger = get_trigger(trigger)
|
||||||
|
|
||||||
|
if priority is None:
|
||||||
|
priority = getattr(extension, 'priority', PRIORITY_READER)
|
||||||
|
|
||||||
|
# add suffix to avoid nameing conflict
|
||||||
ordinal = 0
|
ordinal = 0
|
||||||
modified_name = name
|
modified_name = name
|
||||||
while name in self.extensions:
|
while modified_name in self.extensions:
|
||||||
ordinal += 1
|
ordinal += 1
|
||||||
modified_name = f"{name}_{ordinal}"
|
modified_name = f"{name}_{ordinal}"
|
||||||
|
extension.name = modified_name
|
||||||
|
|
||||||
self.extensions[modified_name] = ExtensionEntry(extension, trigger,
|
self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
|
||||||
priority)
|
priority)
|
||||||
|
|
||||||
|
def get_extension(self, name):
|
||||||
|
"""get extension by name."""
|
||||||
|
extensions = self.extensions
|
||||||
|
if name in extensions:
|
||||||
|
return extensions[name].extension
|
||||||
|
else:
|
||||||
|
raise ValueError(f'extension {name} not found')
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
if self._done:
|
||||||
|
raise RuntimeError("Training is already done!.")
|
||||||
|
|
||||||
|
self.out.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# sort extensions by priorities once
|
# sort extensions by priorities once
|
||||||
extension_order = sorted(
|
extension_order = sorted(
|
||||||
self.extensions.keys(),
|
self.extensions.keys(),
|
||||||
|
@ -64,28 +115,72 @@ class Trainer(object):
|
||||||
extensions = [(name, self.extensions[name])
|
extensions = [(name, self.extensions[name])
|
||||||
for name in extension_order]
|
for name in extension_order]
|
||||||
|
|
||||||
update = self.updater.update
|
# initializing all extensions
|
||||||
|
for name, entry in extensions:
|
||||||
|
if hasattr(entry.extension, "initialize"):
|
||||||
|
entry.extension.initialize(self)
|
||||||
|
|
||||||
|
update = self.updater.update # training step
|
||||||
stop_trigger = self.stop_trigger
|
stop_trigger = self.stop_trigger
|
||||||
|
|
||||||
# TODO(chenfeiyu): display progress bar correctly
|
print(self.updater.state)
|
||||||
# if the trainer is controlled by epoch: use 2 progressbars
|
|
||||||
# if the trainer is controlled by iteration: use 1 progressbar
|
# display only one progress bar
|
||||||
if isinstance(stop_trigger, IntervalTrigger):
|
max_iteration = None
|
||||||
|
if isinstance(stop_trigger, LimitTrigger):
|
||||||
if stop_trigger.unit is 'epoch':
|
if stop_trigger.unit is 'epoch':
|
||||||
max_epoch = self.stop_trigger.period
|
max_epoch = self.stop_trigger.limit
|
||||||
|
updates_per_epoch = getattr(self.updater, "updates_per_epoch",
|
||||||
|
None)
|
||||||
|
max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
|
||||||
else:
|
else:
|
||||||
max_iteration = self.stop_trigger.period
|
max_iteration = self.stop_trigger.limit
|
||||||
|
|
||||||
while not stop_trigger(self):
|
p = tqdm.tqdm(
|
||||||
self.observation = {}
|
initial=self.updater.state.iteration, total=max_iteration)
|
||||||
# set observation as the report target
|
|
||||||
# you can use report freely in Updater.update()
|
|
||||||
|
|
||||||
# updating parameters and state
|
try:
|
||||||
with scope(self.observation):
|
while not stop_trigger(self):
|
||||||
update()
|
self.observation = {}
|
||||||
|
# set observation as the report target
|
||||||
|
# you can use report freely in Updater.update()
|
||||||
|
|
||||||
# execute extension when necessary
|
# updating parameters and state
|
||||||
|
with scope(self.observation):
|
||||||
|
update()
|
||||||
|
p.update()
|
||||||
|
|
||||||
|
# execute extension when necessary
|
||||||
|
for name, entry in extensions:
|
||||||
|
if entry.trigger(self):
|
||||||
|
entry.extension(self)
|
||||||
|
|
||||||
|
# print("###", self.observation)
|
||||||
|
except Exception as e:
|
||||||
|
f = sys.stderr
|
||||||
|
f.write(f"Exception in main training loop: {e}\n")
|
||||||
|
f.write("Traceback (most recent call last):\n")
|
||||||
|
traceback.print_tb(sys.exc_info()[2])
|
||||||
|
f.write(
|
||||||
|
"Trainer extensions will try to handle the extension. Then all extensions will finalize."
|
||||||
|
)
|
||||||
|
|
||||||
|
# capture the exception in the mian training loop
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
|
||||||
|
# try to handle it
|
||||||
for name, entry in extensions:
|
for name, entry in extensions:
|
||||||
if entry.trigger(self):
|
if hasattr(entry.extension, "on_error"):
|
||||||
entry.extension(self)
|
try:
|
||||||
|
entry.extension.on_error(self, e, sys.exc_info()[2])
|
||||||
|
except Exception as ee:
|
||||||
|
f.write(f"Exception in error handler: {ee}\n")
|
||||||
|
f.write('Traceback (most recent call last):\n')
|
||||||
|
traceback.print_tb(sys.exc_info()[2])
|
||||||
|
|
||||||
|
# raise exception in main training loop
|
||||||
|
six.reraise(*exc_info)
|
||||||
|
finally:
|
||||||
|
for name, entry in extensions:
|
||||||
|
if hasattr(entry.extension, "finalize"):
|
||||||
|
entry.extension.finalize(self)
|
||||||
|
|
|
@ -12,21 +12,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from parakeet.training.triggers.interval_trigger import IntervalTrigger
|
||||||
class IntervalTrigger(object):
|
from parakeet.training.triggers.limit_trigger import LimitTrigger
|
||||||
def __init__(self, period: int, unit: str):
|
from parakeet.training.triggers.time_trigger import TimeTrigger
|
||||||
if unit not in ("iteration", "epoch"):
|
|
||||||
raise ValueError("unit should be 'iteration' or 'epoch'")
|
|
||||||
self.period = period
|
|
||||||
self.unit = unit
|
|
||||||
|
|
||||||
def __call__(self, trainer):
|
|
||||||
state = trainer.updater.state
|
|
||||||
if self.unit == "epoch":
|
|
||||||
fire = not (state.epoch % self.period)
|
|
||||||
else:
|
|
||||||
fire = not (state.iteration % self.iteration)
|
|
||||||
return fire
|
|
||||||
|
|
||||||
|
|
||||||
def never_file_trigger(trainer):
|
def never_file_trigger(trainer):
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
class IntervalTrigger(object):
|
||||||
|
"""A Predicate to do something every N cycle."""
|
||||||
|
|
||||||
|
def __init__(self, period: int, unit: str):
|
||||||
|
if unit not in ("iteration", "epoch"):
|
||||||
|
raise ValueError("unit should be 'iteration' or 'epoch'")
|
||||||
|
if period <= 0:
|
||||||
|
raise ValueError("period should be a positive integer.")
|
||||||
|
self.period = period
|
||||||
|
self.unit = unit
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
state = trainer.updater.state
|
||||||
|
index = getattr(state, self.unit)
|
||||||
|
fire = index % self.period == 0
|
||||||
|
return fire
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
class LimitTrigger(object):
|
||||||
|
"""A Predicate to decide whether to stop."""
|
||||||
|
|
||||||
|
def __init__(self, limit: int, unit: str):
|
||||||
|
if unit not in ("iteration", "epoch"):
|
||||||
|
raise ValueError("unit should be 'iteration' or 'epoch'")
|
||||||
|
if limit <= 0:
|
||||||
|
raise ValueError("limit should be a positive integer.")
|
||||||
|
self.limit = limit
|
||||||
|
self.unit = unit
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
state = trainer.updater.state
|
||||||
|
index = getattr(state, self.unit)
|
||||||
|
fire = index >= self.limit
|
||||||
|
return fire
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
class TimeTrigger(object):
|
||||||
|
"""Trigger based on a fixed time interval.
|
||||||
|
|
||||||
|
This trigger accepts iterations with a given interval time.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
period (float): Interval time. It is given in seconds.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, period):
|
||||||
|
self._period = period
|
||||||
|
self._next_time = self._period
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
if self._next_time < trainer.elapsed_time:
|
||||||
|
self._next_time += self._period
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
|
@ -12,12 +12,21 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from timer import timer
|
||||||
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
from paddle.nn import Layer
|
from paddle.nn import Layer
|
||||||
from paddle.optimizer import Optimizer
|
from paddle.optimizer import Optimizer
|
||||||
from paddle.io import DataLoader
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
|
||||||
|
from parakeet.training.reporter import report
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -56,68 +65,33 @@ class UpdaterBase(object):
|
||||||
So the best practice is to define a model and define a updater for it.
|
So the best practice is to define a model and define a updater for it.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def update(self):
|
def __init__(self, init_state=None):
|
||||||
pass
|
|
||||||
|
|
||||||
def update_core(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StandardUpdater(UpdaterBase):
|
|
||||||
"""An example of over-simplification. Things may not be that simple, but
|
|
||||||
you can subclass it to fit your need.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
model: Layer,
|
|
||||||
dataloader: DataLoader,
|
|
||||||
optimizer: Optimizer,
|
|
||||||
loss_func=None,
|
|
||||||
auto_new_epoch: bool=True,
|
|
||||||
init_state: Optional[UpdaterState]=None):
|
|
||||||
self.model = model
|
|
||||||
self.dataloader = dataloader
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.loss_func = loss_func
|
|
||||||
self.auto_new_epoch = auto_new_epoch
|
|
||||||
self.iterator = iter(dataloader)
|
|
||||||
|
|
||||||
if init_state is None:
|
if init_state is None:
|
||||||
self.state = UpdaterState()
|
self.state = UpdaterState()
|
||||||
else:
|
else:
|
||||||
self.state = init_state
|
self.state = init_state
|
||||||
|
|
||||||
def update(self):
|
def update(self, batch):
|
||||||
self.update_core()
|
raise NotImplementedError(
|
||||||
self.state.iteration += 1
|
"Implement your own `update` method for training a step.")
|
||||||
|
|
||||||
def new_epoch(self):
|
def state_dict(self):
|
||||||
self.iterator = iter(self.dataloader)
|
state_dict = {
|
||||||
self.state.epoch += 1
|
"epoch": self.state.epoch,
|
||||||
|
"iteration": self.state.iteration,
|
||||||
|
}
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def update_core(self):
|
def set_state_dict(self, state_dict):
|
||||||
model = self.model
|
self.state.epoch = state_dict["epoch"]
|
||||||
optimizer = self.optimizer
|
self.state.iteration = state_dict["iteration"]
|
||||||
loss_func = self.loss_func
|
|
||||||
|
|
||||||
model.train()
|
def save(self, path):
|
||||||
optimizer.clear_grad()
|
logging.debug(f"Saving to {path}.")
|
||||||
|
archive = self.state_dict()
|
||||||
|
paddle.save(archive, str(path))
|
||||||
|
|
||||||
# fetch a batch
|
def load(self, path):
|
||||||
try:
|
logging.debug(f"Loading from {path}.")
|
||||||
batch = next(self.iterator)
|
archive = paddle.load(str(path))
|
||||||
except StopIteration as e:
|
self.set_state_dict(archive)
|
||||||
if self.auto_new_epoch:
|
|
||||||
self.new_epoch()
|
|
||||||
|
|
||||||
# forward
|
|
||||||
if self.loss_func is not None:
|
|
||||||
loss = loss_func(batch)
|
|
||||||
else:
|
|
||||||
loss = model(batch)
|
|
||||||
|
|
||||||
# backward
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
# update parameters
|
|
||||||
optimizer.step()
|
|
||||||
|
|
|
@ -0,0 +1,190 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from timer import timer
|
||||||
|
import paddle
|
||||||
|
from paddle import Tensor
|
||||||
|
from paddle.nn import Layer
|
||||||
|
from paddle.optimizer import Optimizer
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
|
||||||
|
from parakeet.training.reporter import report
|
||||||
|
from parakeet.training.updater import UpdaterBase, UpdaterState
|
||||||
|
|
||||||
|
|
||||||
|
class StandardUpdater(UpdaterBase):
|
||||||
|
"""An example of over-simplification. Things may not be that simple, but
|
||||||
|
you can subclass it to fit your need.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model: Layer,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
init_state: Optional[UpdaterState]=None):
|
||||||
|
# it is designed to hold multiple models
|
||||||
|
models = {"main": model}
|
||||||
|
self.models: Dict[str, Layer] = models
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
# it is designed to hold multiple optimizers
|
||||||
|
optimizers = {"main": optimizer}
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.optimizers: Dict[str, Optimizer] = optimizers
|
||||||
|
|
||||||
|
# dataloaders
|
||||||
|
self.dataloader = dataloader
|
||||||
|
|
||||||
|
# init state
|
||||||
|
if init_state is None:
|
||||||
|
self.state = UpdaterState()
|
||||||
|
else:
|
||||||
|
self.state = init_state
|
||||||
|
|
||||||
|
self.train_iterator = iter(dataloader)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
# We increase the iteration index after updating and before extension.
|
||||||
|
# Here are the reasons.
|
||||||
|
|
||||||
|
# 0. Snapshotting(as well as other extensions, like visualizer) is
|
||||||
|
# executed after a step of updating;
|
||||||
|
# 1. We decide to increase the iteration index after updating and
|
||||||
|
# before any all extension is executed.
|
||||||
|
# 3. We do not increase the iteration after extension because we
|
||||||
|
# prefer a consistent resume behavior, when load from a
|
||||||
|
# `snapshot_iter_100.pdz` then the next step to train is `101`,
|
||||||
|
# naturally. But if iteration is increased increased after
|
||||||
|
# extension(including snapshot), then, a `snapshot_iter_99` is
|
||||||
|
# loaded. You would need a extra increasing of the iteration idex
|
||||||
|
# before training to avoid another iteration `99`, which has been
|
||||||
|
# done before snapshotting.
|
||||||
|
# 4. Thus iteration index represrnts "currently how mant epochs has
|
||||||
|
# been done."
|
||||||
|
# NOTE: use report to capture the correctly value. If you want to
|
||||||
|
# report the learning rate used for a step, you must report it before
|
||||||
|
# the learning rate scheduler's step() has been called. In paddle's
|
||||||
|
# convention, we do not use an extension to change the learning rate.
|
||||||
|
# so if you want to report it, do it in the updater.
|
||||||
|
|
||||||
|
# Then here comes the next question. When is the proper time to
|
||||||
|
# increase the epoch index? Since all extensions are executed after
|
||||||
|
# updating, it is the time that after updating is the proper time to
|
||||||
|
# increase epoch index.
|
||||||
|
# 1. If we increase the epoch index before updating, then an extension
|
||||||
|
# based ot epoch would miss the correct timing. It could only be
|
||||||
|
# triggerd after an extra updating.
|
||||||
|
# 2. Theoretically, when an epoch is done, the epoch index should be
|
||||||
|
# increased. So it would be increase after updating.
|
||||||
|
# 3. Thus, eppoch index represents "currently how many epochs has been
|
||||||
|
# done." So it starts from 0.
|
||||||
|
|
||||||
|
# switch to training mode
|
||||||
|
for layer in self.models.values():
|
||||||
|
layer.train()
|
||||||
|
|
||||||
|
# training for a step is implemented here
|
||||||
|
batch = self.read_batch()
|
||||||
|
self.update_core(batch)
|
||||||
|
|
||||||
|
self.state.iteration += 1
|
||||||
|
if self.updaters_per_epoch is not None:
|
||||||
|
if self.state.iteration % self.updaters_per_epoch == 0:
|
||||||
|
self.state.epoch += 1
|
||||||
|
|
||||||
|
def update_core(self, batch):
|
||||||
|
"""A simple case for a training step. Basic assumptions are:
|
||||||
|
Single model;
|
||||||
|
Single optimizer;
|
||||||
|
A batch from the dataloader is just the input of the model;
|
||||||
|
The model return a single loss, or a dict containing serval losses.
|
||||||
|
Parameters updates at every batch, no gradient accumulation.
|
||||||
|
"""
|
||||||
|
loss = self.model(*batch)
|
||||||
|
|
||||||
|
if isinstance(loss, Tensor):
|
||||||
|
loss_dict = {"main": loss}
|
||||||
|
else:
|
||||||
|
# Dict[str, Tensor]
|
||||||
|
loss_dict = loss
|
||||||
|
if "main" not in loss_dict:
|
||||||
|
main_loss = 0
|
||||||
|
for loss_item in loss.values():
|
||||||
|
main_loss += loss_item
|
||||||
|
loss_dict["main"] = main_loss
|
||||||
|
|
||||||
|
for name, loss_item in loss_dict.items():
|
||||||
|
report(name, float(loss_item))
|
||||||
|
|
||||||
|
self.optimizer.clear_gradient()
|
||||||
|
loss_dict["main"].backward()
|
||||||
|
self.optimizer.update()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def updaters_per_epoch(self):
|
||||||
|
"""Number of updater per epoch, determined by the length of the
|
||||||
|
dataloader."""
|
||||||
|
length_of_dataloader = None
|
||||||
|
try:
|
||||||
|
length_of_dataloader = len(self.dataloader)
|
||||||
|
except TypeError:
|
||||||
|
logging.debug("This dataloader has no __len__.")
|
||||||
|
finally:
|
||||||
|
return length_of_dataloader
|
||||||
|
|
||||||
|
def new_epoch(self):
|
||||||
|
"""Start a new epoch."""
|
||||||
|
# NOTE: all batch sampler for distributed training should
|
||||||
|
# subclass DistributedBatchSampler and implement `set_epoch` method
|
||||||
|
batch_sampler = self.dataloader.batch_sampler
|
||||||
|
if isinstance(batch_sampler, DistributedBatchSampler):
|
||||||
|
batch_sampler.set_epoch(self.state.epoch)
|
||||||
|
self.train_iterator = iter(self.dataloader)
|
||||||
|
|
||||||
|
def read_batch(self):
|
||||||
|
"""Read a batch from the data loader, auto renew when data is exhausted."""
|
||||||
|
with timer() as t:
|
||||||
|
try:
|
||||||
|
batch = next(self.train_iterator)
|
||||||
|
except StopIteration:
|
||||||
|
self.new_epoch()
|
||||||
|
batch = next(self.train_iterator)
|
||||||
|
logging.debug(
|
||||||
|
f"Read a batch takes {t.elapse}s.") # replace it with logging
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""State dict of a Updater, model, optimizer and updater state are included."""
|
||||||
|
state_dict = super().state_dict()
|
||||||
|
for name, layer in self.models.items():
|
||||||
|
state_dict[f"{name}_params"] = layer.state_dict()
|
||||||
|
for name, optim in self.optimizers.items():
|
||||||
|
state_dict[f"{name}_optimizer"] = optim.state_dict()
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def set_state_dict(self, state_dict):
|
||||||
|
"""Set state dict for a Updater. Parameters of models, states for
|
||||||
|
optimizers and UpdaterState are restored."""
|
||||||
|
for name, layer in self.models.items():
|
||||||
|
layer.set_state_dict(state_dict[f"{name}_params"])
|
||||||
|
for name, optim in self.optimizers.items():
|
||||||
|
optim.set_state_dict(state_dict[f"{name}_optimizer"])
|
||||||
|
super().set_state_dict(state_dict)
|
|
@ -0,0 +1,105 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union, Any
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def read_hdf5(filename: Union[Path, str], dataset_name: str) -> Any:
|
||||||
|
"""Read a dataset from a HDF5 file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename : Union[Path, str]
|
||||||
|
Path of the HDF5 file.
|
||||||
|
dataset_name : str
|
||||||
|
Name of the dataset to read.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Any
|
||||||
|
The retrieved dataset.
|
||||||
|
"""
|
||||||
|
filename = Path(filename)
|
||||||
|
|
||||||
|
if not filename.exists():
|
||||||
|
logging.error(f"There is no such a hdf5 file ({filename}).")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
hdf5_file = h5py.File(filename, "r")
|
||||||
|
|
||||||
|
if dataset_name not in hdf5_file:
|
||||||
|
logging.error(
|
||||||
|
f"There is no such a data in hdf5 file. ({dataset_name})")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# [()]: a special syntax of h5py to get the dataset as-is
|
||||||
|
hdf5_data = hdf5_file[dataset_name][()]
|
||||||
|
hdf5_file.close()
|
||||||
|
|
||||||
|
return hdf5_data
|
||||||
|
|
||||||
|
|
||||||
|
def write_hdf5(filename: Union[Path, str],
|
||||||
|
dataset_name: str,
|
||||||
|
write_data: np.ndarray,
|
||||||
|
is_overwrite: bool=True) -> None:
|
||||||
|
"""Write dataset to HDF5 file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename : Union[Path, str]
|
||||||
|
Path of the HDF5 file.
|
||||||
|
dataset_name : str
|
||||||
|
Name of the dataset to write to.
|
||||||
|
write_data : np.ndarrays
|
||||||
|
The data to write.
|
||||||
|
is_overwrite : bool, optional
|
||||||
|
Whether to overwrite, by default True
|
||||||
|
"""
|
||||||
|
# convert to numpy array
|
||||||
|
filename = Path(filename)
|
||||||
|
write_data = np.array(write_data)
|
||||||
|
|
||||||
|
# check folder existence
|
||||||
|
filename.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# check hdf5 existence
|
||||||
|
if filename.exists():
|
||||||
|
# if already exists, open with r+ mode
|
||||||
|
hdf5_file = h5py.File(filename, "r+")
|
||||||
|
# check dataset existence
|
||||||
|
if dataset_name in hdf5_file:
|
||||||
|
if is_overwrite:
|
||||||
|
logging.warning("Dataset in hdf5 file already exists. "
|
||||||
|
"recreate dataset in hdf5.")
|
||||||
|
hdf5_file.__delitem__(dataset_name)
|
||||||
|
else:
|
||||||
|
logging.error(
|
||||||
|
"Dataset in hdf5 file already exists. "
|
||||||
|
"if you want to overwrite, please set is_overwrite = True.")
|
||||||
|
hdf5_file.close()
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
# if not exists, open with w mode
|
||||||
|
hdf5_file = h5py.File(filename, "w")
|
||||||
|
|
||||||
|
# write data to hdf5
|
||||||
|
hdf5_file.create_dataset(dataset_name, data=write_data)
|
||||||
|
hdf5_file.flush()
|
||||||
|
hdf5_file.close()
|
|
@ -0,0 +1,34 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle.framework import core
|
||||||
|
from paddle.framework import CUDAPlace
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
def synchronize():
|
||||||
|
"""Trigger cuda synchronization for better timing."""
|
||||||
|
place = paddle.fluid.framework._current_expected_place()
|
||||||
|
if isinstance(place, CUDAPlace):
|
||||||
|
paddle.fluid.core._cuda_synchronize(place)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def nvtx_span(name):
|
||||||
|
try:
|
||||||
|
core.nvprof_nvtx_push(name)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
core.nvprof_nvtx_pop()
|
|
@ -0,0 +1,319 @@
|
||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import six
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import google.protobuf.text_format as text_format
|
||||||
|
import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
'--profile_path',
|
||||||
|
type=str,
|
||||||
|
default='',
|
||||||
|
help='Input profile file name. If there are multiple file, the format '
|
||||||
|
'should be trainer1=file1,trainer2=file2,ps=file3')
|
||||||
|
parser.add_argument(
|
||||||
|
'--timeline_path', type=str, default='', help='Output timeline file name.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class _ChromeTraceFormatter(object):
|
||||||
|
def __init__(self):
|
||||||
|
self._events = []
|
||||||
|
self._metadata = []
|
||||||
|
|
||||||
|
def _create_event(self, ph, category, name, pid, tid, timestamp):
|
||||||
|
"""Creates a new Chrome Trace event.
|
||||||
|
|
||||||
|
For details of the file format, see:
|
||||||
|
https://github.com/catapult-project/catapult/blob/master/tracing/README.md
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ph: The type of event - usually a single character.
|
||||||
|
category: The event category as a string.
|
||||||
|
name: The event name as a string.
|
||||||
|
pid: Identifier of the process generating this event as an integer.
|
||||||
|
tid: Identifier of the thread generating this event as an integer.
|
||||||
|
timestamp: The timestamp of this event as a long integer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON compatible event object.
|
||||||
|
"""
|
||||||
|
event = {}
|
||||||
|
event['ph'] = ph
|
||||||
|
event['cat'] = category
|
||||||
|
event['name'] = name.replace("ParallelExecutor::Run/", "")
|
||||||
|
event['pid'] = pid
|
||||||
|
event['tid'] = tid
|
||||||
|
event['ts'] = timestamp
|
||||||
|
return event
|
||||||
|
|
||||||
|
def emit_pid(self, name, pid):
|
||||||
|
"""Adds a process metadata event to the trace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The process name as a string.
|
||||||
|
pid: Identifier of the process as an integer.
|
||||||
|
"""
|
||||||
|
event = {}
|
||||||
|
event['name'] = 'process_name'
|
||||||
|
event['ph'] = 'M'
|
||||||
|
event['pid'] = pid
|
||||||
|
event['args'] = {'name': name}
|
||||||
|
self._metadata.append(event)
|
||||||
|
|
||||||
|
def emit_region(self, timestamp, duration, pid, tid, category, name, args):
|
||||||
|
"""Adds a region event to the trace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestamp: The start timestamp of this region as a long integer.
|
||||||
|
duration: The duration of this region as a long integer.
|
||||||
|
pid: Identifier of the process generating this event as an integer.
|
||||||
|
tid: Identifier of the thread generating this event as an integer.
|
||||||
|
category: The event category as a string.
|
||||||
|
name: The event name as a string.
|
||||||
|
args: A JSON-compatible dictionary of event arguments.
|
||||||
|
"""
|
||||||
|
event = self._create_event('X', category, name, pid, tid, timestamp)
|
||||||
|
event['dur'] = duration
|
||||||
|
event['args'] = args
|
||||||
|
self._events.append(event)
|
||||||
|
|
||||||
|
def emit_counter(self, category, name, pid, timestamp, counter, value):
|
||||||
|
"""Emits a record for a single counter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: The event category as string
|
||||||
|
name: The event name as string
|
||||||
|
pid: Identifier of the process generating this event as integer
|
||||||
|
timestamp: The timestamps of this event as long integer
|
||||||
|
counter: Name of the counter as string
|
||||||
|
value: Value of the counter as integer
|
||||||
|
tid: Thread id of the allocation as integer
|
||||||
|
"""
|
||||||
|
event = self._create_event('C', category, name, pid, 0, timestamp)
|
||||||
|
event['args'] = {counter: value}
|
||||||
|
self._events.append(event)
|
||||||
|
|
||||||
|
def format_to_string(self, pretty=False):
|
||||||
|
"""Formats the chrome trace to a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pretty: (Optional.) If True, produce human-readable JSON output.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A JSON-formatted string in Chrome Trace format.
|
||||||
|
"""
|
||||||
|
trace = {}
|
||||||
|
trace['traceEvents'] = self._metadata + self._events
|
||||||
|
if pretty:
|
||||||
|
return json.dumps(trace, indent=4, separators=(',', ': '))
|
||||||
|
else:
|
||||||
|
return json.dumps(trace, separators=(',', ':'))
|
||||||
|
|
||||||
|
|
||||||
|
class Timeline(object):
|
||||||
|
def __init__(self, profile_dict):
|
||||||
|
self._profile_dict = profile_dict
|
||||||
|
self._pid = 0
|
||||||
|
self._devices = dict()
|
||||||
|
self._mem_devices = dict()
|
||||||
|
self._chrome_trace = _ChromeTraceFormatter()
|
||||||
|
|
||||||
|
def _allocate_pid(self):
|
||||||
|
cur_pid = self._pid
|
||||||
|
self._pid += 1
|
||||||
|
return cur_pid
|
||||||
|
|
||||||
|
def _allocate_pids(self):
|
||||||
|
for k, profile_pb in six.iteritems(self._profile_dict):
|
||||||
|
for event in profile_pb.events:
|
||||||
|
if event.type == profiler_pb2.Event.CPU:
|
||||||
|
if (k, event.device_id, "CPU") not in self._devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._devices[(k, event.device_id, "CPU")] = pid
|
||||||
|
# -1 device id represents CUDA API(RunTime) call.(e.g. cudaLaunch, cudaMemcpy)
|
||||||
|
if event.device_id == -1:
|
||||||
|
self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
|
||||||
|
else:
|
||||||
|
self._chrome_trace.emit_pid(
|
||||||
|
"%s:cpu:block:%d" % (k, event.device_id), pid)
|
||||||
|
elif event.type == profiler_pb2.Event.GPUKernel:
|
||||||
|
if (k, event.device_id, "GPUKernel") not in self._devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._devices[(k, event.device_id, "GPUKernel")] = pid
|
||||||
|
self._chrome_trace.emit_pid("%s:gpu:%d" %
|
||||||
|
(k, event.device_id), pid)
|
||||||
|
if not hasattr(profile_pb, "mem_events"):
|
||||||
|
continue
|
||||||
|
for mevent in profile_pb.mem_events:
|
||||||
|
if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
|
||||||
|
if (k, mevent.device_id, "GPU") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, mevent.device_id, "GPU")] = pid
|
||||||
|
self._chrome_trace.emit_pid(
|
||||||
|
"memory usage on %s:gpu:%d" % (k, mevent.device_id),
|
||||||
|
pid)
|
||||||
|
elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
|
||||||
|
if (k, mevent.device_id, "CPU") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, mevent.device_id, "CPU")] = pid
|
||||||
|
self._chrome_trace.emit_pid(
|
||||||
|
"memory usage on %s:cpu:%d" % (k, mevent.device_id),
|
||||||
|
pid)
|
||||||
|
elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
|
||||||
|
if (k, mevent.device_id, "CUDAPinnedPlace"
|
||||||
|
) not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, mevent.device_id,
|
||||||
|
"CUDAPinnedPlace")] = pid
|
||||||
|
self._chrome_trace.emit_pid(
|
||||||
|
"memory usage on %s:cudapinnedplace:%d" %
|
||||||
|
(k, mevent.device_id), pid)
|
||||||
|
elif mevent.place == profiler_pb2.MemEvent.NPUPlace:
|
||||||
|
if (k, mevent.device_id, "NPU") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, mevent.device_id, "NPU")] = pid
|
||||||
|
self._chrome_trace.emit_pid(
|
||||||
|
"memory usage on %s:npu:%d" % (k, mevent.device_id),
|
||||||
|
pid)
|
||||||
|
if (k, 0, "CPU") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, 0, "CPU")] = pid
|
||||||
|
self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" %
|
||||||
|
(k, 0), pid)
|
||||||
|
if (k, 0, "GPU") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, 0, "GPU")] = pid
|
||||||
|
self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" %
|
||||||
|
(k, 0), pid)
|
||||||
|
if (k, 0, "CUDAPinnedPlace") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid
|
||||||
|
self._chrome_trace.emit_pid(
|
||||||
|
"memory usage on %s:cudapinnedplace:%d" % (k, 0), pid)
|
||||||
|
if (k, 0, "NPU") not in self._mem_devices:
|
||||||
|
pid = self._allocate_pid()
|
||||||
|
self._mem_devices[(k, 0, "NPU")] = pid
|
||||||
|
self._chrome_trace.emit_pid("memory usage on %s:npu:%d" %
|
||||||
|
(k, 0), pid)
|
||||||
|
|
||||||
|
def _allocate_events(self):
|
||||||
|
for k, profile_pb in six.iteritems(self._profile_dict):
|
||||||
|
for event in profile_pb.events:
|
||||||
|
if event.type == profiler_pb2.Event.CPU:
|
||||||
|
type = "CPU"
|
||||||
|
elif event.type == profiler_pb2.Event.GPUKernel:
|
||||||
|
type = "GPUKernel"
|
||||||
|
pid = self._devices[(k, event.device_id, type)]
|
||||||
|
args = {'name': event.name}
|
||||||
|
if event.memcopy.bytes > 0:
|
||||||
|
args['mem_bytes'] = event.memcopy.bytes
|
||||||
|
if hasattr(event, "detail_info") and event.detail_info:
|
||||||
|
args['detail_info'] = event.detail_info
|
||||||
|
# TODO(panyx0718): Chrome tracing only handles ms. However, some
|
||||||
|
# ops takes micro-seconds. Hence, we keep the ns here.
|
||||||
|
self._chrome_trace.emit_region(
|
||||||
|
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
|
||||||
|
event.sub_device_id, 'Op', event.name, args)
|
||||||
|
|
||||||
|
def _allocate_memory_event(self):
|
||||||
|
if not hasattr(profiler_pb2, "MemEvent"):
|
||||||
|
return
|
||||||
|
place_to_str = {
|
||||||
|
profiler_pb2.MemEvent.CPUPlace: "CPU",
|
||||||
|
profiler_pb2.MemEvent.CUDAPlace: "GPU",
|
||||||
|
profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace",
|
||||||
|
profiler_pb2.MemEvent.NPUPlace: "NPU"
|
||||||
|
}
|
||||||
|
for k, profile_pb in six.iteritems(self._profile_dict):
|
||||||
|
mem_list = []
|
||||||
|
end_profiler = 0
|
||||||
|
for mevent in profile_pb.mem_events:
|
||||||
|
crt_info = dict()
|
||||||
|
crt_info['time'] = mevent.start_ns
|
||||||
|
crt_info['size'] = mevent.bytes
|
||||||
|
if mevent.place in place_to_str:
|
||||||
|
place = place_to_str[mevent.place]
|
||||||
|
else:
|
||||||
|
place = "UnDefine"
|
||||||
|
crt_info['place'] = place
|
||||||
|
pid = self._mem_devices[(k, mevent.device_id, place)]
|
||||||
|
crt_info['pid'] = pid
|
||||||
|
crt_info['thread_id'] = mevent.thread_id
|
||||||
|
crt_info['device_id'] = mevent.device_id
|
||||||
|
mem_list.append(crt_info)
|
||||||
|
crt_info = dict()
|
||||||
|
crt_info['place'] = place
|
||||||
|
crt_info['pid'] = pid
|
||||||
|
crt_info['thread_id'] = mevent.thread_id
|
||||||
|
crt_info['device_id'] = mevent.device_id
|
||||||
|
crt_info['time'] = mevent.end_ns
|
||||||
|
crt_info['size'] = -mevent.bytes
|
||||||
|
mem_list.append(crt_info)
|
||||||
|
end_profiler = max(end_profiler, crt_info['time'])
|
||||||
|
mem_list.sort(key=lambda tmp: (tmp.get('time', 0)))
|
||||||
|
i = 0
|
||||||
|
total_size = 0
|
||||||
|
while i < len(mem_list):
|
||||||
|
total_size += mem_list[i]['size']
|
||||||
|
while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[
|
||||||
|
i + 1]['time']:
|
||||||
|
total_size += mem_list[i + 1]['size']
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
self._chrome_trace.emit_counter(
|
||||||
|
"Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'],
|
||||||
|
0, total_size)
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
def generate_chrome_trace(self):
|
||||||
|
self._allocate_pids()
|
||||||
|
self._allocate_events()
|
||||||
|
self._allocate_memory_event()
|
||||||
|
return self._chrome_trace.format_to_string()
|
||||||
|
|
||||||
|
|
||||||
|
profile_path = '/tmp/profile'
|
||||||
|
if args.profile_path:
|
||||||
|
profile_path = args.profile_path
|
||||||
|
timeline_path = '/tmp/timeline'
|
||||||
|
if args.timeline_path:
|
||||||
|
timeline_path = args.timeline_path
|
||||||
|
|
||||||
|
profile_paths = profile_path.split(',')
|
||||||
|
profile_dict = dict()
|
||||||
|
if len(profile_paths) == 1:
|
||||||
|
with open(profile_path, 'rb') as f:
|
||||||
|
profile_s = f.read()
|
||||||
|
profile_pb = profiler_pb2.Profile()
|
||||||
|
profile_pb.ParseFromString(profile_s)
|
||||||
|
profile_dict['trainer'] = profile_pb
|
||||||
|
else:
|
||||||
|
for profile_path in profile_paths:
|
||||||
|
k, v = profile_path.split('=')
|
||||||
|
with open(v, 'rb') as f:
|
||||||
|
profile_s = f.read()
|
||||||
|
profile_pb = profiler_pb2.Profile()
|
||||||
|
profile_pb.ParseFromString(profile_s)
|
||||||
|
profile_dict[k] = profile_pb
|
||||||
|
|
||||||
|
tl = Timeline(profile_dict)
|
||||||
|
with open(timeline_path, 'w') as f:
|
||||||
|
f.write(tl.generate_chrome_trace())
|
4
setup.py
4
setup.py
|
@ -64,7 +64,6 @@ setup_info = dict(
|
||||||
'scipy',
|
'scipy',
|
||||||
'pandas',
|
'pandas',
|
||||||
'sox',
|
'sox',
|
||||||
# 'opencc',
|
|
||||||
'soundfile',
|
'soundfile',
|
||||||
'g2p_en',
|
'g2p_en',
|
||||||
'yacs',
|
'yacs',
|
||||||
|
@ -73,6 +72,9 @@ setup_info = dict(
|
||||||
'webrtcvad',
|
'webrtcvad',
|
||||||
'g2pM',
|
'g2pM',
|
||||||
'praatio',
|
'praatio',
|
||||||
|
"h5py",
|
||||||
|
"timer",
|
||||||
|
'jsonlines',
|
||||||
],
|
],
|
||||||
extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
|
extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
|
||||||
|
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from parakeet.training.checkpoint import KBest, KLatest
|
|
||||||
|
|
||||||
|
|
||||||
def test_kbest():
|
|
||||||
def save_fn(path):
|
|
||||||
with open(path, 'wt') as f:
|
|
||||||
f.write(f"My path is {str(path)}\n")
|
|
||||||
|
|
||||||
K = 1
|
|
||||||
kbest_manager = KBest(max_size=K, save_fn=save_fn)
|
|
||||||
checkpoint_dir = Path("checkpoints")
|
|
||||||
shutil.rmtree(checkpoint_dir)
|
|
||||||
checkpoint_dir.mkdir(parents=True)
|
|
||||||
a = np.random.rand(20)
|
|
||||||
for i, score in enumerate(a):
|
|
||||||
path = checkpoint_dir / f"step_{i}"
|
|
||||||
kbest_manager.add_checkpoint(score, path)
|
|
||||||
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
|
||||||
|
|
||||||
|
|
||||||
def test_klatest():
|
|
||||||
def save_fn(path):
|
|
||||||
with open(path, 'wt') as f:
|
|
||||||
f.write(f"My path is {str(path)}\n")
|
|
||||||
|
|
||||||
K = 5
|
|
||||||
klatest_manager = KLatest(max_size=K, save_fn=save_fn)
|
|
||||||
checkpoint_dir = Path("checkpoints")
|
|
||||||
shutil.rmtree(checkpoint_dir)
|
|
||||||
checkpoint_dir.mkdir(parents=True)
|
|
||||||
for i in range(20):
|
|
||||||
path = checkpoint_dir / f"step_{i}"
|
|
||||||
klatest_manager.add_checkpoint(path)
|
|
||||||
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from parakeet.datasets.data_tabel import DataTable
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_dataset():
|
||||||
|
metadata = [{'name': 'Sonic', 'v': 1000}, {'name': 'Prestol', 'v': 2000}]
|
||||||
|
converters = {'v': lambda x: x / 1000}
|
||||||
|
dataset = DataTable(metadata, fields=['v'], converters=converters)
|
||||||
|
assert dataset[0] == {'v': 1.0}
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.optimizer import Adam
|
||||||
|
from paddle.optimizer.lr import StepDecay
|
||||||
|
|
||||||
|
|
||||||
|
def test_optimizer():
|
||||||
|
model1 = nn.Linear(3, 4)
|
||||||
|
optim1 = Adam(
|
||||||
|
parameters=model1.parameters(), learning_rate=StepDecay(0.1, 100))
|
||||||
|
|
||||||
|
output_dir = Path("temp_test_optimizer")
|
||||||
|
shutil.rmtree(output_dir, ignore_errors=True)
|
||||||
|
output_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
# model1.set_state_dict(model1.state_dict())
|
||||||
|
optim1.set_state_dict(optim1.state_dict())
|
||||||
|
|
||||||
|
x = paddle.randn([6, 3])
|
||||||
|
y = model1(x).sum()
|
||||||
|
y.backward()
|
||||||
|
optim1.step()
|
|
@ -0,0 +1,240 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import torch
|
||||||
|
from timer import timer
|
||||||
|
from parallel_wavegan.layers import upsample, residual_block
|
||||||
|
from parallel_wavegan.models import parallel_wavegan as pwgan
|
||||||
|
from parakeet.utils.layer_tools import summary
|
||||||
|
from parakeet.utils.profile import synchronize
|
||||||
|
|
||||||
|
from parakeet.models.parallel_wavegan import ConvInUpsampleNet, ResidualBlock
|
||||||
|
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator, ResidualPWGDiscriminator
|
||||||
|
|
||||||
|
paddle.set_device("gpu:0")
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
|
||||||
|
def test_convin_upsample_net():
|
||||||
|
net = ConvInUpsampleNet(
|
||||||
|
[4, 4, 4, 4],
|
||||||
|
"LeakyReLU", {"negative_slope": 0.2},
|
||||||
|
freq_axis_kernel_size=3,
|
||||||
|
aux_context_window=0)
|
||||||
|
net2 = upsample.ConvInUpsampleNetwork(
|
||||||
|
[4, 4, 4, 4],
|
||||||
|
nonlinear_activation="LeakyReLU",
|
||||||
|
nonlinear_activation_params={"negative_slope": 0.2},
|
||||||
|
freq_axis_kernel_size=3,
|
||||||
|
aux_context_window=0).to(device)
|
||||||
|
summary(net)
|
||||||
|
for k, v in net2.named_parameters():
|
||||||
|
print(k, v.shape)
|
||||||
|
net.state_dict()[k].set_value(v.data.cpu().numpy())
|
||||||
|
|
||||||
|
c = paddle.randn([4, 80, 180])
|
||||||
|
synchronize()
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out = net(c)
|
||||||
|
synchronize()
|
||||||
|
print(f"paddle conv_in_upsample_net forward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out.sum().backward()
|
||||||
|
synchronize()
|
||||||
|
print(f"paddle conv_in_upsample_net backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
c_torch = torch.as_tensor(c.numpy()).to(device)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out2 = net2(c_torch)
|
||||||
|
print(f"torch conv_in_upsample_net forward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out2.sum().backward()
|
||||||
|
print(f"torch conv_in_upsample_net backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
print("forward check")
|
||||||
|
print(out.numpy()[0])
|
||||||
|
print(out2.data.cpu().numpy()[0])
|
||||||
|
|
||||||
|
print("backward check")
|
||||||
|
print(net.conv_in.weight.grad.numpy()[0])
|
||||||
|
print(net2.conv_in.weight.grad.data.cpu().numpy()[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_residual_block():
|
||||||
|
net = ResidualBlock(dilation=9)
|
||||||
|
net2 = residual_block.ResidualBlock(dilation=9)
|
||||||
|
summary(net)
|
||||||
|
summary(net2)
|
||||||
|
for k, v in net2.named_parameters():
|
||||||
|
net.state_dict()[k].set_value(v.data.cpu().numpy())
|
||||||
|
|
||||||
|
x = paddle.randn([4, 64, 180])
|
||||||
|
c = paddle.randn([4, 80, 180])
|
||||||
|
res, skip = net(x, c)
|
||||||
|
res2, skip2 = net2(torch.as_tensor(x.numpy()), torch.as_tensor(c.numpy()))
|
||||||
|
|
||||||
|
print("forward:")
|
||||||
|
print(res.numpy()[0])
|
||||||
|
print(res2.data.cpu().numpy()[0])
|
||||||
|
print(skip.numpy()[0])
|
||||||
|
print(skip2.data.cpu().numpy()[0])
|
||||||
|
|
||||||
|
(res.sum() + skip.sum()).backward()
|
||||||
|
(res2.sum() + skip2.sum()).backward()
|
||||||
|
|
||||||
|
print("backward:")
|
||||||
|
print(net.conv.weight.grad.numpy().squeeze()[0])
|
||||||
|
print(net2.conv.weight.grad.data.cpu().numpy().squeeze()[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_pwg_generator():
|
||||||
|
net = PWGGenerator(
|
||||||
|
layers=9,
|
||||||
|
stacks=3,
|
||||||
|
upsample_scales=[4, 4, 4, 4],
|
||||||
|
nonlinear_activation="LeakyReLU",
|
||||||
|
nonlinear_activation_params={"negative_slope": 0.5},
|
||||||
|
use_weight_norm=True)
|
||||||
|
net2 = pwgan.ParallelWaveGANGenerator(
|
||||||
|
layers=9,
|
||||||
|
stacks=3,
|
||||||
|
upsample_params={
|
||||||
|
"upsample_scales": [4, 4, 4, 4],
|
||||||
|
"nonlinear_activation": "LeakyReLU",
|
||||||
|
"nonlinear_activation_params": {
|
||||||
|
"negative_slope": 0.5
|
||||||
|
}
|
||||||
|
},
|
||||||
|
use_weight_norm=True).to(device)
|
||||||
|
summary(net)
|
||||||
|
summary(net2)
|
||||||
|
for k, v in net2.named_parameters():
|
||||||
|
p = net.state_dict()[k]
|
||||||
|
if k.endswith("_g"):
|
||||||
|
p.set_value(v.data.cpu().numpy().reshape([-1]))
|
||||||
|
else:
|
||||||
|
p.set_value(v.data.cpu().numpy())
|
||||||
|
x = paddle.randn([4, 1, 80 * 256])
|
||||||
|
c = paddle.randn([4, 80, 80 + 4])
|
||||||
|
|
||||||
|
synchronize()
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out = net(x, c)
|
||||||
|
synchronize()
|
||||||
|
print(f"paddle generator forward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
synchronize()
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out.sum().backward()
|
||||||
|
synchronize()
|
||||||
|
print(f"paddle generator backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
x_torch = torch.as_tensor(x.numpy()).to(device)
|
||||||
|
c_torch = torch.as_tensor(c.numpy()).to(device)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out2 = net2(x_torch, c_torch)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print(f"torch generator forward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with timer(unit='s') as t:
|
||||||
|
out2.sum().backward()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print(f"torch generator backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
print("test forward:")
|
||||||
|
print(out.numpy()[0])
|
||||||
|
print(out2.data.cpu().numpy()[0])
|
||||||
|
|
||||||
|
print("test backward:")
|
||||||
|
print("wv")
|
||||||
|
print(net.first_conv.weight_v.grad.numpy().squeeze())
|
||||||
|
print(net2.first_conv.weight_v.grad.data.cpu().numpy().squeeze())
|
||||||
|
|
||||||
|
print("wg")
|
||||||
|
print(net.first_conv.weight_g.grad.numpy().squeeze())
|
||||||
|
print(net2.first_conv.weight_g.grad.data.cpu().numpy().squeeze())
|
||||||
|
# print(out.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pwg_discriminator():
|
||||||
|
net = PWGDiscriminator()
|
||||||
|
net2 = pwgan.ParallelWaveGANDiscriminator().to(device)
|
||||||
|
summary(net)
|
||||||
|
summary(net2)
|
||||||
|
for k, v in net2.named_parameters():
|
||||||
|
p = net.state_dict()[k]
|
||||||
|
if k.endswith("_g"):
|
||||||
|
p.set_value(v.data.cpu().numpy().reshape([-1]))
|
||||||
|
else:
|
||||||
|
p.set_value(v.data.cpu().numpy())
|
||||||
|
x = paddle.randn([4, 1, 180 * 256])
|
||||||
|
|
||||||
|
synchronize()
|
||||||
|
with timer() as t:
|
||||||
|
y = net(x)
|
||||||
|
synchronize()
|
||||||
|
print(f"forward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
synchronize()
|
||||||
|
with timer() as t:
|
||||||
|
y.sum().backward()
|
||||||
|
synchronize()
|
||||||
|
print(f"backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
x_torch = torch.as_tensor(x.numpy()).to(device)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with timer() as t:
|
||||||
|
y2 = net2(x_torch)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print(f"forward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
with timer() as t:
|
||||||
|
y2.sum().backward()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print(f"backward takes {t.elapse}s.")
|
||||||
|
|
||||||
|
print("test forward:")
|
||||||
|
print(y.numpy()[0])
|
||||||
|
print(y2.data.cpu().numpy()[0])
|
||||||
|
|
||||||
|
print("test backward:")
|
||||||
|
print(net.conv_layers[0].weight_v.grad.numpy().squeeze())
|
||||||
|
print(net2.conv_layers[0].weight_v.grad.data.cpu().numpy().squeeze())
|
||||||
|
|
||||||
|
|
||||||
|
def test_residual_pwg_discriminator():
|
||||||
|
net = ResidualPWGDiscriminator()
|
||||||
|
net2 = pwgan.ResidualParallelWaveGANDiscriminator()
|
||||||
|
summary(net)
|
||||||
|
summary(net2)
|
||||||
|
for k, v in net2.named_parameters():
|
||||||
|
p = net.state_dict()[k]
|
||||||
|
if k.endswith("_g"):
|
||||||
|
p.set_value(v.data.cpu().numpy().reshape([-1]))
|
||||||
|
else:
|
||||||
|
p.set_value(v.data.cpu().numpy())
|
||||||
|
x = paddle.randn([4, 1, 180 * 256])
|
||||||
|
y = net(x)
|
||||||
|
y2 = net2(torch.as_tensor(x.numpy()))
|
||||||
|
print(y.numpy()[0])
|
||||||
|
print(y2.data.cpu().numpy()[0])
|
||||||
|
print(y.shape)
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from parakeet.training.reporter import report, scope
|
||||||
|
from parakeet.training.reporter import Summary, DictSummary
|
||||||
|
|
||||||
|
|
||||||
|
def test_reporter_scope():
|
||||||
|
first = {}
|
||||||
|
second = {}
|
||||||
|
third = {}
|
||||||
|
|
||||||
|
with scope(first):
|
||||||
|
report("first_begin", 1)
|
||||||
|
with scope(second):
|
||||||
|
report("second_begin", 2)
|
||||||
|
with scope(third):
|
||||||
|
report("third_begin", 3)
|
||||||
|
report("third_end", 4)
|
||||||
|
report("seconf_end", 5)
|
||||||
|
report("first_end", 6)
|
||||||
|
|
||||||
|
assert first == {'first_begin': 1, 'first_end': 6}
|
||||||
|
assert second == {'second_begin': 2, 'seconf_end': 5}
|
||||||
|
assert third == {'third_begin': 3, 'third_end': 4}
|
||||||
|
print(first)
|
||||||
|
print(second)
|
||||||
|
print(third)
|
||||||
|
|
||||||
|
|
||||||
|
def test_summary():
|
||||||
|
summary = Summary()
|
||||||
|
summary.add(1)
|
||||||
|
summary.add(2)
|
||||||
|
summary.add(3)
|
||||||
|
state = summary.make_statistics()
|
||||||
|
print(state)
|
||||||
|
np.testing.assert_allclose(
|
||||||
|
np.array(list(state)), np.array([2.0, np.std([1, 2, 3])]))
|
|
@ -0,0 +1,55 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.optimizer import Adam
|
||||||
|
from itertools import count
|
||||||
|
|
||||||
|
from parakeet.training.updater import StandardUpdater
|
||||||
|
from parakeet.training.trainer import Trainer
|
||||||
|
from parakeet.training.extensions.snapshot import Snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_snapshot():
|
||||||
|
model = nn.Linear(3, 4)
|
||||||
|
optimizer = Adam(parameters=model.parameters())
|
||||||
|
|
||||||
|
# use a simplest iterable object as dataloader
|
||||||
|
dataloader = count()
|
||||||
|
|
||||||
|
# hack the training proecss: training does nothing except increse iteration
|
||||||
|
updater = StandardUpdater(model, optimizer, dataloader=dataloader)
|
||||||
|
updater.update_core = lambda x: None
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
updater, stop_trigger=(1000, 'iteration'), out='temp_test_snapshot')
|
||||||
|
shutil.rmtree(trainer.out, ignore_errors=True)
|
||||||
|
|
||||||
|
snap = Snapshot(max_size=5)
|
||||||
|
trigger = (10, 'iteration')
|
||||||
|
trainer.extend(snap, name='snapshot', trigger=trigger, priority=0)
|
||||||
|
|
||||||
|
trainer.run()
|
||||||
|
|
||||||
|
checkpoint_dir = trainer.out / "checkpoints"
|
||||||
|
snapshots = sorted(list(checkpoint_dir.glob("snapshot_iter_*.pdz")))
|
||||||
|
for snap in snapshots:
|
||||||
|
print(snap)
|
||||||
|
assert len(snapshots) == 5
|
||||||
|
shutil.rmtree(trainer.out)
|
|
@ -0,0 +1,73 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import torch
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
from parakeet.modules.stft_loss import STFT, MultiResolutionSTFTLoss
|
||||||
|
from parallel_wavegan.losses import stft_loss as sl
|
||||||
|
from scipy import signal
|
||||||
|
|
||||||
|
|
||||||
|
def test_stft():
|
||||||
|
stft = STFT(n_fft=1024, hop_length=256, win_length=1024)
|
||||||
|
x = paddle.uniform([4, 46080])
|
||||||
|
S = stft.magnitude(x)
|
||||||
|
window = signal.get_window('hann', 1024, fftbins=True)
|
||||||
|
D2 = torch.stft(
|
||||||
|
torch.as_tensor(x.numpy()),
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
window=torch.as_tensor(window))
|
||||||
|
S2 = (D2**2).sum(-1).sqrt()
|
||||||
|
S3 = np.abs(
|
||||||
|
librosa.stft(
|
||||||
|
x.numpy()[0], n_fft=1024, hop_length=256, win_length=1024))
|
||||||
|
print(S2.shape)
|
||||||
|
print(S.numpy()[0])
|
||||||
|
print(S2.data.cpu().numpy()[0])
|
||||||
|
print(S3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_torch_stft():
|
||||||
|
# NOTE: torch.stft use no window by default
|
||||||
|
x = np.random.uniform(-1.0, 1.0, size=(46080, ))
|
||||||
|
window = signal.get_window('hann', 1024, fftbins=True)
|
||||||
|
D2 = torch.stft(
|
||||||
|
torch.as_tensor(x),
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
window=torch.as_tensor(window))
|
||||||
|
D3 = librosa.stft(
|
||||||
|
x, n_fft=1024, hop_length=256, win_length=1024, window='hann')
|
||||||
|
print(D2[:, :, 0].data.cpu().numpy()[:, 30:60])
|
||||||
|
print(D3.real[:, 30:60])
|
||||||
|
# print(D3.imag[:, 30:60])
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_resolution_stft_loss():
|
||||||
|
net = MultiResolutionSTFTLoss()
|
||||||
|
net2 = sl.MultiResolutionSTFTLoss()
|
||||||
|
|
||||||
|
x = paddle.uniform([4, 46080])
|
||||||
|
y = paddle.uniform([4, 46080])
|
||||||
|
sc, m = net(x, y)
|
||||||
|
sc2, m2 = net2(torch.as_tensor(x.numpy()), torch.as_tensor(y.numpy()))
|
||||||
|
print(sc.numpy())
|
||||||
|
print(sc2.data.cpu().numpy())
|
||||||
|
print(m.numpy())
|
||||||
|
print(m2.data.cpu().numpy())
|
Loading…
Reference in New Issue