fix pwg
This commit is contained in:
parent
3ac2e01263
commit
796fafbac8
|
@ -21,10 +21,10 @@ from typing import List, Dict, Any
|
|||
import jsonlines
|
||||
import librosa
|
||||
import numpy as np
|
||||
from parakeet.data.get_feats import LogMelFBank, Energy, Pitch
|
||||
import tqdm
|
||||
|
||||
from config import get_cfg_default
|
||||
from get_feats import LogMelFBank, Energy, Pitch
|
||||
|
||||
|
||||
def get_phn_dur(file_name):
|
||||
|
|
|
@ -94,7 +94,7 @@ def main():
|
|||
parser.add_argument(
|
||||
"--fastspeech2-config",
|
||||
type=str,
|
||||
help="config file to overwrite default config")
|
||||
help="config file to overwrite default config.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-checkpoint",
|
||||
type=str,
|
||||
|
@ -123,11 +123,11 @@ def main():
|
|||
type=str,
|
||||
default="phone_id_map.txt",
|
||||
help="phone vocabulary file.")
|
||||
parser.add_argument("--test-metadata", type=str, help="test metadata")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument("--test-metadata", type=str, help="test metadata.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
"--device", type=str, default="gpu", help="device type to use.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
|
||||
args = parser.parse_args()
|
||||
with open(args.fastspeech2_config) as f:
|
||||
|
|
|
@ -99,7 +99,7 @@ def main():
|
|||
parser.add_argument(
|
||||
"--fastspeech2-config",
|
||||
type=str,
|
||||
help="config file to overwrite default config")
|
||||
help="fastspeech2 config file to overwrite default config.")
|
||||
parser.add_argument(
|
||||
"--fastspeech2-checkpoint",
|
||||
type=str,
|
||||
|
@ -112,8 +112,7 @@ def main():
|
|||
parser.add_argument(
|
||||
"--pwg-config",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
|
||||
)
|
||||
help="parallel wavegan config file to overwrite default config.")
|
||||
parser.add_argument(
|
||||
"--pwg-params",
|
||||
type=str,
|
||||
|
@ -131,11 +130,11 @@ def main():
|
|||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
help="text to synthesize, a 'utt_id sentence' pair per line")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
help="text to synthesize, a 'utt_id sentence' pair per line.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
"--device", type=str, default="gpu", help="device type to use.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
|
||||
args = parser.parse_args()
|
||||
with open(args.fastspeech2_config) as f:
|
||||
|
|
|
@ -169,18 +169,18 @@ def train_sp(args, config):
|
|||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
|
||||
parser = argparse.ArgumentParser(description="Train a FastSpeech2 "
|
||||
"model with Baker Mandrin TTS dataset.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
"--config", type=str, help="config file to overwrite default config.")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
"--device", type=str, default="gpu", help="device type to use.")
|
||||
parser.add_argument(
|
||||
"--nprocs", type=int, default=1, help="number of processes")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
"--nprocs", type=int, default=1, help="number of processes.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
parser.add_argument(
|
||||
"--phones-dict",
|
||||
type=str,
|
||||
|
|
|
@ -27,10 +27,14 @@ class Clip(object):
|
|||
aux_context_window=0, ):
|
||||
"""Initialize customized collater for DataLoader.
|
||||
|
||||
Args:
|
||||
batch_max_steps (int): The maximum length of input signal in batch.
|
||||
hop_size (int): Hop size of auxiliary features.
|
||||
aux_context_window (int): Context window size for auxiliary feature conv.
|
||||
Parameters
|
||||
----------
|
||||
batch_max_steps : int
|
||||
The maximum length of input signal in batch.
|
||||
hop_size : int
|
||||
Hop size of auxiliary features.
|
||||
aux_context_window : int
|
||||
Context window size for auxiliary feature conv.
|
||||
|
||||
"""
|
||||
if batch_max_steps % hop_size != 0:
|
||||
|
@ -49,14 +53,18 @@ class Clip(object):
|
|||
def __call__(self, examples):
|
||||
"""Convert into batch tensors.
|
||||
|
||||
Args:
|
||||
batch (list): list of tuple of the pair of audio and features. Audio shape
|
||||
(T, ), features shape(T', C).
|
||||
Parameters
|
||||
----------
|
||||
batch : list
|
||||
list of tuple of the pair of audio and features. Audio shape (T, ), features shape(T', C).
|
||||
|
||||
Returns:
|
||||
Tensor: Auxiliary feature batch (B, C, T'), where
|
||||
Returns
|
||||
----------
|
||||
Tensor
|
||||
Auxiliary feature batch (B, C, T'), where
|
||||
T = (T' - 2 * aux_context_window) * hop_size.
|
||||
Tensor: Target signal batch (B, 1, T).
|
||||
Tensor
|
||||
Target signal batch (B, 1, T).
|
||||
|
||||
"""
|
||||
# check length
|
||||
|
@ -93,7 +101,8 @@ class Clip(object):
|
|||
def _adjust_length(self, x, c):
|
||||
"""Adjust the audio and feature lengths.
|
||||
|
||||
Note:
|
||||
Note
|
||||
-------
|
||||
Basically we assume that the length of x and c are adjusted
|
||||
through preprocessing stage, but if we use other library processed
|
||||
features, this process will be needed.
|
||||
|
|
|
@ -82,7 +82,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
|
|||
###########################################################
|
||||
# DATA LOADER SETTING #
|
||||
###########################################################
|
||||
batch_size: 6 # Batch size.
|
||||
batch_size: 8 # Batch size.
|
||||
batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by hop_size.
|
||||
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
|
||||
num_workers: 4 # Number of workers in Pytorch DataLoader.
|
||||
|
|
|
@ -12,88 +12,28 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Dict, Any
|
||||
import soundfile as sf
|
||||
import librosa
|
||||
import numpy as np
|
||||
import argparse
|
||||
import yaml
|
||||
import json
|
||||
import jsonlines
|
||||
import concurrent.futures
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
import tqdm
|
||||
from operator import itemgetter
|
||||
from praatio import tgio
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import argparse
|
||||
import jsonlines
|
||||
import librosa
|
||||
import logging
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from parakeet.data.get_feats import LogMelFBank
|
||||
from pathlib import Path
|
||||
from praatio import tgio
|
||||
|
||||
from config import get_cfg_default
|
||||
|
||||
|
||||
def logmelfilterbank(audio,
|
||||
sr,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
n_mels=80,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10):
|
||||
"""Compute log-Mel filterbank feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : ndarray
|
||||
Audio signal (T,).
|
||||
sr : int
|
||||
Sampling rate.
|
||||
n_fft : int
|
||||
FFT size. (Default value = 1024)
|
||||
hop_length : int
|
||||
Hop size. (Default value = 256)
|
||||
win_length : int
|
||||
Window length. If set to None, it will be the same as fft_size. (Default value = None)
|
||||
window : str
|
||||
Window function type. (Default value = "hann")
|
||||
n_mels : int
|
||||
Number of mel basis. (Default value = 80)
|
||||
fmin : int
|
||||
Minimum frequency in mel basis calculation. (Default value = None)
|
||||
fmax : int
|
||||
Maximum frequency in mel basis calculation. (Default value = None)
|
||||
eps : float
|
||||
Epsilon value to avoid inf in log calculation. (Default value = 1e-10)
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Log Mel filterbank feature (#frames, num_mels).
|
||||
|
||||
"""
|
||||
# get amplitude spectrogram
|
||||
x_stft = librosa.stft(
|
||||
audio,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
pad_mode="reflect")
|
||||
spc = np.abs(x_stft) # (#bins, #frames,)
|
||||
|
||||
# get mel basis
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = sr / 2 if fmax is None else fmax
|
||||
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
||||
|
||||
return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))
|
||||
|
||||
|
||||
def process_sentence(config: Dict[str, Any],
|
||||
fp: Path,
|
||||
alignment_fp: Path,
|
||||
output_dir: Path):
|
||||
output_dir: Path,
|
||||
mel_extractor=None):
|
||||
utt_id = fp.stem
|
||||
|
||||
# reading
|
||||
|
@ -134,19 +74,11 @@ def process_sentence(config: Dict[str, Any],
|
|||
frame_length=config.trim_frame_length,
|
||||
hop_length=config.trim_hop_length)
|
||||
|
||||
logmel = logmelfilterbank(
|
||||
y,
|
||||
sr=sr,
|
||||
n_fft=config.n_fft,
|
||||
window=config.window,
|
||||
win_length=config.win_length,
|
||||
hop_length=config.hop_length,
|
||||
n_mels=config.n_mels,
|
||||
fmin=config.fmin,
|
||||
fmax=config.fmax)
|
||||
# extract mel feats
|
||||
logmel = mel_extractor.get_log_mel_fbank(y)
|
||||
|
||||
# adjust time to make num_samples == num_frames * hop_length
|
||||
num_frames = logmel.shape[1]
|
||||
num_frames = logmel.shape[0]
|
||||
if y.size < num_frames * config.hop_length:
|
||||
y = np.pad(y, (0, num_frames * config.hop_length - y.size),
|
||||
mode="reflect")
|
||||
|
@ -157,7 +89,7 @@ def process_sentence(config: Dict[str, Any],
|
|||
mel_path = output_dir / (utt_id + "_feats.npy")
|
||||
wav_path = output_dir / (utt_id + "_wave.npy")
|
||||
np.save(wav_path, y) # (num_samples, )
|
||||
np.save(mel_path, logmel.T) # (num_frames, n_mels)
|
||||
np.save(mel_path, logmel) # (num_frames, n_mels)
|
||||
record = {
|
||||
"utt_id": utt_id,
|
||||
"num_samples": num_sample,
|
||||
|
@ -172,19 +104,22 @@ def process_sentences(config,
|
|||
fps: List[Path],
|
||||
alignment_fps: List[Path],
|
||||
output_dir: Path,
|
||||
mel_extractor=None,
|
||||
nprocs: int=1):
|
||||
if nprocs == 1:
|
||||
results = []
|
||||
for fp, alignment_fp in tqdm.tqdm(zip(fps, alignment_fps)):
|
||||
results.append(
|
||||
process_sentence(config, fp, alignment_fp, output_dir))
|
||||
process_sentence(config, fp, alignment_fp, output_dir,
|
||||
mel_extractor))
|
||||
else:
|
||||
with ThreadPoolExecutor(nprocs) as pool:
|
||||
futures = []
|
||||
with tqdm.tqdm(total=len(fps)) as progress:
|
||||
for fp, alignment_fp in zip(fps, alignment_fps):
|
||||
future = pool.submit(process_sentence, config, fp,
|
||||
alignment_fp, output_dir)
|
||||
alignment_fp, output_dir,
|
||||
mel_extractor)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append(future)
|
||||
|
||||
|
@ -260,24 +195,37 @@ def main():
|
|||
test_dump_dir = dumpdir / "test" / "raw"
|
||||
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mel_extractor = LogMelFBank(
|
||||
sr=C.sr,
|
||||
n_fft=C.n_fft,
|
||||
hop_length=C.hop_length,
|
||||
win_length=C.win_length,
|
||||
window=C.window,
|
||||
n_mels=C.n_mels,
|
||||
fmin=C.fmin,
|
||||
fmax=C.fmax)
|
||||
|
||||
# process for the 3 sections
|
||||
process_sentences(
|
||||
C,
|
||||
train_wav_files,
|
||||
train_alignment_files,
|
||||
train_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu)
|
||||
process_sentences(
|
||||
C,
|
||||
dev_wav_files,
|
||||
dev_alignment_files,
|
||||
dev_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu)
|
||||
process_sentences(
|
||||
C,
|
||||
test_wav_files,
|
||||
test_alignment_files,
|
||||
test_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu)
|
||||
|
||||
|
||||
|
|
|
@ -78,16 +78,17 @@ class PWGUpdater(StandardUpdater):
|
|||
wav_ = self.generator(noise, mel)
|
||||
logging.debug(f"Generator takes {t.elapse}s.")
|
||||
|
||||
## Multi-resolution stft loss
|
||||
# initialize
|
||||
gen_loss = 0.0
|
||||
|
||||
## Multi-resolution stft loss
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
|
||||
|
||||
report("train/spectral_convergence_loss", float(sc_loss))
|
||||
report("train/log_stft_magnitude_loss", float(mag_loss))
|
||||
gen_loss = sc_loss + mag_loss
|
||||
gen_loss += sc_loss + mag_loss
|
||||
|
||||
## Adversarial loss
|
||||
if self.state.iteration > self.discriminator_train_start_steps:
|
||||
|
@ -119,9 +120,9 @@ class PWGUpdater(StandardUpdater):
|
|||
p_ = self.discriminator(wav_.detach())
|
||||
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
||||
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
||||
dis_loss = real_loss + fake_loss
|
||||
report("train/real_loss", float(real_loss))
|
||||
report("train/fake_loss", float(fake_loss))
|
||||
dis_loss = real_loss + fake_loss
|
||||
report("train/discriminator_loss", float(dis_loss))
|
||||
|
||||
self.optimizer_d.clear_grad()
|
||||
|
@ -164,8 +165,7 @@ class PWGEvaluator(StandardEvaluator):
|
|||
|
||||
# stft loss
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
|
||||
report("eval/spectral_convergence_loss", float(sc_loss))
|
||||
|
@ -178,7 +178,7 @@ class PWGEvaluator(StandardEvaluator):
|
|||
p = self.discriminator(wav)
|
||||
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
||||
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
||||
dis_loss = real_loss + fake_loss
|
||||
report("eval/real_loss", float(real_loss))
|
||||
report("eval/fake_loss", float(fake_loss))
|
||||
dis_loss = real_loss + fake_loss
|
||||
report("eval/discriminator_loss", float(dis_loss))
|
||||
|
|
|
@ -32,14 +32,14 @@ from parakeet.models.parallel_wavegan import PWGGenerator
|
|||
from config import get_cfg_default
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="synthesize with parallel wavegan.")
|
||||
description="Synthesize with parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config")
|
||||
parser.add_argument("--checkpoint", type=str, help="snapshot to load")
|
||||
parser.add_argument("--test-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument("--device", type=str, default="gpu", help="device to run")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
"--config", type=str, help="config file to overwrite default config.")
|
||||
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
|
||||
parser.add_argument("--test-metadata", type=str, help="dev data.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument("--device", type=str, default="gpu", help="device to run.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = get_cfg_default()
|
||||
|
@ -89,5 +89,5 @@ for example in test_dataset:
|
|||
print(
|
||||
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.sr / speed}."
|
||||
)
|
||||
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=config.sr)
|
||||
sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.sr)
|
||||
print(f"generation speed: {N / T}Hz, RTF: {config.sr / (N / T) }")
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
python3 synthesize.py \
|
||||
--config=conf/default.yaml \
|
||||
--checkpoint=exp/default/checkpoints/snapshot_iter_220000.pdz \
|
||||
--test-metadata=dump/test/norm/metadata.jsonl \
|
||||
--output-dir=exp/debug/test
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import paddle
|
||||
import soundfile as sf
|
||||
import yaml
|
||||
from parakeet.data.get_feats import LogMelFBank
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGInference
|
||||
from parakeet.modules.normalizer import ZScore
|
||||
|
||||
from config import get_cfg_default
|
||||
|
||||
|
||||
def evaluate(args, config):
|
||||
# dataloader has been too verbose
|
||||
logging.getLogger("DataLoader").disabled = True
|
||||
|
||||
vocoder = PWGGenerator(**config["generator_params"])
|
||||
state_dict = paddle.load(args.checkpoint)
|
||||
vocoder.set_state_dict(state_dict["generator_params"])
|
||||
vocoder.remove_weight_norm()
|
||||
vocoder.eval()
|
||||
print("model done!")
|
||||
|
||||
stat = np.load(args.stat)
|
||||
mu, std = stat
|
||||
mu = paddle.to_tensor(mu)
|
||||
std = paddle.to_tensor(std)
|
||||
normalizer = ZScore(mu, std)
|
||||
|
||||
pwg_inference = PWGInference(normalizer, vocoder)
|
||||
|
||||
input_dir = Path(args.input_dir)
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mel_extractor = LogMelFBank(
|
||||
sr=config.sr,
|
||||
n_fft=config.n_fft,
|
||||
hop_length=config.hop_length,
|
||||
win_length=config.win_length,
|
||||
window=config.window,
|
||||
n_mels=config.n_mels,
|
||||
fmin=config.fmin,
|
||||
fmax=config.fmax)
|
||||
|
||||
for utt_name in os.listdir(input_dir):
|
||||
wav, _ = librosa.load(str(input_dir / utt_name), sr=config.sr)
|
||||
# extract mel feats
|
||||
mel = mel_extractor.get_log_mel_fbank(wav)
|
||||
mel = paddle.to_tensor(mel)
|
||||
gen_wav = pwg_inference(mel)
|
||||
sf.write(
|
||||
str(output_dir / ("gen_" + utt_name)),
|
||||
gen_wav.numpy(),
|
||||
samplerate=config.sr)
|
||||
print(f"{utt_name} done!")
|
||||
|
||||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize with parallel wavegan.")
|
||||
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config.")
|
||||
parser.add_argument("--checkpoint", type=str, help="snapshot to load.")
|
||||
parser.add_argument(
|
||||
"--stat",
|
||||
type=str,
|
||||
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
|
||||
)
|
||||
parser.add_argument("--input-dir", type=str, help="input dir of wavs.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device to run.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = get_cfg_default()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
print(config)
|
||||
|
||||
evaluate(args, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -12,36 +12,29 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
import os
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
import jsonlines
|
||||
import paddle
|
||||
import numpy as np
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
import paddle
|
||||
import yaml
|
||||
from paddle import DataParallel
|
||||
from paddle import distributed as dist
|
||||
from paddle import nn
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from paddle.optimizer import Adam # No RAdaom
|
||||
from paddle.optimizer.lr import StepDecay
|
||||
from paddle import DataParallel
|
||||
from visualdl import LogWriter
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.training.updater import UpdaterBase
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training import extension
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.training.seeding import seed_everything
|
||||
from parakeet.training.trainer import Trainer
|
||||
from pathlib import Path
|
||||
from visualdl import LogWriter
|
||||
|
||||
from batch_fn import Clip
|
||||
from config import get_cfg_default
|
||||
|
@ -210,15 +203,15 @@ def main():
|
|||
parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
|
||||
"model with Baker Mandrin TTS dataset.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
"--config", type=str, help="config file to overwrite default config.")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
"--device", type=str, default="gpu", help="device type to use.")
|
||||
parser.add_argument(
|
||||
"--nprocs", type=int, default=1, help="number of processes")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
"--nprocs", type=int, default=1, help="number of processes.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.device == "cpu" and args.nprocs > 1:
|
||||
|
|
|
@ -12,94 +12,34 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Dict, Any
|
||||
import soundfile as sf
|
||||
import librosa
|
||||
import numpy as np
|
||||
import argparse
|
||||
import yaml
|
||||
import json
|
||||
import re
|
||||
import jsonlines
|
||||
import concurrent.futures
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
import tqdm
|
||||
from operator import itemgetter
|
||||
from praatio import tgio
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import argparse
|
||||
import jsonlines
|
||||
import librosa
|
||||
import logging
|
||||
import numpy as np
|
||||
import re
|
||||
import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from parakeet.data.get_feats import LogMelFBank
|
||||
from pathlib import Path
|
||||
from praatio import tgio
|
||||
|
||||
from config import get_cfg_default
|
||||
from tg_utils import validate_textgrid
|
||||
|
||||
|
||||
def logmelfilterbank(audio,
|
||||
sr,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=None,
|
||||
window="hann",
|
||||
n_mels=80,
|
||||
fmin=None,
|
||||
fmax=None,
|
||||
eps=1e-10):
|
||||
"""Compute log-Mel filterbank feature.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : ndarray
|
||||
Audio signal (T,).
|
||||
sr : int
|
||||
Sampling rate.
|
||||
n_fft : int
|
||||
FFT size. (Default value = 1024)
|
||||
hop_length : int
|
||||
Hop size. (Default value = 256)
|
||||
win_length : int
|
||||
Window length. If set to None, it will be the same as fft_size. (Default value = None)
|
||||
window : str
|
||||
Window function type. (Default value = "hann")
|
||||
n_mels : int
|
||||
Number of mel basis. (Default value = 80)
|
||||
fmin : int
|
||||
Minimum frequency in mel basis calculation. (Default value = None)
|
||||
fmax : int
|
||||
Maximum frequency in mel basis calculation. (Default value = None)
|
||||
eps : float
|
||||
Epsilon value to avoid inf in log calculation. (Default value = 1e-10)
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Log Mel filterbank feature (#frames, num_mels).
|
||||
|
||||
"""
|
||||
# get amplitude spectrogram
|
||||
x_stft = librosa.stft(
|
||||
audio,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
pad_mode="reflect")
|
||||
spc = np.abs(x_stft) # (#bins, #frames,)
|
||||
|
||||
# get mel basis
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = sr / 2 if fmax is None else fmax
|
||||
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
||||
|
||||
return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))
|
||||
|
||||
|
||||
def process_sentence(config: Dict[str, Any],
|
||||
fp: Path,
|
||||
alignment_fp: Path,
|
||||
output_dir: Path):
|
||||
output_dir: Path,
|
||||
mel_extractor=None):
|
||||
utt_id = fp.stem
|
||||
|
||||
# reading
|
||||
y, sr = librosa.load(fp, sr=config.sr) # resampling may occur
|
||||
y, sr = librosa.load(str(fp), sr=config.sr) # resampling may occur
|
||||
assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
|
||||
assert np.abs(y).max(
|
||||
) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
|
||||
|
@ -125,16 +65,8 @@ def process_sentence(config: Dict[str, Any],
|
|||
f" There is something wrong with the last interval {last} in utterance: {utt_id}"
|
||||
)
|
||||
|
||||
logmel = logmelfilterbank(
|
||||
y,
|
||||
sr=sr,
|
||||
n_fft=config.n_fft,
|
||||
window=config.window,
|
||||
win_length=config.win_length,
|
||||
hop_length=config.hop_length,
|
||||
n_mels=config.n_mels,
|
||||
fmin=config.fmin,
|
||||
fmax=config.fmax)
|
||||
# extract mel feats
|
||||
logmel = mel_extractor.get_log_mel_fbank(y)
|
||||
|
||||
# extract phone and duration
|
||||
phones = []
|
||||
|
@ -162,7 +94,7 @@ def process_sentence(config: Dict[str, Any],
|
|||
ends, sr=sr, hop_length=config.hop_length)
|
||||
durations_frame = np.diff(frame_pos, prepend=0)
|
||||
|
||||
num_frames = logmel.shape[-1] # number of frames of the spectrogram
|
||||
num_frames = logmel.shape[0] # number of frames of the spectrogram
|
||||
extra = np.sum(durations_frame) - num_frames
|
||||
assert extra <= 0, (
|
||||
f"Number of frames inferred from alignemnt is "
|
||||
|
@ -173,7 +105,7 @@ def process_sentence(config: Dict[str, Any],
|
|||
durations_frame = durations_frame.tolist()
|
||||
|
||||
mel_path = output_dir / (utt_id + "_feats.npy")
|
||||
np.save(mel_path, logmel.T) # (num_frames, n_mels)
|
||||
np.save(mel_path, logmel) # (num_frames, n_mels)
|
||||
record = {
|
||||
"utt_id": utt_id,
|
||||
"phones": phones,
|
||||
|
@ -190,20 +122,23 @@ def process_sentences(config,
|
|||
fps: List[Path],
|
||||
alignment_fps: List[Path],
|
||||
output_dir: Path,
|
||||
mel_extractor=None,
|
||||
nprocs: int=1):
|
||||
if nprocs == 1:
|
||||
results = []
|
||||
for fp, alignment_fp in tqdm.tqdm(
|
||||
zip(fps, alignment_fps), total=len(fps)):
|
||||
results.append(
|
||||
process_sentence(config, fp, alignment_fp, output_dir))
|
||||
process_sentence(config, fp, alignment_fp, output_dir,
|
||||
mel_extractor))
|
||||
else:
|
||||
with ThreadPoolExecutor(nprocs) as pool:
|
||||
futures = []
|
||||
with tqdm.tqdm(total=len(fps)) as progress:
|
||||
for fp, alignment_fp in zip(fps, alignment_fps):
|
||||
future = pool.submit(process_sentence, config, fp,
|
||||
alignment_fp, output_dir)
|
||||
alignment_fp, output_dir,
|
||||
mel_extractor)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append(future)
|
||||
|
||||
|
@ -284,24 +219,37 @@ def main():
|
|||
test_dump_dir = dumpdir / "test" / "raw"
|
||||
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
mel_extractor = LogMelFBank(
|
||||
sr=C.sr,
|
||||
n_fft=C.n_fft,
|
||||
hop_length=C.hop_length,
|
||||
win_length=C.win_length,
|
||||
window=C.window,
|
||||
n_mels=C.n_mels,
|
||||
fmin=C.fmin,
|
||||
fmax=C.fmax)
|
||||
|
||||
# process for the 3 sections
|
||||
process_sentences(
|
||||
C,
|
||||
train_wav_files,
|
||||
train_alignment_files,
|
||||
train_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu)
|
||||
process_sentences(
|
||||
C,
|
||||
dev_wav_files,
|
||||
dev_alignment_files,
|
||||
dev_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu)
|
||||
process_sentences(
|
||||
C,
|
||||
test_wav_files,
|
||||
test_alignment_files,
|
||||
test_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu)
|
||||
|
||||
|
||||
|
|
|
@ -12,40 +12,31 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import os
|
||||
|
||||
import yaml
|
||||
import jsonlines
|
||||
import paddle
|
||||
import numpy as np
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
import paddle
|
||||
import yaml
|
||||
from paddle import distributed as dist
|
||||
from paddle import DataParallel
|
||||
from paddle import nn
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from paddle.optimizer import Adam # No RAdaom
|
||||
from paddle.optimizer.lr import StepDecay
|
||||
from paddle import DataParallel
|
||||
from visualdl import LogWriter
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.speedyspeech import SpeedySpeech
|
||||
|
||||
from parakeet.training.updater import UpdaterBase
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training import extension
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.training.seeding import seed_everything
|
||||
from parakeet.training.trainer import Trainer
|
||||
from pathlib import Path
|
||||
from visualdl import LogWriter
|
||||
|
||||
from batch_fn import collate_baker_examples
|
||||
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
|
||||
from config import get_cfg_default
|
||||
from speedyspeech_updater import SpeedySpeechUpdater, SpeedySpeechEvaluator
|
||||
|
||||
|
||||
def train_sp(args, config):
|
||||
|
@ -93,10 +84,6 @@ def train_sp(args, config):
|
|||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
drop_last=True)
|
||||
# dev_sampler = DistributedBatchSampler(dev_dataset,
|
||||
# batch_size=config.batch_size,
|
||||
# shuffle=False,
|
||||
# drop_last=False)
|
||||
print("samplers done!")
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
|
@ -113,13 +100,9 @@ def train_sp(args, config):
|
|||
num_workers=config.num_workers)
|
||||
print("dataloaders done!")
|
||||
|
||||
# batch = collate_baker_examples([train_dataset[i] for i in range(10)])
|
||||
# # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
|
||||
# import pdb; pdb.set_trace()
|
||||
model = SpeedySpeech(**config["model"])
|
||||
if world_size > 1:
|
||||
model = DataParallel(model) # TODO, do not use vocab size from config
|
||||
# print(model)
|
||||
print("model done!")
|
||||
optimizer = Adam(
|
||||
0.001,
|
||||
|
@ -147,18 +130,18 @@ def train_sp(args, config):
|
|||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
|
||||
parser = argparse.ArgumentParser(description="Train a SpeedySpeech "
|
||||
"model with Baker Mandrin TTS dataset.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
"--config", type=str, help="config file to overwrite default config.")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
"--device", type=str, default="gpu", help="device type to use.")
|
||||
parser.add_argument(
|
||||
"--nprocs", type=int, default=1, help="number of processes")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
"--nprocs", type=int, default=1, help="number of processes.")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose.")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.device == "cpu" and args.nprocs > 1:
|
||||
|
|
|
@ -27,5 +27,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from parakeet.data.dataset import *
|
||||
from parakeet.data.batch import *
|
||||
from parakeet.data.dataset import *
|
||||
from parakeet.data.get_feats import *
|
||||
|
|
|
@ -17,8 +17,6 @@ import numpy as np
|
|||
import pyworld
|
||||
from scipy.interpolate import interp1d
|
||||
|
||||
from config import get_cfg_default
|
||||
|
||||
|
||||
class LogMelFBank():
|
||||
def __init__(self,
|
||||
|
@ -42,8 +40,8 @@ class LogMelFBank():
|
|||
|
||||
# mel
|
||||
self.n_mels = n_mels
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.fmin = 0 if fmin is None else fmin
|
||||
self.fmax = sr / 2 if fmax is None else fmax
|
||||
|
||||
self.mel_filter = self._create_mel_filter()
|
||||
|
||||
|
@ -217,41 +215,3 @@ class Energy():
|
|||
if use_token_averaged_energy and duration is not None:
|
||||
energy = self._average_by_duration(energy, duration)
|
||||
return energy
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
C = get_cfg_default()
|
||||
filename = "../raw_data/data/format.1/000001.flac"
|
||||
wav, _ = librosa.load(filename, sr=C.fs)
|
||||
mel_extractor = LogMelFBank(
|
||||
sr=C.fs,
|
||||
n_fft=C.n_fft,
|
||||
hop_length=C.n_shift,
|
||||
win_length=C.win_length,
|
||||
window=C.window,
|
||||
n_mels=C.n_mels,
|
||||
fmin=C.fmin,
|
||||
fmax=C.fmax, )
|
||||
mel = mel_extractor.get_log_mel_fbank(wav)
|
||||
print(mel)
|
||||
print(mel.shape)
|
||||
|
||||
pitch_extractor = Pitch(
|
||||
sr=C.fs, hop_length=C.n_shift, f0min=C.f0min, f0max=C.f0max)
|
||||
duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
|
||||
duration = np.array([int(x) for x in duration.split(" ")])
|
||||
avg_f0 = pitch_extractor.get_pitch(wav, duration=duration)
|
||||
print(avg_f0)
|
||||
print(avg_f0.shape)
|
||||
|
||||
energy_extractor = Energy(
|
||||
sr=C.fs,
|
||||
n_fft=C.n_fft,
|
||||
hop_length=C.n_shift,
|
||||
win_length=C.win_length,
|
||||
window=C.window)
|
||||
duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
|
||||
duration = np.array([int(x) for x in duration.split(" ")])
|
||||
avg_energy = energy_extractor.get_energy(wav, duration=duration)
|
||||
print(avg_energy)
|
||||
print(avg_energy.sum())
|
|
@ -109,4 +109,5 @@ class Frontend():
|
|||
def get_phonemes(self, sentence):
|
||||
sentences = self.text_normalizer.normalize(sentence)
|
||||
phonemes = self._g2p(sentences)
|
||||
print("phonemes:", phonemes)
|
||||
return phonemes
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
from typing import Dict, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from parakeet.modules.fastspeech2_predictor.duration_predictor import DurationPredictor, DurationPredictorLoss
|
||||
|
|
|
@ -132,7 +132,7 @@ class DurationPredictor(nn.Layer):
|
|||
|
||||
Returns
|
||||
----------
|
||||
LongTensor
|
||||
Tensor
|
||||
Batch of predicted durations in linear domain int64 (B, Tmax).
|
||||
"""
|
||||
return self._forward(xs, x_masks, True)
|
||||
|
@ -166,7 +166,7 @@ class DurationPredictorLoss(nn.Layer):
|
|||
----------
|
||||
outputs : Tensor
|
||||
Batch of prediction durations in log domain (B, T)
|
||||
targets : LongTensor
|
||||
targets : Tensor
|
||||
Batch of groundtruth durations in linear domain (B, T)
|
||||
|
||||
Returns
|
||||
|
|
|
@ -31,7 +31,7 @@ class PositionalEncoding(nn.Layer):
|
|||
max_len : int
|
||||
Maximum input length.
|
||||
reverse : bool
|
||||
Whether to reverse the input position. Only for
|
||||
Whether to reverse the input position.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
||||
|
|
|
@ -29,8 +29,8 @@ class SpectralConvergenceLoss(nn.Layer):
|
|||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, C, T).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, C, T).
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
"""
|
||||
|
@ -50,11 +50,16 @@ class LogSTFTMagnitudeLoss(nn.Layer):
|
|||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns:
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
Parameters
|
||||
----------
|
||||
x_mag : Tensor
|
||||
Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag : Tensor
|
||||
Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
Returns
|
||||
----------
|
||||
Tensor
|
||||
Log STFT magnitude loss value.
|
||||
"""
|
||||
return F.l1_loss(
|
||||
paddle.log(paddle.clip(
|
||||
|
@ -86,15 +91,23 @@ class STFTLoss(nn.Layer):
|
|||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor
|
||||
Predicted signal (B, T).
|
||||
y : Tensor
|
||||
Groundtruth signal (B, T).
|
||||
Returns
|
||||
----------
|
||||
Tensor
|
||||
Spectral convergence loss value.
|
||||
Tensor
|
||||
Log STFT magnitude loss value.
|
||||
"""
|
||||
x_mag = self.stft.magnitude(x)
|
||||
y_mag = self.stft.magnitude(y)
|
||||
x_mag = x_mag.transpose([0, 2, 1])
|
||||
y_mag = y_mag.transpose([0, 2, 1])
|
||||
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
|
||||
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||
|
||||
|
@ -111,11 +124,16 @@ class MultiResolutionSTFTLoss(nn.Layer):
|
|||
win_lengths=[600, 1200, 240],
|
||||
window="hann", ):
|
||||
"""Initialize Multi resolution STFT loss module.
|
||||
Args:
|
||||
fft_sizes (list): List of FFT sizes.
|
||||
hop_sizes (list): List of hop sizes.
|
||||
win_lengths (list): List of window lengths.
|
||||
window (str): Window function type.
|
||||
Parameters
|
||||
----------
|
||||
fft_sizes : list
|
||||
List of FFT sizes.
|
||||
hop_sizes : list
|
||||
List of hop sizes.
|
||||
win_lengths : list
|
||||
List of window lengths.
|
||||
window : str
|
||||
Window function type.
|
||||
"""
|
||||
super().__init__()
|
||||
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
||||
|
@ -125,13 +143,24 @@ class MultiResolutionSTFTLoss(nn.Layer):
|
|||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
Returns:
|
||||
Tensor: Multi resolution spectral convergence loss value.
|
||||
Tensor: Multi resolution log STFT magnitude loss value.
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor
|
||||
Predicted signal (B, T) or (B, #subband, T).
|
||||
y : Tensor
|
||||
Groundtruth signal (B, T) or (B, #subband, T).
|
||||
Returns
|
||||
----------
|
||||
Tensor
|
||||
Multi resolution spectral convergence loss value.
|
||||
Tensor
|
||||
Multi resolution log STFT magnitude loss value.
|
||||
"""
|
||||
if len(x.shape) == 3:
|
||||
# (B, C, T) -> (B x C, T)
|
||||
x = x.reshape([-1, x.shape[2]])
|
||||
# (B, C, T) -> (B x C, T)
|
||||
y = y.reshape([-1, y.shape[2]])
|
||||
sc_loss = 0.0
|
||||
mag_loss = 0.0
|
||||
for f in self.stft_losses:
|
||||
|
|
Loading…
Reference in New Issue