ParakeetRebeccaRosario/examples/parallelwave_gan/baker/train.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
import dataclasses
from pathlib import Path
import yaml
import dacite
import json
import paddle
import numpy as np
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam  # No RAdam
from paddle.optimizer.lr import StepDecay
from paddle import DataParallel
from visualdl import LogWriter
from parakeet.datasets.data_table import DataTable
from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training.checkpoint import KBest, KLatest
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss

from batch_fn import Clip
from config import get_cfg_default
from pwg_updater import PWGUpdater


def train_sp(args, config):
    # decides device type and whether to run in parallel
    # setup running environment correctly
    if not paddle.is_compiled_with_cuda():
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()
    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # construct dataset for training and validation
    with open(args.train_metadata) as f:
        train_metadata = json.load(f)
    train_dataset = DataTable(
        data=train_metadata,
        fields=["wave_path", "feats_path"],
        converters={
            "wave_path": np.load,
            "feats_path": np.load,
        }, )
    with open(args.dev_metadata) as f:
        dev_metadata = json.load(f)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=["wave_path", "feats_path"],
        converters={
            "wave_path": np.load,
            "feats_path": np.load,
        }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    print("samplers done!")
    train_batch_fn = Clip(
        batch_max_steps=config.batch_max_steps,
        hop_size=config.hop_length,
        aux_context_window=config.generator_params.aux_context_window)

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")
    generator = PWGGenerator(**config["generator_params"])
    discriminator = PWGDiscriminator(**config["discriminator_params"])
    if world_size > 1:
        generator = DataParallel(generator)
        discriminator = DataParallel(discriminator)
    print("models done!")
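
    # the multi-resolution STFT loss is the auxiliary reconstruction loss for
    # the generator; the MSE loss presumably serves the (LSGAN-style)
    # adversarial terms (see pwg_updater.py).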
    criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
    criterion_mse = nn.MSELoss()
    print("criterions done!")
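
    # generator and discriminator each get their own StepDecay schedule and
    # Adam optimizer, configured independently.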
    lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
    optimizer_g = Adam(
        lr_schedule_g,
        parameters=generator.parameters(),
        **config["generator_optimizer_params"])
    lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
    optimizer_d = Adam(
        lr_schedule_d,
        parameters=discriminator.parameters(),
        **config["discriminator_optimizer_params"])
    print("optimizers done!")

    output_dir = Path(args.output_dir)
    log_writer = None
    if dist.get_rank() == 0:
        output_dir.mkdir(parents=True, exist_ok=True)
        log_writer = LogWriter(str(output_dir))
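
    # PWGUpdater (pwg_updater.py) bundles models, optimizers, criterions,
    # schedulers and dataloaders into one GAN training step: judging by its
    # parameters, the discriminator joins training after
    # discriminator_train_start_steps and the adversarial loss is weighted by
    # lambda_adv in the generator objective.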
    updater = PWGUpdater(
        models={
            "generator": generator,
            "discriminator": discriminator,
        },
        optimizers={
            "generator": optimizer_g,
            "discriminator": optimizer_d,
        },
        criterions={
            "stft": criterion_stft,
            "mse": criterion_mse,
        },
        schedulers={
            "generator": lr_schedule_g,
            "discriminator": lr_schedule_d,
        },
        dataloaders={
            "train": train_dataloader,
            "dev": dev_dataloader,
        },
        discriminator_train_start_steps=config.discriminator_train_start_steps,
        lambda_adv=config.lambda_adv, )
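
    # the trainer drives the updater until train_max_steps iterations have run.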
    trainer = Trainer(
        updater,
        stop_trigger=(config.train_max_steps, "iteration"),
        out=output_dir, )
    trainer.run()


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
                                     "model with the Baker Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    parser.add_argument("--train-metadata", type=str, help="training data")
    parser.add_argument("--dev-metadata", type=str, help="dev data")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()