fix log format of fastspeech2 speedyspeech and pwg

This commit is contained in:
TianYuan 2021-09-03 07:53:33 +00:00
parent 5bc570aee5
commit 065fa32a37
17 changed files with 327 additions and 102 deletions

View File

@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
def get_cfg_default(): def get_cfg_default():
config = _C.clone() config = _C.clone()
return config return config
print(get_cfg_default())

View File

@ -11,10 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from paddle import distributed as dist
from parakeet.models.fastspeech2 import FastSpeech2Loss from parakeet.models.fastspeech2 import FastSpeech2Loss
from parakeet.training.extensions.evaluator import StandardEvaluator from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater from parakeet.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater): class FastSpeech2Updater(StandardUpdater):
@ -24,12 +32,22 @@ class FastSpeech2Updater(StandardUpdater):
dataloader, dataloader,
init_state=None, init_state=None,
use_masking=False, use_masking=False,
use_weighted_masking=False): use_weighted_masking=False,
output_dir=None):
super().__init__(model, optimizer, dataloader, init_state=None) super().__init__(model, optimizer, dataloader, init_state=None)
self.use_masking = use_masking self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch): def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"], text=batch["text"],
text_lengths=batch["text_lengths"], text_lengths=batch["text_lengths"],
@ -70,18 +88,36 @@ class FastSpeech2Updater(StandardUpdater):
report("train/pitch_loss", float(pitch_loss)) report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss)) report("train/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class FastSpeech2Evaluator(StandardEvaluator): class FastSpeech2Evaluator(StandardEvaluator):
def __init__(self, def __init__(self,
model, model,
dataloader, dataloader,
use_masking=False, use_masking=False,
use_weighted_masking=False): use_weighted_masking=False,
output_dir=None):
super().__init__(model, dataloader) super().__init__(model, dataloader)
self.use_masking = use_masking self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch): def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"], text=batch["text"],
text_lengths=batch["text_lengths"], text_lengths=batch["text_lengths"],
@ -114,3 +150,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
report("eval/duration_loss", float(duration_loss)) report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss)) report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss)) report("eval/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,8 +23,10 @@ import soundfile as sf
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore from parakeet.modules.normalizer import ZScore
@ -102,9 +104,7 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.") description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument( parser.add_argument(
"--fastspeech2-config", "--fastspeech2-config", type=str, help="fastspeech2 config file.")
type=str,
help="config file to overwrite default config.")
parser.add_argument( parser.add_argument(
"--fastspeech2-checkpoint", "--fastspeech2-checkpoint",
type=str, type=str,
@ -115,10 +115,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2." help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
) )
parser.add_argument( parser.add_argument(
"--pwg-config", "--pwg-config", type=str, help="parallel wavegan config file.")
type=str,
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
parser.add_argument( parser.add_argument(
"--pwg-params", "--pwg-params",
type=str, type=str,

View File

@ -21,8 +21,10 @@ import paddle
import soundfile as sf import soundfile as sf
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore from parakeet.modules.normalizer import ZScore
from frontend import Frontend from frontend import Frontend
@ -113,9 +115,7 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.") description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument( parser.add_argument(
"--fastspeech2-config", "--fastspeech2-config", type=str, help="fastspeech2 config file.")
type=str,
help="fastspeech2 config file to overwrite default config.")
parser.add_argument( parser.add_argument(
"--fastspeech2-checkpoint", "--fastspeech2-checkpoint",
type=str, type=str,
@ -126,9 +126,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2." help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
) )
parser.add_argument( parser.add_argument(
"--pwg-config", "--pwg-config", type=str, help="parallel wavegan config file.")
type=str,
help="parallel wavegan config file to overwrite default config.")
parser.add_argument( parser.add_argument(
"--pwg-params", "--pwg-params",
type=str, type=str,

View File

@ -23,7 +23,8 @@ import paddle
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle import nn from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from parakeet.datasets.data_table import DataTable from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2 from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.training.extensions.snapshot import Snapshot from parakeet.training.extensions.snapshot import Snapshot
@ -35,7 +36,8 @@ import yaml
from batch_fn import collate_aishell3_examples from batch_fn import collate_aishell3_examples
from config import get_cfg_default from config import get_cfg_default
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator from fastspeech2_updater import FastSpeech2Evaluator
from fastspeech2_updater import FastSpeech2Updater
optim_classes = dict( optim_classes = dict(
adadelta=paddle.optimizer.Adadelta, adadelta=paddle.optimizer.Adadelta,
@ -97,6 +99,7 @@ def train_sp(args, config):
"energy": np.load}, ) "energy": np.load}, )
with jsonlines.open(args.dev_metadata, 'r') as reader: with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader) dev_metadata = list(reader)
dev_dataset = DataTable( dev_dataset = DataTable(
data=dev_metadata, data=dev_metadata,
fields=[ fields=[
@ -154,16 +157,19 @@ def train_sp(args, config):
optimizer = build_optimizers(model, **config["optimizer"]) optimizer = build_optimizers(model, **config["optimizer"])
print("optimizer done!") print("optimizer done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
updater = FastSpeech2Updater( updater = FastSpeech2Updater(
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
dataloader=train_dataloader, dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"]) **config["updater"])
output_dir = Path(args.output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"]) evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])
if dist.get_rank() == 0: if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch")) trainer.extend(evaluator, trigger=(1, "epoch"))
@ -171,7 +177,7 @@ def train_sp(args, config):
trainer.extend(VisualDL(writer), trigger=(1, "iteration")) trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
trainer.extend( trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print(trainer.extensions) # print(trainer.extensions)
trainer.run() trainer.run()

View File

@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
def get_cfg_default(): def get_cfg_default():
config = _C.clone() config = _C.clone()
return config return config
print(get_cfg_default())

View File

@ -11,10 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from paddle import distributed as dist
from parakeet.models.fastspeech2 import FastSpeech2Loss from parakeet.models.fastspeech2 import FastSpeech2Loss
from parakeet.training.extensions.evaluator import StandardEvaluator from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater from parakeet.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class FastSpeech2Updater(StandardUpdater): class FastSpeech2Updater(StandardUpdater):
@ -24,12 +32,21 @@ class FastSpeech2Updater(StandardUpdater):
dataloader, dataloader,
init_state=None, init_state=None,
use_masking=False, use_masking=False,
use_weighted_masking=False): use_weighted_masking=False,
output_dir=None):
super().__init__(model, optimizer, dataloader, init_state=None) super().__init__(model, optimizer, dataloader, init_state=None)
self.use_masking = use_masking self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch): def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"], text=batch["text"],
text_lengths=batch["text_lengths"], text_lengths=batch["text_lengths"],
@ -69,18 +86,36 @@ class FastSpeech2Updater(StandardUpdater):
report("train/pitch_loss", float(pitch_loss)) report("train/pitch_loss", float(pitch_loss))
report("train/energy_loss", float(energy_loss)) report("train/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class FastSpeech2Evaluator(StandardEvaluator): class FastSpeech2Evaluator(StandardEvaluator):
def __init__(self, def __init__(self,
model, model,
dataloader, dataloader,
use_masking=False, use_masking=False,
use_weighted_masking=False): use_weighted_masking=False,
output_dir=None):
super().__init__(model, dataloader) super().__init__(model, dataloader)
self.use_masking = use_masking self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking self.use_weighted_masking = use_weighted_masking
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch): def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"], text=batch["text"],
text_lengths=batch["text_lengths"], text_lengths=batch["text_lengths"],
@ -112,3 +147,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
report("eval/duration_loss", float(duration_loss)) report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss)) report("eval/pitch_loss", float(pitch_loss))
report("eval/energy_loss", float(energy_loss)) report("eval/energy_loss", float(energy_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,8 +23,10 @@ import soundfile as sf
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from parakeet.datasets.data_table import DataTable from parakeet.datasets.data_table import DataTable
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore from parakeet.modules.normalizer import ZScore
@ -91,9 +93,7 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.") description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument( parser.add_argument(
"--fastspeech2-config", "--fastspeech2-config", type=str, help="fastspeech2 config file.")
type=str,
help="config file to overwrite default config.")
parser.add_argument( parser.add_argument(
"--fastspeech2-checkpoint", "--fastspeech2-checkpoint",
type=str, type=str,
@ -104,10 +104,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2." help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
) )
parser.add_argument( parser.add_argument(
"--pwg-config", "--pwg-config", type=str, help="parallel wavegan config file.")
type=str,
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
parser.add_argument( parser.add_argument(
"--pwg-params", "--pwg-params",
type=str, type=str,

View File

@ -21,8 +21,10 @@ import paddle
import soundfile as sf import soundfile as sf
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference from parakeet.models.fastspeech2 import FastSpeech2
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference from parakeet.models.fastspeech2 import FastSpeech2Inference
from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGInference
from parakeet.modules.normalizer import ZScore from parakeet.modules.normalizer import ZScore
from frontend import Frontend from frontend import Frontend
@ -103,9 +105,7 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize with fastspeech2 & parallel wavegan.") description="Synthesize with fastspeech2 & parallel wavegan.")
parser.add_argument( parser.add_argument(
"--fastspeech2-config", "--fastspeech2-config", type=str, help="fastspeech2 config file.")
type=str,
help="fastspeech2 config file to overwrite default config.")
parser.add_argument( parser.add_argument(
"--fastspeech2-checkpoint", "--fastspeech2-checkpoint",
type=str, type=str,
@ -116,9 +116,7 @@ def main():
help="mean and standard deviation used to normalize spectrogram when training fastspeech2." help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
) )
parser.add_argument( parser.add_argument(
"--pwg-config", "--pwg-config", type=str, help="parallel wavegan config file.")
type=str,
help="parallel wavegan config file to overwrite default config.")
parser.add_argument( parser.add_argument(
"--pwg-params", "--pwg-params",
type=str, type=str,

View File

@ -35,7 +35,8 @@ import yaml
from batch_fn import collate_baker_examples from batch_fn import collate_baker_examples
from config import get_cfg_default from config import get_cfg_default
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator from fastspeech2_updater import FastSpeech2Evaluator
from fastspeech2_updater import FastSpeech2Updater
optim_classes = dict( optim_classes = dict(
adadelta=paddle.optimizer.Adadelta, adadelta=paddle.optimizer.Adadelta,
@ -108,6 +109,7 @@ def train_sp(args, config):
"energy": np.load}, ) "energy": np.load}, )
# collate function and dataloader # collate function and dataloader
train_sampler = DistributedBatchSampler( train_sampler = DistributedBatchSampler(
train_dataset, train_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
@ -145,16 +147,20 @@ def train_sp(args, config):
optimizer = build_optimizers(model, **config["optimizer"]) optimizer = build_optimizers(model, **config["optimizer"])
print("optimizer done!") print("optimizer done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
updater = FastSpeech2Updater( updater = FastSpeech2Updater(
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
dataloader=train_dataloader, dataloader=train_dataloader,
output_dir=output_dir,
**config["updater"]) **config["updater"])
output_dir = Path(args.output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"]) evaluator = FastSpeech2Evaluator(
model, dev_dataloader, output_dir=output_dir, **config["updater"])
if dist.get_rank() == 0: if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch")) trainer.extend(evaluator, trigger=(1, "epoch"))
@ -162,7 +168,7 @@ def train_sp(args, config):
trainer.extend(VisualDL(writer), trigger=(1, "iteration")) trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
trainer.extend( trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print(trainer.extensions) # print(trainer.extensions)
trainer.run() trainer.run()

View File

@ -16,27 +16,33 @@ import logging
from typing import Dict from typing import Dict
import paddle import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler from paddle.optimizer.lr import LRScheduler
from paddle.io import DataLoader
from timer import timer
from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
from parakeet.training.extensions.evaluator import StandardEvaluator from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.updaters.standard_updater import UpdaterState
from timer import timer
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PWGUpdater(StandardUpdater): class PWGUpdater(StandardUpdater):
def __init__( def __init__(self,
self,
models: Dict[str, Layer], models: Dict[str, Layer],
optimizers: Dict[str, Optimizer], optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer], criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler], schedulers: Dict[str, LRScheduler],
dataloader: DataLoader, dataloader: DataLoader,
discriminator_train_start_steps: int, discriminator_train_start_steps: int,
lambda_adv: float, ): lambda_adv: float,
output_dir=None):
self.models = models self.models = models
self.generator: Layer = models['generator'] self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator'] self.discriminator: Layer = models['discriminator']
@ -61,7 +67,16 @@ class PWGUpdater(StandardUpdater):
self.train_iterator = iter(self.dataloader) self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch): def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch # parse batch
wav, mel = batch wav, mel = batch
@ -70,7 +85,7 @@ class PWGUpdater(StandardUpdater):
with timer() as t: with timer() as t:
wav_ = self.generator(noise, mel) wav_ = self.generator(noise, mel)
logging.debug(f"Generator takes {t.elapse}s.") # logging.debug(f"Generator takes {t.elapse}s.")
# initialize # initialize
gen_loss = 0.0 gen_loss = 0.0
@ -78,10 +93,14 @@ class PWGUpdater(StandardUpdater):
## Multi-resolution stft loss ## Multi-resolution stft loss
with timer() as t: with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav) sc_loss, mag_loss = self.criterion_stft(wav_, wav)
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.") # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
report("train/spectral_convergence_loss", float(sc_loss)) report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss)) report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss gen_loss += sc_loss + mag_loss
## Adversarial loss ## Adversarial loss
@ -89,22 +108,24 @@ class PWGUpdater(StandardUpdater):
with timer() as t: with timer() as t:
p_ = self.discriminator(wav_) p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
logging.debug( # logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s") # f"Discriminator and adversarial loss takes {t.elapse}s")
report("train/adversarial_loss", float(adv_loss)) report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss)) report("train/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
with timer() as t: with timer() as t:
self.optimizer_g.clear_grad() self.optimizer_g.clear_grad()
gen_loss.backward() gen_loss.backward()
logging.debug(f"Backward takes {t.elapse}s.") # logging.debug(f"Backward takes {t.elapse}s.")
with timer() as t: with timer() as t:
self.optimizer_g.step() self.optimizer_g.step()
self.scheduler_g.step() self.scheduler_g.step()
logging.debug(f"Update takes {t.elapse}s.") # logging.debug(f"Update takes {t.elapse}s.")
# Disctiminator # Disctiminator
if self.state.iteration > self.discriminator_train_start_steps: if self.state.iteration > self.discriminator_train_start_steps:
@ -118,6 +139,9 @@ class PWGUpdater(StandardUpdater):
report("train/real_loss", float(real_loss)) report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss)) report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss)) report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad() self.optimizer_d.clear_grad()
dis_loss.backward() dis_loss.backward()
@ -125,9 +149,17 @@ class PWGUpdater(StandardUpdater):
self.optimizer_d.step() self.optimizer_d.step()
self.scheduler_d.step() self.scheduler_d.step()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class PWGEvaluator(StandardEvaluator): class PWGEvaluator(StandardEvaluator):
def __init__(self, models, criterions, dataloader, lambda_adv): def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
self.models = models self.models = models
self.generator = models['generator'] self.generator = models['generator']
self.discriminator = models['discriminator'] self.discriminator = models['discriminator']
@ -139,34 +171,47 @@ class PWGEvaluator(StandardEvaluator):
self.dataloader = dataloader self.dataloader = dataloader
self.lambda_adv = lambda_adv self.lambda_adv = lambda_adv
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch): def evaluate_core(self, batch):
logging.debug("Evaluate: ") # logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
wav, mel = batch wav, mel = batch
noise = paddle.randn(wav.shape) noise = paddle.randn(wav.shape)
with timer() as t: with timer() as t:
wav_ = self.generator(noise, mel) wav_ = self.generator(noise, mel)
logging.debug(f"Generator takes {t.elapse}s") # logging.debug(f"Generator takes {t.elapse}s")
## Adversarial loss ## Adversarial loss
with timer() as t: with timer() as t:
p_ = self.discriminator(wav_) p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
logging.debug( # logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s") # f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss)) report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss gen_loss = self.lambda_adv * adv_loss
# stft loss # stft loss
with timer() as t: with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav) sc_loss, mag_loss = self.criterion_stft(wav_, wav)
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss)) report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss)) report("eval/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss gen_loss += sc_loss + mag_loss
report("eval/generator_loss", float(gen_loss)) report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
# Disctiminator # Disctiminator
p = self.discriminator(wav) p = self.discriminator(wav)
@ -176,3 +221,11 @@ class PWGEvaluator(StandardEvaluator):
report("eval/real_loss", float(real_loss)) report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss)) report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss)) report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,11 +23,13 @@ import yaml
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle import nn from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam # No RAdaom from paddle.optimizer import Adam # No RAdaom
from paddle.optimizer.lr import StepDecay from paddle.optimizer.lr import StepDecay
from parakeet.datasets.data_table import DataTable from parakeet.datasets.data_table import DataTable
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator from parakeet.models.parallel_wavegan import PWGGenerator
from parakeet.models.parallel_wavegan import PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
from parakeet.training.extensions.snapshot import Snapshot from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL from parakeet.training.extensions.visualizer import VisualDL
@ -38,7 +40,8 @@ from visualdl import LogWriter
from batch_fn import Clip from batch_fn import Clip
from config import get_cfg_default from config import get_cfg_default
from pwg_updater import PWGUpdater, PWGEvaluator from pwg_updater import PWGUpdater
from pwg_updater import PWGEvaluator
def train_sp(args, config): def train_sp(args, config):
@ -99,11 +102,13 @@ def train_sp(args, config):
batch_max_steps=config.batch_max_steps, batch_max_steps=config.batch_max_steps,
hop_size=config.hop_length, hop_size=config.hop_length,
aux_context_window=config.generator_params.aux_context_window) aux_context_window=config.generator_params.aux_context_window)
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, train_dataset,
batch_sampler=train_sampler, batch_sampler=train_sampler,
collate_fn=train_batch_fn, collate_fn=train_batch_fn,
num_workers=config.num_workers) num_workers=config.num_workers)
dev_dataloader = DataLoader( dev_dataloader = DataLoader(
dev_dataset, dev_dataset,
batch_sampler=dev_sampler, batch_sampler=dev_sampler,
@ -139,10 +144,8 @@ def train_sp(args, config):
print("optimizers done!") print("optimizers done!")
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
checkpoint_dir = output_dir / "checkpoints"
if dist.get_rank() == 0: if dist.get_rank() == 0:
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / "config.yaml", 'wt') as f: with open(output_dir / "config.yaml", 'wt') as f:
f.write(config.dump(default_flow_style=None)) f.write(config.dump(default_flow_style=None))
@ -165,7 +168,8 @@ def train_sp(args, config):
}, },
dataloader=train_dataloader, dataloader=train_dataloader,
discriminator_train_start_steps=config.discriminator_train_start_steps, discriminator_train_start_steps=config.discriminator_train_start_steps,
lambda_adv=config.lambda_adv, ) lambda_adv=config.lambda_adv,
output_dir=output_dir)
evaluator = PWGEvaluator( evaluator = PWGEvaluator(
models={ models={
@ -177,21 +181,23 @@ def train_sp(args, config):
"mse": criterion_mse, "mse": criterion_mse,
}, },
dataloader=dev_dataloader, dataloader=dev_dataloader,
lambda_adv=config.lambda_adv, ) lambda_adv=config.lambda_adv,
output_dir=output_dir)
trainer = Trainer( trainer = Trainer(
updater, updater,
stop_trigger=(config.train_max_steps, "iteration"), stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir, ) out=output_dir, )
trainer.extend(evaluator, trigger=(config.eval_interval_steps, 'iteration'))
if dist.get_rank() == 0: if dist.get_rank() == 0:
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
writer = LogWriter(str(trainer.out)) writer = LogWriter(str(trainer.out))
trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
trainer.extend( trainer.extend(
Snapshot(max_size=config.num_snapshots), Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration')) trigger=(config.save_interval_steps, 'iteration'))
print(trainer.extensions.keys()) # print(trainer.extensions.keys())
print("Trainer Done!") print("Trainer Done!")
trainer.run() trainer.run()

View File

@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import paddle import paddle
from paddle import distributed as dist
from paddle.fluid.layers import huber_loss from paddle.fluid.layers import huber_loss
from paddle.nn import functional as F from paddle.nn import functional as F
from parakeet.modules.losses import masked_l1_loss, weighted_mean from parakeet.modules.losses import masked_l1_loss, weighted_mean
@ -20,10 +22,32 @@ from parakeet.modules.ssim import ssim
from parakeet.training.extensions.evaluator import StandardEvaluator from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater from parakeet.training.updaters.standard_updater import StandardUpdater
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class SpeedySpeechUpdater(StandardUpdater): class SpeedySpeechUpdater(StandardUpdater):
def __init__(self,
model,
optimizer,
dataloader,
init_state=None,
output_dir=None):
super().__init__(model, optimizer, dataloader, init_state=None)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch): def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
decoded, predicted_durations = self.model( decoded, predicted_durations = self.model(
text=batch["phones"], text=batch["phones"],
tones=batch["tones"], tones=batch["tones"],
@ -65,9 +89,28 @@ class SpeedySpeechUpdater(StandardUpdater):
report("train/duration_loss", float(duration_loss)) report("train/duration_loss", float(duration_loss))
report("train/ssim_loss", float(ssim_loss)) report("train/ssim_loss", float(ssim_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["ssim_loss"] = float(ssim_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class SpeedySpeechEvaluator(StandardEvaluator): class SpeedySpeechEvaluator(StandardEvaluator):
def __init__(self, model, dataloader, output_dir=None):
super().__init__(model, dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch): def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
decoded, predicted_durations = self.model( decoded, predicted_durations = self.model(
text=batch["phones"], text=batch["phones"],
tones=batch["tones"], tones=batch["tones"],
@ -105,3 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
report("eval/l1_loss", float(l1_loss)) report("eval/l1_loss", float(l1_loss))
report("eval/duration_loss", float(duration_loss)) report("eval/duration_loss", float(duration_loss))
report("eval/ssim_loss", float(ssim_loss)) report("eval/ssim_loss", float(ssim_loss))
losses_dict["l1_loss"] = float(l1_loss)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["ssim_loss"] = float(ssim_loss)
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

View File

@ -23,8 +23,9 @@ import yaml
from paddle import distributed as dist from paddle import distributed as dist
from paddle import DataParallel from paddle import DataParallel
from paddle import nn from paddle import nn
from paddle.io import DataLoader, DistributedBatchSampler from paddle.io import DataLoader
from paddle.optimizer import Adam # No RAdaom from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam # No RAdam
from parakeet.datasets.data_table import DataTable from parakeet.datasets.data_table import DataTable
from parakeet.models.speedyspeech import SpeedySpeech from parakeet.models.speedyspeech import SpeedySpeech
from parakeet.training.extensions.snapshot import Snapshot from parakeet.training.extensions.snapshot import Snapshot
@ -36,7 +37,8 @@ from visualdl import LogWriter
from batch_fn import collate_baker_examples from batch_fn import collate_baker_examples
from config import get_cfg_default from config import get_cfg_default
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator from speedyspeech_updater import SpeedySpeechUpdater
from speedyspeech_updater import SpeedySpeechEvaluator
def train_sp(args, config): def train_sp(args, config):
@ -121,13 +123,19 @@ def train_sp(args, config):
grad_clip=nn.ClipGradByGlobalNorm(5.0)) grad_clip=nn.ClipGradByGlobalNorm(5.0))
print("optimizer done!") print("optimizer done!")
updater = SpeedySpeechUpdater(
model=model, optimizer=optimizer, dataloader=train_dataloader)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
updater = SpeedySpeechUpdater(
model=model,
optimizer=optimizer,
dataloader=train_dataloader,
output_dir=output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
evaluator = SpeedySpeechEvaluator(model, dev_dataloader) evaluator = SpeedySpeechEvaluator(
model, dev_dataloader, output_dir=output_dir)
if dist.get_rank() == 0: if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch")) trainer.extend(evaluator, trigger=(1, "epoch"))

View File

@ -280,14 +280,12 @@ class FastSpeech2(nn.Layer):
use_batch_norm=use_batch_norm, use_batch_norm=use_batch_norm,
dropout_rate=postnet_dropout_rate, )) dropout_rate=postnet_dropout_rate, ))
nn.initializer.set_global_initializer(None)
self._reset_parameters( self._reset_parameters(
init_enc_alpha=init_enc_alpha, init_enc_alpha=init_enc_alpha,
init_dec_alpha=init_dec_alpha, ) init_dec_alpha=init_dec_alpha, )
# define criterions
self.criterion = FastSpeech2Loss(
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
def forward( def forward(
self, self,
text: paddle.Tensor, text: paddle.Tensor,

View File

@ -20,7 +20,6 @@ from typing import List
from typing import Union from typing import Union
import six import six
import tqdm
from parakeet.training.extension import Extension from parakeet.training.extension import Extension
from parakeet.training.extension import PRIORITY_READER from parakeet.training.extension import PRIORITY_READER
@ -122,6 +121,7 @@ class Trainer(object):
entry.extension.initialize(self) entry.extension.initialize(self)
update = self.updater.update # training step update = self.updater.update # training step
stop_trigger = self.stop_trigger stop_trigger = self.stop_trigger
# display only one progress bar # display only one progress bar
@ -135,8 +135,6 @@ class Trainer(object):
else: else:
max_iteration = self.stop_trigger.limit max_iteration = self.stop_trigger.limit
p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
try: try:
while not stop_trigger(self): while not stop_trigger(self):
self.observation = {} self.observation = {}
@ -146,7 +144,21 @@ class Trainer(object):
# updating parameters and state # updating parameters and state
with scope(self.observation): with scope(self.observation):
update() update()
p.update() batch_read_time = self.updater.batch_read_time
batch_time = self.updater.batch_time
logger = self.updater.logger
logger.removeHandler(self.updater.filehandler)
msg = self.updater.msg
msg = " iter: {}/{}, ".format(self.updater.state.iteration,
max_iteration) + msg
msg += ", avg_reader_cost: {:.5f} sec, ".format(
batch_read_time
) + "avg_batch_cost: {:.5f} sec, ".format(batch_time)
msg += "avg_samples: {}, ".format(
self.updater.
batch_size) + "avg_ips: {:.5f} sequences/sec".format(
self.updater.batch_size / batch_time)
logger.info(msg)
# execute extension when necessary # execute extension when necessary
for name, entry in extensions: for name, entry in extensions:

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import time
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
@ -57,6 +58,8 @@ class StandardUpdater(UpdaterBase):
self.state = init_state self.state = init_state
self.train_iterator = iter(dataloader) self.train_iterator = iter(dataloader)
self.batch_read_time = 0
self.batch_time = 0
def update(self): def update(self):
# We increase the iteration index after updating and before extension. # We increase the iteration index after updating and before extension.
@ -99,8 +102,17 @@ class StandardUpdater(UpdaterBase):
layer.train() layer.train()
# training for a step is implemented here # training for a step is implemented here
time_before_read = time.time()
batch = self.read_batch() batch = self.read_batch()
time_before_core = time.time()
self.update_core(batch) self.update_core(batch)
self.batch_time = time.time() - time_before_core
self.batch_read_time = time_before_core - time_before_read
if isinstance(batch, dict):
self.batch_size = len(list(batch.items())[0][-1])
# for pwg
elif isinstance(batch, list):
self.batch_size = batch[0].shape[0]
self.state.iteration += 1 self.state.iteration += 1
if self.updates_per_epoch is not None: if self.updates_per_epoch is not None: