WIP: add speedyspeech model and example with the Baker dataset.

parent 124dedbd7b
commit 6c21d80025
@@ -208,8 +208,7 @@ def main():
         "--rootdir",
         default=None,
         type=str,
-        help="directory including wav files. you need to specify either scp or rootdir."
-    )
+        help="directory to the Baker dataset.")
     parser.add_argument(
         "--dumpdir",
         type=str,
@@ -0,0 +1,37 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from parakeet.data.batch import batch_sequences


def collate_baker_examples(examples):
    # fields = ["phones", "tones", "num_phones", "num_frames", "feats"]
    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
    feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
    num_phones = np.array([item["num_phones"] for item in examples])
    num_frames = np.array([item["num_frames"] for item in examples])

    phones = batch_sequences(phones)
    tones = batch_sequences(tones)
    feats = batch_sequences(feats)
    batch = {
        "phones": phones,
        "tones": tones,
        "num_phones": num_phones,
        "num_frames": num_frames,
        "feats": feats,
    }
    return batch
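For reference, a minimal sketch of how this collate function behaves on two hypothetical toy examples (the `batch_sequences` helper, added to parakeet/data/batch.py later in this diff, zero-pads along the time axis before stacking):

import numpy as np

# Two toy examples of different lengths (hypothetical values).
examples = [
    {"phones": [1, 2, 3], "tones": [0, 1, 0], "num_phones": 3,
     "num_frames": 4, "feats": np.zeros([4, 80], dtype=np.float32)},
    {"phones": [5, 6], "tones": [2, 0], "num_phones": 2,
     "num_frames": 2, "feats": np.zeros([2, 80], dtype=np.float32)},
]
batch = collate_baker_examples(examples)
print(batch["phones"].shape)  # (2, 3) -- zero-padded to the longest sequence
print(batch["feats"].shape)   # (2, 4, 80) -- padded along the frame axis
print(batch["num_frames"])    # [4 2] -- original lengths kept for masking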
@@ -0,0 +1,110 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate statistics of feature files."""

import argparse
import logging
import os

import numpy as np
import yaml
import json
import jsonlines

from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from parakeet.datasets.data_table import DataTable
from parakeet.utils.h5_utils import read_hdf5
from parakeet.utils.h5_utils import write_hdf5

from config import get_cfg_default


def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features.")
    parser.add_argument(
        "--metadata", type=str, help="json file with id and file paths ")
    parser.add_argument(
        "--field-name",
        type=str,
        help="name of the field to compute statistics for.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        help="directory to save statistics. if not provided, "
        "stats will be saved in the above root directory. (default=None)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
        logging.warning('Skip DEBUG/INFO messages')

    config = get_cfg_default()
    # load config
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    if args.dumpdir is None:
        args.dumpdir = os.path.dirname(args.metadata)
    if not os.path.exists(args.dumpdir):
        os.makedirs(args.dumpdir)

    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        fields=[args.field_name],
        converters={args.field_name: np.load}, )
    logging.info(f"The number of files = {len(dataset)}.")

    # calculate statistics
    scaler = StandardScaler()
    for datum in tqdm(dataset):
        # StandardScaler supports (*, num_features) by default
        scaler.partial_fit(datum[args.field_name])

    stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
    np.save(
        os.path.join(args.dumpdir, "stats.npy"),
        stats.astype(np.float32),
        allow_pickle=False)


if __name__ == "__main__":
    main()
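A hedged sanity check of the streaming statistics used above: `partial_fit` over chunks accumulates the same mean/scale as fitting on the full array, which is why the script can stream one utterance at a time (toy data with hypothetical shapes):

import numpy as np
from sklearn.preprocessing import StandardScaler

data = np.random.randn(1000, 80).astype(np.float32)  # fake mel features
scaler = StandardScaler()
for chunk in np.array_split(data, 10):  # stream in 10 chunks
    scaler.partial_fit(chunk)

stats = np.stack([scaler.mean_, scaler.scale_], axis=0)  # what gets saved
assert stats.shape == (2, 80)
assert np.allclose(stats[0], data.mean(axis=0), atol=1e-3)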
@@ -0,0 +1,54 @@
###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
sr: 24000            # Sampling rate.
n_fft: 2048          # FFT size.
hop_length: 300      # Hop size.
win_length: 1200     # Window length.
                     # If set to null, it will be the same as n_fft.
window: "hann"       # Window function.
n_mels: 80           # Number of mel basis.
fmin: 80             # Minimum frequency in mel basis calculation.
fmax: 7600           # Maximum frequency in mel basis calculation.
# global_gain_scale: 1.0  # Will be multiplied to all of waveform.
trim_silence: false  # Whether to trim the start and end of silence.
top_db: 60           # Need to tune carefully if the recording is not good.
trim_frame_length: 2048  # Frame size in trimming. (in samples)
trim_hop_length: 512     # Hop size in trimming. (in samples)


###########################################################
#                       DATA SETTING                      #
###########################################################
batch_size: 16
num_workers: 0




###########################################################
#                      MODEL SETTING                      #
###########################################################
model:
  vocab_size: 68
  tone_size: 6
  encoder_hidden_size: 128
  encoder_kernel_size: 3
  encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
  duration_predictor_hidden_size: 128
  decoder_hidden_size: 128
  decoder_output_size: 80
  decoder_kernel_size: 3
  decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]


###########################################################
#                    OPTIMIZER SETTING                    #
###########################################################


###########################################################
#                      OTHER SETTING                      #
###########################################################
seed: 10086
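For orientation, the frame geometry implied by this config (simple arithmetic, not part of the commit):

sr, hop_length, win_length = 24000, 300, 1200
print(1000 * hop_length / sr)  # 12.5 -> 12.5 ms hop, i.e. 80 frames per second
print(1000 * win_length / sr)  # 50.0 -> 50 ms analysis window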
@@ -0,0 +1,25 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml
from yacs.config import CfgNode as Configuration

with open("conf/default.yaml", 'rt') as f:
    _C = yaml.safe_load(f)
    _C = Configuration(_C)


def get_cfg_default():
    config = _C.clone()
    return config
@@ -0,0 +1,143 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""

import argparse
import logging
import os
from copy import copy
from operator import itemgetter
from pathlib import Path

import numpy as np
import yaml
import jsonlines
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from parakeet.frontend.vocab import Vocab
from parakeet.datasets.data_table import DataTable

from config import get_cfg_default


def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="metadata file (jsonl) of the features to be normalized.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.")
    parser.add_argument(
        "--stats", type=str, required=True, help="statistics file.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
        )
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    dumpdir = Path(args.dumpdir).resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(metadata, converters={'feats': np.load, })
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
    scaler = StandardScaler()
    scaler.mean_ = np.load(args.stats)[0]
    scaler.scale_ = np.load(args.stats)[1]

    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    with open("phones.txt", 'rt') as f:
        phones = [line.strip() for line in f.readlines()]

    with open("tones.txt", 'rt') as f:
        tones = [line.strip() for line in f.readlines()]
    voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
    voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)

    # process each file
    output_metadata = []

    for item in tqdm(dataset):
        utt_id = item['utt_id']
        mel = item['feats']
        # normalize
        mel = scaler.transform(mel)

        # save
        mel_path = dumpdir / f"{utt_id}-feats.npy"
        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
        phone_ids = [voc_phones.lookup(p) for p in item['phones']]
        tone_ids = [voc_tones.lookup(t) for t in item['tones']]
        output_metadata.append({
            'utt_id': utt_id,
            'phones': phone_ids,
            'tones': tone_ids,
            'num_phones': item['num_phones'],
            'num_frames': item['num_frames'],
            'durations': item['durations'],
            'feats': str(mel_path),
        })
    output_metadata.sort(key=itemgetter('utt_id'))
    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:
            writer.write(item)
    logging.info(f"metadata dumped into {output_metadata_path}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,99 @@
b
p
m
f
d
t
n
l
g
k
h
zh
ch
sh
r
z
c
s
j
q
x
a
ar
ai
air
ao
aor
an
anr
ang
angr
e
er
ei
eir
en
enr
eng
engr
o
or
ou
our
ong
ongr
ii
iir
iii
iiir
i
ir
ia
iar
iao
iaor
ian
ianr
iang
iangr
ie
ier
io
ior
iou
iour
iong
iongr
in
inr
ing
ingr
u
ur
ua
uar
uai
uair
uan
uanr
uang
uangr
uei
ueir
uo
uor
uen
uenr
ueng
uengr
v
vr
ve
ver
van
vanr
vn
vnr
sil
sp
@@ -0,0 +1,311 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Any
import soundfile as sf
import librosa
import numpy as np
import argparse
import yaml
import json
import re
import jsonlines
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
import tqdm
from operator import itemgetter
from praatio import tgio
import logging

from config import get_cfg_default
from tg_utils import validate_textgrid


def logmelfilterbank(audio,
                     sr,
                     n_fft=1024,
                     hop_length=256,
                     win_length=None,
                     window="hann",
                     n_mels=80,
                     fmin=None,
                     fmax=None,
                     eps=1e-10):
    """Compute log-Mel filterbank feature.

    Parameters
    ----------
    audio : ndarray
        Audio signal (T,).
    sr : int
        Sampling rate.
    n_fft : int
        FFT size. (Default value = 1024)
    hop_length : int
        Hop size. (Default value = 256)
    win_length : int
        Window length. If set to None, it will be the same as fft_size. (Default value = None)
    window : str
        Window function type. (Default value = "hann")
    n_mels : int
        Number of mel basis. (Default value = 80)
    fmin : int
        Minimum frequency in mel basis calculation. (Default value = None)
    fmax : int
        Maximum frequency in mel basis calculation. (Default value = None)
    eps : float
        Epsilon value to avoid inf in log calculation. (Default value = 1e-10)

    Returns
    -------
    np.ndarray
        Log Mel filterbank feature (#frames, num_mels).

    """
    # get amplitude spectrogram
    x_stft = librosa.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        pad_mode="reflect")
    spc = np.abs(x_stft)  # (#bins, #frames,)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sr / 2 if fmax is None else fmax
    mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)

    return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))


def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     alignment_fp: Path,
                     output_dir: Path):
    utt_id = fp.stem

    # reading
    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
    assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
    assert np.abs(y).max(
    ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
    duration = librosa.get_duration(y, sr=sr)

    # intervals with empty labels are ignored
    alignment = tgio.openTextgrid(alignment_fp)

    # validate text grid against audio file
    num_samples = y.shape[0]
    validate_textgrid(alignment, num_samples, sr)

    # only with baker's annotation
    intervals = alignment.tierDict[alignment.tierNameList[0]].entryList

    first, last = intervals[0], intervals[-1]
    if not (first.label == "sil" and first.end < duration):
        logging.warning(
            f"There is something wrong with the first interval {first} in utterance: {utt_id}"
        )
    if not (last.label == "sil" and last.start < duration):
        logging.warning(
            f"There is something wrong with the last interval {last} in utterance: {utt_id}"
        )

    logmel = logmelfilterbank(
        y,
        sr=sr,
        n_fft=config.n_fft,
        window=config.window,
        win_length=config.win_length,
        hop_length=config.hop_length,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)

    # extract phone and duration
    phones = []
    tones = []
    ends = []
    durations_sec = []

    for interval in intervals:
        label = interval.label
        label = label.replace("sp1", "sp")  # Baker has sp1 rather than sp

        # split tone from finals
        match = re.match(r'^(\w+)([012345])$', label)
        if match:
            phones.append(match.group(1))
            tones.append(match.group(2))
        else:
            phones.append(label)
            tones.append('0')
        end = min(duration, interval.end)
        ends.append(end)
        durations_sec.append(end - interval.start)  # duration in seconds

    frame_pos = librosa.time_to_frames(
        ends, sr=sr, hop_length=config.hop_length)
    durations_frame = np.diff(frame_pos, prepend=0)

    num_frames = logmel.shape[-1]  # number of frames of the spectrogram
    extra = np.sum(durations_frame) - num_frames
    assert extra <= 0, (
        f"Number of frames inferred from alignment is "
        f"larger than number of frames of the spectrogram by {extra} frames")
    durations_frame[-1] += (-extra)

    assert np.sum(durations_frame) == num_frames
    durations_frame = durations_frame.tolist()

    mel_path = output_dir / (utt_id + "_feats.npy")
    np.save(mel_path, logmel.T)  # (num_frames, n_mels)
    record = {
        "utt_id": utt_id,
        "phones": phones,
        "tones": tones,
        "num_phones": len(phones),
        "num_frames": num_frames,
        "durations": durations_frame,
        "feats": str(mel_path.resolve()),  # use absolute path
    }
    return record


def process_sentences(config,
                      fps: List[Path],
                      alignment_fps: List[Path],
                      output_dir: Path,
                      nprocs: int=1):
    if nprocs == 1:
        results = []
        for fp, alignment_fp in tqdm.tqdm(
                zip(fps, alignment_fps), total=len(fps)):
            results.append(
                process_sentence(config, fp, alignment_fp, output_dir))
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp, alignment_fp in zip(fps, alignment_fps):
                    future = pool.submit(process_sentence, config, fp,
                                         alignment_fp, output_dir)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    results.append(ft.result())

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        help="directory including wav files. you need to specify either scp or rootdir."
    )
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num_cpu", type=int, default=1, help="number of process.")
    args = parser.parse_args()

    C = get_cfg_default()
    if args.config:
        C.merge_from_file(args.config)
    C.freeze()

    if args.verbose > 1:
        print(vars(args))
        print(C)

    root_dir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    dumpdir.mkdir(parents=True, exist_ok=True)

    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
    alignment_files = sorted(
        list((root_dir / "PhoneLabeling").rglob("*.interval")))

    # filter out several files that have errors in annotation
    exclude = {'000611', '000662', '002365', '005107'}
    wav_files = [f for f in wav_files if f.stem not in exclude]
    alignment_files = [f for f in alignment_files if f.stem not in exclude]

    # split data into 3 sections
    num_train = 9800
    num_dev = 100

    train_wav_files = wav_files[:num_train]
    dev_wav_files = wav_files[num_train:num_train + num_dev]
    test_wav_files = wav_files[num_train + num_dev:]

    train_alignment_files = alignment_files[:num_train]
    dev_alignment_files = alignment_files[num_train:num_train + num_dev]
    test_alignment_files = alignment_files[num_train + num_dev:]

    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test" / "raw"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # process for the 3 sections
    process_sentences(
        C,
        train_wav_files,
        train_alignment_files,
        train_dump_dir,
        nprocs=args.num_cpu)
    process_sentences(
        C,
        dev_wav_files,
        dev_alignment_files,
        dev_dump_dir,
        nprocs=args.num_cpu)
    process_sentences(
        C,
        test_wav_files,
        test_alignment_files,
        test_dump_dir,
        nprocs=args.num_cpu)


if __name__ == "__main__":
    main()
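A small numeric illustration of the duration bookkeeping above, with hypothetical interval end frames; the last phone absorbs any shortfall so the durations sum exactly to the spectrogram length:

import numpy as np

frame_pos = np.array([2, 5, 7, 9])            # end frame of each phone
durations_frame = np.diff(frame_pos, prepend=0)
print(durations_frame)                        # [2 3 2 2]
num_frames = 10                               # frames in the spectrogram
extra = np.sum(durations_frame) - num_frames  # -1: alignment one frame short
durations_frame[-1] += (-extra)               # pad the last phone
assert np.sum(durations_frame) == num_frames  # [2 3 2 3] sums to 10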
@@ -0,0 +1,65 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.nn import functional as F
from paddle.fluid.layers import huber_loss

from parakeet.modules.ssim import ssim
from parakeet.modules.losses import masked_l1_loss, weighted_mean
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.models.speedyspeech import SpeedySpeech


class SpeedySpeechUpdater(StandardUpdater):
    def update_core(self, batch):
        decoded, predicted_durations = self.model(
            text=batch["phonemes"],
            tones=batch["tones"],
            plens=batch["phoneme_lengths"],
            durations=batch["phoneme_durations"])

        target_mel = batch["mel"]
        spec_mask = F.sequence_mask(
            batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
        text_mask = F.sequence_mask(
            batch["phoneme_lengths"], dtype=predicted_durations.dtype)

        # spec loss
        l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)

        # duration loss
        target_durations = batch["phoneme_durations"]
        target_durations = paddle.clip(target_durations, min=1.0)
        duration_loss = weighted_mean(
            huber_loss(
                predicted_durations, paddle.log(target_durations), delta=1.0),
            text_mask, )

        # ssim loss
        ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
                               (target_mel * spec_mask).unsqueeze(1))

        loss = l1_loss + duration_loss + ssim_loss

        optimizer = self.optimizer
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

        report("train/l1_loss", float(l1_loss))
        report("train/duration_loss", float(duration_loss))
        report("train/ssim_loss", float(ssim_loss))
@@ -0,0 +1,27 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import librosa
from praatio import tgio


def validate_textgrid(text_grid, num_samples, sr):
    """Validate TextGrid to make sure that the time interval annotated
    by the text grid file does not go beyond the audio file.
    """
    start = text_grid.minTimestamp
    end = text_grid.maxTimestamp

    end_audio = librosa.samples_to_time(num_samples, sr)
    return start == 0.0 and end <= end_audio
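A quick numeric check of the sample-to-time conversion this validation relies on (hypothetical values):

import librosa

# 48000 samples at 24 kHz is 2.0 s of audio, so a TextGrid starting at 0.0
# and ending at or before 2.0 s passes the check above.
print(librosa.samples_to_time(48000, sr=24000))  # 2.0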
@@ -0,0 +1,6 @@
0
1
2
3
4
5
@@ -0,0 +1,155 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
import argparse
import dataclasses
from pathlib import Path

import yaml
import jsonlines
import paddle
import numpy as np
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam  # No RAdam
from paddle.optimizer.lr import StepDecay
from paddle import DataParallel
from visualdl import LogWriter

from parakeet.datasets.data_table import DataTable
from parakeet.models.speedyspeech import SpeedySpeech

from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training import extension
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything

from batch_fn import collate_baker_examples
from config import get_cfg_default


def train_sp(args, config):
    # decides device type and whether to run in parallel
    # setup running environment correctly
    if not paddle.is_compiled_with_cuda():
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
        converters={"feats": np.load, }, )
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
        converters={"feats": np.load, }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_baker_examples,
        num_workers=config.num_workers)
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=collate_baker_examples,
        num_workers=config.num_workers)
    print("dataloaders done!")

    # batch = collate_baker_examples([train_dataset[i] for i in range(10)])
    # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
    # import pdb; pdb.set_trace()
    model = SpeedySpeech(**config["model"])
    print(model)


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(
        description="Train a SpeedySpeech model with the Baker Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    parser.add_argument("--train-metadata", type=str, help="training data")
    parser.add_argument("--dev-metadata", type=str, help="dev data")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    if args.device == "cpu" and args.nprocs > 1:
        raise RuntimeError("Multiprocess training on CPU is not supported.")
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
@@ -161,3 +161,27 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
                mode='constant',
                constant_values=pad_value))
    return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)


def batch_sequences(sequences, axis=0, pad_value=0):
    # import pdb; pdb.set_trace()
    seq = sequences[0]
    ndim = seq.ndim
    if axis < 0:
        axis += ndim
    dtype = seq.dtype
    pad_value = dtype.type(pad_value)
    seq_lengths = [seq.shape[axis] for seq in sequences]
    max_length = np.max(seq_lengths)

    padded_sequences = []
    for seq, length in zip(sequences, seq_lengths):
        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
            ndim - axis - 1)
        padded_seq = np.pad(seq,
                            padding,
                            mode='constant',
                            constant_values=pad_value)
        padded_sequences.append(padded_seq)
    batch = np.stack(padded_sequences)
    return batch
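A minimal usage sketch of `batch_sequences` (toy inputs): sequences are padded with `pad_value` along `axis` up to the longest length, then stacked on a new batch dimension:

import numpy as np

a = np.array([1, 2, 3], dtype=np.int64)
b = np.array([4, 5], dtype=np.int64)
print(batch_sequences([a, b], axis=0, pad_value=0))
# [[1 2 3]
#  [4 5 0]]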
@@ -0,0 +1,214 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

from parakeet.modules.positional_encoding import sinusoid_position_encoding
from parakeet.modules.expansion import expand


class ResidualBlock(nn.Layer):
    def __init__(self, channels, kernel_size, dilation, n=2):
        super().__init__()
        blocks = [
            nn.Sequential(
                nn.Conv1D(
                    channels,
                    channels,
                    kernel_size,
                    dilation=dilation,
                    padding="same",
                    data_format="NLC"),
                nn.ReLU(),
                nn.BatchNorm1D(
                    channels, data_format="NLC"), ) for _ in range(n)
        ]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return x + self.blocks(x)


class TextEmbedding(nn.Layer):
    def __init__(self,
                 vocab_size: int,
                 embedding_size: int,
                 tone_vocab_size: int=None,
                 tone_embedding_size: int=None,
                 padding_idx: int=None,
                 tone_padding_idx: int=None,
                 concat: bool=False):
        super().__init__()
        self.text_embedding = nn.Embedding(vocab_size, embedding_size,
                                           padding_idx)
        if tone_vocab_size:
            tone_embedding_size = tone_embedding_size or embedding_size
            if tone_embedding_size != embedding_size and not concat:
                raise ValueError(
                    "embedding size != tone_embedding size, only concat is available."
                )
            self.tone_embedding = nn.Embedding(
                tone_vocab_size, tone_embedding_size, tone_padding_idx)
        self.concat = concat

    def forward(self, text, tone=None):
        text_embed = self.text_embedding(text)
        if tone is None:
            return text_embed
        tone_embed = self.tone_embedding(tone)
        if self.concat:
            embed = paddle.concat([text_embed, tone_embed], -1)
        else:
            embed = text_embed + tone_embed
        return embed


class SpeedySpeechEncoder(nn.Layer):
    def __init__(self, vocab_size, tone_size, hidden_size, kernel_size,
                 dilations):
        super().__init__()
        self.embedding = TextEmbedding(
            vocab_size,
            hidden_size,
            tone_size,
            padding_idx=0,
            tone_padding_idx=0)
        self.prenet = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(), )
        res_blocks = [
            ResidualBlock(
                hidden_size, kernel_size, d, n=2) for d in dilations
        ]
        self.res_blocks = nn.Sequential(*res_blocks)

        self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
        self.postnet2 = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1D(
                hidden_size, data_format="NLC"),
            nn.Linear(hidden_size, hidden_size), )

    def forward(self, text, tones):
        embedding = self.embedding(text, tones)
        embedding = self.prenet(embedding)
        x = self.res_blocks(embedding)
        x = embedding + self.postnet1(x)
        x = self.postnet2(x)
        return x


class DurationPredictor(nn.Layer):
    def __init__(self, hidden_size):
        super().__init__()
        self.layers = nn.Sequential(
            ResidualBlock(
                hidden_size, 4, 1, n=1),
            ResidualBlock(
                hidden_size, 3, 1, n=1),
            ResidualBlock(
                hidden_size, 1, 1, n=1),
            nn.Linear(hidden_size, 1))

    def forward(self, x):
        return paddle.squeeze(self.layers(x), -1)


class SpeedySpeechDecoder(nn.Layer):
    def __init__(self, hidden_size, output_size, kernel_size, dilations):
        super().__init__()
        res_blocks = [
            ResidualBlock(
                hidden_size, kernel_size, d, n=2) for d in dilations
        ]
        self.res_blocks = nn.Sequential(*res_blocks)

        self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
        self.postnet2 = nn.Sequential(
            ResidualBlock(
                hidden_size, kernel_size, 1, n=2),
            nn.Linear(hidden_size, output_size))

    def forward(self, x):
        xx = self.res_blocks(x)
        x = x + self.postnet1(xx)
        x = self.postnet2(x)
        return x


class SpeedySpeech(nn.Layer):
    def __init__(
            self,
            vocab_size,
            encoder_hidden_size,
            encoder_kernel_size,
            encoder_dilations,
            duration_predictor_hidden_size,
            decoder_hidden_size,
            decoder_output_size,
            decoder_kernel_size,
            decoder_dilations,
            tone_size=None, ):
        super().__init__()
        encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                      encoder_hidden_size, encoder_kernel_size,
                                      encoder_dilations)
        duration_predictor = DurationPredictor(duration_predictor_hidden_size)
        decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size,
                                      decoder_kernel_size, decoder_dilations)

        self.encoder = encoder
        self.duration_predictor = duration_predictor
        self.decoder = decoder

    def forward(self, text, tones, plens, durations):
        encodings = self.encoder(text, tones)
        pred_durations = self.duration_predictor(encodings.detach())  # (B, T)

        # expand encodings
        durations_to_expand = durations
        encodings = expand(encodings, durations_to_expand)

        # decode
        # remove positional encoding here
        _, t_dec, feature_size = encodings.shape
        encodings += sinusoid_position_encoding(t_dec, feature_size)
        decoded = self.decoder(encodings)
        return decoded, pred_durations

    def inference(self, text, tones):
        # text: [T]
        # tones: [T]
        text = text.unsqueeze(0)
        if tones is not None:
            tones = tones.unsqueeze(0)

        encodings = self.encoder(text, tones)
        pred_durations = self.duration_predictor(encodings)  # (1, T)
        durations_to_expand = paddle.round(pred_durations.exp())
        durations_to_expand = (durations_to_expand).astype(paddle.int64)
        encodings = expand(encodings, durations_to_expand)

        shape = paddle.shape(encodings)
        t_dec, feature_size = shape[1], shape[2]
        encodings += sinusoid_position_encoding(t_dec, feature_size)
        decoded = self.decoder(encodings)
        return decoded, pred_durations
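A rough shape walkthrough of the forward pass under the default config; this is a hedged sketch assuming the parakeet modules imported above are available (toy tensors are hypothetical):

import paddle

model = SpeedySpeech(
    vocab_size=68, encoder_hidden_size=128, encoder_kernel_size=3,
    encoder_dilations=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
    duration_predictor_hidden_size=128, decoder_hidden_size=128,
    decoder_output_size=80, decoder_kernel_size=3,
    decoder_dilations=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1], tone_size=6)

text = paddle.randint(1, 68, [2, 5])   # (B, T) phone ids
tones = paddle.randint(1, 6, [2, 5])   # (B, T) tone ids
durations = paddle.full([2, 5], 2, dtype="int64")  # 2 frames per phone
mel, log_d = model(text, tones, plens=None, durations=durations)
print(mel.shape)    # [2, 10, 80] -- t_dec is the sum of durations
print(log_d.shape)  # [2, 5] -- one predicted log-duration per phone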
@@ -403,7 +403,7 @@ class TransformerTTS(nn.Layer):
         else:
             self.toned = False
         # position encoding matrix may be extended later
-        self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
+        self.encoder_pe = pe.sinusoid_position_encoding(1000, d_encoder)
         self.encoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,

@@ -411,7 +411,7 @@ class TransformerTTS(nn.Layer):

         # decoder
         self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
-        self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
+        self.decoder_pe = pe.sinusoid_position_encoding(1000, d_decoder)
         self.decoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.decoder = TransformerDecoder(

@@ -488,7 +488,7 @@ class TransformerTTS(nn.Layer):
         # twice its length if needed
         if x.shape[1] * self.r > self.decoder_pe.shape[0]:
             new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
-            self.decoder_pe = pe.sinusoid_positional_encoding(0, new_T,
+            self.decoder_pe = pe.sinusoid_position_encoding(new_T,
                                                               self.d_decoder)
         pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
         x = x.scale(math.sqrt(
@@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
from paddle import Tensor


def expand(encodings: Tensor, durations: Tensor) -> Tensor:
    """
    encodings: (B, T, C)
    durations: (B, T)
    """
    batch_size, t_enc = durations.shape
    durations = durations.numpy()
    slens = np.sum(durations, -1)
    t_dec = np.max(slens)
    M = np.zeros([batch_size, t_dec, t_enc])
    for i in range(batch_size):
        k = 0
        for j in range(t_enc):
            d = durations[i, j]
            M[i, k:k + d, j] = 1
            k += d
    M = paddle.to_tensor(M, dtype=encodings.dtype)
    encodings = paddle.matmul(M, encodings)
    return encodings
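The 0/1 matrix `M` built here is a decoder-frame-to-phone alignment: row `k` has a single 1 in the column of the phone covering frame `k`, so the batched matmul copies each encoder frame `durations[i, j]` times. A tiny numpy sketch of one batch item:

import numpy as np

durations = [2, 1]                 # phone 0 lasts 2 frames, phone 1 lasts 1
M = np.array([[1, 0],              # frame 0 <- phone 0
              [1, 0],              # frame 1 <- phone 0
              [0, 1]])             # frame 2 <- phone 1
encodings = np.array([[10., 10.],  # (t_enc=2, C=2)
                      [20., 20.]])
print(M @ encodings)               # each phone repeated for its duration
# [[10. 10.]
#  [10. 10.]
#  [20. 20.]]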
@@ -14,47 +14,56 @@

 import math
 import numpy as np

 import paddle
 from paddle import Tensor
 from paddle.nn import functional as F

-__all__ = ["sinusoid_positional_encoding"]
+__all__ = ["sinusoid_position_encoding", "scaled_position_encoding"]


-def sinusoid_positional_encoding(start_index, length, size, dtype=None):
-    r"""Generate standard positional encoding matrix.
-
-    .. math::
-
-        pe(pos, 2i) = sin(\frac{pos}{10000^{\frac{2i}{size}}}) \\
-        pe(pos, 2i+1) = cos(\frac{pos}{10000^{\frac{2i}{size}}})
-
-    Parameters
-    ----------
-    start_index : int
-        The start index.
-    length : int
-        The timesteps of the positional encoding to generate.
-    size : int
-        Feature size of positional encoding.
-
-    Returns
-    -------
-    Tensor [shape=(length, size)]
-        The positional encoding.
-
-    Raises
-    ------
-    ValueError
-        If ``size`` is not divisible by 2.
-    """
-    if (size % 2 != 0):
+def sinusoid_position_encoding(num_positions: int,
+                               feature_size: int,
+                               omega: float=1.0,
+                               start_pos: int=0,
+                               dtype=None) -> Tensor:
+    # return tensor shape (num_positions, feature_size)
+    if (feature_size % 2 != 0):
         raise ValueError("size should be divisible by 2")
     dtype = dtype or paddle.get_default_dtype()
-    channel = np.arange(0, size, 2)
-    index = np.arange(start_index, start_index + length, 1)
-    p = np.expand_dims(index, -1) / (10000**(channel / float(size)))
-    encodings = np.zeros([length, size])
-    encodings[:, 0::2] = np.sin(p)
-    encodings[:, 1::2] = np.cos(p)
-    encodings = paddle.to_tensor(encodings)
+
+    channel = paddle.arange(0, feature_size, 2, dtype=dtype)
+    index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype)
+    p = (paddle.unsqueeze(index, -1) *
+         omega) / (10000.0**(channel / float(feature_size)))
+    encodings = paddle.zeros([num_positions, feature_size], dtype=dtype)
+    encodings[:, 0::2] = paddle.sin(p)
+    encodings[:, 1::2] = paddle.cos(p)
     return encodings
+
+
+def scaled_position_encoding(num_positions: int,
+                             feature_size: int,
+                             omega: Tensor,
+                             start_pos: int=0,
+                             dtype=None) -> Tensor:
+    # omega: Tensor (batch_size, )
+    # return tensor shape (batch_size, num_positions, feature_size)
+    # consider renaming this as batched position encoding
+    if (feature_size % 2 != 0):
+        raise ValueError("size should be divisible by 2")
+    dtype = dtype or paddle.get_default_dtype()
+
+    channel = paddle.arange(0, feature_size, 2, dtype=dtype)
+    index = paddle.arange(
+        start_pos, start_pos + num_positions, 1, dtype=omega.dtype)
+    batch_size = omega.shape[0]
+    omega = paddle.unsqueeze(omega, [1, 2])
+    p = (paddle.unsqueeze(index, -1) *
+         omega) / (10000.0**(channel / float(feature_size)))
+    encodings = paddle.zeros(
+        [batch_size, num_positions, feature_size], dtype=dtype)
+    # it is nice to have fancy indexing and inplace operations
+    encodings[:, :, 0::2] = paddle.sin(p)
+    encodings[:, :, 1::2] = paddle.cos(p)
+    return encodings
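The rewrite drops the old docstring with the defining equations; for reference, the encodings computed above are the standard sinusoid form with the new `omega` frequency scale (`omega = 1` recovers the original behaviour, and `d` is `feature_size`):

\mathrm{pe}(pos, 2i)   = \sin\left(\frac{\omega \, pos}{10000^{2i/d}}\right), \qquad
\mathrm{pe}(pos, 2i+1) = \cos\left(\frac{\omega \, pos}{10000^{2i/d}}\right)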
@@ -0,0 +1,84 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from math import exp

import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F


def gaussian(window_size, sigma):
    gauss = paddle.to_tensor([
        exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = paddle.matmul(_1D_window,
                               paddle.transpose(_1D_window,
                                                [1, 0])).unsqueeze([0, 1])
    window = paddle.expand(_2D_window, [channel, 1, window_size, window_size])
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=window_size // 2,
        groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) \
        / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(nn.Layer):
    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        return _ssim(img1, img2, self.window, self.window_size, self.channel,
                     self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.shape
    window = create_window(window_size, channel)
    return _ssim(img1, img2, window, window_size, channel, size_average)
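A hedged usage sketch: inputs are 4-D (batch, channel, height, width), which is why the updater above unsqueezes mel spectrograms to add a channel dimension; SSIM of an image with itself is 1.

import paddle

mel = paddle.rand([2, 1, 120, 80])  # (B, 1, frames, n_mels), toy values
print(float(ssim(mel, mel)))        # ~1.0 (identical images)
print(float(ssim(mel, paddle.zeros_like(mel))))  # well below 1.0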
@@ -0,0 +1,29 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from parakeet.modules import expansion


def test_expand():
    x = paddle.randn([2, 4, 3])  # (B, T, C)
    lengths = paddle.to_tensor([[1, 2, 2, 1], [3, 1, 4, 0]])
    y = expansion.expand(x, lengths)

    assert y.shape == [2, 8, 3]
    print("the first sequence")
    print(y[0])

    print("the second sequence")
    print(y[1])
@@ -0,0 +1,34 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import paddle
from paddle.jit import to_static
from paddle.static import InputSpec


def test_applicative_evaluation():
    def m_sqrt2(x):
        return paddle.scale(x, math.sqrt(2))

    subgraph = to_static(m_sqrt2, input_spec=[InputSpec([-1])])
    paddle.jit.save(subgraph, './temp_test_to_static')

    fn = paddle.jit.load('./temp_test_to_static')
    x = paddle.arange(10, dtype=paddle.float32)
    y = fn(x)

    print(x)
    print(y)