WIP: pwg training works
parent 54c7905f40
commit b0983e4d76
@@ -60,13 +60,13 @@ class Clip(object):
         """
         # check length
         examples = [
-            self._adjust_length(*b) for b in examples
-            if len(b[1]) > self.mel_threshold
+            self._adjust_length(b['wave_path'], b['feats_path'])
+            for b in examples if b['feats_path'].shape[1] > self.mel_threshold
         ]
         xs, cs = [b[0] for b in examples], [b[1] for b in examples]

         # make batch with random cut
-        c_lengths = [len(c) for c in cs]
+        c_lengths = [c.shape[1] for c in cs]
         start_frames = np.array([
             np.random.randint(self.start_offset, cl + self.end_offset)
             for cl in c_lengths
@@ -76,16 +76,17 @@ class Clip(object):
         c_starts = start_frames - self.aux_context_window
         c_ends = start_frames + self.batch_max_frames + self.aux_context_window
-        y_batch = [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]
-        c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]
+        y_batch = np.stack(
+            [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)])
+        c_batch = np.stack(
+            [c[:, start:end] for c, start, end in zip(cs, c_starts, c_ends)])

         # convert each batch to tensor, assume that each item in batch has the same length
         y_batch = paddle.to_tensor(
             y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
-        c_batch = paddle.to_tensor(
-            c_batch, dtype=paddle.float32).transpose([0, 2, 1])  # (B, C, T')
+        c_batch = paddle.to_tensor(c_batch, dtype=paddle.float32)  # (B, C, T')

-        return (c_batch, ), y_batch
+        return y_batch, c_batch

     def _adjust_length(self, x, c):
         """Adjust the audio and feature lengths.

@@ -96,10 +97,12 @@ class Clip(object):
         features, this process will be needed.

         """
-        if len(x) < len(c) * self.hop_size:
-            x = np.pad(x, (0, len(c) * self.hop_size - len(x)), mode="edge")
+        if len(x) < c.shape[1] * self.hop_size:
+            x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)),
+                       mode="edge")

         # check the length is valid
-        assert len(x) == len(c) * self.hop_size
+        assert len(x) == c.shape[
+            1] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[1]})"

         return x, c
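
For orientation, a rough usage sketch of the collate contract after this change. Key names follow the diff ('wave_path' and 'feats_path' hold already-loaded arrays, despite the names); the hop size, mel dimension and lengths below are illustrative assumptions, and the commented shapes are only the expected ones under those assumptions:

import numpy as np
from batch_fn import Clip

hop_size, n_mels, n_frames = 300, 80, 200     # assumed values, not taken from this diff
wave = np.random.randn(hop_size * n_frames).astype(np.float32)   # (T,)
mel = np.random.randn(n_mels, n_frames).astype(np.float32)       # (C, T'), time on the last axis
examples = [{"wave_path": wave, "feats_path": mel} for _ in range(4)]

clip = Clip(batch_max_steps=25500, hop_size=hop_size, aux_context_window=2)
wav_batch, mel_batch = clip(examples)
# wav_batch: expected (4, 1, 25500)          -> (B, 1, T) float32 paddle tensor
# mel_batch: expected (4, 80, 85 + 2 * 2)    -> (B, C, T'), context window frames included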
@@ -86,7 +86,7 @@ lambda_adv: 4.0             # Loss balancing coefficient.
 batch_size: 6               # Batch size.
 batch_max_steps: 25500      # Length of each audio in batch. Make sure divisible by hop_size.
 pin_memory: true            # Whether to pin memory in DataLoader.
-num_workers: 2              # Number of workers in DataLoader.
+num_workers: 0              # Number of workers in DataLoader.
 remove_short_samples: true  # Whether to remove samples the length of which are less than batch_max_steps.
 allow_cache: true           # Whether to allow cache in dataset. If true, it requires cpu memory.
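
The divisibility note on batch_max_steps matters because the collate function cuts whole mel frames and then maps them back to waveform samples. A quick sanity check, with hop_size assumed to be 300 for illustration (it is configured elsewhere, not in this hunk):

batch_max_steps, hop_size = 25500, 300           # hop_size is an assumption here
assert batch_max_steps % hop_size == 0, "batch_max_steps must be divisible by hop_size"
batch_max_frames = batch_max_steps // hop_size   # 85 mel frames per training clip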
@@ -0,0 +1,111 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.training.updater import UpdaterBase, UpdaterState
+from parakeet.training.trainer import Trainer
+from parakeet.training.reporter import report
+from parakeet.training.checkpoint import KBest, KLatest
+from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
+from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
+
+
+class PWGUpdater(UpdaterBase):
+    def __init__(
+            self,
+            models,
+            optimizers,
+            criterions,
+            schedulers,
+            dataloaders,
+            discriminator_train_start_steps: int,
+            lambda_adv: float, ):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.optimizers = optimizers
+        self.optimizer_g = optimizers['generator']
+        self.optimizer_d = optimizers['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_mse = criterions['mse']
+
+        self.schedulers = schedulers
+        self.scheduler_g = schedulers['generator']
+        self.scheduler_d = schedulers['discriminator']
+
+        self.dataloaders = dataloaders
+        self.train_dataloader = dataloaders['train']
+        self.dev_dataloader = dataloaders['dev']
+
+        self.discriminator_train_start_steps = discriminator_train_start_steps
+        self.lambda_adv = lambda_adv
+        self.state = UpdaterState(iteration=0, epoch=0)
+
+        self.train_iterator = iter(self.train_dataloader)
+
+    def update_core(self):
+        try:
+            batch = next(self.train_iterator)
+        except StopIteration:
+            self.train_iterator = iter(self.train_dataloader)
+            batch = next(self.train_iterator)
+
+        wav, mel = batch
+
+        # Generator
+        noise = paddle.randn(wav.shape)
+        wav_ = self.generator(noise, mel)
+
+        ## Multi-resolution stft loss
+        sc_loss, mag_loss = self.criterion_stft(
+            wav_.squeeze(1), wav.squeeze(1))
+        report("train/spectral_convergence_loss", float(sc_loss))
+        report("train/log_stft_magnitude_loss", float(mag_loss))
+        gen_loss = sc_loss + mag_loss
+
+        ## Adversarial loss
+        if self.state.iteration > self.discriminator_train_start_steps:
+            p_ = self.discriminator(wav_)
+            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+            report("train/adversarial_loss", float(adv_loss))
+            gen_loss += self.lambda_adv * adv_loss
+
+        report("train/generator_loss", float(gen_loss))
+        self.optimizer_g.clear_grad()
+        gen_loss.backward()
+        self.optimizer_g.step()
+        self.scheduler_g.step()
+
+        # Discriminator
+        if self.state.iteration > self.discriminator_train_start_steps:
+            with paddle.no_grad():
+                wav_ = self.generator(noise, mel)
+            p = self.discriminator(wav)
+            p_ = self.discriminator(wav_.detach())
+            real_loss = self.criterion_mse(p, paddle.ones_like(p))
+            fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
+            report("train/real_loss", float(real_loss))
+            report("train/fake_loss", float(fake_loss))
+            dis_loss = real_loss + fake_loss
+            report("train/discriminator_loss", float(dis_loss))
+
+            self.optimizer_d.clear_grad()
+            dis_loss.backward()
+            self.optimizer_d.step()
+            self.scheduler_d.step()
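
The report() calls above only take effect inside the scope() context that the Trainer opens around update() (see the Trainer hunk further down). A minimal sketch of that interaction; the import path for scope is assumed to match the one used for report:

from parakeet.training.reporter import report, scope   # scope import path assumed

observation = {}
with scope(observation):                  # what Trainer does around update()
    report("train/generator_loss", 1.23)  # what PWGUpdater.update_core() does
print(observation)                        # expected: {'train/generator_loss': 1.23}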
@@ -41,7 +41,9 @@ from parakeet.training.checkpoint import KBest, KLatest
 from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
 from parakeet.modules.stft_loss import MultiResolutionSTFTLoss

+from batch_fn import Clip
 from config import get_cfg_default
+from pwg_updater import PWGUpdater


 def train_sp(args, config):
@@ -90,25 +92,35 @@ def train_sp(args, config):
         batch_size=config.batch_size,
         shuffle=False,
         drop_last=False)
     print("samplers done!")

+    train_batch_fn = Clip(
+        batch_max_steps=config.batch_max_steps,
+        hop_size=config.hop_length,
+        aux_context_window=config.generator_params.aux_context_window)
     train_dataloader = DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
-        collate_fn=None,  # TODO(defaine collate fn)
-        num_workers=4)
+        collate_fn=train_batch_fn,
+        num_workers=config.num_workers)
     dev_dataloader = DataLoader(
         dev_dataset,
         batch_sampler=dev_sampler,
-        collate_fn=None,  # TODO(defaine collate fn)
-        num_workers=4)
+        collate_fn=train_batch_fn,
+        num_workers=config.num_workers)
     print("dataloaders done!")

     generator = PWGGenerator(**config["generator_params"])
     discriminator = PWGDiscriminator(**config["discriminator_params"])
     if world_size > 1:
         generator = DataParallel(generator)
         discriminator = DataParallel(discriminator)
     print("models done!")

     criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
     criterion_mse = nn.MSELoss()
     print("criterions done!")

     lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
     optimizer_g = Adam(
         lr_schedule_g,
@@ -119,14 +131,43 @@ def train_sp(args, config):
         lr_schedule_d,
         parameters=discriminator.parameters(),
         **config["discriminator_optimizer_params"])
     print("optimizers done!")

     output_dir = Path(args.output_dir)
     log_writer = None
     if dist.get_rank() == 0:
         output_dir.mkdir(parents=True, exist_ok=True)
-        log_writer = LogWriter(output_dir)
+        log_writer = LogWriter(str(output_dir))

+    # training loop
+    updater = PWGUpdater(
+        models={
+            "generator": generator,
+            "discriminator": discriminator,
+        },
+        optimizers={
+            "generator": optimizer_g,
+            "discriminator": optimizer_d,
+        },
+        criterions={
+            "stft": criterion_stft,
+            "mse": criterion_mse,
+        },
+        schedulers={
+            "generator": lr_schedule_g,
+            "discriminator": lr_schedule_d,
+        },
+        dataloaders={
+            "train": train_dataloader,
+            "dev": dev_dataloader,
+        },
+        discriminator_train_start_steps=config.discriminator_train_start_steps,
+        lambda_adv=config.lambda_adv, )
+
+    trainer = Trainer(
+        updater,
+        stop_trigger=(config.train_max_steps, "iteration"),
+        out=output_dir, )
+    trainer.run()


 def main():
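
The optimizer/scheduler wiring above relies on Paddle accepting an LRScheduler object as the optimizer's learning rate, with scheduler.step() advancing it once per iteration, which is how PWGUpdater drives it. A minimal sketch of the pattern with made-up hyperparameters; the real values come from the config:

import paddle
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay

layer = paddle.nn.Linear(4, 4)                       # stand-in for the generator
scheduler = StepDecay(learning_rate=1e-4, step_size=200000, gamma=0.5)
optimizer = Adam(learning_rate=scheduler, parameters=layer.parameters())

loss = layer(paddle.randn([2, 4])).mean()
optimizer.clear_grad()
loss.backward()
optimizer.step()
scheduler.step()    # PWGUpdater also calls this once per iteration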
@@ -76,7 +76,7 @@ class Trainer(object):
         else:
             max_iteration = self.stop_trigger.period

-        while not stop_trigger(self):
+        while True:
             self.observation = {}
             # set observation as the report target
             # you can use report freely in Updater.update()
@@ -84,8 +84,13 @@ class Trainer(object):
             # updating parameters and state
             with scope(self.observation):
                 update()
+                print(self.observation)

             # execute extension when necessary
             for name, entry in extensions:
                 if entry.trigger(self):
                     entry.extension(self)
+
+            if stop_trigger(self):
+                print("Training Done!")
+                break
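
With the stop check moved to the end of the loop, stop_trigger is just a callable over the trainer. A rough sketch of such a trigger; this helper is hypothetical, the real one is built from Trainer's stop_trigger argument:

def make_max_iteration_trigger(max_iteration):
    # hypothetical helper: fires once the updater has run max_iteration steps
    def stop_trigger(trainer):
        return trainer.updater.state.iteration >= max_iteration
    return stop_trigger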
@@ -23,9 +23,9 @@ class IntervalTrigger(object):
     def __call__(self, trainer):
         state = trainer.updater.state
         if self.unit == "epoch":
-            fire = not (state.epoch % self.period)
+            fire = state.epoch % self.period == 0
         else:
-            fire = not (state.iteration % self.iteration)
+            fire = state.iteration % self.period == 0
         return fire
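
A short usage sketch of the corrected trigger; the IntervalTrigger(period, unit) constructor is inferred from how the class is used and is not shown in this hunk:

trigger = IntervalTrigger(period=1000, unit="iteration")   # constructor assumed
# trigger(trainer) now fires exactly when trainer.updater.state.iteration % 1000 == 0,
# instead of relying on `not (...)` truthiness and the nonexistent self.iteration attribute.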
@@ -57,7 +57,8 @@ class UpdaterBase(object):
     """

     def update(self):
-        pass
+        self.state.iteration += 1
+        self.update_core()

     def update_core(self):
         pass
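
After this change the contract is: subclasses override update_core(), and UpdaterBase.update() bumps the iteration counter before delegating, which is what PWGUpdater's warm-up check relies on. A minimal subclass sketch:

class MyUpdater(UpdaterBase):
    def update_core(self):
        # one training step; self.state.iteration has already been incremented
        ...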