Add an inference interface for SpeedySpeech and Parallel WaveGAN (PWG)
This commit is contained in:
parent
4a7888b8c6
commit
8b7dabbd8d
|
@ -27,30 +27,15 @@ import soundfile as sf
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
from paddle import distributed as dist
|
from paddle import distributed as dist
|
||||||
from paddle.io import DataLoader, DistributedBatchSampler
|
from yacs.config import CfgNode
|
||||||
from paddle.optimizer import Adam # No RAdam
|
|
||||||
from paddle.optimizer.lr import StepDecay
|
|
||||||
from paddle import DataParallel
|
|
||||||
from visualdl import LogWriter
|
|
||||||
|
|
||||||
from parakeet.datasets.data_table import DataTable
|
from parakeet.datasets.data_table import DataTable
|
||||||
from parakeet.models.speedyspeech import SpeedySpeech
|
from parakeet.models.speedyspeech import SpeedySpeech, SpeedySpeechInference
|
||||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||||
|
from parakeet.modules.normalizer import ZScore
|
||||||
from parakeet.training.updater import UpdaterBase
|
|
||||||
from parakeet.training.trainer import Trainer
|
|
||||||
from parakeet.training.reporter import report
|
|
||||||
from parakeet.training import extension
|
|
||||||
from parakeet.training.extensions.snapshot import Snapshot
|
|
||||||
from parakeet.training.extensions.visualizer import VisualDL
|
|
||||||
from parakeet.training.seeding import seed_everything
|
|
||||||
|
|
||||||
from batch_fn import collate_baker_examples
|
|
||||||
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
|
|
||||||
from config import get_cfg_default
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, config):
|
def evaluate(args, speedyspeech_config, pwg_config):
|
||||||
# dataloader has been too verbose
|
# dataloader has been too verbose
|
||||||
logging.getLogger("DataLoader").disabled = True
|
logging.getLogger("DataLoader").disabled = True
|
||||||
|
|
||||||
|
@ -60,28 +45,32 @@ def evaluate(args, config):
|
||||||
test_dataset = DataTable(
|
test_dataset = DataTable(
|
||||||
data=test_metadata, fields=["utt_id", "phones", "tones"])
|
data=test_metadata, fields=["utt_id", "phones", "tones"])
|
||||||
|
|
||||||
model = SpeedySpeech(**config["model"])
|
model = SpeedySpeech(**speedyspeech_config["model"])
|
||||||
model.set_state_dict(paddle.load(args.checkpoint)["main_params"])
|
model.set_state_dict(
|
||||||
|
paddle.load(args.speedyspeech_checkpoint)["main_params"])
|
||||||
model.eval()
|
model.eval()
|
||||||
vocoder_config = yaml.safe_load(
|
|
||||||
open("../../parallelwave_gan/baker/conf/default.yaml"))
|
vocoder = PWGGenerator(**pwg_config["generator_params"])
|
||||||
vocoder = PWGGenerator(**vocoder_config["generator_params"])
|
vocoder.set_state_dict(paddle.load(args.pwg_params))
|
||||||
vocoder.set_state_dict(
|
|
||||||
paddle.load("../../parallelwave_gan/baker/converted.pdparams"))
|
|
||||||
vocoder.remove_weight_norm()
|
vocoder.remove_weight_norm()
|
||||||
vocoder.eval()
|
vocoder.eval()
|
||||||
# print(model)
|
|
||||||
print("model done!")
|
print("model done!")
|
||||||
|
|
||||||
stat = np.load("../../speedyspeech/baker/dump/train/stats.npy")
|
stat = np.load(args.speedyspeech_stat)
|
||||||
mu, std = stat
|
mu, std = stat
|
||||||
mu = paddle.to_tensor(mu)
|
mu = paddle.to_tensor(mu)
|
||||||
std = paddle.to_tensor(std)
|
std = paddle.to_tensor(std)
|
||||||
|
speedyspeech_normalizer = ZScore(mu, std)
|
||||||
|
|
||||||
stat2 = np.load("../../parallelwave_gan/baker/dump/train/stats.npy")
|
stat = np.load(args.pwg_stat)
|
||||||
mu2, std2 = stat2
|
mu, std = stat
|
||||||
mu2 = paddle.to_tensor(mu2)
|
mu = paddle.to_tensor(mu)
|
||||||
std2 = paddle.to_tensor(std2)
|
std = paddle.to_tensor(std)
|
||||||
|
pwg_normalizer = ZScore(mu, std)
|
||||||
|
|
||||||
|
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
|
||||||
|
model)
|
||||||
|
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
||||||
|
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -91,23 +80,45 @@ def evaluate(args, config):
|
||||||
phones = paddle.to_tensor(datum["phones"])
|
phones = paddle.to_tensor(datum["phones"])
|
||||||
tones = paddle.to_tensor(datum["tones"])
|
tones = paddle.to_tensor(datum["tones"])
|
||||||
|
|
||||||
mel, _ = model.inference(phones, tones)
|
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
|
||||||
mel = mel * std + mu
|
|
||||||
mel = (mel - mu2) / std2
|
|
||||||
|
|
||||||
wav = vocoder.inference(mel)
|
|
||||||
sf.write(
|
sf.write(
|
||||||
output_dir / (utt_id + ".wav"), wav.numpy(), samplerate=config.sr)
|
output_dir / (utt_id + ".wav"),
|
||||||
|
wav.numpy(),
|
||||||
|
samplerate=speedyspeech_config.sr)
|
||||||
print(f"{utt_id} done!")
|
print(f"{utt_id} done!")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# parse args and config and redirect to train_sp
|
# parse args and config and redirect to train_sp
|
||||||
parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
|
parser = argparse.ArgumentParser(
|
||||||
"model with Baker Mandrin TTS dataset.")
|
description="Synthesize with speedyspeech & parallel wavegan.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config", type=str, help="config file to overwrite default config")
|
"--speedyspeech-config",
|
||||||
parser.add_argument("--checkpoint", type=str, help="checkpoint to load.")
|
type=str,
|
||||||
|
help="config file to overwrite default config")
|
||||||
|
parser.add_argument(
|
||||||
|
"--speedyspeech-checkpoint",
|
||||||
|
type=str,
|
||||||
|
help="speedyspeech checkpoint to load.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--speedyspeech-stat",
|
||||||
|
type=str,
|
||||||
|
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pwg-config",
|
||||||
|
type=str,
|
||||||
|
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pwg-params",
|
||||||
|
type=str,
|
||||||
|
help="parallel wavegan generator parameters to load.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--pwg-stat",
|
||||||
|
type=str,
|
||||||
|
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||||
|
)
|
||||||
parser.add_argument("--test-metadata", type=str, help="training data")
|
parser.add_argument("--test-metadata", type=str, help="training data")
|
||||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -115,16 +126,18 @@ def main():
|
||||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
config = get_cfg_default()
|
with open(args.speedyspeech_config) as f:
|
||||||
if args.config:
|
speedyspeech_config = CfgNode(yaml.safe_load(f))
|
||||||
config.merge_from_file(args.config)
|
with open(args.pwg_config) as f:
|
||||||
|
pwg_config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
print("========Args========")
|
print("========Args========")
|
||||||
print(yaml.safe_dump(vars(args)))
|
print(yaml.safe_dump(vars(args)))
|
||||||
print("========Config========")
|
print("========Config========")
|
||||||
print(config)
|
print(speedyspeech_config)
|
||||||
|
print(pwg_config)
|
||||||
|
|
||||||
evaluate(args, config)
|
evaluate(args, speedyspeech_config, pwg_config)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -768,3 +768,15 @@ class ResidualPWGDiscriminator(nn.Layer):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.apply(_remove_weight_norm)
|
self.apply(_remove_weight_norm)
|
||||||
|
|
||||||
|
|
||||||
|
class PWGInference(nn.Layer):
|
||||||
|
def __init__(self, normalizer, pwg_generator):
|
||||||
|
super().__init__()
|
||||||
|
self.normalizer = normalizer
|
||||||
|
self.pwg_generator = pwg_generator
|
||||||
|
|
||||||
|
def forward(self, logmel):
|
||||||
|
normalized_mel = self.normalizer(logmel)
|
||||||
|
wav = self.pwg_generator.inference(normalized_mel)
|
||||||
|
return wav
|
||||||
|
|
|
@ -211,4 +211,16 @@ class SpeedySpeech(nn.Layer):
|
||||||
t_dec, feature_size = shape[1], shape[2]
|
t_dec, feature_size = shape[1], shape[2]
|
||||||
encodings += sinusoid_position_encoding(t_dec, feature_size)
|
encodings += sinusoid_position_encoding(t_dec, feature_size)
|
||||||
decoded = self.decoder(encodings)
|
decoded = self.decoder(encodings)
|
||||||
return decoded[0], pred_durations[0]
|
return decoded[0]
|
||||||
|
|
||||||
|
|
||||||
|
class SpeedySpeechInference(nn.Layer):
|
||||||
|
def __init__(self, normalizer, speedyspeech_model):
|
||||||
|
super().__init__()
|
||||||
|
self.normalizer = normalizer
|
||||||
|
self.acoustic_model = speedyspeech_model
|
||||||
|
|
||||||
|
def forward(self, phones, tones):
|
||||||
|
normalized_mel = self.acoustic_model.inference(phones, tones)
|
||||||
|
logmel = self.normalizer.inverse(normalized_mel)
|
||||||
|
return logmel
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class ZScore(nn.Layer):
|
||||||
|
# feature last
|
||||||
|
def __init__(self, mu, sigma):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer("mu", mu)
|
||||||
|
self.register_buffer("sigma", sigma)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return (x - self.mu) / self.sigma
|
||||||
|
|
||||||
|
def inverse(self, x):
|
||||||
|
return x * self.sigma + self.mu
|
Loading…
Reference in New Issue