fix log format of fastspeech2, speedyspeech and pwg

parent 5bc570aee5
commit 065fa32a37
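This commit applies one pattern across every updater and evaluator: configure a module-level logger once per file, have each updater attach a per-rank FileHandler under the training output directory, and accumulate the iteration's losses into a message string. A minimal sketch of that pattern, assuming the paddle API shown in the diff (the ExampleUpdater class is a hypothetical stand-in for the real updaters):

    import logging
    from pathlib import Path

    from paddle import distributed as dist

    # Module-level logger, configured once per updater file.
    logging.basicConfig(
        format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
        datefmt='[%Y-%m-%d %H:%M:%S]')
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)


    class ExampleUpdater:  # hypothetical stand-in
        def __init__(self, output_dir: Path):
            # Each distributed worker writes to its own file, e.g. worker_0.log.
            log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
            self.filehandler = logging.FileHandler(str(log_file))
            logger.addHandler(self.filehandler)
            self.logger = logger
            self.msg = ""

        def update_core(self, losses_dict):
            # Rebuilt every iteration; the trainer reads it after each update.
            self.msg = "Rank: {}, ".format(dist.get_rank())
            self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in losses_dict.items())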
[config.py — FastSpeech2 example. Exact file paths were not preserved in this capture; this first group of files is attributed to the aishell3 recipe, inferred from the collate_aishell3_examples import in its train.py.]

@@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
 def get_cfg_default():
     config = _C.clone()
     return config
-
-
-print(get_cfg_default())
[fastspeech2_updater.py — aishell3 FastSpeech2 example]

@@ -11,10 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+from paddle import distributed as dist
 from parakeet.models.fastspeech2 import FastSpeech2Loss
 from parakeet.training.extensions.evaluator import StandardEvaluator
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 class FastSpeech2Updater(StandardUpdater):

@@ -24,12 +32,22 @@ class FastSpeech2Updater(StandardUpdater):
                  dataloader,
                  init_state=None,
                  use_masking=False,
-                 use_weighted_masking=False):
+                 use_weighted_masking=False,
+                 output_dir=None):
         super().__init__(model, optimizer, dataloader, init_state=None)
         self.use_masking = use_masking
         self.use_weighted_masking = use_weighted_masking
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
 
     def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
         before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],

@@ -70,18 +88,36 @@ class FastSpeech2Updater(StandardUpdater):
         report("train/pitch_loss", float(pitch_loss))
         report("train/energy_loss", float(energy_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["pitch_loss"] = float(pitch_loss)
+        losses_dict["energy_loss"] = float(energy_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
 
 
 class FastSpeech2Evaluator(StandardEvaluator):
     def __init__(self,
                  model,
                  dataloader,
                  use_masking=False,
-                 use_weighted_masking=False):
+                 use_weighted_masking=False,
+                 output_dir=None):
         super().__init__(model, dataloader)
         self.use_masking = use_masking
         self.use_weighted_masking = use_weighted_masking
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
 
     def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
         before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],

@@ -114,3 +150,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
         report("eval/duration_loss", float(duration_loss))
         report("eval/pitch_loss", float(pitch_loss))
         report("eval/energy_loss", float(energy_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["pitch_loss"] = float(pitch_loss)
+        losses_dict["energy_loss"] = float(energy_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
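For reference, the message each evaluator logs is just the comma-joined, six-decimal rendering of the losses dict; a tiny standalone check (values invented):

    losses_dict = {"l1_loss": 0.123456, "duration_loss": 0.01, "loss": 0.2}
    msg = "Evaluate: " + ', '.join('{}: {:>.6f}'.format(k, v)
                                   for k, v in losses_dict.items())
    print(msg)
    # Evaluate: l1_loss: 0.123456, duration_loss: 0.010000, loss: 0.200000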
[synthesize.py — aishell3 FastSpeech2 example (name inferred; this script reads test metadata via DataTable)]

@@ -23,8 +23,10 @@ import soundfile as sf
 import yaml
 from yacs.config import CfgNode
 from parakeet.datasets.data_table import DataTable
-from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
-from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
+from parakeet.models.fastspeech2 import FastSpeech2
+from parakeet.models.fastspeech2 import FastSpeech2Inference
+from parakeet.models.parallel_wavegan import PWGGenerator
+from parakeet.models.parallel_wavegan import PWGInference
 from parakeet.modules.normalizer import ZScore
 
 

@@ -102,9 +104,7 @@ def main():
     parser = argparse.ArgumentParser(
         description="Synthesize with fastspeech2 & parallel wavegan.")
     parser.add_argument(
-        "--fastspeech2-config",
-        type=str,
-        help="config file to overwrite default config.")
+        "--fastspeech2-config", type=str, help="fastspeech2 config file.")
     parser.add_argument(
         "--fastspeech2-checkpoint",
         type=str,

@@ -115,10 +115,7 @@ def main():
         help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
     )
     parser.add_argument(
-        "--pwg-config",
-        type=str,
-        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
-    )
+        "--pwg-config", type=str, help="parallel wavegan config file.")
     parser.add_argument(
         "--pwg-params",
         type=str,
[synthesize_e2e.py — aishell3 FastSpeech2 example (name inferred; this variant synthesizes from raw text via the Frontend)]

@@ -21,8 +21,10 @@ import paddle
 import soundfile as sf
 import yaml
 from yacs.config import CfgNode
-from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
-from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
+from parakeet.models.fastspeech2 import FastSpeech2
+from parakeet.models.fastspeech2 import FastSpeech2Inference
+from parakeet.models.parallel_wavegan import PWGGenerator
+from parakeet.models.parallel_wavegan import PWGInference
 from parakeet.modules.normalizer import ZScore
 
 from frontend import Frontend

@@ -113,9 +115,7 @@ def main():
     parser = argparse.ArgumentParser(
         description="Synthesize with fastspeech2 & parallel wavegan.")
     parser.add_argument(
-        "--fastspeech2-config",
-        type=str,
-        help="fastspeech2 config file to overwrite default config.")
+        "--fastspeech2-config", type=str, help="fastspeech2 config file.")
     parser.add_argument(
         "--fastspeech2-checkpoint",
         type=str,

@@ -126,9 +126,7 @@ def main():
         help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
     )
     parser.add_argument(
-        "--pwg-config",
-        type=str,
-        help="parallel wavegan config file to overwrite default config.")
+        "--pwg-config", type=str, help="parallel wavegan config file.")
     parser.add_argument(
         "--pwg-params",
         type=str,
[train.py — aishell3 FastSpeech2 example]

@@ -23,7 +23,8 @@ import paddle
 from paddle import DataParallel
 from paddle import distributed as dist
 from paddle import nn
-from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
 from parakeet.datasets.data_table import DataTable
 from parakeet.models.fastspeech2 import FastSpeech2
 from parakeet.training.extensions.snapshot import Snapshot

@@ -35,7 +36,8 @@ import yaml
 
 from batch_fn import collate_aishell3_examples
 from config import get_cfg_default
-from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
+from fastspeech2_updater import FastSpeech2Evaluator
+from fastspeech2_updater import FastSpeech2Updater
 
 optim_classes = dict(
     adadelta=paddle.optimizer.Adadelta,

@@ -97,6 +99,7 @@ def train_sp(args, config):
             "energy": np.load}, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
+
     dev_dataset = DataTable(
         data=dev_metadata,
         fields=[

@@ -154,16 +157,19 @@ def train_sp(args, config):
     optimizer = build_optimizers(model, **config["optimizer"])
     print("optimizer done!")
 
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
     updater = FastSpeech2Updater(
         model=model,
         optimizer=optimizer,
         dataloader=train_dataloader,
+        output_dir=output_dir,
         **config["updater"])
 
-    output_dir = Path(args.output_dir)
     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
 
-    evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
+    evaluator = FastSpeech2Evaluator(
+        model, dev_dataloader, output_dir=output_dir, **config["updater"])
 
     if dist.get_rank() == 0:
         trainer.extend(evaluator, trigger=(1, "epoch"))

@@ -171,7 +177,7 @@ def train_sp(args, config):
     trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
     trainer.extend(
         Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
-    print(trainer.extensions)
+    # print(trainer.extensions)
     trainer.run()
 
 
[config.py — baker FastSpeech2 example; this second group of files mirrors the first, inferred from the collate_baker_examples import in its train.py]

@@ -26,6 +26,3 @@ with open(config_path, 'rt') as f:
 def get_cfg_default():
     config = _C.clone()
     return config
-
-
-print(get_cfg_default())
[fastspeech2_updater.py — baker FastSpeech2 example]

@@ -11,10 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+from paddle import distributed as dist
 from parakeet.models.fastspeech2 import FastSpeech2Loss
 from parakeet.training.extensions.evaluator import StandardEvaluator
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 class FastSpeech2Updater(StandardUpdater):

@@ -24,12 +32,21 @@ class FastSpeech2Updater(StandardUpdater):
                  dataloader,
                  init_state=None,
                  use_masking=False,
-                 use_weighted_masking=False):
+                 use_weighted_masking=False,
+                 output_dir=None):
         super().__init__(model, optimizer, dataloader, init_state=None)
         self.use_masking = use_masking
         self.use_weighted_masking = use_weighted_masking
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
 
     def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
         before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],

@@ -69,18 +86,36 @@ class FastSpeech2Updater(StandardUpdater):
         report("train/pitch_loss", float(pitch_loss))
         report("train/energy_loss", float(energy_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["pitch_loss"] = float(pitch_loss)
+        losses_dict["energy_loss"] = float(energy_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
 
 
 class FastSpeech2Evaluator(StandardEvaluator):
     def __init__(self,
                  model,
                  dataloader,
                  use_masking=False,
-                 use_weighted_masking=False):
+                 use_weighted_masking=False,
+                 output_dir=None):
         super().__init__(model, dataloader)
         self.use_masking = use_masking
         self.use_weighted_masking = use_weighted_masking
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
 
     def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
         before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],

@@ -112,3 +147,12 @@ class FastSpeech2Evaluator(StandardEvaluator):
         report("eval/duration_loss", float(duration_loss))
         report("eval/pitch_loss", float(pitch_loss))
         report("eval/energy_loss", float(energy_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["pitch_loss"] = float(pitch_loss)
+        losses_dict["energy_loss"] = float(energy_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
[synthesize.py — baker FastSpeech2 example]

@@ -23,8 +23,10 @@ import soundfile as sf
 import yaml
 from yacs.config import CfgNode
 from parakeet.datasets.data_table import DataTable
-from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
-from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
+from parakeet.models.fastspeech2 import FastSpeech2
+from parakeet.models.fastspeech2 import FastSpeech2Inference
+from parakeet.models.parallel_wavegan import PWGGenerator
+from parakeet.models.parallel_wavegan import PWGInference
 from parakeet.modules.normalizer import ZScore
 
 

@@ -91,9 +93,7 @@ def main():
     parser = argparse.ArgumentParser(
         description="Synthesize with fastspeech2 & parallel wavegan.")
     parser.add_argument(
-        "--fastspeech2-config",
-        type=str,
-        help="config file to overwrite default config.")
+        "--fastspeech2-config", type=str, help="fastspeech2 config file.")
     parser.add_argument(
         "--fastspeech2-checkpoint",
         type=str,

@@ -104,10 +104,7 @@ def main():
         help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
     )
     parser.add_argument(
-        "--pwg-config",
-        type=str,
-        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
-    )
+        "--pwg-config", type=str, help="parallel wavegan config file.")
     parser.add_argument(
         "--pwg-params",
         type=str,
[synthesize_e2e.py — baker FastSpeech2 example]

@@ -21,8 +21,10 @@ import paddle
 import soundfile as sf
 import yaml
 from yacs.config import CfgNode
-from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Inference
-from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
+from parakeet.models.fastspeech2 import FastSpeech2
+from parakeet.models.fastspeech2 import FastSpeech2Inference
+from parakeet.models.parallel_wavegan import PWGGenerator
+from parakeet.models.parallel_wavegan import PWGInference
 from parakeet.modules.normalizer import ZScore
 
 from frontend import Frontend

@@ -103,9 +105,7 @@ def main():
     parser = argparse.ArgumentParser(
         description="Synthesize with fastspeech2 & parallel wavegan.")
     parser.add_argument(
-        "--fastspeech2-config",
-        type=str,
-        help="fastspeech2 config file to overwrite default config.")
+        "--fastspeech2-config", type=str, help="fastspeech2 config file.")
     parser.add_argument(
         "--fastspeech2-checkpoint",
         type=str,

@@ -116,9 +116,7 @@ def main():
         help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
     )
     parser.add_argument(
-        "--pwg-config",
-        type=str,
-        help="parallel wavegan config file to overwrite default config.")
+        "--pwg-config", type=str, help="parallel wavegan config file.")
     parser.add_argument(
         "--pwg-params",
         type=str,
[train.py — baker FastSpeech2 example]

@@ -35,7 +35,8 @@ import yaml
 
 from batch_fn import collate_baker_examples
 from config import get_cfg_default
-from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
+from fastspeech2_updater import FastSpeech2Evaluator
+from fastspeech2_updater import FastSpeech2Updater
 
 optim_classes = dict(
     adadelta=paddle.optimizer.Adadelta,

@@ -108,6 +109,7 @@ def train_sp(args, config):
             "energy": np.load}, )
 
     # collate function and dataloader
+
     train_sampler = DistributedBatchSampler(
         train_dataset,
         batch_size=config.batch_size,

@@ -145,16 +147,20 @@ def train_sp(args, config):
     optimizer = build_optimizers(model, **config["optimizer"])
     print("optimizer done!")
 
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
     updater = FastSpeech2Updater(
         model=model,
         optimizer=optimizer,
         dataloader=train_dataloader,
+        output_dir=output_dir,
         **config["updater"])
 
-    output_dir = Path(args.output_dir)
     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
 
-    evaluator = FastSpeech2Evaluator(model, dev_dataloader, **config["updater"])
+    evaluator = FastSpeech2Evaluator(
+        model, dev_dataloader, output_dir=output_dir, **config["updater"])
 
     if dist.get_rank() == 0:
         trainer.extend(evaluator, trigger=(1, "epoch"))

@@ -162,7 +168,7 @@ def train_sp(args, config):
     trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
     trainer.extend(
         Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
-    print(trainer.extensions)
+    # print(trainer.extensions)
     trainer.run()
 
 
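The same wiring recurs in each train.py: create the output directory once, then pass it to both the updater and the evaluator so their handlers write into the same place. A condensed sketch (argument plumbing abbreviated):

    from pathlib import Path

    output_dir = Path(args.output_dir)
    # mkdir must happen before the updater is constructed: logging.FileHandler
    # opens worker_{rank}.log eagerly, so the directory has to exist already.
    output_dir.mkdir(parents=True, exist_ok=True)

    updater = FastSpeech2Updater(
        model=model,
        optimizer=optimizer,
        dataloader=train_dataloader,
        output_dir=output_dir,
        **config["updater"])
    evaluator = FastSpeech2Evaluator(
        model, dev_dataloader, output_dir=output_dir, **config["updater"])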
[pwg_updater.py — Parallel WaveGAN example]

@@ -16,27 +16,33 @@ import logging
 from typing import Dict
 
 import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
 from paddle.optimizer.lr import LRScheduler
-from paddle.io import DataLoader
-from timer import timer
-
-from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
 from parakeet.training.extensions.evaluator import StandardEvaluator
 from parakeet.training.reporter import report
+from parakeet.training.updaters.standard_updater import StandardUpdater
+from parakeet.training.updaters.standard_updater import UpdaterState
+from timer import timer
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 class PWGUpdater(StandardUpdater):
-    def __init__(
-            self,
+    def __init__(self,
             models: Dict[str, Layer],
             optimizers: Dict[str, Optimizer],
             criterions: Dict[str, Layer],
             schedulers: Dict[str, LRScheduler],
             dataloader: DataLoader,
             discriminator_train_start_steps: int,
-            lambda_adv: float, ):
+            lambda_adv: float,
+            output_dir=None):
         self.models = models
         self.generator: Layer = models['generator']
         self.discriminator: Layer = models['discriminator']

@@ -61,7 +67,16 @@ class PWGUpdater(StandardUpdater):
 
         self.train_iterator = iter(self.dataloader)
 
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
     def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
         # parse batch
         wav, mel = batch
 

@@ -70,7 +85,7 @@ class PWGUpdater(StandardUpdater):
 
         with timer() as t:
             wav_ = self.generator(noise, mel)
-            logging.debug(f"Generator takes {t.elapse}s.")
+            # logging.debug(f"Generator takes {t.elapse}s.")
 
         # initialize
         gen_loss = 0.0

@@ -78,10 +93,14 @@ class PWGUpdater(StandardUpdater):
         ## Multi-resolution stft loss
         with timer() as t:
             sc_loss, mag_loss = self.criterion_stft(wav_, wav)
-            logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
+            # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
 
         report("train/spectral_convergence_loss", float(sc_loss))
         report("train/log_stft_magnitude_loss", float(mag_loss))
+
+        losses_dict["spectral_convergence_loss"] = float(sc_loss)
+        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
+
         gen_loss += sc_loss + mag_loss
 
         ## Adversarial loss

@@ -89,22 +108,24 @@ class PWGUpdater(StandardUpdater):
         with timer() as t:
             p_ = self.discriminator(wav_)
             adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
-            logging.debug(
-                f"Discriminator and adversarial loss takes {t.elapse}s")
+            # logging.debug(
+            #     f"Discriminator and adversarial loss takes {t.elapse}s")
         report("train/adversarial_loss", float(adv_loss))
+        losses_dict["adversarial_loss"] = float(adv_loss)
         gen_loss += self.lambda_adv * adv_loss
 
         report("train/generator_loss", float(gen_loss))
+        losses_dict["generator_loss"] = float(gen_loss)
 
         with timer() as t:
             self.optimizer_g.clear_grad()
             gen_loss.backward()
-            logging.debug(f"Backward takes {t.elapse}s.")
+            # logging.debug(f"Backward takes {t.elapse}s.")
 
         with timer() as t:
             self.optimizer_g.step()
             self.scheduler_g.step()
-            logging.debug(f"Update takes {t.elapse}s.")
+            # logging.debug(f"Update takes {t.elapse}s.")
 
         # Disctiminator
         if self.state.iteration > self.discriminator_train_start_steps:

@@ -118,6 +139,9 @@ class PWGUpdater(StandardUpdater):
             report("train/real_loss", float(real_loss))
             report("train/fake_loss", float(fake_loss))
             report("train/discriminator_loss", float(dis_loss))
+            losses_dict["real_loss"] = float(real_loss)
+            losses_dict["fake_loss"] = float(fake_loss)
+            losses_dict["discriminator_loss"] = float(dis_loss)
 
             self.optimizer_d.clear_grad()
             dis_loss.backward()

@@ -125,9 +149,17 @@ class PWGUpdater(StandardUpdater):
             self.optimizer_d.step()
             self.scheduler_d.step()
 
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
 
 class PWGEvaluator(StandardEvaluator):
-    def __init__(self, models, criterions, dataloader, lambda_adv):
+    def __init__(self,
+                 models,
+                 criterions,
+                 dataloader,
+                 lambda_adv,
+                 output_dir=None):
         self.models = models
         self.generator = models['generator']
         self.discriminator = models['discriminator']

@@ -139,34 +171,47 @@ class PWGEvaluator(StandardEvaluator):
         self.dataloader = dataloader
         self.lambda_adv = lambda_adv
 
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
     def evaluate_core(self, batch):
-        logging.debug("Evaluate: ")
+        # logging.debug("Evaluate: ")
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
         wav, mel = batch
         noise = paddle.randn(wav.shape)
 
         with timer() as t:
             wav_ = self.generator(noise, mel)
-            logging.debug(f"Generator takes {t.elapse}s")
+            # logging.debug(f"Generator takes {t.elapse}s")
 
         ## Adversarial loss
         with timer() as t:
             p_ = self.discriminator(wav_)
             adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
-            logging.debug(
-                f"Discriminator and adversarial loss takes {t.elapse}s")
+            # logging.debug(
+            #     f"Discriminator and adversarial loss takes {t.elapse}s")
         report("eval/adversarial_loss", float(adv_loss))
+        losses_dict["adversarial_loss"] = float(adv_loss)
         gen_loss = self.lambda_adv * adv_loss
 
         # stft loss
         with timer() as t:
             sc_loss, mag_loss = self.criterion_stft(wav_, wav)
-            logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
+            # logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
 
         report("eval/spectral_convergence_loss", float(sc_loss))
         report("eval/log_stft_magnitude_loss", float(mag_loss))
+        losses_dict["spectral_convergence_loss"] = float(sc_loss)
+        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
         gen_loss += sc_loss + mag_loss
 
         report("eval/generator_loss", float(gen_loss))
+        losses_dict["generator_loss"] = float(gen_loss)
 
         # Disctiminator
         p = self.discriminator(wav)

@@ -176,3 +221,11 @@ class PWGEvaluator(StandardEvaluator):
         report("eval/real_loss", float(real_loss))
         report("eval/fake_loss", float(fake_loss))
         report("eval/discriminator_loss", float(dis_loss))
+
+        losses_dict["real_loss"] = float(real_loss)
+        losses_dict["fake_loss"] = float(fake_loss)
+        losses_dict["discriminator_loss"] = float(dis_loss)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
[train.py — Parallel WaveGAN example]

@@ -23,11 +23,13 @@ import yaml
 from paddle import DataParallel
 from paddle import distributed as dist
 from paddle import nn
-from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
 from paddle.optimizer import Adam  # No RAdaom
 from paddle.optimizer.lr import StepDecay
 from parakeet.datasets.data_table import DataTable
-from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
+from parakeet.models.parallel_wavegan import PWGGenerator
+from parakeet.models.parallel_wavegan import PWGDiscriminator
 from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
 from parakeet.training.extensions.snapshot import Snapshot
 from parakeet.training.extensions.visualizer import VisualDL

@@ -38,7 +40,8 @@ from visualdl import LogWriter
 
 from batch_fn import Clip
 from config import get_cfg_default
-from pwg_updater import PWGUpdater, PWGEvaluator
+from pwg_updater import PWGUpdater
+from pwg_updater import PWGEvaluator
 
 
 def train_sp(args, config):

@@ -99,11 +102,13 @@ def train_sp(args, config):
         batch_max_steps=config.batch_max_steps,
         hop_size=config.hop_length,
         aux_context_window=config.generator_params.aux_context_window)
+
     train_dataloader = DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
         collate_fn=train_batch_fn,
         num_workers=config.num_workers)
+
     dev_dataloader = DataLoader(
         dev_dataset,
         batch_sampler=dev_sampler,

@@ -139,10 +144,8 @@ def train_sp(args, config):
     print("optimizers done!")
 
     output_dir = Path(args.output_dir)
-    checkpoint_dir = output_dir / "checkpoints"
     if dist.get_rank() == 0:
         output_dir.mkdir(parents=True, exist_ok=True)
-        checkpoint_dir.mkdir(parents=True, exist_ok=True)
         with open(output_dir / "config.yaml", 'wt') as f:
             f.write(config.dump(default_flow_style=None))
 

@@ -165,7 +168,8 @@ def train_sp(args, config):
         },
         dataloader=train_dataloader,
         discriminator_train_start_steps=config.discriminator_train_start_steps,
-        lambda_adv=config.lambda_adv, )
+        lambda_adv=config.lambda_adv,
+        output_dir=output_dir)
 
     evaluator = PWGEvaluator(
         models={

@@ -177,21 +181,23 @@ def train_sp(args, config):
             "mse": criterion_mse,
         },
         dataloader=dev_dataloader,
-        lambda_adv=config.lambda_adv, )
+        lambda_adv=config.lambda_adv,
+        output_dir=output_dir)
     trainer = Trainer(
         updater,
         stop_trigger=(config.train_max_steps, "iteration"),
         out=output_dir, )
 
-    trainer.extend(evaluator, trigger=(config.eval_interval_steps, 'iteration'))
     if dist.get_rank() == 0:
+        trainer.extend(
+            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
         writer = LogWriter(str(trainer.out))
         trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
         trainer.extend(
             Snapshot(max_size=config.num_snapshots),
             trigger=(config.save_interval_steps, 'iteration'))
 
-    print(trainer.extensions.keys())
+    # print(trainer.extensions.keys())
     print("Trainer Done!")
     trainer.run()
 
[speedyspeech_updater.py — SpeedySpeech example]

@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 
 import paddle
+from paddle import distributed as dist
 from paddle.fluid.layers import huber_loss
 from paddle.nn import functional as F
 from parakeet.modules.losses import masked_l1_loss, weighted_mean

@@ -20,10 +22,32 @@ from parakeet.modules.ssim import ssim
 from parakeet.training.extensions.evaluator import StandardEvaluator
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 class SpeedySpeechUpdater(StandardUpdater):
+    def __init__(self,
+                 model,
+                 optimizer,
+                 dataloader,
+                 init_state=None,
+                 output_dir=None):
+        super().__init__(model, optimizer, dataloader, init_state=None)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
     def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],

@@ -65,9 +89,28 @@ class SpeedySpeechUpdater(StandardUpdater):
         report("train/duration_loss", float(duration_loss))
         report("train/ssim_loss", float(ssim_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["ssim_loss"] = float(ssim_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
 
 
 class SpeedySpeechEvaluator(StandardEvaluator):
+    def __init__(self, model, dataloader, output_dir=None):
+        super().__init__(model, dataloader)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
     def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],

@@ -105,3 +148,11 @@ class SpeedySpeechEvaluator(StandardEvaluator):
         report("eval/l1_loss", float(l1_loss))
         report("eval/duration_loss", float(duration_loss))
         report("eval/ssim_loss", float(ssim_loss))
+
+        losses_dict["l1_loss"] = float(l1_loss)
+        losses_dict["duration_loss"] = float(duration_loss)
+        losses_dict["ssim_loss"] = float(ssim_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
[train.py — SpeedySpeech example]

@@ -23,8 +23,9 @@ import yaml
 from paddle import distributed as dist
 from paddle import DataParallel
 from paddle import nn
-from paddle.io import DataLoader, DistributedBatchSampler
-from paddle.optimizer import Adam  # No RAdaom
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.optimizer import Adam  # No RAdam
 from parakeet.datasets.data_table import DataTable
 from parakeet.models.speedyspeech import SpeedySpeech
 from parakeet.training.extensions.snapshot import Snapshot

@@ -36,7 +37,8 @@ from visualdl import LogWriter
 
 from batch_fn import collate_baker_examples
 from config import get_cfg_default
-from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
+from speedyspeech_updater import SpeedySpeechUpdater
+from speedyspeech_updater import SpeedySpeechEvaluator
 
 
 def train_sp(args, config):

@@ -121,13 +123,19 @@ def train_sp(args, config):
         grad_clip=nn.ClipGradByGlobalNorm(5.0))
     print("optimizer done!")
 
-    updater = SpeedySpeechUpdater(
-        model=model, optimizer=optimizer, dataloader=train_dataloader)
-
     output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    updater = SpeedySpeechUpdater(
+        model=model,
+        optimizer=optimizer,
+        dataloader=train_dataloader,
+        output_dir=output_dir)
+
     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
 
-    evaluator = SpeedySpeechEvaluator(model, dev_dataloader)
+    evaluator = SpeedySpeechEvaluator(
+        model, dev_dataloader, output_dir=output_dir)
 
     if dist.get_rank() == 0:
         trainer.extend(evaluator, trigger=(1, "epoch"))
[parakeet/models/fastspeech2.py]

@@ -280,14 +280,12 @@ class FastSpeech2(nn.Layer):
                 use_batch_norm=use_batch_norm,
                 dropout_rate=postnet_dropout_rate, ))
 
+        nn.initializer.set_global_initializer(None)
+
         self._reset_parameters(
             init_enc_alpha=init_enc_alpha,
             init_dec_alpha=init_dec_alpha, )
 
-        # define criterions
-        self.criterion = FastSpeech2Loss(
-            use_masking=use_masking, use_weighted_masking=use_weighted_masking)
-
     def forward(
             self,
             text: paddle.Tensor,
[parakeet/training/trainer.py — module path inferred from the Trainer class]

@@ -20,7 +20,6 @@ from typing import List
 from typing import Union
 
 import six
-import tqdm
 
 from parakeet.training.extension import Extension
 from parakeet.training.extension import PRIORITY_READER

@@ -122,6 +121,7 @@ class Trainer(object):
             entry.extension.initialize(self)
 
         update = self.updater.update  # training step
+
         stop_trigger = self.stop_trigger
 
         # display only one progress bar

@@ -135,8 +135,6 @@ class Trainer(object):
         else:
             max_iteration = self.stop_trigger.limit
 
-        p = tqdm.tqdm(initial=self.updater.state.iteration, total=max_iteration)
-
         try:
             while not stop_trigger(self):
                 self.observation = {}

@@ -146,7 +144,21 @@ class Trainer(object):
                 # updating parameters and state
                 with scope(self.observation):
                     update()
-                    p.update()
+                    batch_read_time = self.updater.batch_read_time
+                    batch_time = self.updater.batch_time
+                    logger = self.updater.logger
+                    logger.removeHandler(self.updater.filehandler)
+                    msg = self.updater.msg
+                    msg = " iter: {}/{}, ".format(self.updater.state.iteration,
+                                                  max_iteration) + msg
+                    msg += ", avg_reader_cost: {:.5f} sec, ".format(
+                        batch_read_time
+                    ) + "avg_batch_cost: {:.5f} sec, ".format(batch_time)
+                    msg += "avg_samples: {}, ".format(
+                        self.updater.
+                        batch_size) + "avg_ips: {:.5f} sequences/sec".format(
+                            self.updater.batch_size / batch_time)
+                    logger.info(msg)
 
                 # execute extension when necessary
                 for name, entry in extensions:
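Pulled out of its diff context, the per-iteration line the Trainer now logs combines the updater's timing counters into one message. A standalone sketch with invented values:

    # Hypothetical values standing in for the updater's counters.
    iteration, max_iteration = 42, 1000
    batch_read_time = 0.012  # seconds spent in read_batch()
    batch_time = 0.340       # seconds spent in update_core()
    batch_size = 16
    msg = "Rank: 0, loss: 1.234567"

    line = " iter: {}/{}, ".format(iteration, max_iteration) + msg
    line += ", avg_reader_cost: {:.5f} sec, ".format(batch_read_time)
    line += "avg_batch_cost: {:.5f} sec, ".format(batch_time)
    line += "avg_samples: {}, ".format(batch_size)
    line += "avg_ips: {:.5f} sequences/sec".format(batch_size / batch_time)
    print(line)
    #  iter: 42/1000, Rank: 0, loss: 1.234567, avg_reader_cost: 0.01200 sec,
    #  avg_batch_cost: 0.34000 sec, avg_samples: 16, avg_ips: 47.05882 sequences/sec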
[parakeet/training/updaters/standard_updater.py]

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import time
 from typing import Dict
 from typing import Optional
 

@@ -57,6 +58,8 @@ class StandardUpdater(UpdaterBase):
         self.state = init_state
 
         self.train_iterator = iter(dataloader)
+        self.batch_read_time = 0
+        self.batch_time = 0
 
     def update(self):
         # We increase the iteration index after updating and before extension.

@@ -99,8 +102,17 @@ class StandardUpdater(UpdaterBase):
                 layer.train()
 
         # training for a step is implemented here
+        time_before_read = time.time()
         batch = self.read_batch()
+        time_before_core = time.time()
         self.update_core(batch)
+        self.batch_time = time.time() - time_before_core
+        self.batch_read_time = time_before_core - time_before_read
+        if isinstance(batch, dict):
+            self.batch_size = len(list(batch.items())[0][-1])
+        # for pwg
+        elif isinstance(batch, list):
+            self.batch_size = batch[0].shape[0]
 
         self.state.iteration += 1
         if self.updates_per_epoch is not None:
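The batch-size bookkeeping above infers the number of sequences from whichever batch layout arrives; a quick illustration with dummy arrays (shapes invented):

    import numpy as np

    # Dict batch, as the fastspeech2/speedyspeech collate functions produce:
    # take the first field and use its leading dimension.
    dict_batch = {"text": np.zeros((16, 50)), "text_lengths": np.zeros((16, ))}
    assert len(list(dict_batch.items())[0][-1]) == 16

    # List batch, as the pwg Clip batch function produces: [wav, mel] arrays.
    list_batch = [np.zeros((16, 24000)), np.zeros((16, 80, 100))]
    assert list_batch[0].shape[0] == 16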