WIP: pwg training works

chenfeiyu 2021-06-14 17:05:37 +08:00
parent 54c7905f40
commit b0983e4d76
7 changed files with 183 additions and 22 deletions

View File

@@ -60,13 +60,13 @@ class Clip(object):
"""
# check length
examples = [
self._adjust_length(*b) for b in examples
if len(b[1]) > self.mel_threshold
self._adjust_length(b['wave_path'], b['feats_path'])
for b in examples if b['feats_path'].shape[1] > self.mel_threshold
]
xs, cs = [b[0] for b in examples], [b[1] for b in examples]
# make batch with random cut
c_lengths = [len(c) for c in cs]
c_lengths = [c.shape[1] for c in cs]
start_frames = np.array([
np.random.randint(self.start_offset, cl + self.end_offset)
for cl in c_lengths
@@ -76,16 +76,17 @@ class Clip(object):
c_starts = start_frames - self.aux_context_window
c_ends = start_frames + self.batch_max_frames + self.aux_context_window
y_batch = [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]
c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]
y_batch = np.stack(
[x[start:end] for x, start, end in zip(xs, x_starts, x_ends)])
c_batch = np.stack(
[c[:, start:end] for c, start, end in zip(cs, c_starts, c_ends)])
# convert each batch to tensor, assume that each item in the batch has the same length
y_batch = paddle.to_tensor(
y_batch, dtype=paddle.float32).unsqueeze(1) # (B, 1, T)
c_batch = paddle.to_tensor(
c_batch, dtype=paddle.float32).transpose([0, 2, 1]) # (B, C, T')
c_batch = paddle.to_tensor(c_batch, dtype=paddle.float32) # (B, C, T')
return (c_batch, ), y_batch
return y_batch, c_batch
def _adjust_length(self, x, c):
"""Adjust the audio and feature lengths.
@@ -96,10 +97,12 @@ class Clip(object):
features, this process will be needed.
"""
if len(x) < len(c) * self.hop_size:
x = np.pad(x, (0, len(c) * self.hop_size - len(x)), mode="edge")
if len(x) < c.shape[1] * self.hop_size:
x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)),
mode="edge")
# check the length is valid
assert len(x) == len(c) * self.hop_size
assert len(x) == c.shape[
1] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[1]})"
return x, c
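
The collate now consumes dict examples ('wave_path' and 'feats_path' hold the loaded arrays despite the names) and returns a plain (wav, mel) pair instead of the old ((mel,), wav) tuple, with the waveform batched as (B, 1, T) and the features as (B, C, T'). A standalone shape sketch of the random-cut batching; hop_size, the frame count and the mel dimension are illustrative, and the aux context window is omitted:

import numpy as np
import paddle

hop_size = 256          # illustrative, not the project default
batch_max_frames = 24   # batch_max_steps // hop_size
mel_dim = 80

# one fake example: a 1-D waveform and its (C, T') mel features
wave = np.random.randn(100 * hop_size).astype("float32")
mel = np.random.randn(mel_dim, 100).astype("float32")

# random cut, mirroring the collate above (aux_context_window left out)
start = np.random.randint(0, mel.shape[1] - batch_max_frames)
y = wave[start * hop_size:(start + batch_max_frames) * hop_size]
c = mel[:, start:start + batch_max_frames]

y_batch = paddle.to_tensor(np.stack([y]), dtype=paddle.float32).unsqueeze(1)
c_batch = paddle.to_tensor(np.stack([c]), dtype=paddle.float32)
print(y_batch.shape, c_batch.shape)   # [1, 1, 6144] and [1, 80, 24]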

View File

@@ -86,7 +86,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
batch_size: 6 # Batch size.
batch_max_steps: 25500 # Length of each audio in batch. Make sure it is divisible by hop_size.
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
num_workers: 2 # Number of workers in Pytorch DataLoader.
num_workers: 0 # Number of workers in Pytorch DataLoader.
remove_short_samples: true # Whether to remove samples whose length is less than batch_max_steps.
allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory.
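
A quick sketch of reading these values through the example's get_cfg_default helper (imported in train.py below); treating the result as a yacs-style CfgNode and the merge_from_file call and yaml path are assumptions here:

from config import get_cfg_default   # the example's own config module

config = get_cfg_default()
config.merge_from_file("conf/default.yaml")   # hypothetical path to the yaml above
assert config.batch_max_steps % config.hop_length == 0, "batch_max_steps must be divisible by hop_size"
print(config.num_workers)   # 0: batches are now built in the main process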

View File

@@ -0,0 +1,111 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from parakeet.datasets.data_table import DataTable
from parakeet.training.updater import UpdaterBase, UpdaterState
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training.checkpoint import KBest, KLatest
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
class PWGUpdater(UpdaterBase):
def __init__(
self,
models,
optimizers,
criterions,
schedulers,
dataloaders,
discriminator_train_start_steps: int,
lambda_adv: float, ):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']
self.optimizers = optimizers
self.optimizer_g = optimizers['generator']
self.optimizer_d = optimizers['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_mse = criterions['mse']
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloaders = dataloaders
self.train_dataloader = dataloaders['train']
self.dev_dataloader = dataloaders['dev']
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.train_dataloader)
def update_core(self):
try:
batch = next(self.train_iterator)
except StopIteration:
self.train_iterator = iter(self.train_dataloader)
batch = next(self.train_iterator)
wav, mel = batch
# Generator
noise = paddle.randn(wav.shape)
wav_ = self.generator(noise, mel)
## Multi-resolution stft loss
sc_loss, mag_loss = self.criterion_stft(
wav_.squeeze(1), wav.squeeze(1))
report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss))
gen_loss = sc_loss + mag_loss
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
report("train/adversarial_loss", float(adv_loss))
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
self.optimizer_g.clear_grad()
gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# Discriminator
if self.state.iteration > self.discriminator_train_start_steps:
with paddle.no_grad():
wav_ = self.generator(noise, mel)
p = self.discriminator(wav)
p_ = self.discriminator(wav_.detach())
real_loss = self.criterion_mse(p, paddle.ones_like(p))
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
dis_loss = real_loss + fake_loss
report("train/discriminator_loss", float(dis_loss))
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
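
update_core trains the generator on the multi-resolution STFT loss from the first step and only adds the weighted adversarial term, plus the discriminator update, once the iteration count passes discriminator_train_start_steps. A dependency-free restatement of that loss schedule; the start step is an illustrative value, lambda_adv matches the config above:

discriminator_train_start_steps = 100_000   # illustrative value
lambda_adv = 4.0                            # as in the config above

def generator_loss(iteration, sc_loss, mag_loss, adv_loss):
    """Mirror of the loss composition in update_core."""
    gen_loss = sc_loss + mag_loss
    if iteration > discriminator_train_start_steps:
        gen_loss += lambda_adv * adv_loss
    return gen_loss

print(generator_loss(50_000, 0.5, 0.25, 0.125))    # 0.75: STFT losses only
print(generator_loss(150_000, 0.5, 0.25, 0.125))   # 1.25: adversarial term added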

View File

@@ -41,7 +41,9 @@ from parakeet.training.checkpoint import KBest, KLatest
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
from batch_fn import Clip
from config import get_cfg_default
from pwg_updater import PWGUpdater
def train_sp(args, config):
@@ -90,25 +92,35 @@ def train_sp(args, config):
batch_size=config.batch_size,
shuffle=False,
drop_last=False)
print("samplers done!")
train_batch_fn = Clip(
batch_max_steps=config.batch_max_steps,
hop_size=config.hop_length,
aux_context_window=config.generator_params.aux_context_window)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=None, # TODO(defaine collate fn)
num_workers=4)
collate_fn=train_batch_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=None, # TODO(defaine collate fn)
num_workers=4)
collate_fn=train_batch_fn,
num_workers=config.num_workers)
print("dataloaders done!")
generator = PWGGenerator(**config["generator_params"])
discriminator = PWGDiscriminator(**config["discriminator_params"])
if world_size > 1:
generator = DataParallel(generator)
discriminator = DataParallel(discriminator)
print("models done!")
criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
criterion_mse = nn.MSELoss()
print("criterions done!")
lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
optimizer_g = Adam(
lr_schedule_g,
@@ -119,14 +131,43 @@ def train_sp(args, config):
lr_schedule_d,
parameters=discriminator.parameters(),
**config["discriminator_optimizer_params"])
print("optimizers done!")
output_dir = Path(args.output_dir)
log_writer = None
if dist.get_rank() == 0:
output_dir.mkdir(parents=True, exist_ok=True)
log_writer = LogWriter(output_dir)
log_writer = LogWriter(str(output_dir))
# training loop
updater = PWGUpdater(
models={
"generator": generator,
"discriminator": discriminator,
},
optimizers={
"generator": optimizer_g,
"discriminator": optimizer_d,
},
criterions={
"stft": criterion_stft,
"mse": criterion_mse,
},
schedulers={
"generator": lr_schedule_g,
"discriminator": lr_schedule_d,
},
dataloaders={
"train": train_dataloader,
"dev": dev_dataloader,
},
discriminator_train_start_steps=config.discriminator_train_start_steps,
lambda_adv=config.lambda_adv, )
trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir, )
trainer.run()
def main():
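
The StepDecay objects do double duty here: each is passed to Adam as its learning_rate and is also stepped explicitly in the updater after the corresponding optimizer step, which is how Paddle handles manual LR scheduling. A minimal sketch of that wiring with a placeholder layer (the scheduler hyperparameters are made up, not the config values):

import paddle
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay

layer = paddle.nn.Linear(4, 4)            # stand-in for the generator
sched = StepDecay(learning_rate=1e-4, step_size=200_000, gamma=0.5)
opt = Adam(learning_rate=sched, parameters=layer.parameters())

loss = (layer(paddle.randn([8, 4])) ** 2).mean()
opt.clear_grad()
loss.backward()
opt.step()
sched.step()        # advances the LR the optimizer reads on the next step
print(opt.get_lr())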

View File

@@ -76,7 +76,7 @@ class Trainer(object):
else:
max_iteration = self.stop_trigger.period
while not stop_trigger(self):
while True:
self.observation = {}
# set observation as the report target
# you can use report freely in Updater.update()
@@ -84,8 +84,13 @@ class Trainer(object):
# updating parameters and state
with scope(self.observation):
update()
print(self.observation)
# execute extension when necessary
for name, entry in extensions:
if entry.trigger(self):
entry.extension(self)
if stop_trigger(self):
print("Training Done!")
break
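
Checking the stop trigger at the bottom of the loop guarantees at least one update before the trigger is consulted, which matters because an interval-style trigger (like the one fixed below) also fires at iteration 0, since 0 % period == 0. A toy version of the new control flow with a dummy updater and trigger, not the parakeet classes:

class ToyUpdater:
    iteration = 0
    def update(self):
        self.iteration += 1

updater = ToyUpdater()
period = 3
stop_trigger = lambda: updater.iteration % period == 0   # fires at 0, 3, 6, ...

# the trigger is only checked after update(), so training runs 3 steps
# instead of stopping immediately at iteration 0
while True:
    updater.update()
    if stop_trigger():
        print("Training Done! at iteration", updater.iteration)   # 3
        break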

View File

@@ -23,9 +23,9 @@ class IntervalTrigger(object):
def __call__(self, trainer):
state = trainer.updater.state
if self.unit == "epoch":
fire = not (state.epoch % self.period)
fire = state.epoch % self.period == 0
else:
fire = not (state.iteration % self.iteration)
fire = state.iteration % self.period == 0
return fire
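
Besides dropping the double negation, the iteration branch now takes the modulus against self.period (the old line used self.iteration). A quick standalone check of the corrected predicate, with a bare step counter standing in for trainer.updater.state:

period = 1000

def fires(iteration):
    return iteration % period == 0

print([i for i in range(1, 3001) if fires(i)])   # [1000, 2000, 3000]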

View File

@@ -57,7 +57,8 @@ class UpdaterBase(object):
"""
def update(self):
pass
self.state.iteration += 1
self.update_core()
def update_core(self):
pass
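
UpdaterBase.update() now owns the iteration bookkeeping and delegates the actual work to update_core(), the hook that PWGUpdater overrides above. A self-contained miniature of that contract, with UpdaterState reduced to a plain counter:

class MiniUpdaterBase:
    def __init__(self):
        self.iteration = 0

    def update(self):
        self.iteration += 1
        self.update_core()

    def update_core(self):
        pass

class MiniPWGUpdater(MiniUpdaterBase):
    def update_core(self):
        print("training step", self.iteration)

updater = MiniPWGUpdater()
for _ in range(3):
    updater.update()    # prints training step 1, 2, 3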