add WIP: speedyspeech model and example with baker dataset.

This commit is contained in:
chenfeiyu 2021-07-08 16:47:08 +08:00
parent 124dedbd7b
commit 6c21d80025
20 changed files with 1505 additions and 41 deletions

View File

@ -208,8 +208,7 @@ def main():
"--rootdir",
default=None,
type=str,
help="directory including wav files. you need to specify either scp or rootdir."
)
help="directory to baker dataset.")
parser.add_argument(
"--dumpdir",
type=str,

View File

@ -0,0 +1,37 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from parakeet.data.batch import batch_sequences
def collate_baker_examples(examples):
    """Collate a list of baker examples into a single batch dict.

    Parameters
    ----------
    examples : list of dict
        Every example provides the fields ``phones``, ``tones``,
        ``num_phones``, ``num_frames`` and ``feats``.

    Returns
    -------
    dict
        ``phones`` and ``tones`` (converted to int64) and ``feats``
        (converted to float32) are batched with ``batch_sequences``;
        ``num_phones`` and ``num_frames`` are arrays holding one entry per
        example.
    """

    def _batched(field, dtype):
        # one array per example, merged by the shared batching helper
        arrays = [np.array(example[field], dtype=dtype) for example in examples]
        return batch_sequences(arrays)

    def _per_example(field):
        return np.array([example[field] for example in examples])

    return {
        "phones": _batched("phones", np.int64),
        "tones": _batched("tones", np.int64),
        "num_phones": _per_example("num_phones"),
        "num_frames": _per_example("num_frames"),
        "feats": _batched("feats", np.float32),
    }

View File

@ -0,0 +1,110 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Calculate statistics of feature files."""
import argparse
import logging
import os
import numpy as np
import yaml
import json
import jsonlines
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from parakeet.datasets.data_table import DataTable
from parakeet.utils.h5_utils import read_hdf5
from parakeet.utils.h5_utils import write_hdf5
from config import get_cfg_default
def main():
    """Compute mean and scale statistics of one dumped feature field.

    Every record of the jsonlines file given by ``--metadata`` is loaded into
    a ``DataTable`` whose ``--field-name`` field is converted with
    ``np.load``. The arrays are fed one by one to
    ``StandardScaler.partial_fit`` and the resulting ``mean_`` and ``scale_``
    are stacked and saved as ``stats.npy`` in the dump directory (the
    directory of the metadata file when ``--dumpdir`` is not given).
    """
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features.")
    # --metadata and --field-name are used unconditionally below, so let
    # argparse report a missing one instead of failing later on ``None``
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="json file with id and file paths ")
    parser.add_argument(
        "--field-name",
        type=str,
        required=True,
        help="name of the field to compute statistics for.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        help="directory to save statistics. if not provided, "
        "stats will be saved in the above root directory. (default=None)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    )
    if log_level == logging.WARN:
        logging.warning('Skip DEBUG/INFO messages')

    # load config (merged so that a broken config file is reported; the
    # values themselves are not used further in this script)
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    if args.dumpdir is None:
        # dirname() is "" for a bare file name, and os.makedirs("") raises;
        # fall back to the current directory in that case
        args.dumpdir = os.path.dirname(args.metadata) or "."
    os.makedirs(args.dumpdir, exist_ok=True)

    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        fields=[args.field_name],
        converters={args.field_name: np.load}, )
    logging.info(f"The number of files = {len(dataset)}.")

    # calculate statistics
    scaler = StandardScaler()
    for datum in tqdm(dataset):
        # StandardScaler supports (*, num_features) by default
        scaler.partial_fit(datum[args.field_name])

    # row 0 holds the mean, row 1 the scale
    stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
    np.save(
        os.path.join(args.dumpdir, "stats.npy"),
        stats.astype(np.float32),
        allow_pickle=False)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,54 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
sr: 24000 # Sampling rate.
n_fft: 2048 # FFT size.
hop_length: 300 # Hop size.
win_length: 1200 # Window length.
# If set to null, it will be the same as n_fft.
window: "hann" # Window function.
n_mels: 80 # Number of mel basis.
fmin: 80 # Minimum freq in mel basis calculation.
fmax: 7600 # Maximum frequency in mel basis calculation.
# global_gain_scale: 1.0 # Will be multiplied to all of waveform.
trim_silence: false # Whether to trim the start and end of silence.
top_db: 60 # Need to tune carefully if the recording is not good.
trim_frame_length: 2048 # Frame size in trimming.(in samples)
trim_hop_length: 512 # Hop size in trimming.(in samples)
###########################################################
# DATA SETTING #
###########################################################
batch_size: 16
num_workers: 0
###########################################################
# MODEL SETTING #
###########################################################
model:
vocab_size: 68
tone_size: 6
encoder_hidden_size: 128
encoder_kernel_size: 3
encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
duration_predictor_hidden_size: 128
decoder_hidden_size: 128
decoder_output_size: 80
decoder_kernel_size: 3
decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1]
###########################################################
# OPTIMIZER SETTING #
###########################################################
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

View File

@ -0,0 +1,25 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from yacs.config import CfgNode as Configuration
# Load the default configuration once, when this module is imported.
# NOTE: the path is relative to the current working directory.
with open("conf/default.yaml", 'rt') as config_file:
    _C = Configuration(yaml.safe_load(config_file))
def get_cfg_default():
    """Return a clone of the module-level default configuration."""
    return _C.clone()

View File

@ -0,0 +1,143 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""
import argparse
import logging
import os
from copy import copy
from operator import itemgetter
from pathlib import Path
import numpy as np
import yaml
import jsonlines
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from parakeet.frontend.vocab import Vocab
from parakeet.datasets.data_table import DataTable
from config import get_cfg_default
def main():
    """Normalize dumped features and convert phones and tones to ids.

    Reads the jsonlines file given by ``--metadata``, standardizes the
    ``feats`` array of every record with the mean/scale stored in
    ``--stats``, saves it as ``<utt_id>-feats.npy`` in ``--dumpdir`` and
    writes a new ``metadata.jsonl`` there. Phones and tones are mapped to
    ids with vocabularies built from ``phones.txt`` and ``tones.txt``, which
    are opened relative to the current working directory.
    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="jsonlines metadata file listing the feature files to be "
        "normalized.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.")
    parser.add_argument(
        "--stats", type=str, required=True, help="statistics file.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    )
    if log_level == logging.WARN:
        logging.warning('Skip DEBUG/INFO messages')

    # load config (merged so that a broken config file is reported; the
    # values themselves are not used further in this script)
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    dumpdir = Path(args.dumpdir).resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(metadata, converters={'feats': np.load, })
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler; the stats file is read once (row 0: mean, row 1: scale)
    stats = np.load(args.stats)
    scaler = StandardScaler()
    scaler.mean_ = stats[0]
    scaler.scale_ = stats[1]
    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    with open("phones.txt", 'rt') as f:
        phones = [line.strip() for line in f]
    with open("tones.txt", 'rt') as f:
        tones = [line.strip() for line in f]
    voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
    voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)

    # process each file
    output_metadata = []
    for item in tqdm(dataset):
        utt_id = item['utt_id']

        # normalize and save
        mel = scaler.transform(item['feats'])
        mel_path = dumpdir / f"{utt_id}-feats.npy"
        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)

        phone_ids = [voc_phones.lookup(p) for p in item['phones']]
        tone_ids = [voc_tones.lookup(t) for t in item['tones']]
        output_metadata.append({
            'utt_id': utt_id,
            'phones': phone_ids,
            'tones': tone_ids,
            'num_phones': item['num_phones'],
            'num_frames': item['num_frames'],
            'durations': item['durations'],
            'feats': str(mel_path),
        })
    output_metadata.sort(key=itemgetter('utt_id'))

    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:
            writer.write(item)
    logging.info(f"metadata dumped into {output_metadata_path}")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,99 @@
b
p
m
f
d
t
n
l
g
k
h
zh
ch
sh
r
z
c
s
j
q
x
a
ar
ai
air
ao
aor
an
anr
ang
angr
e
er
ei
eir
en
enr
eng
engr
o
or
ou
our
ong
ongr
ii
iir
iii
iiir
i
ir
ia
iar
iao
iaor
ian
ianr
iang
iangr
ie
ier
io
ior
iou
iour
iong
iongr
in
inr
ing
ingr
u
ur
ua
uar
uai
uair
uan
uanr
uang
uangr
uei
ueir
uo
uor
uen
uenr
ueng
uengr
v
vr
ve
ver
van
vanr
vn
vnr
sil
sp

View File

@ -0,0 +1,311 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Any
import soundfile as sf
import librosa
import numpy as np
import argparse
import yaml
import json
import re
import jsonlines
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
import tqdm
from operator import itemgetter
from praatio import tgio
import logging
from config import get_cfg_default
from tg_utils import validate_textgrid
def logmelfilterbank(audio,
                     sr,
                     n_fft=1024,
                     hop_length=256,
                     win_length=None,
                     window="hann",
                     n_mels=80,
                     fmin=None,
                     fmax=None,
                     eps=1e-10):
    """Compute log-Mel filterbank feature.

    Parameters
    ----------
    audio : ndarray
        Audio signal (T,).
    sr : int
        Sampling rate.
    n_fft : int
        FFT size. (Default value = 1024)
    hop_length : int
        Hop size. (Default value = 256)
    win_length : int
        Window length. If set to None, it will be the same as n_fft. (Default value = None)
    window : str
        Window function type. (Default value = "hann")
    n_mels : int
        Number of mel basis. (Default value = 80)
    fmin : int
        Minimum frequency in mel basis calculation. ``None`` means 0. (Default value = None)
    fmax : int
        Maximum frequency in mel basis calculation. ``None`` means ``sr / 2``. (Default value = None)
    eps : float
        Epsilon value to avoid inf in log calculation. (Default value = 1e-10)

    Returns
    -------
    np.ndarray
        Log Mel filterbank feature (num_mels, #frames). The result is NOT
        transposed here; callers that need (#frames, num_mels) transpose it.
    """
    # get amplitude spectrogram
    x_stft = librosa.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        pad_mode="reflect")
    spc = np.abs(x_stft)  # (#bins, #frames,)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sr / 2 if fmax is None else fmax
    # keyword arguments work with both old and new librosa releases; recent
    # releases no longer accept these parameters positionally
    mel_basis = librosa.filters.mel(
        sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

    # clamp with eps before the log so that silent bins do not produce -inf
    return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))
def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     alignment_fp: Path,
                     output_dir: Path):
    """Extract features and phone-level durations for one utterance.

    Parameters
    ----------
    config
        Configuration object; the fields ``sr``, ``n_fft``, ``window``,
        ``win_length``, ``hop_length``, ``n_mels``, ``fmin`` and ``fmax`` are
        read by attribute access.
    fp : Path
        Audio file. Its stem is used as the utterance id.
    alignment_fp : Path
        TextGrid file with the phone alignment of the utterance.
    output_dir : Path
        Directory where ``<utt_id>_feats.npy`` is saved.

    Returns
    -------
    dict
        Record with ``utt_id``, ``phones``, ``tones``, ``num_phones``,
        ``num_frames``, ``durations`` (in frames) and ``feats`` (absolute
        path of the saved feature file).
    """
    utt_id = fp.stem

    # reading
    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
    assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
    assert np.abs(y).max(
    ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
    # ``y`` is passed by keyword: recent librosa releases reject it
    # positionally, older ones accept both forms
    duration = librosa.get_duration(y=y, sr=sr)

    # intervals with empty labels are ignored
    alignment = tgio.openTextgrid(alignment_fp)

    # validate text grid against audio file; the result used to be dropped,
    # now a failed validation is at least reported
    num_samples = y.shape[0]
    if not validate_textgrid(alignment, num_samples, sr):
        logging.warning(
            f"The text grid of utterance {utt_id} does not start at 0 "
            f"or goes beyond the end of the audio.")

    # only with baker's annotation
    intervals = alignment.tierDict[alignment.tierNameList[0]].entryList

    first, last = intervals[0], intervals[-1]
    if not (first.label == "sil" and first.end < duration):
        logging.warning(
            f" There is something wrong with the fisrt interval {first} in utterance: {utt_id}"
        )
    if not (last.label == "sil" and last.start < duration):
        logging.warning(
            f" There is something wrong with the last interval {last} in utterance: {utt_id}"
        )

    logmel = logmelfilterbank(
        y,
        sr=sr,
        n_fft=config.n_fft,
        window=config.window,
        win_length=config.win_length,
        hop_length=config.hop_length,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)

    # extract phone and duration
    phones = []
    tones = []
    ends = []
    for interval in intervals:
        label = interval.label
        label = label.replace("sp1", "sp")  # Baker has sp1 rather than sp

        # split tone from finals
        match = re.match(r'^(\w+)([012345])$', label)
        if match:
            phones.append(match.group(1))
            tones.append(match.group(2))
        else:
            # labels without a trailing tone digit get tone '0'
            phones.append(label)
            tones.append('0')

        # clamp interval ends to the duration of the audio
        ends.append(min(duration, interval.end))

    # frame index of every interval end; consecutive differences are the
    # per-phone durations in frames
    frame_pos = librosa.time_to_frames(
        ends, sr=sr, hop_length=config.hop_length)
    durations_frame = np.diff(frame_pos, prepend=0)

    num_frames = logmel.shape[-1]  # number of frames of the spectrogram
    extra = np.sum(durations_frame) - num_frames
    assert extra <= 0, (
        f"Number of frames inferred from alignemnt is "
        f"larger than number of frames of the spectrogram by {extra} frames")
    # give the frames that the alignment does not cover to the last interval
    durations_frame[-1] += (-extra)
    assert np.sum(durations_frame) == num_frames
    durations_frame = durations_frame.tolist()

    mel_path = output_dir / (utt_id + "_feats.npy")
    np.save(mel_path, logmel.T)  # (num_frames, n_mels)
    record = {
        "utt_id": utt_id,
        "phones": phones,
        "tones": tones,
        "num_phones": len(phones),
        "num_frames": num_frames,
        "durations": durations_frame,
        "feats": str(mel_path.resolve()),  # use absolute path
    }
    return record
def process_sentences(config,
                      fps: List[Path],
                      alignment_fps: List[Path],
                      output_dir: Path,
                      nprocs: int=1):
    """Process a list of utterances and write their metadata.

    Each ``(audio, alignment)`` pair is handed to ``process_sentence``,
    either sequentially (``nprocs == 1``) or through a thread pool with
    ``nprocs`` workers. The returned records are sorted by ``utt_id`` and
    written to ``metadata.jsonl`` in ``output_dir``.
    """
    pairs = zip(fps, alignment_fps)
    if nprocs == 1:
        results = [
            process_sentence(config, wav_fp, textgrid_fp, output_dir)
            for wav_fp, textgrid_fp in tqdm.tqdm(pairs, total=len(fps))
        ]
    else:
        with ThreadPoolExecutor(nprocs) as pool, \
                tqdm.tqdm(total=len(fps)) as progress:

            def _advance(_finished):
                # called once per finished task to move the progress bar
                progress.update()

            futures = []
            for wav_fp, textgrid_fp in pairs:
                task = pool.submit(process_sentence, config, wav_fp,
                                   textgrid_fp, output_dir)
                task.add_done_callback(_advance)
                futures.append(task)

            # collect in submission order while the progress bar is open
            results = [task.result() for task in futures]

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for record in results:
            writer.write(record)
    print("Done")
def main():
    """Extract features for the baker dataset and split it into 3 sections.

    Wav files are searched under ``<rootdir>/Wave`` and alignments under
    ``<rootdir>/PhoneLabeling``. After dropping a fixed set of utterances
    with annotation errors, the first 9800 utterances form the train
    section, the next 100 the dev section and the rest the test section;
    each section is dumped into ``<dumpdir>/<section>/raw``.
    """
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        help="directory including wav files. you need to specify either scp or rootdir."
    )
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num_cpu", type=int, default=1, help="number of process.")
    args = parser.parse_args()
    if args.rootdir is None:
        # Path(None) below would only raise an unhelpful TypeError
        parser.error("--rootdir is required.")

    C = get_cfg_default()
    if args.config:
        C.merge_from_file(args.config)
    C.freeze()

    if args.verbose > 1:
        print(vars(args))
        print(C)

    root_dir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    dumpdir.mkdir(parents=True, exist_ok=True)

    wav_files = sorted((root_dir / "Wave").rglob("*.wav"))
    alignment_files = sorted((root_dir / "PhoneLabeling").rglob("*.interval"))

    # filter out several files that have errors in annotation
    exclude = {'000611', '000662', '002365', '005107'}
    wav_files = [f for f in wav_files if f.stem not in exclude]
    alignment_files = [f for f in alignment_files if f.stem not in exclude]

    # Pair audio and alignment by utterance id instead of by list position:
    # with positional pairing a single missing file silently shifts every
    # following pair.
    alignment_by_id = {f.stem: f for f in alignment_files}
    wav_ids = {f.stem for f in wav_files}
    unpaired = sorted(wav_ids.symmetric_difference(alignment_by_id))
    if unpaired:
        logging.warning(
            f"{len(unpaired)} utterances lack either the wav file or the "
            f"alignment and are skipped: {unpaired}")
    wav_files = [f for f in wav_files if f.stem in alignment_by_id]
    alignment_files = [alignment_by_id[f.stem] for f in wav_files]

    # split data into 3 sections
    num_train = 9800
    num_dev = 100
    sections = {
        "train": slice(0, num_train),
        "dev": slice(num_train, num_train + num_dev),
        "test": slice(num_train + num_dev, None),
    }

    # create every dump directory before any processing starts
    dump_dirs = {}
    for name in sections:
        dump_dirs[name] = dumpdir / name / "raw"
        dump_dirs[name].mkdir(parents=True, exist_ok=True)

    # process for the 3 sections
    for name, section in sections.items():
        process_sentences(
            C,
            wav_files[section],
            alignment_files[section],
            dump_dirs[name],
            nprocs=args.num_cpu)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,65 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.nn import functional as F
from paddle.fluid.layers import huber_loss
from parakeet.modules.ssim import ssim
from parakeet.modules.modules.losses import masked_l1_loss, weighted_mean
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.models.speedyspeech import SpeedySpeech
class SpeedySpeechUpdater(StandardUpdater):
    """Updater that runs one SpeedySpeech optimization step per batch.

    The batch fields follow the names used by the baker metadata and
    collate function: ``phones``, ``tones``, ``num_phones``, ``num_frames``,
    ``feats`` and ``durations``.
    """

    def update_core(self, batch):
        """Run forward, compute the losses, back-propagate and step.

        The total loss is the sum of a masked L1 loss and an SSIM loss on the
        spectrogram and a masked Huber loss between the predicted durations
        and the log of the target durations. The three terms are reported
        under ``train/l1_loss``, ``train/duration_loss`` and
        ``train/ssim_loss``.
        """
        # NOTE(review): "durations" has to be part of the batch; confirm
        # that the collate function in use provides it.
        decoded, predicted_durations = self.model(
            text=batch["phones"],
            tones=batch["tones"],
            plens=batch["num_phones"],
            durations=batch["durations"])

        target_mel = batch["feats"]
        spec_mask = F.sequence_mask(
            batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
        text_mask = F.sequence_mask(
            batch["num_phones"], dtype=predicted_durations.dtype)

        # spec loss
        l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)

        # duration loss
        # durations are dumped as integer frame counts; cast them to the
        # dtype of the prediction before taking the log
        target_durations = batch["durations"].astype(
            predicted_durations.dtype)
        # clip to at least one frame so that the log is never below 0
        target_durations = paddle.clip(target_durations, min=1.0)
        duration_loss = weighted_mean(
            huber_loss(
                predicted_durations, paddle.log(target_durations), delta=1.0),
            text_mask, )

        # ssim loss
        ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
                               (target_mel * spec_mask).unsqueeze(1))

        loss = l1_loss + duration_loss + ssim_loss

        optimizer = self.optimizer
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

        report("train/l1_loss", float(l1_loss))
        report("train/duration_loss", float(duration_loss))
        report("train/ssim_loss", float(ssim_loss))

View File

@ -0,0 +1,27 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import librosa
from praatio import tgio
def validate_textgrid(text_grid, num_samples, sr):
    """Validate Text Grid to make sure that the time interval annotated
    by the text grid file does not go beyond the audio file.

    Parameters
    ----------
    text_grid
        Object providing ``minTimestamp`` and ``maxTimestamp``.
    num_samples : int
        Number of samples of the audio.
    sr : int
        Sampling rate of the audio.

    Returns
    -------
    bool
        True if the text grid starts at 0.0 and ends no later than the audio.
    """
    start = text_grid.minTimestamp
    end = text_grid.maxTimestamp
    # ``sr`` is passed by keyword: recent librosa releases reject it
    # positionally, older ones accept both forms
    end_audio = librosa.samples_to_time(num_samples, sr=sr)
    return start == 0.0 and end <= end_audio

View File

@ -0,0 +1,6 @@
0
1
2
3
4
5

View File

@ -0,0 +1,155 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
import dataclasses
from pathlib import Path
import yaml
import jsonlines
import paddle
import numpy as np
from paddle import nn
from paddle.nn import functional as F
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from paddle.optimizer import Adam # No RAdaom
from paddle.optimizer.lr import StepDecay
from paddle import DataParallel
from visualdl import LogWriter
from parakeet.datasets.data_table import DataTable
from parakeet.models.speedyspeech import SpeedySpeech
from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training import extension
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.seeding import seed_everything
from batch_fn import collate_baker_examples
from config import get_cfg_default
def train_sp(args, config):
    """Set up device, data pipeline and model for one training process.

    Parameters
    ----------
    args
        Parsed command line arguments; ``device``, ``train_metadata`` and
        ``dev_metadata`` are read here.
    config
        Configuration; ``seed``, ``batch_size``, ``num_workers`` and
        ``model`` are read here.
    """
    # decides device type and whether to run in parallel
    # setup running environment correctly
    # ``is_compiled_with_cuda`` has to be CALLED: the bare function object is
    # always truthy, which made the cpu branch unreachable
    if args.device == "cpu" or not paddle.is_compiled_with_cuda():
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
        world_size = paddle.distributed.get_world_size()
        if world_size > 1:
            paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)
    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
        converters={"feats": np.load, }, )
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
        converters={"feats": np.load, }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_baker_examples,
        num_workers=config.num_workers)
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=collate_baker_examples,
        num_workers=config.num_workers)
    print("dataloaders done!")

    # NOTE: work in progress - the dataloaders are built but no training
    # loop consumes them yet; the model is only constructed and printed
    model = SpeedySpeech(**config["model"])
    print(model)
def main():
    """Parse arguments and config, then dispatch to ``train_sp``.

    With ``--nprocs`` greater than 1 the training function is spawned
    through ``paddle.distributed.spawn``; otherwise it runs in the current
    process.

    Raises
    ------
    RuntimeError
        If multiprocess training is requested together with ``--device cpu``.
    """
    # parse args and config and redirect to train_sp
    # the description used to name ParallelWaveGAN, a copy-paste leftover
    parser = argparse.ArgumentParser(description="Train a SpeedySpeech "
                                     "model with Baker Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    parser.add_argument("--train-metadata", type=str, help="training data")
    parser.add_argument("--dev-metadata", type=str, help="dev data")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    if args.device == "cpu" and args.nprocs > 1:
        raise RuntimeError("Multiprocess training on CPU is not supported.")
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()

View File

@ -161,3 +161,27 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
def batch_sequences(sequences, axis=0, pad_value=0):
    """Pad arrays to a common length along ``axis`` and stack them.

    Parameters
    ----------
    sequences : list of np.ndarray
        Arrays that may differ in size along ``axis``. The rank and dtype of
        the first array are used for the padding spec and the pad value.
    axis : int
        Axis to pad; negative values count from the last axis.
        (Default value = 0)
    pad_value : scalar
        Fill value, converted to the dtype of the first array.
        (Default value = 0)

    Returns
    -------
    np.ndarray
        The padded arrays stacked along a new leading axis.
    """
    reference = sequences[0]
    rank = reference.ndim
    if axis < 0:
        axis += rank
    fill = reference.dtype.type(pad_value)

    lengths = [item.shape[axis] for item in sequences]
    longest = np.max(lengths)

    # only the entry for ``axis`` varies between the sequences
    leading = [(0, 0)] * axis
    trailing = [(0, 0)] * (rank - axis - 1)

    def _pad(item, length):
        widths = leading + [(0, longest - length)] + trailing
        return np.pad(item, widths, mode='constant', constant_values=fill)

    return np.stack(
        [_pad(item, length) for item, length in zip(sequences, lengths)])

View File

@ -0,0 +1,214 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import Tensor
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from parakeet.modules.positional_encoding import sinusoid_position_encoding
from parakeet.modules.expansion import expand
class ResidualBlock(nn.Layer):
    """``n`` stacked Conv1D -> ReLU -> BatchNorm1D units with a skip connection.

    All layers use the ``NLC`` data format and the convolution keeps the
    number of channels and uses ``"same"`` padding.
    """

    def __init__(self, channels, kernel_size, dilation, n=2):
        super().__init__()
        units = []
        for _ in range(n):
            conv = nn.Conv1D(
                channels,
                channels,
                kernel_size,
                dilation=dilation,
                padding="same",
                data_format="NLC")
            activation = nn.ReLU()
            norm = nn.BatchNorm1D(channels, data_format="NLC")
            units.append(nn.Sequential(conv, activation, norm))
        self.blocks = nn.Sequential(*units)

    def forward(self, x):
        """Return the input plus the output of the stacked units."""
        transformed = self.blocks(x)
        return x + transformed
class TextEmbedding(nn.Layer):
    """Embedding of text ids, optionally combined with a tone embedding.

    When ``tone_vocab_size`` is given, a second embedding table is created
    for tones. Text and tone embeddings are either concatenated along the
    last axis (``concat=True``) or added; adding requires both embedding
    sizes to be equal.
    """

    def __init__(self,
                 vocab_size: int,
                 embedding_size: int,
                 tone_vocab_size: int=None,
                 tone_embedding_size: int=None,
                 padding_idx: int=None,
                 tone_padding_idx: int=None,
                 concat: bool=False):
        super().__init__()
        self.text_embedding = nn.Embedding(vocab_size, embedding_size,
                                           padding_idx)
        if tone_vocab_size:
            # the tone embedding size defaults to the text embedding size
            tone_dim = tone_embedding_size or embedding_size
            sizes_differ = tone_dim != embedding_size
            if sizes_differ and not concat:
                raise ValueError(
                    "embedding size != tone_embedding size, only conat is avaiable."
                )
            self.tone_embedding = nn.Embedding(tone_vocab_size, tone_dim,
                                               tone_padding_idx)
        self.concat = concat

    def forward(self, text, tone=None):
        """Embed ``text`` and, if ``tone`` is given, combine it with the tone embedding."""
        embedded_text = self.text_embedding(text)
        if tone is None:
            return embedded_text

        embedded_tone = self.tone_embedding(tone)
        if not self.concat:
            return embedded_text + embedded_tone
        return paddle.concat([embedded_text, embedded_tone], -1)
class SpeedySpeechEncoder(nn.Layer):
    """Encoder: text/tone embedding, prenet, dilated residual blocks and postnets."""

    def __init__(self, vocab_size, tone_size, hidden_size, kernel_size,
                 dilations):
        super().__init__()
        # index 0 is the padding index for both phones and tones
        self.embedding = TextEmbedding(
            vocab_size,
            hidden_size,
            tone_size,
            padding_idx=0,
            tone_padding_idx=0)
        self.prenet = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU())

        # one residual block (two conv units each) per dilation
        stack = []
        for dilation in dilations:
            stack.append(
                ResidualBlock(
                    hidden_size, kernel_size, dilation, n=2))
        self.res_blocks = nn.Sequential(*stack)

        self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
        self.postnet2 = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1D(
                hidden_size, data_format="NLC"),
            nn.Linear(hidden_size, hidden_size))

    def forward(self, text, tones):
        """Encode ``text`` and ``tones``.

        The prenet output is added back to the projected output of the
        residual stack before the final postnet is applied.
        """
        prenet_out = self.prenet(self.embedding(text, tones))
        projected = self.postnet1(self.res_blocks(prenet_out))
        return self.postnet2(prenet_out + projected)
class DurationPredictor(nn.Layer):
    """Three single-unit residual blocks followed by a linear layer with one output."""

    def __init__(self, hidden_size):
        super().__init__()
        # kernel sizes of the consecutive residual blocks, all with dilation 1
        kernel_sizes = (4, 3, 1)
        stages = [
            ResidualBlock(
                hidden_size, kernel_size, 1, n=1)
            for kernel_size in kernel_sizes
        ]
        stages.append(nn.Linear(hidden_size, 1))
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the layers and squeeze the last (size 1) axis of the result."""
        predicted = self.layers(x)
        return paddle.squeeze(predicted, -1)
class SpeedySpeechDecoder(nn.Layer):
    """Decoder: dilated residual stack followed by two postnets.

    Parameters
    ----------
    hidden_size : int
        Feature size of the input and of the hidden layers.
    output_size : int
        Feature size of the decoder output.
    kernel_size : int
        Kernel size forwarded to each ``ResidualBlock``.
    dilations : iterable of int
        One ``ResidualBlock`` is created per dilation value, in order.
    """

    def __init__(self, hidden_size, output_size, kernel_size, dilations):
        super().__init__()
        stack = []
        for rate in dilations:
            stack.append(ResidualBlock(hidden_size, kernel_size, rate, n=2))
        self.res_blocks = nn.Sequential(*stack)

        self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
        # A final residual block with dilation 1, then the output projection.
        self.postnet2 = nn.Sequential(
            ResidualBlock(
                hidden_size, kernel_size, 1, n=2),
            nn.Linear(hidden_size, output_size))

    def forward(self, x):
        """Decode the hidden sequence ``x`` into output features."""
        conv_out = self.res_blocks(x)
        # Skip connection from the decoder input around the residual stack.
        merged = x + self.postnet1(conv_out)
        return self.postnet2(merged)
class SpeedySpeech(nn.Layer):
    """SpeedySpeech model: encoder, duration predictor and decoder.

    Parameters
    ----------
    vocab_size : int
        Size of the text vocabulary.
    encoder_hidden_size : int
        Hidden size of the encoder.
    encoder_kernel_size : int
        Kernel size of the encoder's residual blocks.
    encoder_dilations : iterable of int
        Dilations of the encoder's residual blocks.
    duration_predictor_hidden_size : int
        Hidden size of the duration predictor. It consumes the encoder
        output, so it presumably has to equal ``encoder_hidden_size``
        (not checked here).
    decoder_hidden_size : int
        Hidden size of the decoder. It consumes the expanded encoder output,
        so it presumably has to equal ``encoder_hidden_size`` (not checked
        here).
    decoder_output_size : int
        Feature size of the decoder output.
    decoder_kernel_size : int
        Kernel size of the decoder's residual blocks.
    decoder_dilations : iterable of int
        Dilations of the decoder's residual blocks.
    tone_size : int, optional
        Size of the tone vocabulary, forwarded to the encoder.
    """

    def __init__(
            self,
            vocab_size,
            encoder_hidden_size,
            encoder_kernel_size,
            encoder_dilations,
            duration_predictor_hidden_size,
            decoder_hidden_size,
            decoder_output_size,
            decoder_kernel_size,
            decoder_dilations,
            tone_size=None, ):
        super().__init__()
        encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                      encoder_hidden_size, encoder_kernel_size,
                                      encoder_dilations)
        duration_predictor = DurationPredictor(duration_predictor_hidden_size)
        decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size,
                                      decoder_kernel_size, decoder_dilations)
        self.encoder = encoder
        self.duration_predictor = duration_predictor
        self.decoder = decoder

    def forward(self, text, tones, plens, durations):
        """Run the model using the given (reference) durations.

        Parameters
        ----------
        text : Tensor
            Batch of text ids.
        tones : Tensor
            Batch of tone ids.
        plens : Tensor
            Not used by this method; kept in the signature for callers.
        durations : Tensor
            Integer durations, shape (B, T), used to expand the encodings.

        Returns
        -------
        decoded : Tensor
            Decoder output for the expanded sequence.
        pred_durations : Tensor
            Output of the duration predictor, shape (B, T).
        """
        encodings = self.encoder(text, tones)
        # The duration predictor sees a detached copy of the encodings, so
        # its gradients do not flow back into the encoder.
        pred_durations = self.duration_predictor(encodings.detach())  # (B, T)

        # expand encodings with the given durations
        durations_to_expand = durations
        encodings = expand(encodings, durations_to_expand)

        # add positional encoding over the expanded time axis, then decode
        _, t_dec, feature_size = encodings.shape
        encodings += sinusoid_position_encoding(t_dec, feature_size)
        decoded = self.decoder(encodings)
        return decoded, pred_durations

    def inference(self, text, tones):
        """Synthesize from a single, unbatched sequence.

        Parameters
        ----------
        text : Tensor
            Text ids, shape (T,).
        tones : Tensor or None
            Tone ids, shape (T,), or None.

        Returns
        -------
        decoded : Tensor
            Decoder output for the expanded sequence (batch size 1).
        pred_durations : Tensor
            Raw output of the duration predictor, shape (1, T).
        """
        # text: [T]
        # tones: [T]
        text = text.unsqueeze(0)
        if tones is not None:
            tones = tones.unsqueeze(0)

        encodings = self.encoder(text, tones)
        pred_durations = self.duration_predictor(encodings)  # (1, T)
        # The prediction is exponentiated and rounded to get integer
        # durations for expansion.
        durations_to_expand = paddle.round(pred_durations.exp())
        durations_to_expand = (durations_to_expand).astype(paddle.int64)
        encodings = expand(encodings, durations_to_expand)

        shape = paddle.shape(encodings)
        t_dec, feature_size = shape[1], shape[2]
        encodings += sinusoid_position_encoding(t_dec, feature_size)
        decoded = self.decoder(encodings)
        return decoded, pred_durations

View File

@ -403,7 +403,7 @@ class TransformerTTS(nn.Layer):
else:
self.toned = False
# position encoding matrix may be extended later
self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
self.encoder_pe = pe.sinusoid_positional_encoding(1000, d_encoder)
self.encoder_pe_scalar = self.create_parameter(
[1], attr=I.Constant(1.))
self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
@ -411,7 +411,7 @@ class TransformerTTS(nn.Layer):
# decoder
self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
self.decoder_pe = pe.sinusoid_positional_encoding(1000, d_decoder)
self.decoder_pe_scalar = self.create_parameter(
[1], attr=I.Constant(1.))
self.decoder = TransformerDecoder(
@ -488,7 +488,7 @@ class TransformerTTS(nn.Layer):
# twice its length if needed
if x.shape[1] * self.r > self.decoder_pe.shape[0]:
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
self.decoder_pe = pe.sinusoid_positional_encoding(0, new_T,
self.decoder_pe = pe.sinusoid_positional_encoding(new_T,
self.d_decoder)
pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
x = x.scale(math.sqrt(

View File

@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import Tensor
def expand(encodings: Tensor, durations: Tensor) -> Tensor:
    """Repeat each encoder step along time according to its duration.

    A 0/1 alignment matrix is built on the host from ``durations`` and
    multiplied with ``encodings``. Sequences whose total duration is shorter
    than the longest one in the batch are followed by all-zero frames.

    encodings: (B, T, C)
    durations: (B, T)
    returns: (B, max total duration over the batch, C)
    """
    n_batch, n_src = durations.shape
    duration_table = durations.numpy()
    n_dst = np.max(np.sum(duration_table, -1))

    alignment = np.zeros([n_batch, n_dst, n_src])
    for b, row in enumerate(duration_table):
        start = 0
        for src, span in enumerate(row):
            stop = start + span
            # Output frames [start, stop) all copy encoder step `src`.
            alignment[b, start:stop, src] = 1
            start = stop

    alignment = paddle.to_tensor(alignment, dtype=encodings.dtype)
    return paddle.matmul(alignment, encodings)

View File

@ -14,47 +14,56 @@
import math
import numpy as np
import paddle
from paddle import Tensor
from paddle.nn import functional as F
__all__ = ["sinusoid_positional_encoding"]
__all__ = ["sinusoid_position_encoding", "scaled_position_encoding"]
def sinusoid_positional_encoding(start_index, length, size, dtype=None):
r"""Generate standard positional encoding matrix.
.. math::
pe(pos, 2i) = sin(\frac{pos}{10000^{\frac{2i}{size}}}) \\
pe(pos, 2i+1) = cos(\frac{pos}{10000^{\frac{2i}{size}}})
Parameters
----------
start_index : int
The start index.
length : int
The timesteps of the positional encoding to generate.
size : int
Feature size of positional encoding.
Returns
-------
Tensor [shape=(length, size)]
The positional encoding.
Raises
------
ValueError
If ``size`` is not divisible by 2.
"""
if (size % 2 != 0):
def sinusoid_position_encoding(num_positions: int,
feature_size: int,
omega: float=1.0,
start_pos: int=0,
dtype=None) -> Tensor:
# return tensor shape (num_positions, feature_size)
if (feature_size % 2 != 0):
raise ValueError("size should be divisible by 2")
dtype = dtype or paddle.get_default_dtype()
channel = np.arange(0, size, 2)
index = np.arange(start_index, start_index + length, 1)
p = np.expand_dims(index, -1) / (10000**(channel / float(size)))
encodings = np.zeros([length, size])
encodings[:, 0::2] = np.sin(p)
encodings[:, 1::2] = np.cos(p)
encodings = paddle.to_tensor(encodings)
channel = paddle.arange(0, feature_size, 2, dtype=dtype)
index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype)
p = (paddle.unsqueeze(index, -1) *
omega) / (10000.0**(channel / float(feature_size)))
encodings = paddle.zeros([num_positions, feature_size], dtype=dtype)
encodings[:, 0::2] = paddle.sin(p)
encodings[:, 1::2] = paddle.cos(p)
return encodings
def scaled_position_encoding(num_positions: int,
                             feature_size: int,
                             omega: Tensor,
                             start_pos: int=0,
                             dtype=None) -> Tensor:
    """Generate a batch of sinusoidal position encodings, one per scale.

    Parameters
    ----------
    num_positions : int
        Number of positions to encode.
    feature_size : int
        Feature size of the encoding; must be even.
    omega : Tensor [shape=(batch_size, )]
        Per-example factor the position index is multiplied with.
    start_pos : int
        Index of the first position.
    dtype : optional
        Data type of the result; defaults to paddle's default dtype.

    Returns
    -------
    Tensor [shape=(batch_size, num_positions, feature_size)]
        Sine values on even channels, cosine values on odd channels.

    Raises
    ------
    ValueError
        If ``feature_size`` is not divisible by 2.
    """
    # consider renaming this as batched positioning encoding
    if (feature_size % 2 != 0):
        raise ValueError("size should be divisible by 2")
    if not dtype:
        dtype = paddle.get_default_dtype()

    n_batch = omega.shape[0]
    channels = paddle.arange(0, feature_size, 2, dtype=dtype)
    positions = paddle.arange(
        start_pos, start_pos + num_positions, 1, dtype=omega.dtype)

    # (num_positions, 1) * (batch_size, 1, 1) -> (batch_size, num_positions, 1)
    scaled_positions = paddle.unsqueeze(positions, -1) * paddle.unsqueeze(
        omega, [1, 2])
    wavelengths = 10000.0**(channels / float(feature_size))
    angles = scaled_positions / wavelengths

    table = paddle.zeros([n_batch, num_positions, feature_size], dtype=dtype)
    # it is nice to have fancy indexing and inplace operations
    table[:, :, 0::2] = paddle.sin(angles)
    table[:, :, 1::2] = paddle.cos(angles)
    return table

84
parakeet/modules/ssim.py Normal file
View File

@ -0,0 +1,84 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import exp
import numpy as np
import paddle
from paddle import nn
import paddle.nn.functional as F
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian window as a tensor.

    The window has ``window_size`` taps centered at ``window_size // 2``
    and is divided by its sum so that the taps add up to 1.
    """
    center = window_size // 2
    denom = float(2 * sigma**2)
    taps = []
    for x in range(window_size):
        taps.append(exp(-(x - center)**2 / denom))
    kernel = paddle.to_tensor(taps)
    return kernel / kernel.sum()
def create_window(window_size, channel):
    """Build a 2-D Gaussian window (sigma 1.5) for a grouped convolution.

    Returns a tensor of shape (channel, 1, window_size, window_size): the
    outer product of the 1-D window with itself, expanded over ``channel``.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    row = paddle.transpose(column, [1, 0])
    kernel_2d = paddle.matmul(column, row).unsqueeze([0, 1])
    return paddle.expand(kernel_2d, [channel, 1, window_size, window_size])
def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute the SSIM between two image batches with a given window.

    Local means, variances and the covariance are estimated by filtering
    with ``window`` (grouped convolution, ``window_size // 2`` padding).
    Returns the mean over the whole SSIM map if ``size_average`` is true,
    otherwise the mean over axis 1 applied three times in a row.
    """
    pad = window_size // 2

    def smooth(image):
        # Window-weighted local average, computed per channel group.
        return F.conv2d(image, window, padding=pad, groups=channel)

    mu1 = smooth(img1)
    mu2 = smooth(img2)
    mu1_sq, mu2_sq = mu1.pow(2), mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = smooth(img1 * img1) - mu1_sq
    sigma2_sq = smooth(img2 * img2) - mu2_sq
    sigma12 = smooth(img1 * img2) - mu1_mu2

    # Small constants that keep the ratio below well-defined.
    c1 = 0.01**2
    c2 = 0.03**2

    numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
    ssim_map = numerator / denominator

    if not size_average:
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()
class SSIM(nn.Layer):
    """Layer computing SSIM between two image batches.

    The Gaussian window is created once at construction for a single
    channel (``self.channel`` is fixed to 1).

    Parameters
    ----------
    window_size : int
        Side length of the Gaussian window.
    size_average : bool
        Forwarded to ``_ssim``; selects the reduction of the SSIM map.
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM of ``img1`` and ``img2`` using the stored window."""
        score = _ssim(img1, img2, self.window, self.window_size, self.channel,
                      self.size_average)
        return score
def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM: build a window matching ``img1``'s channels and apply it.

    ``img1`` must be 4-D; its second axis is taken as the channel count.
    """
    _batch, channel, _height, _width = img1.shape
    kernel = create_window(window_size, channel)
    return _ssim(img1, img2, kernel, window_size, channel, size_average)

29
tests/test_expansion.py Normal file
View File

@ -0,0 +1,29 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from parakeet.modules import expansion
def test_expand():
    """Check the output shape of ``expansion.expand`` on a small batch."""
    encodings = paddle.randn([2, 4, 3])  # (B, T, C)
    durations = paddle.to_tensor([[1, 2, 2, 1], [3, 1, 4, 0]])
    expanded = expansion.expand(encodings, durations)
    # The per-sequence totals are 6 and 8, so the time axis has length 8.
    assert expanded.shape == [2, 8, 3]

    titles = ("the first sequence", "the second sequence")
    for idx, title in enumerate(titles):
        print(title)
        print(expanded[idx])

34
tests/test_to_static.py Normal file
View File

@ -0,0 +1,34 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec
def test_applicative_evaluation():
    """Export a function with to_static, save it, reload it and run it."""

    def m_sqrt2(x):
        return paddle.scale(x, math.sqrt(2))

    model_path = './temp_test_to_static'
    spec = [InputSpec([-1])]
    static_fn = to_static(m_sqrt2, input_spec=spec)
    paddle.jit.save(static_fn, model_path)

    loaded_fn = paddle.jit.load(model_path)
    inputs = paddle.arange(10, dtype=paddle.float32)
    outputs = loaded_fn(inputs)
    print(inputs)
    print(outputs)