Merge pull request #158 from yt605155624/fix_log
fix log format of fastspeech2 speedyspeech and pwg
This commit is contained in:
commit
def8218d33
|
@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
|
|||
def get_cfg_default():
|
||||
config = _C.clone()
|
||||
return config
|
||||
|
||||
|
||||
print(get_cfg_default())
|
||||
|
|
|
@ -11,10 +11,18 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from paddle import distributed as dist
|
||||
from parakeet.models.fastspeech2 import FastSpeech2Loss
|
||||
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.updaters.standard_updater import StandardUpdater
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class FastSpeech2Updater(StandardUpdater):
|
||||
|
@ -24,12 +32,22 @@ class FastSpeech2Updater(StandardUpdater):
|
|||
dataloader,
|
||||
init_state=None,
|
||||
use_masking=False,
|
||||
use_weighted_masking=False):
|
||||
use_weighted_masking=False,
|
||||
output_dir=None):
|
||||
super().__init__(model, optimizer, dataloader, init_state=None)
|
||||
self.use_masking = use_masking
|
||||
self.use_weighted_masking = use_weighted_masking
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def update_core(self, batch):
|
||||
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||
losses_dict = {}
|
||||
|
||||
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
|
||||
text=batch["text"],
|
||||
text_lengths=batch["text_lengths"],
|
||||
|
@ -70,18 +88,36 @@ class FastSpeech2Updater(StandardUpdater):
|
|||
report("train/pitch_loss", float(pitch_loss))
|
||||
report("train/energy_loss", float(energy_loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["duration_loss"] = float(duration_loss)
|
||||
losses_dict["pitch_loss"] = float(pitch_loss)
|
||||
losses_dict["energy_loss"] = float(energy_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
|
||||
|
||||
class FastSpeech2Evaluator(StandardEvaluator):
|
||||
def __init__(self,
|
||||
model,
|
||||
dataloader,
|
||||
use_masking=False,
|
||||
use_weighted_masking=False):
|
||||
use_weighted_masking=False,
|
||||
output_dir=None):
|
||||
super().__init__(model, dataloader)
|
||||
self.use_masking = use_masking
|
||||
self.use_weighted_masking = use_weighted_masking
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def evaluate_core(self, batch):
|
||||
self.msg = "Evaluate: "
|
||||
losses_dict = {}
|
||||
|
||||
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
|
||||
text=batch["text"],
|
||||
text_lengths=batch["text_lengths"],
|
||||
|
@ -114,3 +150,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
|
|||
report("eval/duration_loss", float(duration_loss))
|
||||
report("eval/pitch_loss", float(pitch_loss))
|
||||
report("eval/energy_loss", float(energy_loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["duration_loss"] = float(duration_loss)
|
||||
losses_dict["pitch_loss"] = float(pitch_loss)
|
||||
losses_dict["energy_loss"] = float(energy_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
self.logger.info(self.msg)
|
||||
|
|
|
@ -23,8 +23,10 @@ import soundfile as sf
|
|||
import yaml
|
||||
from yacs.config import CfgNode
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.models.fastspeech2 import FastSpeech2
|
||||
from parakeet.models.fastspeech2 import FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
from parakeet.models.parallel_wavegan import PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
|
||||
|
@ -102,9 +104,7 @@ def main():
|
|||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with fastspeech2 & parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-config",
|
||||
type=str,
|
||||
help="config file to overwrite default config.")
|
||||
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-checkpoint",
|
||||
type=str,
|
||||
|
@ -115,10 +115,7 @@ def main():
|
|||
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
|
||||
)
|
||||
"--pwg-config", type=str, help="parallel wavegan config file.")
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
|
|
|
@ -21,8 +21,10 @@ import paddle
|
|||
import soundfile as sf
|
||||
import yaml
|
||||
from yacs.config import CfgNode
|
||||
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.models.fastspeech2 import FastSpeech2
|
||||
from parakeet.models.fastspeech2 import FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
from parakeet.models.parallel_wavegan import PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
from frontend import Frontend
|
||||
|
@ -113,9 +115,7 @@ def main():
|
|||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with fastspeech2 & parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-config",
|
||||
type=str,
|
||||
help="fastspeech2 config file to overwrite default config.")
|
||||
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-checkpoint",
|
||||
type=str,
|
||||
|
@ -126,9 +126,7 @@ def main():
|
|||
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="parallel wavegan config file to overwrite default config.")
|
||||
"--pwg-config", type=str, help="parallel wavegan config file.")
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
|
|
|
@ -23,7 +23,8 @@ import paddle
|
|||
from paddle import DataParallel
|
||||
from paddle import distributed as dist
|
||||
from paddle import nn
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.fastspeech2 import FastSpeech2
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
|
@ -35,7 +36,8 @@ import yaml
|
|||
|
||||
from batch_fn import collate_aishell3_examples
|
||||
from config import get_cfg_default
|
||||
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
|
||||
from fastspeech2_updater import FastSpeech2Evaluator
|
||||
from fastspeech2_updater import FastSpeech2Updater
|
||||
|
||||
optim_classes = dict(
|
||||
adadelta=paddle.optimizer.Adadelta,
|
||||
|
@ -97,6 +99,7 @@ def train_sp(args, config):
|
|||
"energy": np.load}, )
|
||||
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
||||
dev_metadata = list(reader)
|
||||
|
||||
dev_dataset = DataTable(
|
||||
data=dev_metadata,
|
||||
fields=[
|
||||
|
@ -154,16 +157,19 @@ def train_sp(args, config):
|
|||
optimizer = build_optimizers(model, **config["optimizer"])
|
||||
print("optimizer done!")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
updater = FastSpeech2Updater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
output_dir=output_dir,
|
||||
**config["updater"])
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
|
||||
|
||||
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
|
||||
evaluator = FastSpeech2Evaluator(
|
||||
model, dev_dataloader, output_dir=output_dir, **config["updater"])
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
trainer.extend(evaluator, trigger=(1, "epoch"))
|
||||
|
@ -171,7 +177,7 @@ def train_sp(args, config):
|
|||
trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
|
||||
trainer.extend(
|
||||
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
|
||||
print(trainer.extensions)
|
||||
# print(trainer.extensions)
|
||||
trainer.run()
|
||||
|
||||
|
||||
|
|
|
@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
|
|||
def get_cfg_default():
|
||||
config = _C.clone()
|
||||
return config
|
||||
|
||||
|
||||
print(get_cfg_default())
|
||||
|
|
|
@ -11,10 +11,18 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from paddle import distributed as dist
|
||||
from parakeet.models.fastspeech2 import FastSpeech2Loss
|
||||
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.updaters.standard_updater import StandardUpdater
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class FastSpeech2Updater(StandardUpdater):
|
||||
|
@ -24,12 +32,21 @@ class FastSpeech2Updater(StandardUpdater):
|
|||
dataloader,
|
||||
init_state=None,
|
||||
use_masking=False,
|
||||
use_weighted_masking=False):
|
||||
use_weighted_masking=False,
|
||||
output_dir=None):
|
||||
super().__init__(model, optimizer, dataloader, init_state=None)
|
||||
self.use_masking = use_masking
|
||||
self.use_weighted_masking = use_weighted_masking
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def update_core(self, batch):
|
||||
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||
losses_dict = {}
|
||||
|
||||
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
|
||||
text=batch["text"],
|
||||
text_lengths=batch["text_lengths"],
|
||||
|
@ -69,18 +86,36 @@ class FastSpeech2Updater(StandardUpdater):
|
|||
report("train/pitch_loss", float(pitch_loss))
|
||||
report("train/energy_loss", float(energy_loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["duration_loss"] = float(duration_loss)
|
||||
losses_dict["pitch_loss"] = float(pitch_loss)
|
||||
losses_dict["energy_loss"] = float(energy_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
|
||||
|
||||
class FastSpeech2Evaluator(StandardEvaluator):
|
||||
def __init__(self,
|
||||
model,
|
||||
dataloader,
|
||||
use_masking=False,
|
||||
use_weighted_masking=False):
|
||||
use_weighted_masking=False,
|
||||
output_dir=None):
|
||||
super().__init__(model, dataloader)
|
||||
self.use_masking = use_masking
|
||||
self.use_weighted_masking = use_weighted_masking
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def evaluate_core(self, batch):
|
||||
self.msg = "Evaluate: "
|
||||
losses_dict = {}
|
||||
|
||||
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
|
||||
text=batch["text"],
|
||||
text_lengths=batch["text_lengths"],
|
||||
|
@ -112,3 +147,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
|
|||
report("eval/duration_loss", float(duration_loss))
|
||||
report("eval/pitch_loss", float(pitch_loss))
|
||||
report("eval/energy_loss", float(energy_loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["duration_loss"] = float(duration_loss)
|
||||
losses_dict["pitch_loss"] = float(pitch_loss)
|
||||
losses_dict["energy_loss"] = float(energy_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
self.logger.info(self.msg)
|
||||
|
|
|
@ -23,8 +23,10 @@ import soundfile as sf
|
|||
import yaml
|
||||
from yacs.config import CfgNode
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.models.fastspeech2 import FastSpeech2
|
||||
from parakeet.models.fastspeech2 import FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
from parakeet.models.parallel_wavegan import PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
|
||||
|
@ -91,9 +93,7 @@ def main():
|
|||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with fastspeech2 & parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-config",
|
||||
type=str,
|
||||
help="config file to overwrite default config.")
|
||||
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-checkpoint",
|
||||
type=str,
|
||||
|
@ -104,10 +104,7 @@ def main():
|
|||
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
|
||||
)
|
||||
"--pwg-config", type=str, help="parallel wavegan config file.")
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
|
|
|
@ -21,8 +21,10 @@ import paddle
|
|||
import soundfile as sf
|
||||
import yaml
|
||||
from yacs.config import CfgNode
|
||||
from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.models.fastspeech2 import FastSpeech2
|
||||
from parakeet.models.fastspeech2 import FastSpeech2Inference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
from parakeet.models.parallel_wavegan import PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
from frontend import Frontend
|
||||
|
@ -103,9 +105,7 @@ def main():
|
|||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with fastspeech2 & parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-config",
|
||||
type=str,
|
||||
help="fastspeech2 config file to overwrite default config.")
|
||||
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-checkpoint",
|
||||
type=str,
|
||||
|
@ -116,9 +116,7 @@ def main():
|
|||
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="parallel wavegan config file to overwrite default config.")
|
||||
"--pwg-config", type=str, help="parallel wavegan config file.")
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
|
|
|
@ -35,7 +35,8 @@ import yaml
|
|||
|
||||
from batch_fn import collate_baker_examples
|
||||
from config import get_cfg_default
|
||||
from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
|
||||
from fastspeech2_updater import FastSpeech2Evaluator
|
||||
from fastspeech2_updater import FastSpeech2Updater
|
||||
|
||||
optim_classes = dict(
|
||||
adadelta=paddle.optimizer.Adadelta,
|
||||
|
@ -108,6 +109,7 @@ def train_sp(args, config):
|
|||
"energy": np.load}, )
|
||||
|
||||
# collate function and dataloader
|
||||
|
||||
train_sampler = DistributedBatchSampler(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
|
@ -145,16 +147,20 @@ def train_sp(args, config):
|
|||
optimizer = build_optimizers(model, **config["optimizer"])
|
||||
print("optimizer done!")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
updater = FastSpeech2Updater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
output_dir=output_dir,
|
||||
**config["updater"])
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
|
||||
|
||||
evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
|
||||
evaluator = FastSpeech2Evaluator(
|
||||
model, dev_dataloader, output_dir=output_dir, **config["updater"])
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
trainer.extend(evaluator, trigger=(1, "epoch"))
|
||||
|
@ -162,7 +168,7 @@ def train_sp(args, config):
|
|||
trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
|
||||
trainer.extend(
|
||||
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
|
||||
print(trainer.extensions)
|
||||
# print(trainer.extensions)
|
||||
trainer.run()
|
||||
|
||||
|
||||
|
|
|
@ -16,27 +16,33 @@ import logging
|
|||
from typing import Dict
|
||||
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
from paddle.optimizer.lr import LRScheduler
|
||||
from paddle.io import DataLoader
|
||||
from timer import timer
|
||||
|
||||
from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
|
||||
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.updaters.standard_updater import StandardUpdater
|
||||
from parakeet.training.updaters.standard_updater import UpdaterState
|
||||
from timer import timer
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class PWGUpdater(StandardUpdater):
|
||||
def __init__(
|
||||
self,
|
||||
models: Dict[str, Layer],
|
||||
optimizers: Dict[str, Optimizer],
|
||||
criterions: Dict[str, Layer],
|
||||
schedulers: Dict[str, LRScheduler],
|
||||
dataloader: DataLoader,
|
||||
discriminator_train_start_steps: int,
|
||||
lambda_adv: float, ):
|
||||
def __init__(self,
|
||||
models: Dict[str, Layer],
|
||||
optimizers: Dict[str, Optimizer],
|
||||
criterions: Dict[str, Layer],
|
||||
schedulers: Dict[str, LRScheduler],
|
||||
dataloader: DataLoader,
|
||||
discriminator_train_start_steps: int,
|
||||
lambda_adv: float,
|
||||
output_dir=None):
|
||||
self.models = models
|
||||
self.generator: Layer = models['generator']
|
||||
self.discriminator: Layer = models['discriminator']
|
||||
|
@ -61,7 +67,16 @@ class PWGUpdater(StandardUpdater):
|
|||
|
||||
self.train_iterator = iter(self.dataloader)
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def update_core(self, batch):
|
||||
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||
losses_dict = {}
|
||||
|
||||
# parse batch
|
||||
wav, mel = batch
|
||||
|
||||
|
@ -70,7 +85,7 @@ class PWGUpdater(StandardUpdater):
|
|||
|
||||
with timer() as t:
|
||||
wav_ = self.generator(noise, mel)
|
||||
logging.debug(f"Generator takes {t.elapse}s.")
|
||||
# logging.debug(f"Generator takes {t.elapse}s.")
|
||||
|
||||
# initialize
|
||||
gen_loss = 0.0
|
||||
|
@ -78,10 +93,14 @@ class PWGUpdater(StandardUpdater):
|
|||
## Multi-resolution stft loss
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
|
||||
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
|
||||
|
||||
report("train/spectral_convergence_loss", float(sc_loss))
|
||||
report("train/log_stft_magnitude_loss", float(mag_loss))
|
||||
|
||||
losses_dict["spectral_convergence_loss"] = float(sc_loss)
|
||||
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
|
||||
|
||||
gen_loss += sc_loss + mag_loss
|
||||
|
||||
## Adversarial loss
|
||||
|
@ -89,22 +108,24 @@ class PWGUpdater(StandardUpdater):
|
|||
with timer() as t:
|
||||
p_ = self.discriminator(wav_)
|
||||
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||
logging.debug(
|
||||
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
# logging.debug(
|
||||
# f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
report("train/adversarial_loss", float(adv_loss))
|
||||
losses_dict["adversarial_loss"] = float(adv_loss)
|
||||
gen_loss += self.lambda_adv * adv_loss
|
||||
|
||||
report("train/generator_loss", float(gen_loss))
|
||||
losses_dict["generator_loss"] = float(gen_loss)
|
||||
|
||||
with timer() as t:
|
||||
self.optimizer_g.clear_grad()
|
||||
gen_loss.backward()
|
||||
logging.debug(f"Backward takes {t.elapse}s.")
|
||||
# logging.debug(f"Backward takes {t.elapse}s.")
|
||||
|
||||
with timer() as t:
|
||||
self.optimizer_g.step()
|
||||
self.scheduler_g.step()
|
||||
logging.debug(f"Update takes {t.elapse}s.")
|
||||
# logging.debug(f"Update takes {t.elapse}s.")
|
||||
|
||||
# Disctiminator
|
||||
if self.state.iteration > self.discriminator_train_start_steps:
|
||||
|
@ -118,6 +139,9 @@ class PWGUpdater(StandardUpdater):
|
|||
report("train/real_loss", float(real_loss))
|
||||
report("train/fake_loss", float(fake_loss))
|
||||
report("train/discriminator_loss", float(dis_loss))
|
||||
losses_dict["real_loss"] = float(real_loss)
|
||||
losses_dict["fake_loss"] = float(fake_loss)
|
||||
losses_dict["discriminator_loss"] = float(dis_loss)
|
||||
|
||||
self.optimizer_d.clear_grad()
|
||||
dis_loss.backward()
|
||||
|
@ -125,9 +149,17 @@ class PWGUpdater(StandardUpdater):
|
|||
self.optimizer_d.step()
|
||||
self.scheduler_d.step()
|
||||
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
|
||||
|
||||
class PWGEvaluator(StandardEvaluator):
|
||||
def __init__(self, models, criterions, dataloader, lambda_adv):
|
||||
def __init__(self,
|
||||
models,
|
||||
criterions,
|
||||
dataloader,
|
||||
lambda_adv,
|
||||
output_dir=None):
|
||||
self.models = models
|
||||
self.generator = models['generator']
|
||||
self.discriminator = models['discriminator']
|
||||
|
@ -139,34 +171,47 @@ class PWGEvaluator(StandardEvaluator):
|
|||
self.dataloader = dataloader
|
||||
self.lambda_adv = lambda_adv
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def evaluate_core(self, batch):
|
||||
logging.debug("Evaluate: ")
|
||||
# logging.debug("Evaluate: ")
|
||||
self.msg = "Evaluate: "
|
||||
losses_dict = {}
|
||||
|
||||
wav, mel = batch
|
||||
noise = paddle.randn(wav.shape)
|
||||
|
||||
with timer() as t:
|
||||
wav_ = self.generator(noise, mel)
|
||||
logging.debug(f"Generator takes {t.elapse}s")
|
||||
# logging.debug(f"Generator takes {t.elapse}s")
|
||||
|
||||
## Adversarial loss
|
||||
with timer() as t:
|
||||
p_ = self.discriminator(wav_)
|
||||
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||
logging.debug(
|
||||
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
# logging.debug(
|
||||
# f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
report("eval/adversarial_loss", float(adv_loss))
|
||||
losses_dict["adversarial_loss"] = float(adv_loss)
|
||||
gen_loss = self.lambda_adv * adv_loss
|
||||
|
||||
# stft loss
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
|
||||
report("eval/spectral_convergence_loss", float(sc_loss))
|
||||
report("eval/log_stft_magnitude_loss", float(mag_loss))
|
||||
losses_dict["spectral_convergence_loss"] = float(sc_loss)
|
||||
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
|
||||
gen_loss += sc_loss + mag_loss
|
||||
|
||||
report("eval/generator_loss", float(gen_loss))
|
||||
losses_dict["generator_loss"] = float(gen_loss)
|
||||
|
||||
# Disctiminator
|
||||
p = self.discriminator(wav)
|
||||
|
@ -176,3 +221,11 @@ class PWGEvaluator(StandardEvaluator):
|
|||
report("eval/real_loss", float(real_loss))
|
||||
report("eval/fake_loss", float(fake_loss))
|
||||
report("eval/discriminator_loss", float(dis_loss))
|
||||
|
||||
losses_dict["real_loss"] = float(real_loss)
|
||||
losses_dict["fake_loss"] = float(fake_loss)
|
||||
losses_dict["discriminator_loss"] = float(dis_loss)
|
||||
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
self.logger.info(self.msg)
|
||||
|
|
|
@ -23,11 +23,13 @@ import yaml
|
|||
from paddle import DataParallel
|
||||
from paddle import distributed as dist
|
||||
from paddle import nn
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from paddle.optimizer import Adam # No RAdaom
|
||||
from paddle.optimizer.lr import StepDecay
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
from parakeet.models.parallel_wavegan import PWGDiscriminator
|
||||
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
|
@ -38,7 +40,8 @@ from visualdl import LogWriter
|
|||
|
||||
from batch_fn import Clip
|
||||
from config import get_cfg_default
|
||||
from pwg_updater import PWGUpdater, PWGEvaluator
|
||||
from pwg_updater import PWGUpdater
|
||||
from pwg_updater import PWGEvaluator
|
||||
|
||||
|
||||
def train_sp(args, config):
|
||||
|
@ -99,11 +102,13 @@ def train_sp(args, config):
|
|||
batch_max_steps=config.batch_max_steps,
|
||||
hop_size=config.hop_length,
|
||||
aux_context_window=config.generator_params.aux_context_window)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
collate_fn=train_batch_fn,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
batch_sampler=dev_sampler,
|
||||
|
@ -139,10 +144,8 @@ def train_sp(args, config):
|
|||
print("optimizers done!")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
checkpoint_dir = output_dir / "checkpoints"
|
||||
if dist.get_rank() == 0:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / "config.yaml", 'wt') as f:
|
||||
f.write(config.dump(default_flow_style=None))
|
||||
|
||||
|
@ -165,7 +168,8 @@ def train_sp(args, config):
|
|||
},
|
||||
dataloader=train_dataloader,
|
||||
discriminator_train_start_steps=config.discriminator_train_start_steps,
|
||||
lambda_adv=config.lambda_adv, )
|
||||
lambda_adv=config.lambda_adv,
|
||||
output_dir=output_dir)
|
||||
|
||||
evaluator = PWGEvaluator(
|
||||
models={
|
||||
|
@ -177,21 +181,23 @@ def train_sp(args, config):
|
|||
"mse": criterion_mse,
|
||||
},
|
||||
dataloader=dev_dataloader,
|
||||
lambda_adv=config.lambda_adv, )
|
||||
lambda_adv=config.lambda_adv,
|
||||
output_dir=output_dir)
|
||||
trainer = Trainer(
|
||||
updater,
|
||||
stop_trigger=(config.train_max_steps, "iteration"),
|
||||
out=output_dir, )
|
||||
|
||||
trainer.extend(evaluator, trigger=(config.eval_interval_steps, 'iteration'))
|
||||
if dist.get_rank() == 0:
|
||||
trainer.extend(
|
||||
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
|
||||
writer = LogWriter(str(trainer.out))
|
||||
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
|
||||
trainer.extend(
|
||||
Snapshot(max_size=config.num_snapshots),
|
||||
trigger=(config.save_interval_steps, 'iteration'))
|
||||
|
||||
print(trainer.extensions.keys())
|
||||
# print(trainer.extensions.keys())
|
||||
print("Trainer Done!")
|
||||
trainer.run()
|
||||
|
||||
|
|
|
@ -11,8 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.fluid.layers import huber_loss
|
||||
from paddle.nn import functional as F
|
||||
from parakeet.modules.losses import masked_l1_loss, weighted_mean
|
||||
|
@ -20,10 +22,32 @@ from parakeet.modules.ssim import ssim
|
|||
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.updaters.standard_updater import StandardUpdater
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class SpeedySpeechUpdater(StandardUpdater):
|
||||
def __init__(self,
|
||||
model,
|
||||
optimizer,
|
||||
dataloader,
|
||||
init_state=None,
|
||||
output_dir=None):
|
||||
super().__init__(model, optimizer, dataloader, init_state=None)
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def update_core(self, batch):
|
||||
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||
losses_dict = {}
|
||||
|
||||
decoded, predicted_durations = self.model(
|
||||
text=batch["phones"],
|
||||
tones=batch["tones"],
|
||||
|
@ -65,9 +89,28 @@ class SpeedySpeechUpdater(StandardUpdater):
|
|||
report("train/duration_loss", float(duration_loss))
|
||||
report("train/ssim_loss", float(ssim_loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["duration_loss"] = float(duration_loss)
|
||||
losses_dict["ssim_loss"] = float(ssim_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
|
||||
|
||||
class SpeedySpeechEvaluator(StandardEvaluator):
|
||||
def __init__(self, model, dataloader, output_dir=None):
|
||||
super().__init__(model, dataloader)
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def evaluate_core(self, batch):
|
||||
self.msg = "Evaluate: "
|
||||
losses_dict = {}
|
||||
|
||||
decoded, predicted_durations = self.model(
|
||||
text=batch["phones"],
|
||||
tones=batch["tones"],
|
||||
|
@ -105,3 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
|
|||
report("eval/l1_loss", float(l1_loss))
|
||||
report("eval/duration_loss", float(duration_loss))
|
||||
report("eval/ssim_loss", float(ssim_loss))
|
||||
|
||||
losses_dict["l1_loss"] = float(l1_loss)
|
||||
losses_dict["duration_loss"] = float(duration_loss)
|
||||
losses_dict["ssim_loss"] = float(ssim_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
self.logger.info(self.msg)
|
||||
|
|
|
@ -23,8 +23,9 @@ import yaml
|
|||
from paddle import distributed as dist
|
||||
from paddle import DataParallel
|
||||
from paddle import nn
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from paddle.optimizer import Adam # No RAdaom
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from paddle.optimizer import Adam # No RAdam
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.speedyspeech import SpeedySpeech
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
|
@ -36,7 +37,8 @@ from visualdl import LogWriter
|
|||
|
||||
from batch_fn import collate_baker_examples
|
||||
from config import get_cfg_default
|
||||
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
|
||||
from speedyspeech_updater import SpeedySpeechUpdater
|
||||
from speedyspeech_updater import SpeedySpeechEvaluator
|
||||
|
||||
|
||||
def train_sp(args, config):
|
||||
|
@ -121,13 +123,19 @@ def train_sp(args, config):
|
|||
grad_clip=nn.ClipGradByGlobalNorm(5.0))
|
||||
print("optimizer done!")
|
||||
|
||||
updater = SpeedySpeechUpdater(
|
||||
model=model, optimizer=optimizer, dataloader=train_dataloader)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
updater = SpeedySpeechUpdater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
output_dir=output_dir)
|
||||
|
||||
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
|
||||
|
||||
evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
|
||||
evaluator = SpeedySpeechEvaluator(
|
||||
model, dev_dataloader, output_dir=output_dir)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
trainer.extend(evaluator, trigger=(1, "epoch"))
|
||||
|
|
|
@ -280,14 +280,12 @@ class FastSpeech2(nn.Layer):
|
|||
use_batch_norm=use_batch_norm,
|
||||
dropout_rate=postnet_dropout_rate, ))
|
||||
|
||||
nn.initializer.set_global_initializer(None)
|
||||
|
||||
self._reset_parameters(
|
||||
init_enc_alpha=init_enc_alpha,
|
||||
init_dec_alpha=init_dec_alpha, )
|
||||
|
||||
# define criterions
|
||||
self.criterion = FastSpeech2Loss(
|
||||
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
|
|
|
@ -20,7 +20,6 @@ from typing import List
|
|||
from typing import Union
|
||||
|
||||
import six
|
||||
import tqdm
|
||||
|
||||
from parakeet.training.extension import Extension
|
||||
from parakeet.training.extension import PRIORITY_READER
|
||||
|
@ -122,6 +121,7 @@ class Trainer(object):
|
|||
entry.extension.initialize(self)
|
||||
|
||||
update = self.updater.update # training step
|
||||
|
||||
stop_trigger = self.stop_trigger
|
||||
|
||||
# display only one progress bar
|
||||
|
@ -135,8 +135,6 @@ class Trainer(object):
|
|||
else:
|
||||
max_iteration = self.stop_trigger.limit
|
||||
|
||||
p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
|
||||
|
||||
try:
|
||||
while not stop_trigger(self):
|
||||
self.observation = {}
|
||||
|
@ -146,7 +144,21 @@ class Trainer(object):
|
|||
# updating parameters and state
|
||||
with scope(self.observation):
|
||||
update()
|
||||
p.update()
|
||||
batch_read_time = self.updater.batch_read_time
|
||||
batch_time = self.updater.batch_time
|
||||
logger = self.updater.logger
|
||||
logger.removeHandler(self.updater.filehandler)
|
||||
msg = self.updater.msg
|
||||
msg = " iter: {}/{}, ".format(self.updater.state.iteration,
|
||||
max_iteration) + msg
|
||||
msg += ", avg_reader_cost: {:.5f} sec, ".format(
|
||||
batch_read_time
|
||||
) + "avg_batch_cost: {:.5f} sec, ".format(batch_time)
|
||||
msg += "avg_samples: {}, ".format(
|
||||
self.updater.
|
||||
batch_size) + "avg_ips: {:.5f} sequences/sec".format(
|
||||
self.updater.batch_size / batch_time)
|
||||
logger.info(msg)
|
||||
|
||||
# execute extension when necessary
|
||||
for name, entry in extensions:
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
|
@ -57,6 +58,8 @@ class StandardUpdater(UpdaterBase):
|
|||
self.state = init_state
|
||||
|
||||
self.train_iterator = iter(dataloader)
|
||||
self.batch_read_time = 0
|
||||
self.batch_time = 0
|
||||
|
||||
def update(self):
|
||||
# We increase the iteration index after updating and before extension.
|
||||
|
@ -99,8 +102,17 @@ class StandardUpdater(UpdaterBase):
|
|||
layer.train()
|
||||
|
||||
# training for a step is implemented here
|
||||
time_before_read = time.time()
|
||||
batch = self.read_batch()
|
||||
time_before_core = time.time()
|
||||
self.update_core(batch)
|
||||
self.batch_time = time.time() - time_before_core
|
||||
self.batch_read_time = time_before_core - time_before_read
|
||||
if isinstance(batch, dict):
|
||||
self.batch_size = len(list(batch.items())[0][-1])
|
||||
# for pwg
|
||||
elif isinstance(batch, list):
|
||||
self.batch_size = batch[0].shape[0]
|
||||
|
||||
self.state.iteration += 1
|
||||
if self.updates_per_epoch is not None:
|
||||
|
|
Loading…
Reference in New Issue