add fastspeech2 example data preprocess
This commit is contained in: parent 474bc4c06a, commit 47ec051136
@ -0,0 +1,104 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate statistics of feature files."""

import argparse
import logging
from pathlib import Path

import jsonlines
import numpy as np
from parakeet.datasets.data_table import DataTable
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from config import get_cfg_default


def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features.")
    parser.add_argument(
        "--metadata", type=str, help="json file with id and file paths")
    parser.add_argument(
        "--field-name",
        type=str,
        help="name of the field to compute statistics for.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--output",
        type=str,
        help="path to save statistics. if not provided, "
        "stats will be saved in the above root directory with name stats.npy")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    if args.output is None:
        args.output = Path(args.metadata).parent.with_name(args.field_name +
                                                           "_stats.npy")
    else:
        args.output = Path(args.output)
    args.output.parent.mkdir(parents=True, exist_ok=True)

    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        fields=[args.field_name],
        converters={args.field_name: np.load}, )
    logging.info(f"The number of files = {len(dataset)}.")

    # calculate statistics
    scaler = StandardScaler()
    for datum in tqdm(dataset):
        # StandardScaler supports (*, num_features) by default
        scaler.partial_fit(datum[args.field_name])

    stats = np.stack([scaler.mean_, scaler.scale_], axis=0)

    np.save(str(args.output), stats.astype(np.float32), allow_pickle=False)


if __name__ == "__main__":
    main()
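A note on the stats format (illustrative, not part of the diff): the saved stats.npy stacks the running mean in row 0 and the scale (standard deviation) in row 1, which is exactly what normalize.py below reads back into a StandardScaler. A minimal sketch of that round trip, assuming a dumped dump/train/speech_stats.npy:

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    # row 0 is the mean, row 1 the scale, as saved by compute_statistics.py
    stats = np.load("dump/train/speech_stats.npy")  # shape: (2, num_features)
    scaler = StandardScaler()
    scaler.mean_, scaler.scale_ = stats[0], stats[1]
    scaler.n_features_in_ = scaler.mean_.shape[0]
    # any (num_frames, num_features) feature matrix can now be normalized:
    # normalized = scaler.transform(features)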
@ -88,8 +88,8 @@ updater:
 #                    OPTIMIZER SETTING                    #
 ###########################################################
 optimizer:
-  optim: adam              # optimizer type
-  learning_rate: 0.001     # learning rate
+  optim: adam              # optimizer type
+  learning_rate: 0.0001    # learning rate

 ###########################################################
 #                    TRAINING SETTING                     #
@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import yaml
 from yacs.config import CfgNode as Configuration
+import yaml

 with open("conf/default.yaml", 'rt') as f:
     _C = yaml.safe_load(f)
@ -12,15 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import paddle
-from paddle.nn import functional as F
-from paddle.fluid.layers import huber_loss
-
-from parakeet.modules.losses import masked_l1_loss, weighted_mean
-from parakeet.models.fastspeech2 import FastSpeech2, FastSpeech2Loss
-from parakeet.training.extensions.evaluator import StandardEvaluator
 from parakeet.training.reporter import report
 from parakeet.training.updaters.standard_updater import StandardUpdater
+from parakeet.training.extensions.evaluator import StandardEvaluator
+from parakeet.models.fastspeech2_new import FastSpeech2, FastSpeech2Loss


 class FastSpeech2Updater(StandardUpdater):
@ -36,7 +31,7 @@ class FastSpeech2Updater(StandardUpdater):
         self.use_weighted_masking = use_weighted_masking

     def update_core(self, batch):
-        before_outs, after_outs, d_outs, p_outs, e_outs, ys, ilens, olens = self.model(
+        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],
             speech=batch["speech"],
@ -48,6 +43,7 @@ class FastSpeech2Updater(StandardUpdater):
         criterion = FastSpeech2Loss(
             use_masking=self.use_masking,
             use_weighted_masking=self.use_weighted_masking)
+
         l1_loss, duration_loss, pitch_loss, energy_loss = criterion(
             after_outs=after_outs,
             before_outs=before_outs,
@ -58,8 +54,9 @@ class FastSpeech2Updater(StandardUpdater):
             ds=batch["durations"],
             ps=batch["pitch"],
             es=batch["energy"],
-            ilens=ilens,
-            olens=olens, )
+            ilens=batch["text_lengths"],
+            olens=olens)
+
         loss = l1_loss + duration_loss + pitch_loss + energy_loss

         optimizer = self.optimizer
@ -67,7 +64,6 @@ class FastSpeech2Updater(StandardUpdater):
         loss.backward()
         optimizer.step()

-        # import pdb; pdb.set_trace()
         report("train/loss", float(loss))
         report("train/l1_loss", float(l1_loss))
         report("train/duration_loss", float(duration_loss))
@ -86,14 +82,14 @@ class FastSpeech2Evaluator(StandardEvaluator):
         self.use_weighted_masking = use_weighted_masking

     def evaluate_core(self, batch):
-        before_outs, after_outs, d_outs, p_outs, e_outs, ys, ilens, olens = self.model(
+        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
             text=batch["text"],
             text_lengths=batch["text_lengths"],
             speech=batch["speech"],
             speech_lengths=batch["speech_lengths"],
             durations=batch["durations"],
             pitch=batch["pitch"],
-            energy=batch["energy"], )
+            energy=batch["energy"])

         criterion = FastSpeech2Loss(
             use_masking=self.use_masking,
@ -108,11 +104,10 @@ class FastSpeech2Evaluator(StandardEvaluator):
             ds=batch["durations"],
             ps=batch["pitch"],
             es=batch["energy"],
-            ilens=ilens,
+            ilens=batch["text_lengths"],
             olens=olens, )
         loss = l1_loss + duration_loss + pitch_loss + energy_loss

-        # import pdb; pdb.set_trace()
         report("eval/loss", float(loss))
         report("eval/l1_loss", float(l1_loss))
         report("eval/duration_loss", float(duration_loss))
@ -0,0 +1,88 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path

import librosa
import numpy as np
from praatio import tgio

from config import get_cfg_default


def readtg(config, tg_path):
    alignment = tgio.openTextgrid(tg_path, readRaw=True)
    phones = []
    ends = []
    for interval in alignment.tierDict["phones"].entryList:
        phone = interval.label
        phones.append(phone)
        ends.append(interval.end)
    frame_pos = librosa.time_to_frames(
        ends, sr=config.fs, hop_length=config.n_shift)
    durations = np.diff(frame_pos, prepend=0)
    assert len(durations) == len(phones)
    results = ""
    for (p, d) in zip(phones, durations):
        p = "sil" if p == "" else p
        results += p + " " + str(d) + " "
    return results.strip()


# assume that the directory structure of inputdir is inputdir/speaker/*.TextGrid
# in MFA 1.x there are blank labels ("") at the end; we replace them with "sil"
def gen_duration_from_textgrid(config, inputdir, output):
    durations_dict = {}

    for speaker in os.listdir(inputdir):
        subdir = inputdir / speaker
        for file in os.listdir(subdir):
            if file.endswith(".TextGrid"):
                tg_path = subdir / file
                name = file.split(".")[0]
                durations_dict[name] = readtg(config, tg_path)
    with open(output, "w") as wf:
        for name in sorted(durations_dict.keys()):
            wf.write(name + "|" + durations_dict[name] + "\n")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features.")
    parser.add_argument(
        "--inputdir",
        default=None,
        type=str,
        help="directory to alignment files.")
    parser.add_argument(
        "--output", type=str, required=True, help="output duration file name")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")

    args = parser.parse_args()
    C = get_cfg_default()
    if args.config:
        C.merge_from_file(args.config)
    C.freeze()

    inputdir = Path(args.inputdir).expanduser()
    output = Path(args.output).expanduser()
    gen_duration_from_textgrid(C, inputdir, output)


if __name__ == "__main__":
    main()
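For reference, each line written by gen_duration_from_textgrid.py pairs an utterance id with alternating phone and frame-count tokens, separated by "|". A hypothetical line (id and values illustrative only):

    000001|sil 20 b 8 ai2 12 k 6 e1 10 sil 25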
@ -0,0 +1,221 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import librosa
import numpy as np
import pyworld
from scipy.interpolate import interp1d

from config import get_cfg_default


class LogMelFBank():
    def __init__(self, conf):
        self.sr = conf.fs
        # stft
        self.n_fft = conf.n_fft
        self.win_length = conf.win_length
        self.hop_length = conf.n_shift
        self.window = conf.window
        self.center = True
        self.pad_mode = "reflect"

        # mel
        self.n_mels = conf.n_mels
        self.fmin = conf.fmin
        self.fmax = conf.fmax

        self.mel_filter = self._create_mel_filter()

    def _create_mel_filter(self):
        mel_filter = librosa.filters.mel(sr=self.sr,
                                         n_fft=self.n_fft,
                                         n_mels=self.n_mels,
                                         fmin=self.fmin,
                                         fmax=self.fmax)
        return mel_filter

    def _stft(self, wav):
        D = librosa.core.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode)
        return D

    def _spectrogram(self, wav):
        D = self._stft(wav)
        return np.abs(D)

    def _mel_spectrogram(self, wav):
        S = self._spectrogram(wav)
        mel = np.dot(self.mel_filter, S)
        return mel

    def get_log_mel_fbank(self, wav):
        mel = self._mel_spectrogram(wav)
        mel = np.clip(mel, a_min=1e-10, a_max=float("inf"))
        mel = np.log10(mel.T)
        # (num_frames, n_mels)
        return mel


class Pitch():
    def __init__(self, conf):
        self.sr = conf.fs
        self.hop_length = conf.n_shift
        self.f0min = conf.f0min
        self.f0max = conf.f0max

    def _convert_to_continuous_f0(self, f0: np.array) -> np.array:
        if (f0 == 0).all():
            print("All frames seem to be unvoiced.")
            return f0

        # padding start and end of f0 sequence
        start_f0 = f0[f0 != 0][0]
        end_f0 = f0[f0 != 0][-1]
        start_idx = np.where(f0 == start_f0)[0][0]
        end_idx = np.where(f0 == end_f0)[0][-1]
        f0[:start_idx] = start_f0
        f0[end_idx:] = end_f0

        # get non-zero frame index
        nonzero_idxs = np.where(f0 != 0)[0]

        # perform linear interpolation
        interp_fn = interp1d(nonzero_idxs, f0[nonzero_idxs])
        f0 = interp_fn(np.arange(0, f0.shape[0]))

        return f0

    def _calculate_f0(self,
                      input: np.array,
                      use_continuous_f0=True,
                      use_log_f0=True) -> np.array:
        input = input.astype(np.float)
        frame_period = 1000 * self.hop_length / self.sr
        f0, timeaxis = pyworld.dio(input,
                                   fs=self.sr,
                                   f0_floor=self.f0min,
                                   f0_ceil=self.f0max,
                                   frame_period=frame_period)
        f0 = pyworld.stonemask(input, f0, timeaxis, self.sr)
        if use_continuous_f0:
            f0 = self._convert_to_continuous_f0(f0)
        if use_log_f0:
            nonzero_idxs = np.where(f0 != 0)[0]
            f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
        return f0.reshape(-1)

    def _average_by_duration(self, input: np.array, d: np.array) -> np.array:
        d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
        arr_list = []
        for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
            arr = input[start:end]
            mask = arr == 0
            arr[mask] = 0
            avg_arr = np.mean(arr, axis=0) if len(arr) != 0 else np.array(0)
            arr_list.append(avg_arr)
        arr_list = np.expand_dims(np.array(arr_list), 0).T

        return arr_list

    def get_pitch(self,
                  wav,
                  use_continuous_f0=True,
                  use_log_f0=True,
                  use_token_averaged_f0=True,
                  duration=None):
        f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0)
        if use_token_averaged_f0 and duration is not None:
            f0 = self._average_by_duration(f0, duration)
        return f0


class Energy():
    def __init__(self, conf):
        self.sr = conf.fs
        self.n_fft = conf.n_fft
        self.win_length = conf.win_length
        self.hop_length = conf.n_shift
        self.window = conf.window
        self.center = True
        self.pad_mode = "reflect"

    def _stft(self, wav):
        D = librosa.core.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=self.center,
            pad_mode=self.pad_mode)
        return D

    def _calculate_energy(self, input):
        input = input.astype(np.float32)
        input_stft = self._stft(input)
        input_power = np.abs(input_stft)**2
        energy = np.sqrt(
            np.clip(
                np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float(
                    'inf')))
        return energy

    def _average_by_duration(self, input: np.array, d: np.array) -> np.array:
        d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
        arr_list = []
        for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
            arr = input[start:end]
            avg_arr = np.mean(arr, axis=0) if len(arr) != 0 else np.array(0)
            arr_list.append(avg_arr)
        arr_list = np.expand_dims(np.array(arr_list), 0).T
        return arr_list

    def get_energy(self, wav, use_token_averaged_energy=True, duration=None):
        energy = self._calculate_energy(wav)
        if use_token_averaged_energy and duration is not None:
            energy = self._average_by_duration(energy, duration)
        return energy


if __name__ == "__main__":
    C = get_cfg_default()
    filename = "../raw_data/data/format.1/000001.flac"
    wav, _ = librosa.load(filename, sr=C.fs)
    mel_extractor = LogMelFBank(C)
    mel = mel_extractor.get_log_mel_fbank(wav)
    print(mel)
    print(mel.shape)

    pitch_extractor = Pitch(C)
    duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
    duration = np.array([int(x) for x in duration.split(" ")])
    avg_f0 = pitch_extractor.get_pitch(wav, duration=duration)
    print(avg_f0)
    print(avg_f0.shape)

    energy_extractor = Energy(C)
    duration = "2 8 8 8 12 11 10 13 11 10 18 9 12 10 12 11 5"
    duration = np.array([int(x) for x in duration.split(" ")])
    avg_energy = energy_extractor.get_energy(wav, duration=duration)
    print(avg_energy)
    print(avg_energy.sum())
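To summarize the shapes produced above: the mel extractor returns a frame-level matrix, while pitch and energy are collapsed to one value per phone when a duration array is given. A minimal sketch (not part of the diff), assuming a default config C as in the __main__ block and a mono waveform; the wav path and durations are illustrative:

    import librosa
    import numpy as np

    wav, _ = librosa.load("000001.flac", sr=C.fs)     # hypothetical file
    dur = np.array([2, 8, 8, 12])                     # per-phone frame counts
    mel = LogMelFBank(C).get_log_mel_fbank(wav)       # (num_frames, n_mels)
    f0 = Pitch(C).get_pitch(wav, duration=dur)        # (len(dur), 1) after token averaging
    energy = Energy(C).get_energy(wav, duration=dur)  # (len(dur), 1) after token averaging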
@ -0,0 +1,184 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""

import argparse
import logging
from operator import itemgetter
from pathlib import Path

import jsonlines
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from parakeet.datasets.data_table import DataTable

from config import get_cfg_default


def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="directory including feature files to be normalized. "
        "you need to specify either *-scp or rootdir.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.")
    parser.add_argument(
        "--speech_stats",
        type=str,
        required=True,
        help="speech statistics file.")
    parser.add_argument(
        "--pitch_stats",
        type=str,
        required=True,
        help="pitch statistics file.")
    parser.add_argument(
        "--energy_stats",
        type=str,
        required=True,
        help="energy statistics file.")
    parser.add_argument(
        "--phones",
        type=str,
        default="phone_id_map.txt",
        help="phone vocabulary file.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    dumpdir = Path(args.dumpdir).resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        converters={
            "speech": np.load,
            "pitch": np.load,
            "energy": np.load,
        })
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
    speech_scaler = StandardScaler()
    speech_scaler.mean_ = np.load(args.speech_stats)[0]
    speech_scaler.scale_ = np.load(args.speech_stats)[1]
    speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]

    pitch_scaler = StandardScaler()
    pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
    pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
    pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]

    energy_scaler = StandardScaler()
    energy_scaler.mean_ = np.load(args.energy_stats)[0]
    energy_scaler.scale_ = np.load(args.energy_stats)[1]
    energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]

    voc_phones = {}
    with open(args.phones, 'rt') as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for phn, id in phn_id:
        voc_phones[phn] = int(id)

    # process each file
    output_metadata = []

    for item in tqdm(dataset):
        utt_id = item['utt_id']
        speech = item['speech']
        pitch = item['pitch']
        energy = item['energy']
        # normalize
        speech = speech_scaler.transform(speech)
        speech_dir = dumpdir / "data_speech"
        speech_dir.mkdir(parents=True, exist_ok=True)
        speech_path = speech_dir / f"{utt_id}_speech.npy"
        np.save(speech_path, speech.astype(np.float32), allow_pickle=False)

        pitch = pitch_scaler.transform(pitch)
        pitch_dir = dumpdir / "data_pitch"
        pitch_dir.mkdir(parents=True, exist_ok=True)
        pitch_path = pitch_dir / f"{utt_id}_pitch.npy"
        np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False)

        energy = energy_scaler.transform(energy)
        energy_dir = dumpdir / "data_energy"
        energy_dir.mkdir(parents=True, exist_ok=True)
        energy_path = energy_dir / f"{utt_id}_energy.npy"
        np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
        phone_ids = [voc_phones[p] for p in item['phones']]
        record = {
            "utt_id": item['utt_id'],
            "text": phone_ids,
            "text_lengths": item['text_lengths'],
            "speech_lengths": item['speech_lengths'],
            "durations": item['durations'],
            "speech": str(speech_path),
            "pitch": str(pitch_path),
            "energy": str(energy_path)
        }
        output_metadata.append(record)
    output_metadata.sort(key=itemgetter('utt_id'))
    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:
            writer.write(item)
    logging.info(f"metadata dumped into {output_metadata_path}")


if __name__ == "__main__":
    main()
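Each line of the resulting metadata.jsonl is one JSON object; a hypothetical record (values illustrative only, wrapped here for readability):

    {"utt_id": "000001", "text": [58, 11, 27, 2], "text_lengths": 4,
     "speech_lengths": 120, "durations": [20, 60, 30, 10],
     "speech": ".../dump/train/norm/data_speech/000001_speech.npy",
     "pitch": ".../dump/train/norm/data_pitch/000001_pitch.npy",
     "energy": ".../dump/train/norm/data_energy/000001_energy.npy"}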
@ -0,0 +1,351 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import List, Dict, Any

import jsonlines
import librosa
import numpy as np
import tqdm

from config import get_cfg_default
from get_feats import LogMelFBank, Energy, Pitch


def get_phn_dur(file_name):
    '''
    read MFA duration.txt
    Parameters
    ----------
    file_name : str or Path
        path of gen_duration_from_textgrid.py's result
    Returns
    ----------
    Dict
        sentence: {'utt': ([char], [int])}
    '''
    f = open(file_name, 'r')
    sentence = {}
    for line in f:
        utt = line.strip().split('|')[0]
        p_d = line.strip().split('|')[-1]
        phn_dur = p_d.split()
        phn = phn_dur[::2]
        dur = phn_dur[1::2]
        assert len(phn) == len(dur)
        sentence[utt] = (phn, [int(i) for i in dur])
    f.close()
    return sentence


def deal_silence(sentence):
    '''
    merge silences, set <eos>
    Parameters
    ----------
    sentence : Dict
        sentence: {'utt': ([char], [int])}
    '''
    for utt in sentence:
        cur_phn, cur_dur = sentence[utt]
        new_phn = []
        new_dur = []

        # merge sp and sil
        for i, p in enumerate(cur_phn):
            if i > 0 and 'sil' == p and cur_phn[i - 1] in {"sil", "sp"}:
                new_dur[-1] += cur_dur[i]
                new_phn[-1] = 'sil'
            else:
                new_phn.append(p)
                new_dur.append(cur_dur[i])

        # merge a short sil at the beginning into the next phone
        if new_phn[0] == 'sil' and new_dur[0] <= 14:
            new_phn = new_phn[1:]
            new_dur[1] += new_dur[0]
            new_dur = new_dur[1:]

        # replace the last sil with <eos> if it exists
        if new_phn[-1] == 'sil':
            new_phn[-1] = '<eos>'
        else:
            new_phn.append('<eos>')
            new_dur.append(0)

        for i, (p, d) in enumerate(zip(new_phn, new_dur)):
            if p in {"sil", "sp"}:
                if d < 14:
                    new_phn[i] = 'sp'
                else:
                    new_phn[i] = 'sp1'

        assert len(new_phn) == len(new_dur)
        sentence[utt] = (new_phn, new_dur)


def get_input_token(sentence, output_path):
    '''
    get phone set from training data and save it
    Parameters
    ----------
    sentence : Dict
        sentence: {'utt': ([char], [int])}
    output_path : str or path
        path to save phone_id_map
    '''
    phn_emb = set()
    for utt in sentence:
        for phn in sentence[utt][0]:
            if phn != "<eos>":
                phn_emb.add(phn)
    phn_emb = list(phn_emb)
    phn_emb.sort()
    phn_emb = ["<pad>", "<unk>"] + phn_emb
    phn_emb += [",", "。", "?", "!", "<eos>"]

    f = open(output_path, 'w')
    for i, phn in enumerate(phn_emb):
        f.write(phn + ' ' + str(i) + '\n')
    f.close()


def compare_duration_and_mel_length(sentences, utt, mel):
    '''
    check duration error, correct sentences[utt] if possible, else pop sentences[utt]
    Parameters
    ----------
    sentences : Dict
        sentences[utt] = [phones_list, durations_list]
    utt : str
        utt_id
    mel : np.ndarray
        features (num_frames, n_mels)
    '''

    if utt in sentences:
        len_diff = mel.shape[0] - sum(sentences[utt][1])
        if len_diff != 0:
            if len_diff > 0:
                sentences[utt][1][-1] += len_diff
            elif sentences[utt][1][-1] + len_diff > 0:
                sentences[utt][1][-1] += len_diff
            elif sentences[utt][1][0] + len_diff > 0:
                sentences[utt][1][0] += len_diff
            else:
                # this branch should rarely be taken
                print("the len_diff is unable to correct:", len_diff)
                sentences.pop(utt)


def process_sentence(
        config: Dict[str, Any],
        fp: Path,
        sentences: Dict,
        output_dir: Path,
        mel_extractor=None,
        pitch_extractor=None,
        energy_extractor=None, ):
    utt_id = fp.stem
    record = None
    if utt_id in sentences:
        # reading; resampling may occur
        wav, _ = librosa.load(str(fp), sr=config.fs)
        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
        assert np.abs(wav).max(
        ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
        # extract mel feats
        logmel = mel_extractor.get_log_mel_fbank(wav)
        # change duration according to mel_length
        compare_duration_and_mel_length(sentences, utt_id, logmel)
        phones = sentences[utt_id][0]
        duration = sentences[utt_id][1]
        num_frames = logmel.shape[0]
        assert sum(duration) == num_frames
        mel_dir = output_dir / "data_speech"
        mel_dir.mkdir(parents=True, exist_ok=True)
        mel_path = mel_dir / (utt_id + "_speech.npy")
        np.save(mel_path, logmel)
        # extract pitch and energy
        f0 = pitch_extractor.get_pitch(wav, duration=np.array(duration))
        assert f0.shape[0] == len(duration)
        f0_dir = output_dir / "data_pitch"
        f0_dir.mkdir(parents=True, exist_ok=True)
        f0_path = f0_dir / (utt_id + "_pitch.npy")
        np.save(f0_path, f0)
        energy = energy_extractor.get_energy(wav, duration=np.array(duration))
        assert energy.shape[0] == len(duration)
        energy_dir = output_dir / "data_energy"
        energy_dir.mkdir(parents=True, exist_ok=True)
        energy_path = energy_dir / (utt_id + "_energy.npy")
        np.save(energy_path, energy)
        record = {
            "utt_id": utt_id,
            "phones": phones,
            "text_lengths": len(phones),
            "speech_lengths": num_frames,
            "durations": duration,
            # use absolute path
            "speech": str(mel_path.resolve()),
            "pitch": str(f0_path.resolve()),
            "energy": str(energy_path.resolve())
        }
    return record


def process_sentences(config,
                      fps: List[Path],
                      sentences: Dict,
                      output_dir: Path,
                      mel_extractor=None,
                      pitch_extractor=None,
                      energy_extractor=None,
                      nprocs: int=1):
    if nprocs == 1:
        results = []
        for fp in tqdm.tqdm(fps, total=len(fps)):
            record = process_sentence(config, fp, sentences, output_dir,
                                      mel_extractor, pitch_extractor,
                                      energy_extractor)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(process_sentence, config, fp,
                                         sentences, output_dir, mel_extractor,
                                         pitch_extractor, energy_extractor)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    record = ft.result()
                    if record:
                        results.append(record)

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features.")
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        help="directory to baker dataset.")
    parser.add_argument(
        "--dur_path",
        default=None,
        type=str,
        help="path to baker durations.txt.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num_cpu", type=int, default=1, help="number of process.")
    args = parser.parse_args()

    C = get_cfg_default()
    if args.config:
        C.merge_from_file(args.config)
    C.freeze()

    if args.verbose > 1:
        print(vars(args))
        print(C)

    root_dir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    dumpdir.mkdir(parents=True, exist_ok=True)

    sentences = get_phn_dur(args.dur_path)
    deal_silence(sentences)
    phone_id_map_path = dumpdir / "phone_id_map.txt"
    get_input_token(sentences, phone_id_map_path)
    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))

    # split data into 3 sections
    num_train = 9800
    num_dev = 100

    train_wav_files = wav_files[:num_train]
    dev_wav_files = wav_files[num_train:num_train + num_dev]
    test_wav_files = wav_files[num_train + num_dev:]

    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test" / "raw"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # Extractor
    mel_extractor = LogMelFBank(C)
    pitch_extractor = Pitch(C)
    energy_extractor = Energy(C)

    # process for the 3 sections
    process_sentences(
        C,
        train_wav_files,
        sentences,
        train_dump_dir,
        mel_extractor,
        pitch_extractor,
        energy_extractor,
        nprocs=args.num_cpu)
    process_sentences(
        C,
        dev_wav_files,
        sentences,
        dev_dump_dir,
        mel_extractor,
        pitch_extractor,
        energy_extractor,
        nprocs=args.num_cpu)
    process_sentences(
        C,
        test_wav_files,
        sentences,
        test_dump_dir,
        mel_extractor,
        pitch_extractor,
        energy_extractor,
        nprocs=args.num_cpu)


if __name__ == "__main__":
    main()
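A worked example of the duration correction above (numbers illustrative): with durations [2, 8, 5] and a 16-frame mel, len_diff = 16 - 15 = 1, so the last duration becomes 6 and sum(duration) == num_frames holds. A negative len_diff is absorbed by the last phone if it stays positive, then the first phone is tried, and the utterance is dropped from sentences when neither works.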
@ -0,0 +1,18 @@
#!/bin/bash

# get durations from MFA's result
python3 gen_duration_from_textgrid.py --inputdir ./baker_alignment_tone --output durations.txt

# extract features
python3 preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --dur_path durations.txt --num_cpu 16

# get features' stats (mean and std)
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="speech"
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="pitch"
python3 compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="energy"

# normalize and convert phone to id; dev and test should use train's stats
python3 normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --speech_stats=dump/train/speech_stats.npy --pitch_stats=dump/train/pitch_stats.npy --energy_stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
python3 normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --speech_stats=dump/train/speech_stats.npy --pitch_stats=dump/train/pitch_stats.npy --energy_stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
python3 normalize.py --metadata=dump/test/raw/metadata.jsonl --dumpdir=dump/test/norm --speech_stats=dump/train/speech_stats.npy --pitch_stats=dump/train/pitch_stats.npy --energy_stats=dump/train/energy_stats.npy --phones dump/phone_id_map.txt
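After the whole pipeline runs, the dump directory laid out by these scripts (default flags assumed) looks roughly like:

    dump/
        phone_id_map.txt
        train/
            raw/               # metadata.jsonl + data_speech/, data_pitch/, data_energy/
            speech_stats.npy   # plus pitch_stats.npy and energy_stats.npy
            norm/              # normalized features + metadata.jsonl with phone ids
        dev/                   # same layout as train, normalized with train's stats
        test/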
@ -0,0 +1,9 @@
#!/bin/bash

python3 train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=conf/default.yaml \
    --output-dir=exp/default \
    --nprocs=1 \
    --phones=dump/phone_id_map.txt
(File diff suppressed because it is too large)
@ -12,40 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-import sys
-import logging
 import argparse
 import dataclasses
+import os
+import logging
 from pathlib import Path

-import yaml
 import jsonlines
-import paddle
 import numpy as np
-from paddle import nn
-from paddle.nn import functional as F
-from paddle import distributed as dist
-from paddle.io import DataLoader, DistributedBatchSampler
-from paddle.optimizer import Adam  # No RAdaom
-from paddle.optimizer.lr import StepDecay
+import paddle
+from paddle import DataParallel
-from visualdl import LogWriter
-
+from paddle import distributed as dist
+from paddle import nn
+from paddle.io import DataLoader, DistributedBatchSampler
 from parakeet.datasets.data_table import DataTable
+from parakeet.models.fastspeech2_new import FastSpeech2
-
-from parakeet.training.updater import UpdaterBase
-from parakeet.training.trainer import Trainer
 from parakeet.training.reporter import report
-from parakeet.training import extension
-from parakeet.models.fastspeech2 import FastSpeech2
 from parakeet.training.extensions.snapshot import Snapshot
 from parakeet.training.extensions.visualizer import VisualDL
 from parakeet.training.seeding import seed_everything
+from parakeet.training.trainer import Trainer
+from visualdl import LogWriter
+import yaml

 from batch_fn import collate_baker_examples
-from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator
 from config import get_cfg_default
+from fastspeech2_updater import FastSpeech2Updater, FastSpeech2Evaluator

 optim_classes = dict(
     adadelta=paddle.optimizer.Adadelta,
@ -61,7 +51,6 @@ optim_classes = dict(

 def build_optimizers(model: nn.Layer, optim='adadelta',
                      learning_rate=0.01) -> paddle.optimizer:
-
     optim_class = optim_classes.get(optim)
     if optim_class is None:
         raise ValueError(f"must be one of {list(optim_classes)}: {optim}")
@ -100,22 +89,12 @@ def train_sp(args, config):
     train_dataset = DataTable(
         data=train_metadata,
         fields=[
-            "text",
-            "text_lengths",
-            "speech",
-            "speech_lengths",
-            "durations",
-            "pitch",
-            "energy",
-            # "durations_lengths",
-            # "pitch_lengths",
-            # "energy_lengths"
+            "text", "text_lengths", "speech", "speech_lengths", "durations",
+            "pitch", "energy"
         ],
-        converters={
-            "speech": np.load,
-            "pitch": np.load,
-            "energy": np.load,
-        }, )
+        converters={"speech": np.load,
+                    "pitch": np.load,
+                    "energy": np.load}, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
     dev_dataset = DataTable(
@ -124,17 +103,15 @@ def train_sp(args, config):
             "text", "text_lengths", "speech", "speech_lengths", "durations",
             "pitch", "energy"
         ],
-        converters={
-            "speech": np.load,
-            "pitch": np.load,
-            "energy": np.load,
-        }, )
+        converters={"speech": np.load,
+                    "pitch": np.load,
+                    "energy": np.load}, )

     # collate function and dataloader
     train_sampler = DistributedBatchSampler(
         train_dataset,
         batch_size=config.batch_size,
-        shuffle=False,
+        shuffle=True,
         drop_last=True)

     print("samplers done!")
@ -144,6 +121,7 @@ def train_sp(args, config):
         batch_sampler=train_sampler,
         collate_fn=collate_baker_examples,
         num_workers=config.num_workers)
+
     dev_dataloader = DataLoader(
         dev_dataset,
         shuffle=False,
@ -153,16 +131,18 @@ def train_sp(args, config):
         num_workers=config.num_workers)
     print("dataloaders done!")

-    vocab_size = 202
+    with open(args.phones, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)

     odim = config.n_mels
     model = FastSpeech2(idim=vocab_size, odim=odim, **config["model"])
     if world_size > 1:
         model = DataParallel(model)  # TODO, do not use vocab size from config
     # print(model)
     print("model done!")

     optimizer = build_optimizers(model, **config["optimizer"])
-
     print("optimizer done!")

     updater = FastSpeech2Updater(
@ -174,8 +154,8 @@ def train_sp(args, config):
     output_dir = Path(args.output_dir)
     trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

-    evaluator = FastSpeech2Evaluator(model, dev_dataloader, **
-                                     config["updater"])
+    evaluator = FastSpeech2Evaluator(model, dev_dataloader,
+                                     **config["updater"])

     if dist.get_rank() == 0:
         trainer.extend(evaluator, trigger=(1, "epoch"))
@ -201,6 +181,11 @@ def main():
     parser.add_argument(
         "--nprocs", type=int, default=1, help="number of processes")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")
+    parser.add_argument(
+        "--phones",
+        type=str,
+        default="phone_id_map.txt",
+        help="phone vocabulary file.")

     args = parser.parse_args()
     if args.device == "cpu" and args.nprocs > 1:
@ -20,4 +20,4 @@ from parakeet.models.transformer_tts import *
 #from parakeet.models.deepvoice3 import *
 # from parakeet.models.fastspeech import *
 from parakeet.models.tacotron2 import *
-from parakeet.models.fastspeech2_new import *
+from parakeet.models.fastspeech2 import *
(File diff suppressed because it is too large)
@ -1,660 +0,0 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Fastspeech2 related modules for paddle"""
|
||||
|
||||
from typing import Dict
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typeguard import check_argument_types
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
from paddle import nn
|
||||
from parakeet.modules.fastspeech2_predictor.duration_predictor import DurationPredictor
|
||||
from parakeet.modules.fastspeech2_predictor.duration_predictor import DurationPredictorLoss
|
||||
from parakeet.modules.fastspeech2_predictor.length_regulator import LengthRegulator
|
||||
from parakeet.modules.fastspeech2_predictor.postnet import Postnet
|
||||
from parakeet.modules.fastspeech2_predictor.variance_predictor import VariancePredictor
|
||||
from parakeet.modules.fastspeech2_transformer.embedding import PositionalEncoding
|
||||
from parakeet.modules.fastspeech2_transformer.embedding import ScaledPositionalEncoding
|
||||
from parakeet.modules.fastspeech2_transformer.encoder import Encoder as TransformerEncoder
|
||||
from parakeet.modules.nets_utils import initialize
|
||||
from parakeet.modules.nets_utils import make_non_pad_mask
|
||||
from parakeet.modules.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class FastSpeech2(nn.Layer):
|
||||
"""FastSpeech2 module.
|
||||
|
||||
This is a module of FastSpeech2 described in `FastSpeech 2: Fast and
|
||||
High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and
|
||||
energy, we use token-averaged value introduced in `FastPitch: Parallel
|
||||
Text-to-speech with Pitch Prediction`_.
|
||||
|
||||
.. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
|
||||
https://arxiv.org/abs/2006.04558
|
||||
.. _`FastPitch: Parallel Text-to-speech with Pitch Prediction`:
|
||||
https://arxiv.org/abs/2006.06873
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# network structure related
|
||||
idim: int,
|
||||
odim: int,
|
||||
adim: int=384,
|
||||
aheads: int=4,
|
||||
elayers: int=6,
|
||||
eunits: int=1536,
|
||||
dlayers: int=6,
|
||||
dunits: int=1536,
|
||||
postnet_layers: int=5,
|
||||
postnet_chans: int=512,
|
||||
postnet_filts: int=5,
|
||||
positionwise_layer_type: str="conv1d",
|
||||
positionwise_conv_kernel_size: int=1,
|
||||
use_scaled_pos_enc: bool=True,
|
||||
use_batch_norm: bool=True,
|
||||
encoder_normalize_before: bool=True,
|
||||
decoder_normalize_before: bool=True,
|
||||
encoder_concat_after: bool=False,
|
||||
decoder_concat_after: bool=False,
|
||||
reduction_factor: int=1,
|
||||
encoder_type: str="transformer",
|
||||
decoder_type: str="transformer",
|
||||
# duration predictor
|
||||
duration_predictor_layers: int=2,
|
||||
duration_predictor_chans: int=384,
|
||||
duration_predictor_kernel_size: int=3,
|
||||
# energy predictor
|
||||
energy_predictor_layers: int=2,
|
||||
energy_predictor_chans: int=384,
|
||||
energy_predictor_kernel_size: int=3,
|
||||
energy_predictor_dropout: float=0.5,
|
||||
energy_embed_kernel_size: int=9,
|
||||
energy_embed_dropout: float=0.5,
|
||||
stop_gradient_from_energy_predictor: bool=False,
|
||||
# pitch predictor
|
||||
pitch_predictor_layers: int=2,
|
||||
pitch_predictor_chans: int=384,
|
||||
pitch_predictor_kernel_size: int=3,
|
||||
pitch_predictor_dropout: float=0.5,
|
||||
pitch_embed_kernel_size: int=9,
|
||||
pitch_embed_dropout: float=0.5,
|
||||
stop_gradient_from_pitch_predictor: bool=False,
|
||||
# training related
|
||||
transformer_enc_dropout_rate: float=0.1,
|
||||
transformer_enc_positional_dropout_rate: float=0.1,
|
||||
transformer_enc_attn_dropout_rate: float=0.1,
|
||||
transformer_dec_dropout_rate: float=0.1,
|
||||
transformer_dec_positional_dropout_rate: float=0.1,
|
||||
transformer_dec_attn_dropout_rate: float=0.1,
|
||||
duration_predictor_dropout_rate: float=0.1,
|
||||
postnet_dropout_rate: float=0.5,
|
||||
init_type: str="xavier_uniform",
|
||||
init_enc_alpha: float=1.0,
|
||||
init_dec_alpha: float=1.0,
|
||||
use_masking: bool=False,
|
||||
use_weighted_masking: bool=False, ):
|
||||
"""Initialize FastSpeech2 module."""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
# store hyperparameters
|
||||
self.idim = idim
|
||||
self.odim = odim
|
||||
self.eos = idim - 1
|
||||
self.reduction_factor = reduction_factor
|
||||
self.encoder_type = encoder_type
|
||||
self.decoder_type = decoder_type
|
||||
self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
|
||||
self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
|
||||
self.use_scaled_pos_enc = use_scaled_pos_enc
|
||||
|
||||
# use idx 0 as padding idx
|
||||
self.padding_idx = 0
|
||||
|
||||
# initialize parameters
|
||||
initialize(self, init_type)
|
||||
|
||||
# get positional encoding class
|
||||
pos_enc_class = (ScaledPositionalEncoding
|
||||
if self.use_scaled_pos_enc else PositionalEncoding)
|
||||
|
||||
# define encoder
|
||||
encoder_input_layer = nn.Embedding(
|
||||
num_embeddings=idim,
|
||||
embedding_dim=adim,
|
||||
padding_idx=self.padding_idx)
|
||||
|
||||
if encoder_type == "transformer":
|
||||
self.encoder = TransformerEncoder(
|
||||
idim=idim,
|
||||
attention_dim=adim,
|
||||
attention_heads=aheads,
|
||||
linear_units=eunits,
|
||||
num_blocks=elayers,
|
||||
input_layer=encoder_input_layer,
|
||||
dropout_rate=transformer_enc_dropout_rate,
|
||||
positional_dropout_rate=transformer_enc_positional_dropout_rate,
|
||||
attention_dropout_rate=transformer_enc_attn_dropout_rate,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=encoder_normalize_before,
|
||||
concat_after=encoder_concat_after,
|
||||
positionwise_layer_type=positionwise_layer_type,
|
||||
positionwise_conv_kernel_size=positionwise_conv_kernel_size, )
|
||||
else:
|
||||
raise ValueError(f"{encoder_type} is not supported.")
|
||||
|
||||
# define duration predictor
|
||||
self.duration_predictor = DurationPredictor(
|
||||
idim=adim,
|
||||
n_layers=duration_predictor_layers,
|
||||
n_chans=duration_predictor_chans,
|
||||
kernel_size=duration_predictor_kernel_size,
|
||||
dropout_rate=duration_predictor_dropout_rate, )
|
||||
|
||||
# define pitch predictor
|
||||
self.pitch_predictor = VariancePredictor(
|
||||
idim=adim,
|
||||
n_layers=pitch_predictor_layers,
|
||||
n_chans=pitch_predictor_chans,
|
||||
kernel_size=pitch_predictor_kernel_size,
|
||||
dropout_rate=pitch_predictor_dropout, )
|
||||
# We use continuous pitch + FastPitch style avg
|
||||
self.pitch_embed = nn.Sequential(
|
||||
nn.Conv1D(
|
||||
in_channels=1,
|
||||
out_channels=adim,
|
||||
kernel_size=pitch_embed_kernel_size,
|
||||
padding=(pitch_embed_kernel_size - 1) // 2, ),
|
||||
nn.Dropout(pitch_embed_dropout), )
|
||||
|
||||
# define energy predictor
|
||||
self.energy_predictor = VariancePredictor(
|
||||
idim=adim,
|
||||
n_layers=energy_predictor_layers,
|
||||
n_chans=energy_predictor_chans,
|
||||
kernel_size=energy_predictor_kernel_size,
|
||||
dropout_rate=energy_predictor_dropout, )
|
||||
# We use continuous enegy + FastPitch style avg
|
||||
self.energy_embed = nn.Sequential(
|
||||
nn.Conv1D(
|
||||
in_channels=1,
|
||||
out_channels=adim,
|
||||
kernel_size=energy_embed_kernel_size,
|
||||
padding=(energy_embed_kernel_size - 1) // 2, ),
|
||||
nn.Dropout(energy_embed_dropout), )
|
||||
|
||||
# define length regulator
|
||||
self.length_regulator = LengthRegulator()
|
||||
|
||||
# define decoder
|
||||
# NOTE: we use encoder as decoder
|
||||
# because fastspeech's decoder is the same as encoder
|
||||
if decoder_type == "transformer":
|
||||
self.decoder = TransformerEncoder(
|
||||
idim=0,
|
||||
attention_dim=adim,
|
||||
attention_heads=aheads,
|
||||
linear_units=dunits,
|
||||
num_blocks=dlayers,
|
||||
input_layer=None,
|
||||
dropout_rate=transformer_dec_dropout_rate,
|
||||
positional_dropout_rate=transformer_dec_positional_dropout_rate,
|
||||
attention_dropout_rate=transformer_dec_attn_dropout_rate,
|
||||
pos_enc_class=pos_enc_class,
|
||||
normalize_before=decoder_normalize_before,
|
||||
concat_after=decoder_concat_after,
|
||||
positionwise_layer_type=positionwise_layer_type,
|
||||
positionwise_conv_kernel_size=positionwise_conv_kernel_size, )
|
||||
else:
|
||||
raise ValueError(f"{decoder_type} is not supported.")
|
||||
|
||||
# define final projection
|
||||
self.feat_out = nn.Linear(adim, odim * reduction_factor)
|
||||
|
||||
# define postnet
|
||||
self.postnet = (None if postnet_layers == 0 else Postnet(
|
||||
idim=idim,
|
||||
odim=odim,
|
||||
n_layers=postnet_layers,
|
||||
n_chans=postnet_chans,
|
||||
n_filts=postnet_filts,
|
||||
use_batch_norm=use_batch_norm,
|
||||
dropout_rate=postnet_dropout_rate, ))
|
||||
|
||||
self._reset_parameters(
|
||||
init_type=init_type,
|
||||
init_enc_alpha=init_enc_alpha,
|
||||
init_dec_alpha=init_dec_alpha, )
|
||||
|
||||
# define criterions
|
||||
self.criterion = FastSpeech2Loss(
|
||||
use_masking=use_masking, use_weighted_masking=use_weighted_masking)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_lengths: paddle.Tensor,
|
||||
speech: paddle.Tensor,
|
||||
speech_lengths: paddle.Tensor,
|
||||
durations: paddle.Tensor,
|
||||
pitch: paddle.Tensor,
|
||||
energy: paddle.Tensor, ) -> Tuple[paddle.Tensor, Dict[
|
||||
str, paddle.Tensor], paddle.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : LongTensor
|
||||
Batch of padded token ids (B, Tmax).
|
||||
text_lengths : LongTensor)
|
||||
Batch of lengths of each input (B,).
|
||||
speech : Tensor
|
||||
Batch of padded target features (B, Lmax, odim).
|
||||
speech_lengths : LongTensor
|
||||
Batch of the lengths of each target (B,).
|
||||
durations : LongTensor
|
||||
Batch of padded durations (B, Tmax + 1).
|
||||
pitch : Tensor
|
||||
Batch of padded token-averaged pitch (B, Tmax + 1, 1).
|
||||
energy : Tensor
|
||||
Batch of padded token-averaged energy (B, Tmax + 1, 1).
|
||||
Returns
|
||||
----------
|
||||
Tensor
|
||||
mel outs before postnet
|
||||
Tensor
|
||||
mel outs after postnet
|
||||
Tensor
|
||||
duration predictor's output
|
||||
Tensor
|
||||
pitch predictor's output
|
||||
Tensor
|
||||
energy predictor's output
|
||||
Tensor
|
||||
speech
|
||||
Tensor
|
||||
real text_lengths
|
||||
Tensor
|
||||
speech_lengths, modified if reduction_factor >1
|
||||
"""
|
||||
|
||||
batch_size = text.shape[0]
|
||||
|
||||
# Add eos at the last of sequence
|
||||
xs = np.pad(text.numpy(),
|
||||
pad_width=((0, 0), (0, 1)),
|
||||
mode="constant",
|
||||
constant_values=self.padding_idx)
|
||||
xs = paddle.to_tensor(xs)
|
||||
for i, l in enumerate(text_lengths):
|
||||
xs[i, l] = self.eos
|
||||
ilens = text_lengths + 1
|
||||
|
||||
ys, ds, ps, es = speech, durations, pitch, energy
|
||||
olens = speech_lengths
|
||||
|
||||
# forward propagation
|
||||
before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(
|
||||
xs, ilens, ys, olens, ds, ps, es, is_inference=False)
|
||||
# modify mod part of groundtruth
|
||||
if self.reduction_factor > 1:
|
||||
olens = paddle.to_tensor([
|
||||
olen - olen % self.reduction_factor for olen in olens.numpy()
|
||||
])
|
||||
max_olen = max(olens)
|
||||
ys = ys[:, :max_olen]
|
||||
|
||||
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, ilens, olens
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
xs: paddle.Tensor,
|
||||
ilens: paddle.Tensor,
|
||||
ys: paddle.Tensor=None,
|
||||
olens: paddle.Tensor=None,
|
||||
ds: paddle.Tensor=None,
|
||||
ps: paddle.Tensor=None,
|
||||
es: paddle.Tensor=None,
|
||||
is_inference: bool=False,
|
||||
alpha: float=1.0, ) -> Sequence[paddle.Tensor]:
|
||||
# forward encoder
|
||||
x_masks = self._source_mask(ilens)
|
||||
|
||||
hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim)
|
||||
# forward duration predictor and variance predictors
|
||||
d_masks = make_pad_mask(ilens)
|
||||
|
||||
if self.stop_gradient_from_pitch_predictor:
|
||||
p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
|
||||
else:
|
||||
p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
|
||||
if self.stop_gradient_from_energy_predictor:
|
||||
e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
|
||||
else:
|
||||
e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(
                hs, d_masks)  # (B, Tmax)
            # use prediction in inference
            # (B, Tmax, 1)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            e_embs = self.energy_embed(e_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, d_outs, alpha)  # (B, Lmax, adim)
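            # the length regulator repeats each token's hidden state according
            # to its duration, expanding (B, Tmax, adim) to frame level;
            # alpha rescales the predicted durations to control speech speed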
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, ds)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = paddle.to_tensor(
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        # (B, Lmax, adim)
        zs, _ = self.decoder(hs, h_masks)
        # (B, Lmax, odim)
        before_outs = self.feat_out(zs).reshape((zs.shape[0], -1, self.odim))

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

        return before_outs, after_outs, d_outs, p_outs, e_outs

    def inference(
            self,
            text: paddle.Tensor,
            speech: paddle.Tensor=None,
            durations: paddle.Tensor=None,
            pitch: paddle.Tensor=None,
            energy: paddle.Tensor=None,
            alpha: float=1.0,
            use_teacher_forcing: bool=False, ) -> Tuple[
                paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Generate the sequence of features given the sequences of characters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text : LongTensor
|
||||
Input sequence of characters (T,).
|
||||
speech : Tensor, optional
|
||||
Feature sequence to extract style (N, idim).
|
||||
durations : LongTensor, optional
|
||||
Groundtruth of duration (T + 1,).
|
||||
pitch : Tensor, optional
|
||||
Groundtruth of token-averaged pitch (T + 1, 1).
|
||||
energy : Tensor, optional
|
||||
Groundtruth of token-averaged energy (T + 1, 1).
|
||||
alpha : float, optional
|
||||
Alpha to control the speed.
|
||||
use_teacher_forcing : bool, optional
|
||||
Whether to use teacher forcing.
|
||||
If true, groundtruth of duration, pitch and energy will be used.
|
||||
|
||||
Returns
|
||||
----------
|
||||
Tensor
|
||||
Output sequence of features (L, odim).
|
||||
None
|
||||
Dummy for compatibility.
|
||||
|
||||
"""
|
||||
        x, y = text, speech
        d, p, e = durations, pitch, energy

        # add eos at the last of sequence
        x = np.pad(text.numpy(),
                   pad_width=(0, 1),
                   mode="constant",
                   constant_values=self.eos)

        x = paddle.to_tensor(x)

        # setup batch axis
        ilens = paddle.to_tensor(
            [x.shape[0]], dtype=paddle.int64, place=x.place)
        xs, ys = x.unsqueeze(0), None

        if y is not None:
            ys = y.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
            _, outs, *_ = self._forward(
                xs,
                ilens,
                ys,
                ds=ds,
                ps=ps,
                es=es, )  # (1, L, odim)
        else:
            _, outs, *_ = self._forward(
                xs,
                ilens,
                ys,
                is_inference=True,
                alpha=alpha, )  # (1, L, odim)

        return outs[0], None, None
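
    # A minimal usage sketch (hypothetical tensor names, not executed):
    #     phone_ids = paddle.to_tensor([12, 7, 42], dtype="int64")
    #     mel, _, _ = model.inference(phone_ids, alpha=1.0)
    #     # mel: (L, odim) spectrogram, ready to be fed to a vocoder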

    def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor:
        """Make masks for self-attention.

        Parameters
        ----------
        ilens : LongTensor
            Batch of lengths (B,).

        Returns
        -------
        Tensor
            Mask tensor for self-attention.
            dtype=paddle.bool

        Examples
        -------
        >>> ilens = [5, 3]
        >>> self._source_mask(ilens)
        tensor([[[1, 1, 1, 1, 1]],
                [[1, 1, 1, 0, 0]]]) bool

        """
        x_masks = make_non_pad_mask(ilens)
        return x_masks.unsqueeze(-2)
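
    # the returned mask has shape (B, 1, Tmax); the singleton axis broadcasts
    # over the query positions inside the attention modules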

    def _reset_parameters(self,
                          init_type: str,
                          init_enc_alpha: float,
                          init_dec_alpha: float):

        # initialize alpha in scaled positional encoding
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            init_enc_alpha = paddle.to_tensor(init_enc_alpha)
            self.encoder.embed[-1].alpha = paddle.create_parameter(
                shape=init_enc_alpha.shape,
                dtype=str(init_enc_alpha.numpy().dtype),
                default_initializer=paddle.nn.initializer.Assign(
                    init_enc_alpha))
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            init_dec_alpha = paddle.to_tensor(init_dec_alpha)
            self.decoder.embed[-1].alpha = paddle.create_parameter(
                shape=init_dec_alpha.shape,
                dtype=str(init_dec_alpha.numpy().dtype),
                default_initializer=paddle.nn.initializer.Assign(
                    init_dec_alpha))


class FastSpeech2Loss(nn.Layer):
    """Loss function module for FastSpeech2."""

    def __init__(self,
                 use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize feed-forward Transformer loss module.

        Parameters
        ----------
        use_masking : bool
            Whether to apply masking for padded part in loss calculation.
        use_weighted_masking : bool
            Whether to apply weighted masking in loss calculation.
        """
        assert check_argument_types()
        super().__init__()

        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = nn.L1Loss(reduction=reduction)
        self.mse_criterion = nn.MSELoss(reduction=reduction)
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)
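        # with weighted masking the criterions keep per-element losses
        # (reduction="none") so the weights computed in forward() can be
        # applied before summation; otherwise losses are averaged directly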

    def forward(
            self,
            after_outs: paddle.Tensor,
            before_outs: paddle.Tensor,
            d_outs: paddle.Tensor,
            p_outs: paddle.Tensor,
            e_outs: paddle.Tensor,
            ys: paddle.Tensor,
            ds: paddle.Tensor,
            ps: paddle.Tensor,
            es: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor,
                                             paddle.Tensor, paddle.Tensor]:
        """Calculate forward propagation.

        Parameters
        ----------
        after_outs : Tensor
            Batch of outputs after postnets (B, Lmax, odim).
        before_outs : Tensor
            Batch of outputs before postnets (B, Lmax, odim).
        d_outs : LongTensor
            Batch of outputs of duration predictor (B, Tmax).
        p_outs : Tensor
            Batch of outputs of pitch predictor (B, Tmax, 1).
        e_outs : Tensor
            Batch of outputs of energy predictor (B, Tmax, 1).
        ys : Tensor
            Batch of target features (B, Lmax, odim).
        ds : LongTensor
            Batch of durations (B, Tmax).
        ps : Tensor
            Batch of target token-averaged pitch (B, Tmax, 1).
        es : Tensor
            Batch of target token-averaged energy (B, Tmax, 1).
        ilens : LongTensor
            Batch of the lengths of each input (B,).
        olens : LongTensor
            Batch of the lengths of each target (B,).

        Returns
        ----------
        Tensor
            L1 loss value.
        Tensor
            Duration predictor loss value.
        Tensor
            Pitch predictor loss value.
        Tensor
            Energy predictor loss value.

        """
        # apply mask to remove padded part
        if self.use_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            before_outs = before_outs.masked_select(
                out_masks.broadcast_to(before_outs.shape))
            if after_outs is not None:
                after_outs = after_outs.masked_select(
                    out_masks.broadcast_to(after_outs.shape))
            ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
            duration_masks = make_non_pad_mask(ilens)
            d_outs = d_outs.masked_select(
                duration_masks.broadcast_to(d_outs.shape))
            ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
            pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1)
            p_outs = p_outs.masked_select(
                pitch_masks.broadcast_to(p_outs.shape))
            e_outs = e_outs.masked_select(
                pitch_masks.broadcast_to(e_outs.shape))
            ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
            es = es.masked_select(pitch_masks.broadcast_to(es.shape))
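            # masked_select flattens each tensor to 1-D, keeping only the
            # non-padded positions, so the criterions below (reduction="mean")
            # average over real frames/tokens only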

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        if after_outs is not None:
            l1_loss += self.l1_criterion(after_outs, ys)
        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.mse_criterion(p_outs, ps)
        energy_loss = self.mse_criterion(e_outs, es)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            out_weights = out_masks.cast(
                dtype=paddle.float32) / out_masks.cast(
                    dtype=paddle.float32).sum(axis=1, keepdim=True)
            out_weights /= ys.shape[0] * ys.shape[2]
            duration_masks = make_non_pad_mask(ilens)
            duration_weights = (duration_masks.cast(dtype=paddle.float32) /
                                duration_masks.cast(dtype=paddle.float32).sum(
                                    axis=1, keepdim=True))
            duration_weights /= ds.shape[0]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                out_masks.broadcast_to(l1_loss.shape)).sum()
            duration_loss = (duration_loss.multiply(duration_weights)
                             .masked_select(duration_masks).sum())
            pitch_masks = duration_masks.unsqueeze(-1)
            pitch_weights = duration_weights.unsqueeze(-1)
            pitch_loss = pitch_loss.multiply(pitch_weights)
            pitch_loss = pitch_loss.masked_select(
                pitch_masks.broadcast_to(pitch_loss.shape)).sum()
            energy_loss = energy_loss.multiply(pitch_weights)
            energy_loss = energy_loss.masked_select(
                pitch_masks.broadcast_to(energy_loss.shape)).sum()

        return l1_loss, duration_loss, pitch_loss, energy_loss
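
# A minimal usage sketch of the loss (hypothetical tensors, not executed):
#     criterion = FastSpeech2Loss(use_masking=True)
#     l1_loss, duration_loss, pitch_loss, energy_loss = criterion(
#         after_outs, before_outs, d_outs, p_outs, e_outs,
#         ys, ds, ps, es, ilens, olens)
#     loss = l1_loss + duration_loss + pitch_loss + energy_loss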