WIP: pwg training works
parent 54c7905f40
commit b0983e4d76
@@ -60,13 +60,13 @@ class Clip(object):
         """
         # check length
         examples = [
-            self._adjust_length(*b) for b in examples
-            if len(b[1]) > self.mel_threshold
+            self._adjust_length(b['wave_path'], b['feats_path'])
+            for b in examples if b['feats_path'].shape[1] > self.mel_threshold
         ]
         xs, cs = [b[0] for b in examples], [b[1] for b in examples]

         # make batch with random cut
-        c_lengths = [len(c) for c in cs]
+        c_lengths = [c.shape[1] for c in cs]
         start_frames = np.array([
             np.random.randint(self.start_offset, cl + self.end_offset)
             for cl in c_lengths
@@ -76,16 +76,17 @@ class Clip(object):

         c_starts = start_frames - self.aux_context_window
         c_ends = start_frames + self.batch_max_frames + self.aux_context_window
-        y_batch = [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)]
-        c_batch = [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)]
+        y_batch = np.stack(
+            [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)])
+        c_batch = np.stack(
+            [c[:, start:end] for c, start, end in zip(cs, c_starts, c_ends)])

         # convert each batch to tensor, assume that each item in batch has the same length
         y_batch = paddle.to_tensor(
             y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
-        c_batch = paddle.to_tensor(
-            c_batch, dtype=paddle.float32).transpose([0, 2, 1])  # (B, C, T')
+        c_batch = paddle.to_tensor(c_batch, dtype=paddle.float32)  # (B, C, T')

-        return (c_batch, ), y_batch
+        return y_batch, c_batch

     def _adjust_length(self, x, c):
         """Adjust the audio and feature lengths.
@@ -96,10 +97,12 @@ class Clip(object):
        features, this process will be needed.

        """
-        if len(x) < len(c) * self.hop_size:
-            x = np.pad(x, (0, len(c) * self.hop_size - len(x)), mode="edge")
+        if len(x) < c.shape[1] * self.hop_size:
+            x = np.pad(x, (0, c.shape[1] * self.hop_size - len(x)),
+                       mode="edge")

         # check the length is valid
-        assert len(x) == len(c) * self.hop_size
+        assert len(x) == c.shape[
+            1] * self.hop_size, f"wave length: ({len(x)}), mel length: ({c.shape[1]})"

         return x, c
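As an aside, the random-cut logic above keeps the waveform and mel segments aligned through hop_size: a start frame is drawn per utterance, the mel slice spans [start - aux_context_window, start + batch_max_frames + aux_context_window), and the matching waveform slice (x_starts/x_ends, presumably start_frames * hop_size elsewhere in this file) covers batch_max_steps samples. A self-contained numpy sketch of that alignment; all values and the simplified start-frame range below are illustrative assumptions, not taken from this repo:

    import numpy as np

    # Illustrative values only.
    hop_size = 300
    batch_max_steps = 25500
    batch_max_frames = batch_max_steps // hop_size  # 85
    aux_context_window = 2

    mel = np.random.randn(80, 200)           # (C, T') mel frames
    wav = np.random.randn(200 * hop_size)    # aligned waveform, len == T' * hop_size

    # draw a start frame that leaves room for the context window on both sides
    start_frame = np.random.randint(
        aux_context_window, mel.shape[1] - batch_max_frames - aux_context_window)
    x_start, x_end = start_frame * hop_size, (start_frame + batch_max_frames) * hop_size
    c_start = start_frame - aux_context_window
    c_end = start_frame + batch_max_frames + aux_context_window

    wav_clip = wav[x_start:x_end]    # (batch_max_steps,)
    mel_clip = mel[:, c_start:c_end] # (C, batch_max_frames + 2 * aux_context_window)
    assert len(wav_clip) == batch_max_steps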
@@ -86,7 +86,7 @@ lambda_adv: 4.0             # Loss balancing coefficient.
 batch_size: 6               # Batch size.
 batch_max_steps: 25500      # Length of each audio in batch. Make sure dividable by hop_size.
 pin_memory: true            # Whether to pin memory in Pytorch DataLoader.
-num_workers: 2              # Number of workers in Pytorch DataLoader.
+num_workers: 0              # Number of workers in Pytorch DataLoader.
 remove_short_samples: true  # Whether to remove samples the length of which are less than batch_max_steps.
 allow_cache: true           # Whether to allow cache in dataset. If true, it requires cpu memory.
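The batch_max_steps comment above implies a constraint worth checking up front; a minimal sketch of that check, where the hop size of 300 is an assumption for illustration (in practice it would come from the config's hop_length):

    # Hypothetical sanity check for the constraint noted in the comment above.
    batch_max_steps = 25500
    hop_size = 300  # assumed value; use config.hop_length in practice
    assert batch_max_steps % hop_size == 0, \
        f"batch_max_steps ({batch_max_steps}) must be divisible by hop_size ({hop_size})"
    print(batch_max_steps // hop_size, "frames per clip")  # 85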
@@ -0,0 +1,111 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.training.updater import UpdaterBase, UpdaterState
+from parakeet.training.trainer import Trainer
+from parakeet.training.reporter import report
+from parakeet.training.checkpoint import KBest, KLatest
+from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
+from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
+
+
+class PWGUpdater(UpdaterBase):
+    def __init__(
+            self,
+            models,
+            optimizers,
+            criterions,
+            schedulers,
+            dataloaders,
+            discriminator_train_start_steps: int,
+            lambda_adv: float, ):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.optimizers = optimizers
+        self.optimizer_g = optimizers['generator']
+        self.optimizer_d = optimizers['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_mse = criterions['mse']
+
+        self.schedulers = schedulers
+        self.scheduler_g = schedulers['generator']
+        self.scheduler_d = schedulers['discriminator']
+
+        self.dataloaders = dataloaders
+        self.train_dataloader = dataloaders['train']
+        self.dev_dataloader = dataloaders['dev']
+
+        self.discriminator_train_start_steps = discriminator_train_start_steps
+        self.lambda_adv = lambda_adv
+        self.state = UpdaterState(iteration=0, epoch=0)
+
+        self.train_iterator = iter(self.train_dataloader)
+
+    def update_core(self):
+        try:
+            batch = next(self.train_iterator)
+        except StopIteration:
+            self.train_iterator = iter(self.train_dataloader)
+            batch = next(self.train_iterator)
+
+        wav, mel = batch
+
+        # Generator
+        noise = paddle.randn(wav.shape)
+        wav_ = self.generator(noise, mel)
+
+        ## Multi-resolution stft loss
+        sc_loss, mag_loss = self.criterion_stft(
+            wav_.squeeze(1), wav.squeeze(1))
+        report("train/spectral_convergence_loss", float(sc_loss))
+        report("train/log_stft_magnitude_loss", float(mag_loss))
+        gen_loss = sc_loss + mag_loss
+
+        ## Adversarial loss
+        if self.state.iteration > self.discriminator_train_start_steps:
+            p_ = self.discriminator(wav_)
+            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+            report("train/adversarial_loss", float(adv_loss))
+            gen_loss += self.lambda_adv * adv_loss
+
+        report("train/generator_loss", float(gen_loss))
+        self.optimizer_g.clear_grad()
+        gen_loss.backward()
+        self.optimizer_g.step()
+        self.scheduler_g.step()
+
+        # Discriminator
+        if self.state.iteration > self.discriminator_train_start_steps:
+            with paddle.no_grad():
+                wav_ = self.generator(noise, mel)
+            p = self.discriminator(wav)
+            p_ = self.discriminator(wav_.detach())
+            real_loss = self.criterion_mse(p, paddle.ones_like(p))
+            fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
+            report("train/real_loss", float(real_loss))
+            report("train/fake_loss", float(fake_loss))
+            dis_loss = real_loss + fake_loss
+            report("train/discriminator_loss", float(dis_loss))
+
+            self.optimizer_d.clear_grad()
+            dis_loss.backward()
+            self.optimizer_d.step()
+            self.scheduler_d.step()
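A note on the schedule implemented by update_core above: before discriminator_train_start_steps only the multi-resolution STFT terms train the generator; past that threshold the LSGAN adversarial term (weighted by lambda_adv) is added and the discriminator starts updating against a freshly generated waveform under no_grad. A toy sketch of just that gating, with assumed values for illustration:

    # Toy sketch of the two-stage loss schedule used in PWGUpdater.update_core.
    discriminator_train_start_steps = 100000  # assumed; comes from the config in practice
    lambda_adv = 4.0                          # assumed; config.lambda_adv in practice

    def generator_loss(iteration, sc_loss, mag_loss, adv_loss):
        # STFT-only warm-up, then STFT + weighted adversarial loss.
        gen_loss = sc_loss + mag_loss
        if iteration > discriminator_train_start_steps:
            gen_loss += lambda_adv * adv_loss
        return gen_loss

    print(generator_loss(1000, 0.5, 0.3, 0.2))    # 0.8 (warm-up phase)
    print(generator_loss(200000, 0.5, 0.3, 0.2))  # 1.6 (adversarial phase)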
@@ -41,7 +41,9 @@ from parakeet.training.checkpoint import KBest, KLatest
 from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
 from parakeet.modules.stft_loss import MultiResolutionSTFTLoss

+from batch_fn import Clip
 from config import get_cfg_default
+from pwg_updater import PWGUpdater


 def train_sp(args, config):
@@ -90,25 +92,35 @@ def train_sp(args, config):
         batch_size=config.batch_size,
         shuffle=False,
         drop_last=False)
+    print("samplers done!")

+    train_batch_fn = Clip(
+        batch_max_steps=config.batch_max_steps,
+        hop_size=config.hop_length,
+        aux_context_window=config.generator_params.aux_context_window)
     train_dataloader = DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
-        collate_fn=None,  # TODO(define collate fn)
-        num_workers=4)
+        collate_fn=train_batch_fn,
+        num_workers=config.num_workers)
     dev_dataloader = DataLoader(
         dev_dataset,
         batch_sampler=dev_sampler,
-        collate_fn=None,  # TODO(define collate fn)
-        num_workers=4)
+        collate_fn=train_batch_fn,
+        num_workers=config.num_workers)
+    print("dataloaders done!")

     generator = PWGGenerator(**config["generator_params"])
     discriminator = PWGDiscriminator(**config["discriminator_params"])
     if world_size > 1:
         generator = DataParallel(generator)
         discriminator = DataParallel(discriminator)
+    print("models done!")

     criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
     criterion_mse = nn.MSELoss()
+    print("criterions done!")

     lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
     optimizer_g = Adam(
         lr_schedule_g,
@@ -119,14 +131,43 @@ def train_sp(args, config):
         lr_schedule_d,
         parameters=discriminator.parameters(),
         **config["discriminator_optimizer_params"])
+    print("optimizers done!")

     output_dir = Path(args.output_dir)
     log_writer = None
     if dist.get_rank() == 0:
         output_dir.mkdir(parents=True, exist_ok=True)
-        log_writer = LogWriter(output_dir)
+        log_writer = LogWriter(str(output_dir))

-    # training loop
+    updater = PWGUpdater(
+        models={
+            "generator": generator,
+            "discriminator": discriminator,
+        },
+        optimizers={
+            "generator": optimizer_g,
+            "discriminator": optimizer_d,
+        },
+        criterions={
+            "stft": criterion_stft,
+            "mse": criterion_mse,
+        },
+        schedulers={
+            "generator": lr_schedule_g,
+            "discriminator": lr_schedule_d,
+        },
+        dataloaders={
+            "train": train_dataloader,
+            "dev": dev_dataloader,
+        },
+        discriminator_train_start_steps=config.discriminator_train_start_steps,
+        lambda_adv=config.lambda_adv, )
+
+    trainer = Trainer(
+        updater,
+        stop_trigger=(config.train_max_steps, "iteration"),
+        out=output_dir, )
+    trainer.run()


 def main():
@@ -76,7 +76,7 @@ class Trainer(object):
         else:
             max_iteration = self.stop_trigger.period

-        while not stop_trigger(self):
+        while True:
             self.observation = {}
             # set observation as the report target
             # you can use report freely in Updater.update()
@@ -84,8 +84,13 @@ class Trainer(object):
             # updating parameters and state
             with scope(self.observation):
                 update()
+                print(self.observation)

             # execute extension when necessary
             for name, entry in extensions:
                 if entry.trigger(self):
                     entry.extension(self)
+
+            if stop_trigger(self):
+                print("Training Done!")
+                break
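The reshaped loop matters for extension scheduling: the stop condition is now evaluated after the current iteration's extensions have run, so snapshot-style extensions still fire on the final step. A toy rendering of the new control flow, with stand-in names rather than the real Trainer API:

    # Toy version of the revised run loop order: update, then extensions, then stop check.
    max_iteration = 3
    iteration = 0
    while True:
        iteration += 1                      # update()
        print("extensions ran at", iteration)
        if iteration >= max_iteration:      # stop_trigger(self)
            print("Training Done!")
            break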
@@ -23,9 +23,9 @@ class IntervalTrigger(object):
     def __call__(self, trainer):
         state = trainer.updater.state
         if self.unit == "epoch":
-            fire = not (state.epoch % self.period)
+            fire = state.epoch % self.period == 0
         else:
-            fire = not (state.iteration % self.iteration)
+            fire = state.iteration % self.period == 0
         return fire
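The old iteration branch read self.iteration, an attribute the class never defines, and relied on truthiness; the fix compares against self.period explicitly. A quick check of the intended firing pattern, with an illustrative period:

    # With period=5, the trigger fires on iterations 5, 10, 15, ...
    period = 5
    fires = [it for it in range(1, 16) if it % period == 0]
    print(fires)  # [5, 10, 15]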
@@ -57,7 +57,8 @@ class UpdaterBase(object):
     """

     def update(self):
-        pass
+        self.state.iteration += 1
+        self.update_core()

     def update_core(self):
         pass
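With update() now bumping the counter before delegating, self.state.iteration inside update_core() is the 1-based index of the step being executed, which is what the discriminator warm-up comparison in PWGUpdater relies on. A minimal stand-in illustrating the contract (toy class, not the real UpdaterBase):

    class ToyUpdater:
        # Mirrors the new UpdaterBase.update() contract: increment, then do the work.
        def __init__(self):
            self.iteration = 0

        def update(self):
            self.iteration += 1
            self.update_core()

        def update_core(self):
            print("running step", self.iteration)

    u = ToyUpdater()
    u.update()  # running step 1
    u.update()  # running step 2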