an inference interface for speedyspeech and pwg
This commit is contained in:
parent
4a7888b8c6
commit
8b7dabbd8d
|
@ -27,30 +27,15 @@ import soundfile as sf
|
|||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from paddle.optimizer import Adam # No RAdaom
|
||||
from paddle.optimizer.lr import StepDecay
|
||||
from paddle import DataParallel
|
||||
from visualdl import LogWriter
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.speedyspeech import SpeedySpeech
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
|
||||
from parakeet.training.updater import UpdaterBase
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training import extension
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.training.seeding import seed_everything
|
||||
|
||||
from batch_fn import collate_baker_examples
|
||||
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
|
||||
from config import get_cfg_default
|
||||
from parakeet.models.speedyspeech import SpeedySpeech, SpeedySpeechInference
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
|
||||
def evaluate(args, config):
|
||||
def evaluate(args, speedyspeech_config, pwg_config):
|
||||
# dataloader has been too verbose
|
||||
logging.getLogger("DataLoader").disabled = True
|
||||
|
||||
|
@ -60,28 +45,32 @@ def evaluate(args, config):
|
|||
test_dataset = DataTable(
|
||||
data=test_metadata, fields=["utt_id", "phones", "tones"])
|
||||
|
||||
model = SpeedySpeech(**config["model"])
|
||||
model.set_state_dict(paddle.load(args.checkpoint)["main_params"])
|
||||
model = SpeedySpeech(**speedyspeech_config["model"])
|
||||
model.set_state_dict(
|
||||
paddle.load(args.speedyspeech_checkpoint)["main_params"])
|
||||
model.eval()
|
||||
vocoder_config = yaml.safe_load(
|
||||
open("../../parallelwave_gan/baker/conf/default.yaml"))
|
||||
vocoder = PWGGenerator(**vocoder_config["generator_params"])
|
||||
vocoder.set_state_dict(
|
||||
paddle.load("../../parallelwave_gan/baker/converted.pdparams"))
|
||||
|
||||
vocoder = PWGGenerator(**pwg_config["generator_params"])
|
||||
vocoder.set_state_dict(paddle.load(args.pwg_params))
|
||||
vocoder.remove_weight_norm()
|
||||
vocoder.eval()
|
||||
# print(model)
|
||||
print("model done!")
|
||||
|
||||
stat = np.load("../../speedyspeech/baker/dump/train/stats.npy")
|
||||
stat = np.load(args.speedyspeech_stat)
|
||||
mu, std = stat
|
||||
mu = paddle.to_tensor(mu)
|
||||
std = paddle.to_tensor(std)
|
||||
speedyspeech_normalizer = ZScore(mu, std)
|
||||
|
||||
stat2 = np.load("../../parallelwave_gan/baker/dump/train/stats.npy")
|
||||
mu2, std2 = stat2
|
||||
mu2 = paddle.to_tensor(mu2)
|
||||
std2 = paddle.to_tensor(std2)
|
||||
stat = np.load(args.pwg_stat)
|
||||
mu, std = stat
|
||||
mu = paddle.to_tensor(mu)
|
||||
std = paddle.to_tensor(std)
|
||||
pwg_normalizer = ZScore(mu, std)
|
||||
|
||||
speedyspeech_inferencce = SpeedySpeechInference(speedyspeech_normalizer,
|
||||
model)
|
||||
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -91,23 +80,45 @@ def evaluate(args, config):
|
|||
phones = paddle.to_tensor(datum["phones"])
|
||||
tones = paddle.to_tensor(datum["tones"])
|
||||
|
||||
mel, _ = model.inference(phones, tones)
|
||||
mel = mel * std + mu
|
||||
mel = (mel - mu2) / std2
|
||||
|
||||
wav = vocoder.inference(mel)
|
||||
wav = pwg_inference(speedyspeech_inferencce(phones, tones))
|
||||
sf.write(
|
||||
output_dir / (utt_id + ".wav"), wav.numpy(), samplerate=config.sr)
|
||||
output_dir / (utt_id + ".wav"),
|
||||
wav.numpy(),
|
||||
samplerate=speedyspeech_config.sr)
|
||||
print(f"{utt_id} done!")
|
||||
|
||||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
|
||||
"model with Baker Mandrin TTS dataset.")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with speedyspeech & parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config")
|
||||
parser.add_argument("--checkpoint", type=str, help="checkpoint to load.")
|
||||
"--speedyspeech-config",
|
||||
type=str,
|
||||
help="config file to overwrite default config")
|
||||
parser.add_argument(
|
||||
"--speedyspeech-checkpoint",
|
||||
type=str,
|
||||
help="speedyspeech checkpoint to load.")
|
||||
parser.add_argument(
|
||||
"--speedyspeech-stat",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
help="parallel wavegan generator parameters to load.")
|
||||
parser.add_argument(
|
||||
"--pwg-stat",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training speedyspeech."
|
||||
)
|
||||
parser.add_argument("--test-metadata", type=str, help="training data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument(
|
||||
|
@ -115,16 +126,18 @@ def main():
|
|||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = get_cfg_default()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
with open(args.speedyspeech_config) as f:
|
||||
speedyspeech_config = CfgNode(yaml.safe_load(f))
|
||||
with open(args.pwg_config) as f:
|
||||
pwg_config = CfgNode(yaml.safe_load(f))
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
print(config)
|
||||
print(speedyspeech_config)
|
||||
print(pwg_config)
|
||||
|
||||
evaluate(args, config)
|
||||
evaluate(args, speedyspeech_config, pwg_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -768,3 +768,15 @@ class ResidualPWGDiscriminator(nn.Layer):
|
|||
pass
|
||||
|
||||
self.apply(_remove_weight_norm)
|
||||
|
||||
|
||||
class PWGInference(nn.Layer):
|
||||
def __init__(self, normalizer, pwg_generator):
|
||||
super().__init__()
|
||||
self.normalizer = normalizer
|
||||
self.pwg_generator = pwg_generator
|
||||
|
||||
def forward(self, logmel):
|
||||
normalized_mel = self.normalizer(logmel)
|
||||
wav = self.pwg_generator.inference(normalized_mel)
|
||||
return wav
|
||||
|
|
|
@ -211,4 +211,16 @@ class SpeedySpeech(nn.Layer):
|
|||
t_dec, feature_size = shape[1], shape[2]
|
||||
encodings += sinusoid_position_encoding(t_dec, feature_size)
|
||||
decoded = self.decoder(encodings)
|
||||
return decoded[0], pred_durations[0]
|
||||
return decoded[0]
|
||||
|
||||
|
||||
class SpeedySpeechInference(nn.Layer):
|
||||
def __init__(self, normalizer, speedyspeech_model):
|
||||
super().__init__()
|
||||
self.normalizer = normalizer
|
||||
self.acoustic_model = speedyspeech_model
|
||||
|
||||
def forward(self, phones, tones):
|
||||
normalized_mel = self.acoustic_model.inference(phones, tones)
|
||||
logmel = self.normalizer.inverse(normalized_mel)
|
||||
return logmel
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class ZScore(nn.Layer):
|
||||
# feature last
|
||||
def __init__(self, mu, sigma):
|
||||
super().__init__()
|
||||
self.register_buffer("mu", mu)
|
||||
self.register_buffer("sigma", sigma)
|
||||
|
||||
def forward(self, x):
|
||||
return (x - self.mu) / self.sigma
|
||||
|
||||
def inverse(self, x):
|
||||
return x * self.sigma + self.mu
|
Loading…
Reference in New Issue