fix log format of fastspeech2 speedyspeech and pwg

This commit is contained in:
TianYuan 2021-09-03 07:53:33 +00:00
parent 5bc570aee5
commit 065fa32a37
17 changed files with 327 additions and 102 deletions

View File

@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
def get_cfg_default():
config = _C.clone()
return config
print(get_cfg_default())

View File

@ -11,10 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from paddle import distributed as dist
from parakeet.models.fastspeech2 import FastSpeech2Loss
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater):
@ -24,12 +32,22 @@ class FastSpeech2Updater(StandardUpdater):
dataloader,
init_state=None,
use_masking=False,
use_weighted_masking=False):
use_weighted_masking=False,
output_dir=None):
super().__init__(model, optimizer, dataloader, init_state=None)
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
@ -70,18 +88,36 @@ class FastSpeech2Updater(StandardUpdater):
report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class FastSpeech2Evaluator(StandardEvaluator):
def __init__(self,
model,
dataloader,
use_masking=False,
use_weighted_masking=False):
use_weighted_masking=False,
output_dir=None):
super().__init__(model, dataloader)
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
@ -114,3 +150,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,8 +23,10 @@ import soundfile as sf
import yaml
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore
@ -102,9 +104,7 @@ def main():
parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument(
"--fastspeech2-config",
type=str,
help="config file to overwrite default config.")
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--fastspeech2-checkpoint",
type=str,
@ -115,10 +115,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
)
parser.add_argument(
"--pwg-config",
type=str,
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
type=str,

View File

@ -21,8 +21,10 @@ import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore
from frontend import Frontend
@ -113,9 +115,7 @@ def main():
parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument(
"--fastspeech2-config",
type=str,
help="fastspeech2 config file to overwrite default config.")
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--fastspeech2-checkpoint",
type=str,
@ -126,9 +126,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
)
parser.add_argument(
"--pwg-config",
type=str,
help="parallel wavegan config file to overwrite default config.")
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
type=str,

View File

@ -23,7 +23,8 @@ import paddle
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.training.extensions.snapshot import Snapshot
@ -35,7 +36,8 @@ import yaml
from batch_fn import collate_aishell3_examples
from config import get_cfg_default
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
from fastspeech2_updater import FastSpeech2Evaluator
from fastspeech2_updater import FastSpeech2Updater
optim_classes = dict(
adadelta=paddle.optimizer.Adadelta,
@ -97,6 +99,7 @@ def train_sp(args, config):
"energy": np.load}, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=[
@ -154,16 +157,19 @@ def train_sp(args, config):
optimizer = build_optimizers(model, **config["optimizer"])
print("optimizer done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
updater = FastSpeech2Updater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"])
output_dir = Path(args.output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
@ -171,7 +177,7 @@ def train_sp(args, config):
trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print(trainer.extensions)
# print(trainer.extensions)
trainer.run()

View File

@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
def get_cfg_default():
config = _C.clone()
return config
print(get_cfg_default())

View File

@ -11,10 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from paddle import distributed as dist
from parakeet.models.fastspeech2 import FastSpeech2Loss
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater):
@ -24,12 +32,21 @@ class FastSpeech2Updater(StandardUpdater):
dataloader,
init_state=None,
use_masking=False,
use_weighted_masking=False):
use_weighted_masking=False,
output_dir=None):
super().__init__(model, optimizer, dataloader, init_state=None)
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
@ -69,18 +86,36 @@ class FastSpeech2Updater(StandardUpdater):
report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class FastSpeech2Evaluator(StandardEvaluator):
def __init__(self,
model,
dataloader,
use_masking=False,
use_weighted_masking=False):
use_weighted_masking=False,
output_dir=None):
super().__init__(model, dataloader)
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
@ -112,3 +147,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,8 +23,10 @@ import soundfile as sf
import yaml
from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore
@ -91,9 +93,7 @@ def main():
parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument(
"--fastspeech2-config",
type=str,
help="config file to overwrite default config.")
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--fastspeech2-checkpoint",
type=str,
@ -104,10 +104,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
)
parser.add_argument(
"--pwg-config",
type=str,
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
type=str,

View File

@ -21,8 +21,10 @@ import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore
from frontend import Frontend
@ -103,9 +105,7 @@ def main():
parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument(
"--fastspeech2-config",
type=str,
help="fastspeech2 config file to overwrite default config.")
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--fastspeech2-checkpoint",
type=str,
@ -116,9 +116,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
)
parser.add_argument(
"--pwg-config",
type=str,
help="parallel wavegan config file to overwrite default config.")
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-params",
type=str,

View File

@ -35,7 +35,8 @@ import yaml
from batch_fn import collate_baker_examples
from config import get_cfg_default
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
from fastspeech2_updater import FastSpeech2Evaluator
from fastspeech2_updater import FastSpeech2Updater
optim_classes = dict(
adadelta=paddle.optimizer.Adadelta,
@ -108,6 +109,7 @@ def train_sp(args, config):
"energy": np.load}, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
@ -145,16 +147,20 @@ def train_sp(args, config):
optimizer = build_optimizers(model, **config["optimizer"])
print("optimizer done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
updater = FastSpeech2Updater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"])
output_dir = Path(args.output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
@ -162,7 +168,7 @@ def train_sp(args, config):
trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print(trainer.extensions)
# print(trainer.extensions)
trainer.run()

View File

@ -16,27 +16,33 @@ import logging
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddle.io import DataLoader
from timer import timer
from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.updaters.standard_updater import UpdaterState
from timer import timer
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PWGUpdater(StandardUpdater):
def __init__(
self,
def __init__(self,
models: Dict[str, Layer],
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float, ):
lambda_adv: float,
output_dir=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
@ -61,7 +67,16 @@ class PWGUpdater(StandardUpdater):
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
wav, mel = batch
@ -70,7 +85,7 @@ class PWGUpdater(StandardUpdater):
with timer() as t:
wav_ = self.generator(noise, mel)
logging.debug(f"Generator takes {t.elapse}s.")
# logging.debug(f"Generator takes {t.elapse}s.")
# initialize
gen_loss = 0.0
@ -78,10 +93,14 @@ class PWGUpdater(StandardUpdater):
## Multi-resolution stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss
## Adversarial loss
@ -89,22 +108,24 @@ class PWGUpdater(StandardUpdater):
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s")
# logging.debug(
# f"Discriminator and adversarial loss takes {t.elapse}s")
report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
with timer() as t:
self.optimizer_g.clear_grad()
gen_loss.backward()
logging.debug(f"Backward takes {t.elapse}s.")
# logging.debug(f"Backward takes {t.elapse}s.")
with timer() as t:
self.optimizer_g.step()
self.scheduler_g.step()
logging.debug(f"Update takes {t.elapse}s.")
# logging.debug(f"Update takes {t.elapse}s.")
# Disctiminator
if self.state.iteration > self.discriminator_train_start_steps:
@ -118,6 +139,9 @@ class PWGUpdater(StandardUpdater):
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
@ -125,9 +149,17 @@ class PWGUpdater(StandardUpdater):
self.optimizer_d.step()
self.scheduler_d.step()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class PWGEvaluator(StandardEvaluator):
def __init__(self, models, criterions, dataloader, lambda_adv):
def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']
@ -139,34 +171,47 @@ class PWGEvaluator(StandardEvaluator):
self.dataloader = dataloader
self.lambda_adv = lambda_adv
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
logging.debug("Evaluate: ")
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
wav, mel = batch
noise = paddle.randn(wav.shape)
with timer() as t:
wav_ = self.generator(noise, mel)
logging.debug(f"Generator takes {t.elapse}s")
# logging.debug(f"Generator takes {t.elapse}s")
## Adversarial loss
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s")
# logging.debug(
# f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss
# stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss
report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
# Disctiminator
p = self.discriminator(wav)
@ -176,3 +221,11 @@ class PWGEvaluator(StandardEvaluator):
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,11 +23,13 @@ import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam # No RAdaom
from paddle.optimizer.lr import StepDecay
from parakeet.datasets.data_table import DataTable
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
@ -38,7 +40,8 @@ from visualdl import LogWriter
from batch_fn import Clip
from config import get_cfg_default
from pwg_updater import PWGUpdater, PWGEvaluator
from pwg_updater import PWGUpdater
from pwg_updater import PWGEvaluator
def train_sp(args, config):
@ -99,11 +102,13 @@ def train_sp(args, config):
batch_max_steps=config.batch_max_steps,
hop_size=config.hop_length,
aux_context_window=config.generator_params.aux_context_window)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=train_batch_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
@ -139,10 +144,8 @@ def train_sp(args, config):
print("optimizers done!")
output_dir = Path(args.output_dir)
checkpoint_dir = output_dir / "checkpoints"
if dist.get_rank() == 0:
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / "config.yaml", 'wt') as f:
f.write(config.dump(default_flow_style=None))
@ -165,7 +168,8 @@ def train_sp(args, config):
},
dataloader=train_dataloader,
discriminator_train_start_steps=config.discriminator_train_start_steps,
lambda_adv=config.lambda_adv, )
lambda_adv=config.lambda_adv,
output_dir=output_dir)
evaluator = PWGEvaluator(
models={
@ -177,21 +181,23 @@ def train_sp(args, config):
"mse": criterion_mse,
},
dataloader=dev_dataloader,
lambda_adv=config.lambda_adv, )
lambda_adv=config.lambda_adv,
output_dir=output_dir)
trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir, )
trainer.extend(evaluator, trigger=(config.eval_interval_steps, 'iteration'))
if dist.get_rank() == 0:
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
writer = LogWriter(str(trainer.out))
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print(trainer.extensions.keys())
# print(trainer.extensions.keys())
print("Trainer Done!")
trainer.run()

View File

@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from paddle import distributed as dist
from paddle.fluid.layers import huber_loss
from paddle.nn import functional as F
from parakeet.modules.losses import masked_l1_loss, weighted_mean
@ -20,10 +22,32 @@ from parakeet.modules.ssim import ssim
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class SpeedySpeechUpdater(StandardUpdater):
def __init__(self,
model,
optimizer,
dataloader,
init_state=None,
output_dir=None):
super().__init__(model, optimizer, dataloader, init_state=None)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
decoded, predicted_durations = self.model(
text=batch["phones"],
tones=batch["tones"],
@ -65,9 +89,28 @@ class SpeedySpeechUpdater(StandardUpdater):
report("train/duration_loss", float(duration_loss))
report("train/ssim_loss", float(ssim_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["ssim_loss"] = float(ssim_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class SpeedySpeechEvaluator(StandardEvaluator):
def __init__(self, model, dataloader, output_dir=None):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
decoded, predicted_durations = self.model(
text=batch["phones"],
tones=batch["tones"],
@ -105,3 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
report("eval/l1_loss", float(l1_loss))
report("eval/duration_loss", float(duration_loss))
report("eval/ssim_loss", float(ssim_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["ssim_loss"] = float(ssim_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,8 +23,9 @@ import yaml
from paddle import distributed as dist
from paddle import DataParallel
from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam # No RAdaom
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam # No RAdam
from parakeet.datasets.data_table import DataTable
from parakeet.models.speedyspeech import SpeedySpeech
from parakeet.training.extensions.snapshot import Snapshot
@ -36,7 +37,8 @@ from visualdl import LogWriter
from batch_fn import collate_baker_examples
from config import get_cfg_default
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
from speedyspeech_updater import SpeedySpeechUpdater
from speedyspeech_updater import SpeedySpeechEvaluator
def train_sp(args, config):
@ -121,13 +123,19 @@ def train_sp(args, config):
grad_clip=nn.ClipGradByGlobalNorm(5.0))
print("optimizer done!")
updater = SpeedySpeechUpdater(
model=model, optimizer=optimizer, dataloader=train_dataloader)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
updater = SpeedySpeechUpdater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
evaluator = SpeedySpeechEvaluator(
model, dev_dataloader, output_dir=output_dir)
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))

View File

@ -280,14 +280,12 @@ class FastSpeech2(nn.Layer):
use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, ))
nn.initializer.set_global_initializer(None)
self._reset_parameters(
init_enc_alpha=init_enc_alpha,
init_dec_alpha=init_dec_alpha, )
# define criterions
self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
def forward(
self,
text: paddle.Tensor,

View File

@ -20,7 +20,6 @@ from typing import List
from typing import Union
import six
import tqdm
from parakeet.training.extension import Extension
from parakeet.training.extension import PRIORITY_READER
@ -122,6 +121,7 @@ class Trainer(object):
entry.extension.initialize(self)
update = self.updater.update # training step
stop_trigger = self.stop_trigger
# display only one progress bar
@ -135,8 +135,6 @@ class Trainer(object):
else:
max_iteration = self.stop_trigger.limit
p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
try:
while not stop_trigger(self):
self.observation = {}
@ -146,7 +144,21 @@ class Trainer(object):
# updating parameters and state
with scope(self.observation):
update()
p.update()
batch_read_time = self.updater.batch_read_time
batch_time = self.updater.batch_time
logger = self.updater.logger
logger.removeHandler(self.updater.filehandler)
msg = self.updater.msg
msg = " iter: {}/{}, ".format(self.updater.state.iteration,
max_iteration) + msg
msg += ", avg_reader_cost: {:.5f} sec, ".format(
batch_read_time
) + "avg_batch_cost: {:.5f} sec, ".format(batch_time)
msg += "avg_samples: {}, ".format(
self.updater.
batch_size) + "avg_ips: {:.5f} sequences/sec".format(
self.updater.batch_size / batch_time)
logger.info(msg)
# execute extension when necessary
for name, entry in extensions:

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Dict
from typing import Optional
@ -57,6 +58,8 @@ class StandardUpdater(UpdaterBase):
self.state = init_state
self.train_iterator = iter(dataloader)
self.batch_read_time = 0
self.batch_time = 0
def update(self):
# We increase the iteration index after updating and before extension.
@ -99,8 +102,17 @@ class StandardUpdater(UpdaterBase):
layer.train()
# training for a step is implemented here
time_before_read = time.time()
batch = self.read_batch()
time_before_core = time.time()
self.update_core(batch)
self.batch_time = time.time() - time_before_core
self.batch_read_time = time_before_core - time_before_read
if isinstance(batch, dict):
self.batch_size = len(list(batch.items())[0][-1])
# for pwg
elif isinstance(batch, list):
self.batch_size = batch[0].shape[0]
self.state.iteration += 1
if self.updates_per_epoch is not None: