WIP: training setup done
This commit is contained in:
parent 0067851950
commit 54c7905f40
@@ -142,3 +142,5 @@ dmypy.json
*.swp
runs
syn_audios
exp/
dump/
@@ -0,0 +1,105 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle


class Clip(object):
    """Collate functor for training vocoders."""

    def __init__(self,
                 batch_max_steps=20480,
                 hop_size=256,
                 aux_context_window=0):
        """Initialize customized collater for Paddle DataLoader.

        Args:
            batch_max_steps (int): The maximum length of input signal in batch.
            hop_size (int): Hop size of auxiliary features.
            aux_context_window (int): Context window size for auxiliary feature conv.

        """
        if batch_max_steps % hop_size != 0:
            batch_max_steps -= batch_max_steps % hop_size
        assert batch_max_steps % hop_size == 0
        self.batch_max_steps = batch_max_steps
        self.batch_max_frames = batch_max_steps // hop_size
        self.hop_size = hop_size
        self.aux_context_window = aux_context_window

        # set useful values in random cutting
        self.start_offset = aux_context_window
        self.end_offset = -(self.batch_max_frames + aux_context_window)
        self.mel_threshold = self.batch_max_frames + 2 * aux_context_window

    def __call__(self, examples):
        """Convert into batch tensors.

        Args:
            examples (list): list of tuples of the pair of audio and features.

        Returns:
            Tensor: Auxiliary feature batch (B, C, T'), where
                T = (T' - 2 * aux_context_window) * hop_size.
            Tensor: Target signal batch (B, 1, T).

        """
        # check length
        examples = [
            self._adjust_length(*b) for b in examples
            if len(b[1]) > self.mel_threshold
        ]
        xs, cs = [b[0] for b in examples], [b[1] for b in examples]

        # make batch with random cut
        c_lengths = [len(c) for c in cs]
        start_frames = np.array([
            np.random.randint(self.start_offset, cl + self.end_offset)
            for cl in c_lengths
        ])
        x_starts = start_frames * self.hop_size
        x_ends = x_starts + self.batch_max_steps

        c_starts = start_frames - self.aux_context_window
        c_ends = start_frames + self.batch_max_frames + self.aux_context_window
        y_batch = [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]
        c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]

        # convert each batch to tensor, assume that each item in batch has the same length
        y_batch = paddle.to_tensor(
            y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
        c_batch = paddle.to_tensor(
            c_batch, dtype=paddle.float32).transpose([0, 2, 1])  # (B, C, T')

        return (c_batch, ), y_batch

    def _adjust_length(self, x, c):
        """Adjust the audio and feature lengths.

        Note:
            Basically we assume that the lengths of x and c are adjusted
            in the preprocessing stage, but if we use features processed
            by another library, this step will be needed.

        """
        if len(x) < len(c) * self.hop_size:
            x = np.pad(x, (0, len(c) * self.hop_size - len(x)), mode="edge")

        # check the length is valid
        assert len(x) == len(c) * self.hop_size

        return x, c
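For context, a minimal sketch of how this collate functor plugs into a Paddle DataLoader. The `train_dataset` name and the literal values are illustrative assumptions (matching the config below), not part of this commit:

# hypothetical usage; train_dataset yields (wave, mel) pairs
from paddle.io import DataLoader

collate_fn = Clip(batch_max_steps=25500, hop_size=300, aux_context_window=2)
loader = DataLoader(train_dataset, batch_size=6, collate_fn=collate_fn)
(c_batch, ), y_batch = next(iter(loader))  # (B, C, T'), (B, 1, T)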
@@ -0,0 +1,127 @@
# This is the hyperparameter configuration file for Parallel WaveGAN.
# Please make sure this is adjusted for the CSMSC dataset. If you want to
# apply it to another dataset, you might need to carefully change some parameters.
# This configuration requires 12 GB GPU memory and takes ~3 days on an RTX TITAN.

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
sr: 24000                # Sampling rate.
n_fft: 2048              # FFT size.
hop_length: 300          # Hop size.
win_length: 1200         # Window length.
                         # If set to null, it will be the same as n_fft.
window: "hann"           # Window function.
n_mels: 80               # Number of mel basis.
fmin: 80                 # Minimum frequency in mel basis calculation.
fmax: 7600               # Maximum frequency in mel basis calculation.
# global_gain_scale: 1.0 # Will be multiplied to all of waveform.
trim_silence: false      # Whether to trim the start and end of silence.
top_db: 60               # Need to tune carefully if the recording is not good.
trim_frame_length: 2048  # Frame size in trimming.
trim_hop_length: 512     # Hop size in trimming.
# format: "npy"          # Feature file format. "npy" or "hdf5" is supported.

###########################################################
#         GENERATOR NETWORK ARCHITECTURE SETTING          #
###########################################################
generator_params:
    in_channels: 1         # Number of input channels.
    out_channels: 1        # Number of output channels.
    kernel_size: 3         # Kernel size of dilated convolution.
    layers: 30             # Number of residual block layers.
    stacks: 3              # Number of stacks, i.e., dilation cycles.
    residual_channels: 64  # Number of channels in residual conv.
    gate_channels: 128     # Number of channels in gated conv.
    skip_channels: 64      # Number of channels in skip conv.
    aux_channels: 80       # Number of channels for auxiliary feature conv.
                           # Must be the same as n_mels.
    aux_context_window: 2  # Context window size for auxiliary feature.
                           # If set to 2, previous 2 and future 2 frames will be considered.
    dropout: 0.0           # Dropout rate. 0.0 means no dropout applied.
    bias: true             # Whether to use bias in residual blocks.
    use_weight_norm: true  # Whether to use weight norm.
                           # If set to true, it will be applied to all of the conv layers.
    use_causal_conv: false # Whether to use causal conv in residual blocks and upsample layers.
    # upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture.
    upsample_scales: [4, 5, 3, 5] # Upsampling scales. Product of these must be the same as hop size.
    interpolate_mode: "nearest"   # Interpolation mode of the upsampling network.
    freq_axis_kernel_size: 1      # Convolution kernel size along the frequency axis in the upsampling network.
    nonlinear_activation: null
    nonlinear_activation_params: {}

###########################################################
#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
###########################################################
discriminator_params:
    in_channels: 1        # Number of input channels.
    out_channels: 1       # Number of output channels.
    kernel_size: 3        # Kernel size of conv layers.
    layers: 10            # Number of conv layers.
    conv_channels: 64     # Number of channels in conv layers.
    bias: true            # Whether to use bias parameter in conv.
    use_weight_norm: true # Whether to use weight norm.
                          # If set to true, it will be applied to all of the conv layers.
    nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
    nonlinear_activation_params:      # Nonlinear function parameters.
        negative_slope: 0.2           # Alpha in LeakyReLU.

###########################################################
#                   STFT LOSS SETTING                     #
###########################################################
stft_loss_params:
    fft_sizes: [1024, 2048, 512]  # List of FFT sizes for STFT-based loss.
    hop_sizes: [120, 240, 50]     # List of hop sizes for STFT-based loss.
    win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss.
    window: "hann"                # Window function for STFT-based loss.

###########################################################
#               ADVERSARIAL LOSS SETTING                  #
###########################################################
lambda_adv: 4.0  # Loss balancing coefficient.

###########################################################
#                  DATA LOADER SETTING                    #
###########################################################
batch_size: 6              # Batch size.
batch_max_steps: 25500     # Length of each audio in batch. Make sure it is divisible by hop_length.
pin_memory: true           # Whether to pin memory in Paddle DataLoader.
num_workers: 2             # Number of workers in Paddle DataLoader.
remove_short_samples: true # Whether to remove samples whose length is shorter than batch_max_steps.
allow_cache: true          # Whether to allow cache in dataset. If true, it requires cpu memory.

###########################################################
#             OPTIMIZER & SCHEDULER SETTING               #
###########################################################
generator_optimizer_params:
    epsilon: 1.0e-6        # Generator's epsilon.
    weight_decay: 0.0      # Generator's weight decay coefficient.
generator_scheduler_params:
    learning_rate: 0.0001  # Generator's learning rate.
    step_size: 200000      # Generator's scheduler step size.
    gamma: 0.5             # Generator's scheduler gamma.
                           # At each step size, lr will be multiplied by this parameter.
generator_grad_norm: 10    # Generator's gradient norm.
discriminator_optimizer_params:
    epsilon: 1.0e-6        # Discriminator's epsilon.
    weight_decay: 0.0      # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
    learning_rate: 0.00005 # Discriminator's learning rate.
    step_size: 200000      # Discriminator's scheduler step size.
    gamma: 0.5             # Discriminator's scheduler gamma.
                           # At each step size, lr will be multiplied by this parameter.
discriminator_grad_norm: 1 # Discriminator's gradient norm.

###########################################################
#                    INTERVAL SETTING                     #
###########################################################
discriminator_train_start_steps: 100000 # Number of steps at which to start training the discriminator.
train_max_steps: 400000                 # Number of training steps.
save_interval_steps: 5000               # Interval steps to save checkpoint.
eval_interval_steps: 1000               # Interval steps to evaluate the network.
log_interval_steps: 100                 # Interval steps to record the training log.

###########################################################
#                     OTHER SETTING                       #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
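As a quick sanity check on the two arithmetic constraints noted in the comments above (a standalone snippet, not part of the commit):

import math
assert math.prod([4, 5, 3, 5]) == 300  # upsample_scales must multiply to hop_length
assert 25500 % 300 == 0                # batch_max_steps must be divisible by hop_length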
@@ -0,0 +1,25 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml
from yacs.config import CfgNode as Configuration

with open("conf/default.yaml", 'rt') as f:
    _C = yaml.safe_load(f)
_C = Configuration(_C)


def get_cfg_default():
    config = _C.clone()
    return config
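The downstream scripts use this module as below; the override file name here is a placeholder:

config = get_cfg_default()
config.merge_from_file("conf/override.yaml")  # hypothetical user overrides
config.freeze()                               # make the config read-only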
@@ -0,0 +1,280 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
from concurrent.futures import ProcessPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any, Dict, List

import librosa
import numpy as np
import tqdm
from praatio import tgio

from config import get_cfg_default


def logmelfilterbank(audio,
                     sr,
                     n_fft=1024,
                     hop_length=256,
                     win_length=None,
                     window="hann",
                     n_mels=80,
                     fmin=None,
                     fmax=None,
                     eps=1e-10):
    """Compute log-Mel filterbank features.

    Parameters
    ----------
    audio : ndarray
        Audio signal (T,).
    sr : int
        Sampling rate.
    n_fft : int
        FFT size. (Default value = 1024)
    hop_length : int
        Hop size. (Default value = 256)
    win_length : int
        Window length. If set to None, it will be the same as n_fft. (Default value = None)
    window : str
        Window function type. (Default value = "hann")
    n_mels : int
        Number of mel basis. (Default value = 80)
    fmin : int
        Minimum frequency in mel basis calculation. (Default value = None)
    fmax : int
        Maximum frequency in mel basis calculation. (Default value = None)
    eps : float
        Epsilon value to avoid inf in log calculation. (Default value = 1e-10)

    Returns
    -------
    np.ndarray
        Log Mel filterbank feature (num_mels, #frames).

    """
    # get amplitude spectrogram
    x_stft = librosa.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        pad_mode="reflect")
    spc = np.abs(x_stft)  # (#bins, #frames)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sr / 2 if fmax is None else fmax
    mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)

    return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))


def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     alignment_fp: Path,
                     output_dir: Path):
    utt_id = fp.stem

    # reading
    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
    assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
    assert np.abs(y).max(
    ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
    duration = librosa.get_duration(y, sr=sr)

    # trim according to the alignment file
    alignment = tgio.openTextgrid(alignment_fp)
    intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
    first, last = intervals[0], intervals[-1]
    start = 0
    end = last.end
    if first.label == "sil" and first.end < duration:
        start = first.end
    else:
        logging.warning(
            f"There is something wrong with the first interval {first} in utterance: {utt_id}"
        )
    if last.label == "sil" and last.start < duration:
        end = last.start
    else:
        end = duration
        logging.warning(
            f"There is something wrong with the last interval {last} in utterance: {utt_id}"
        )
    # silence trimmed
    start, end = librosa.time_to_samples([start, end], sr=sr)
    y = y[start:end]

    # energy based silence trimming
    if config.trim_silence:
        y, _ = librosa.effects.trim(
            y,
            top_db=config.top_db,
            frame_length=config.trim_frame_length,
            hop_length=config.trim_hop_length)

    logmel = logmelfilterbank(
        y,
        sr=sr,
        n_fft=config.n_fft,
        window=config.window,
        win_length=config.win_length,
        hop_length=config.hop_length,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)

    # adjust time to make num_samples == num_frames * hop_length
    num_frames = logmel.shape[1]
    y = np.pad(y, (0, config.n_fft), mode="reflect")
    y = y[:num_frames * config.hop_length]
    num_sample = y.shape[0]

    mel_path = output_dir / (utt_id + "_feats.npy")
    wav_path = output_dir / (utt_id + "_wave.npy")
    np.save(wav_path, y)
    np.save(mel_path, logmel)
    record = {
        "utt_id": utt_id,
        "num_samples": num_sample,
        "num_frames": num_frames,
        "feats_path": str(mel_path.resolve()),
        "wave_path": str(wav_path.resolve()),
    }
    return record


def process_sentences(config,
                      fps: List[Path],
                      alignment_fps: List[Path],
                      output_dir: Path,
                      nprocs: int=1):
    if nprocs == 1:
        results = []
        for fp, alignment_fp in tqdm.tqdm(
                zip(fps, alignment_fps), total=len(fps)):
            results.append(
                process_sentence(config, fp, alignment_fp, output_dir))
    else:
        with ProcessPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp, alignment_fp in zip(fps, alignment_fps):
                    future = pool.submit(process_sentence, config, fp,
                                         alignment_fp, output_dir)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    results.append(ft.result())

    results.sort(key=itemgetter("utt_id"))
    with open(output_dir / "metadata.json", 'wt') as f:
        json.dump(results, f)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        help="directory including wav files.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num_cpu", type=int, default=1, help="number of processes.")
    args = parser.parse_args()

    C = get_cfg_default()
    if args.config:
        C.merge_from_file(args.config)
    C.freeze()

    if args.verbose > 1:
        print(vars(args))
        print(C)

    root_dir = Path(args.rootdir)
    dumpdir = Path(args.dumpdir)
    dumpdir.mkdir(parents=True, exist_ok=True)

    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
    alignment_files = sorted(
        list((root_dir / "PhoneLabeling").rglob("*.interval")))

    # split data into 3 sections
    train_wav_files = wav_files[:9800]
    dev_wav_files = wav_files[9800:9900]
    test_wav_files = wav_files[9900:]

    train_alignment_files = alignment_files[:9800]
    dev_alignment_files = alignment_files[9800:9900]
    test_alignment_files = alignment_files[9900:]

    train_dump_dir = dumpdir / "train"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # process for the 3 sections
    process_sentences(
        C,
        train_wav_files,
        train_alignment_files,
        train_dump_dir,
        nprocs=args.num_cpu)
    process_sentences(
        C,
        dev_wav_files,
        dev_alignment_files,
        dev_dump_dir,
        nprocs=args.num_cpu)
    process_sentences(
        C,
        test_wav_files,
        test_alignment_files,
        test_dump_dir,
        nprocs=args.num_cpu)


if __name__ == "__main__":
    main()
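For orientation, a sketch of reading back what this script dumps. The dump location and the hop length 300 come from the config above; the exact paths are assumptions:

import json
import numpy as np

with open("dump/train/metadata.json") as f:
    metadata = json.load(f)

rec = metadata[0]
wave = np.load(rec["wave_path"])    # (num_samples,)
feats = np.load(rec["feats_path"])  # (n_mels, num_frames), see logmelfilterbank
assert wave.shape[0] == rec["num_frames"] * 300  # num_samples == num_frames * hop_length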
@@ -0,0 +1,166 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
import argparse
from pathlib import Path

import yaml
import json
import paddle
import numpy as np
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam  # No RAdam
from paddle.optimizer.lr import StepDecay
from paddle import DataParallel
from visualdl import LogWriter

from parakeet.datasets.data_table import DataTable
from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training.checkpoint import KBest, KLatest
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss

from config import get_cfg_default


def train_sp(args, config):
    # decides device type and whether to run in parallel
    # setup running environment correctly
    if not paddle.is_compiled_with_cuda():
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # construct dataset for training and validation
    with open(args.train_metadata) as f:
        train_metadata = json.load(f)
    train_dataset = DataTable(
        data=train_metadata,
        fields=["wave_path", "feats_path"],
        converters={
            "wave_path": np.load,
            "feats_path": np.load,
        }, )
    with open(args.dev_metadata) as f:
        dev_metadata = json.load(f)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=["wave_path", "feats_path"],
        converters={
            "wave_path": np.load,
            "feats_path": np.load,
        }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=None,  # TODO(define collate fn)
        num_workers=4)
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=None,  # TODO(define collate fn)
        num_workers=4)

    generator = PWGGenerator(**config["generator_params"])
    discriminator = PWGDiscriminator(**config["discriminator_params"])
    if world_size > 1:
        generator = DataParallel(generator)
        discriminator = DataParallel(discriminator)
    criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
    criterion_mse = nn.MSELoss()
    lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
    optimizer_g = Adam(
        lr_schedule_g,
        parameters=generator.parameters(),
        **config["generator_optimizer_params"])
    lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
    optimizer_d = Adam(
        lr_schedule_d,
        parameters=discriminator.parameters(),
        **config["discriminator_optimizer_params"])

    output_dir = Path(args.output_dir)
    log_writer = None
    if dist.get_rank() == 0:
        output_dir.mkdir(parents=True, exist_ok=True)
        log_writer = LogWriter(str(output_dir))

    # training loop


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(
        description="Train a ParallelWaveGAN model with the Baker Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config.")
    parser.add_argument("--train-metadata", type=str, help="training data.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes.")
    parser.add_argument("--verbose", type=int, default=1, help="verbose.")

    args = parser.parse_args()
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
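The `# training loop` marker above is where this WIP commit stops. Below is a hedged sketch of what a single Parallel WaveGAN training step could look like using the objects built in `train_sp`; every call here is an assumption about where the code is headed (for example, that `PWGGenerator` maps (noise, mel) to a waveform), not part of this commit:

# hypothetical step; (c,), y come batched as Clip produces them:
# c: (B, C, T'), y: (B, 1, T); `step` is a global iteration counter
def train_step(step, c, y):
    # generator: noise + mel conditioning -> waveform
    z = paddle.randn(y.shape)
    y_hat = generator(z, c)

    # multi-resolution STFT loss on the squeezed waveforms
    sc_loss, mag_loss = criterion_stft(y_hat.squeeze(1), y.squeeze(1))
    gen_loss = sc_loss + mag_loss

    # adversarial term only after discriminator_train_start_steps
    if step > config.discriminator_train_start_steps:
        p_hat = discriminator(y_hat)
        gen_loss += config.lambda_adv * criterion_mse(
            p_hat, paddle.ones_like(p_hat))

    optimizer_g.clear_grad()
    gen_loss.backward()
    optimizer_g.step()

    # discriminator: real vs. detached fake
    if step > config.discriminator_train_start_steps:
        p = discriminator(y)
        p_hat = discriminator(y_hat.detach())
        dis_loss = (criterion_mse(p, paddle.ones_like(p)) +
                    criterion_mse(p_hat, paddle.zeros_like(p_hat)))
        optimizer_d.clear_grad()
        dis_loss.backward()
        optimizer_d.step()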
@@ -109,10 +109,11 @@ class UpsampleNet(nn.Layer):
            padding = (freq_axis_padding, scale)
            conv = nn.Conv2D(
                1, 1, kernel_size, padding=padding, bias_attr=False)
            self.up_layers.extend([stretch, conv])
            if nonlinear_activation is not None:
                nonlinear = getattr(
                    nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers.append(nonlinear)

    def forward(self, c: Tensor) -> Tensor:
        """