From 6c21d800259b24929c885394775c91d43e8c6cad Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 8 Jul 2021 16:47:08 +0800 Subject: [PATCH] add WIP: speedyspeech model and example with baker dataset. --- examples/parallelwave_gan/baker/preprocess.py | 3 +- examples/speedyspeech/baker/batch_fn.py | 37 +++ .../speedyspeech/baker/compute_statistics.py | 110 +++++++ examples/speedyspeech/baker/conf/default.yaml | 54 +++ examples/speedyspeech/baker/config.py | 25 ++ examples/speedyspeech/baker/normalize.py | 143 ++++++++ examples/speedyspeech/baker/phones.txt | 99 ++++++ examples/speedyspeech/baker/preprocess.py | 311 ++++++++++++++++++ .../baker/speedyspeech_updater.py | 65 ++++ examples/speedyspeech/baker/tg_utils.py | 27 ++ examples/speedyspeech/baker/tones.txt | 6 + examples/speedyspeech/baker/train.py | 155 +++++++++ parakeet/data/batch.py | 24 ++ parakeet/models/speedyspeech.py | 214 ++++++++++++ parakeet/models/transformer_tts.py | 6 +- parakeet/modules/expansion.py | 39 +++ parakeet/modules/positional_encoding.py | 81 +++-- parakeet/modules/ssim.py | 84 +++++ tests/test_expansion.py | 29 ++ tests/test_to_static.py | 34 ++ 20 files changed, 1505 insertions(+), 41 deletions(-) create mode 100644 examples/speedyspeech/baker/batch_fn.py create mode 100644 examples/speedyspeech/baker/compute_statistics.py create mode 100644 examples/speedyspeech/baker/conf/default.yaml create mode 100644 examples/speedyspeech/baker/config.py create mode 100644 examples/speedyspeech/baker/normalize.py create mode 100644 examples/speedyspeech/baker/phones.txt create mode 100644 examples/speedyspeech/baker/preprocess.py create mode 100644 examples/speedyspeech/baker/speedyspeech_updater.py create mode 100644 examples/speedyspeech/baker/tg_utils.py create mode 100644 examples/speedyspeech/baker/tones.txt create mode 100644 examples/speedyspeech/baker/train.py create mode 100644 parakeet/models/speedyspeech.py create mode 100644 parakeet/modules/expansion.py create mode 100644 
parakeet/modules/ssim.py create mode 100644 tests/test_expansion.py create mode 100644 tests/test_to_static.py diff --git a/examples/parallelwave_gan/baker/preprocess.py b/examples/parallelwave_gan/baker/preprocess.py index 6144c34..52b03de 100644 --- a/examples/parallelwave_gan/baker/preprocess.py +++ b/examples/parallelwave_gan/baker/preprocess.py @@ -208,8 +208,7 @@ def main(): "--rootdir", default=None, type=str, - help="directory including wav files. you need to specify either scp or rootdir." - ) + help="directory to baker dataset.") parser.add_argument( "--dumpdir", type=str, diff --git a/examples/speedyspeech/baker/batch_fn.py b/examples/speedyspeech/baker/batch_fn.py new file mode 100644 index 0000000..873145a --- /dev/null +++ b/examples/speedyspeech/baker/batch_fn.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import numpy as np
+from parakeet.data.batch import batch_sequences
+
+
+def collate_baker_examples(examples):
+    """Pad and stack a list of Baker examples into a batch of numpy arrays."""
+    phones = [np.array(item["phones"], dtype=np.int64) for item in examples]
+    tones = [np.array(item["tones"], dtype=np.int64) for item in examples]
+    feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
+    num_phones = np.array([item["num_phones"] for item in examples], np.int64)
+    num_frames = np.array([item["num_frames"] for item in examples], np.int64)
+
+    phones = batch_sequences(phones)
+    tones = batch_sequences(tones)
+    feats = batch_sequences(feats)
+    batch = {
+        "phones": phones,
+        "tones": tones,
+        "num_phones": num_phones,
+        "num_frames": num_frames,
+        "feats": feats,
+    }
+    return batch
diff --git a/examples/speedyspeech/baker/compute_statistics.py b/examples/speedyspeech/baker/compute_statistics.py
new file mode 100644
index 0000000..06b9b65
--- /dev/null
+++ b/examples/speedyspeech/baker/compute_statistics.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Calculate statistics of feature files."""
+
+import argparse
+import logging
+import os
+
+import numpy as np
+import yaml
+import json
+import jsonlines
+
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.utils.h5_utils import read_hdf5
+from parakeet.utils.h5_utils import write_hdf5
+
+from config import get_cfg_default
+
+
+def main():
+    """Compute the mean and scale of one metadata field and save stats.npy."""
+    parser = argparse.ArgumentParser(
+        description="Compute mean and variance of dumped raw features.")
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="jsonlines file with id and file paths.")
+    parser.add_argument(
+        "--field-name",
+        type=str,
+        required=True,
+        help="name of the field to compute statistics for.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        help="directory to save statistics. if not provided, "
+        "stats will be saved next to the metadata file. (default=None)")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    args = parser.parse_args()
+
+    # set logger: verbose > 1 -> DEBUG, verbose == 1 -> INFO, else WARN
+    if args.verbose > 1:
+        level = logging.DEBUG
+    elif args.verbose > 0:
+        level = logging.INFO
+    else:
+        level = logging.WARN
+    logging.basicConfig(
+        level=level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+    )
+    if args.verbose <= 0:
+        logging.warning('Skip DEBUG/INFO messages')
+
+    config = get_cfg_default()
+    # load config
+    if args.config:
+        config.merge_from_file(args.config)
+
+    # check directory existence
+    if args.dumpdir is None:
+        # abspath: the dirname of a bare file name would be ''
+        args.dumpdir = os.path.dirname(os.path.abspath(args.metadata))
+    os.makedirs(args.dumpdir, exist_ok=True)
+
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(
+        metadata,
+        fields=[args.field_name],
+        converters={args.field_name: np.load}, )
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    # calculate statistics
+    scaler = StandardScaler()
+    for datum in tqdm(dataset):
+        # StandardScaler supports (*, num_features) by default
+        scaler.partial_fit(datum[args.field_name])
+
+    stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
+    np.save(
+        os.path.join(args.dumpdir, "stats.npy"),
+        stats.astype(np.float32),
+        allow_pickle=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/speedyspeech/baker/conf/default.yaml b/examples/speedyspeech/baker/conf/default.yaml
new file mode 100644
index 0000000..bbfeba5
--- /dev/null
+++ b/examples/speedyspeech/baker/conf/default.yaml
@@ -0,0 +1,54 @@
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+sr: 24000            # Sampling rate.
+n_fft: 2048          # FFT size.
+hop_length: 300      # Hop size.
+win_length: 1200     # Window length.
+ # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. +fmax: 7600 # Maximum frequency in mel basis calculation. +# global_gain_scale: 1.0 # Will be multiplied to all of waveform. +trim_silence: false # Whether to trim the start and end of silence. +top_db: 60 # Need to tune carefully if the recording is not good. +trim_frame_length: 2048 # Frame size in trimming.(in samples) +trim_hop_length: 512 # Hop size in trimming.(in samples) + + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 16 +num_workers: 0 + + + + +########################################################### +# MODEL SETTING # +########################################################### +model: + vocab_size: 68 + tone_size: 6 + encoder_hidden_size: 128 + encoder_kernel_size: 3 + encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1] + duration_predictor_hidden_size: 128 + decoder_hidden_size: 128 + decoder_output_size: 80 + decoder_kernel_size: 3 + decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1] + + +########################################################### +# OPTIMIZER SETTING # +########################################################### + + + +########################################################### +# OTHER SETTING # +########################################################### +seed: 10086 \ No newline at end of file diff --git a/examples/speedyspeech/baker/config.py b/examples/speedyspeech/baker/config.py new file mode 100644 index 0000000..f555791 --- /dev/null +++ b/examples/speedyspeech/baker/config.py @@ -0,0 +1,25 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import yaml
+from yacs.config import CfgNode as Configuration
+
+with open(Path(__file__).parent / "conf" / "default.yaml", 'rt') as f:
+    _C = Configuration(yaml.safe_load(f))
+
+
+def get_cfg_default():
+    """Return a fresh copy of the default configuration."""
+    return _C.clone()
diff --git a/examples/speedyspeech/baker/normalize.py b/examples/speedyspeech/baker/normalize.py
new file mode 100644
index 0000000..74661f8
--- /dev/null
+++ b/examples/speedyspeech/baker/normalize.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Normalize feature files and dump them."""
+
+import argparse
+import logging
+import os
+from copy import copy
+from operator import itemgetter
+from pathlib import Path
+
+import numpy as np
+import yaml
+import jsonlines
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from parakeet.frontend.vocab import Vocab
+from parakeet.datasets.data_table import DataTable
+
+from config import get_cfg_default
+
+
+def main():
+    """Normalize the dumped features and convert phones/tones to ids."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="jsonlines metadata file listing the raw feature files "
+        "to be normalized.")
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+    parser.add_argument(
+        "--stats", type=str, required=True, help="statistics file.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    args = parser.parse_args()
+
+    # set logger
+    if args.verbose > 1:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    elif args.verbose > 0:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+        logging.warning('Skip DEBUG/INFO messages')
+
+    # load config
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    # check directory existence
+    dumpdir = Path(args.dumpdir).resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    # get dataset
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(metadata, converters={'feats': np.load, })
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    # restore scaler
+    scaler = StandardScaler()
+    # stats.npy stacks [mean, scale] along axis 0; load it only once
+    scaler.mean_, scaler.scale_ = np.load(args.stats)
+
+    # from version 0.23.0, this information is needed
+    scaler.n_features_in_ = scaler.mean_.shape[0]
+
+    with open(Path(__file__).parent / "phones.txt", 'rt') as f:
+        phones = [line.strip() for line in f.readlines()]
+
+    with open(Path(__file__).parent / "tones.txt", 'rt') as f:
+        tones = [line.strip() for line in f.readlines()]
+    voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
+    voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)
+
+    # process each file
+    output_metadata = []
+
+    for item in tqdm(dataset):
+        utt_id = item['utt_id']
+        mel = item['feats']
+        # normalize
+        mel = scaler.transform(mel)
+
+        # save
+        mel_path = dumpdir / f"{utt_id}-feats.npy"
+        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
+        phone_ids = [voc_phones.lookup(p) for p in item['phones']]
+        tone_ids = [voc_tones.lookup(t) for t in item['tones']]
+        output_metadata.append({
+            'utt_id': utt_id,
+            'phones': phone_ids,
+            'tones': tone_ids,
+            'num_phones': item['num_phones'],
+            'num_frames': item['num_frames'],
+            'durations': item['durations'],
+            'feats': str(mel_path),
+        })
+    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata_path = dumpdir / "metadata.jsonl"
+    with jsonlines.open(output_metadata_path, 'w') as writer:
+        for item in output_metadata:
+            writer.write(item)
+    logging.info(f"metadata dumped into {output_metadata_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/speedyspeech/baker/phones.txt b/examples/speedyspeech/baker/phones.txt
new file mode 100644
index 0000000..d91a4a0
--- /dev/null
+++ b/examples/speedyspeech/baker/phones.txt
@@ -0,0 +1,99 @@
+b
+p
+m
+f
+d
+t
+n
+l
+g
+k
+h
+zh
+ch
+sh
+r
+z
+c
+s
+j
+q
+x
+a
+ar
+ai
+air
+ao
+aor
+an
+anr
+ang
+angr
+e
+er
+ei
+eir
+en
+enr
+eng
+engr
+o
+or
+ou
+our
+ong
+ongr
+ii
+iir
+iii
+iiir
+i
+ir
+ia
+iar
+iao
+iaor
+ian
+ianr
+iang
+iangr
+ie
+ier
+io
+ior
+iou
+iour
+iong
+iongr
+in
+inr
+ing
+ingr
+u
+ur
+ua
+uar
+uai
+uair
+uan
+uanr
+uang
+uangr
+uei
+ueir
+uo
+uor
+uen
+uenr
+ueng
+uengr
+v
+vr
+ve
+ver
+van
+vanr
+vn
+vnr
+sil
+sp
diff --git a/examples/speedyspeech/baker/preprocess.py b/examples/speedyspeech/baker/preprocess.py
new file mode 100644
index 0000000..f228a76
--- /dev/null
+++ b/examples/speedyspeech/baker/preprocess.py
@@ -0,0 +1,311 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Dict, Any
+import soundfile as sf
+import librosa
+import numpy as np
+import argparse
+import yaml
+import json
+import re
+import jsonlines
+import concurrent.futures
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+from pathlib import Path
+import tqdm
+from operator import itemgetter
+from praatio import tgio
+import logging
+
+from config import get_cfg_default
+from tg_utils import validate_textgrid
+
+
+def logmelfilterbank(audio,
+                     sr,
+                     n_fft=1024,
+                     hop_length=256,
+                     win_length=None,
+                     window="hann",
+                     n_mels=80,
+                     fmin=None,
+                     fmax=None,
+                     eps=1e-10):
+    """Compute log-Mel filterbank feature.
+
+    Parameters
+    ----------
+    audio : ndarray
+        Audio signal (T,).
+    sr : int
+        Sampling rate.
+    n_fft : int
+        FFT size. (Default value = 1024)
+    hop_length : int
+        Hop size. (Default value = 256)
+    win_length : int
+        Window length. If set to None, it will be the same as fft_size. (Default value = None)
+    window : str
+        Window function type. (Default value = "hann")
+    n_mels : int
+        Number of mel basis. (Default value = 80)
+    fmin : int
+        Minimum frequency in mel basis calculation. (Default value = None)
+    fmax : int
+        Maximum frequency in mel basis calculation. (Default value = None)
+    eps : float
+        Epsilon value to avoid inf in log calculation. (Default value = 1e-10)
+
+    Returns
+    -------
+    np.ndarray
+        Log Mel filterbank feature (num_mels, #frames).
+
+    """
+    # get amplitude spectrogram
+    x_stft = librosa.stft(
+        audio,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        window=window,
+        pad_mode="reflect")
+    spc = np.abs(x_stft)  # (#bins, #frames,)
+
+    # get mel basis (keyword arguments are mandatory since librosa 0.10)
+    fmin = 0 if fmin is None else fmin
+    fmax = sr / 2 if fmax is None else fmax
+    mel_basis = librosa.filters.mel(
+        sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+    return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))
+
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     alignment_fp: Path,
+                     output_dir: Path):
+    """Extract the log-mel spectrogram and phone/tone/duration labels of one
+    utterance, save the spectrogram and return its metadata record."""
+    utt_id = fp.stem
+    assert alignment_fp.stem == utt_id, f"{alignment_fp} does not match {fp}"
+    # reading
+    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
+    assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
+    assert np.abs(y).max(
+    ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
+    duration = librosa.get_duration(y=y, sr=sr)
+
+    # intervals with empty labels are ignored
+    alignment = tgio.openTextgrid(alignment_fp)
+
+    # validate text grid against audio file
+    if not validate_textgrid(alignment, y.shape[0], sr):
+        logging.warning(f"TextGrid does not fit the audio of {utt_id}.")
+
+    # only with baker's annotation
+    intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
+
+    first, last = intervals[0], intervals[-1]
+    if not (first.label == "sil" and first.end < duration):
+        logging.warning(
+            f" There is something wrong with the first interval {first} in utterance: {utt_id}"
+        )
+    if not (last.label == "sil" and last.start < duration):
+        logging.warning(
+            f" There is something wrong with the last interval {last} in utterance: {utt_id}"
+        )
+
+    logmel = logmelfilterbank(
+        y,
+        sr=sr,
+        n_fft=config.n_fft,
+        window=config.window,
+        win_length=config.win_length,
+        hop_length=config.hop_length,
+        n_mels=config.n_mels,
+        fmin=config.fmin,
+        fmax=config.fmax)
+
+    # extract phone and duration
+    phones = []
+    tones = []
+    ends = []
+
+    for interval in intervals:
+        label = interval.label
+        label = label.replace("sp1", "sp")  # Baker has sp1 rather than sp
+
+        # split tone from finals
+        match = re.match(r'^(\w+)([012345])$', label)
+        if match:
+            phones.append(match.group(1))
+            tones.append(match.group(2))
+        else:
+            phones.append(label)
+            tones.append('0')
+        end = min(duration, interval.end)
+        ends.append(end)
+
+    frame_pos = librosa.time_to_frames(
+        ends, sr=sr, hop_length=config.hop_length)
+    durations_frame = np.diff(frame_pos, prepend=0)
+
+    num_frames = logmel.shape[-1]  # number of frames of the spectrogram
+    extra = np.sum(durations_frame) - num_frames
+    assert extra <= 0, (
+        f"Number of frames inferred from alignment is "
+        f"larger than number of frames of the spectrogram by {extra} frames")
+    durations_frame[-1] += (-extra)
+
+    assert np.sum(durations_frame) == num_frames
+    durations_frame = durations_frame.tolist()
+
+    mel_path = output_dir / (utt_id + "_feats.npy")
+    np.save(mel_path, logmel.T)  # (num_frames, n_mels)
+    record = {
+        "utt_id": utt_id,
+        "phones": phones,
+        "tones": tones,
+        "num_phones": len(phones),
+        "num_frames": num_frames,
+        "durations": durations_frame,
+        "feats": str(mel_path.resolve()),  # use absolute path
+    }
+    return record
+
+
+def process_sentences(config,
+                      fps: List[Path],
+                      alignment_fps: List[Path],
+                      output_dir: Path,
+                      nprocs: int=1):
+    """Process all utterances (with ``nprocs`` worker threads if > 1) and
+    write the records, sorted by utterance id, to ``metadata.jsonl``."""
+    if nprocs == 1:
+        results = []
+        for fp, alignment_fp in tqdm.tqdm(
+                zip(fps, alignment_fps), total=len(fps)):
+            results.append(
+                process_sentence(config, fp, alignment_fp, output_dir))
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp, alignment_fp in zip(fps, alignment_fps):
+                    future = pool.submit(process_sentence, config, fp,
+                                         alignment_fp, output_dir)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = [ft.result() for ft in futures]
+
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
+
+
+def main():
+    """Preprocess the Baker dataset and dump train/dev/test features."""
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
+    )
+    parser.add_argument(
+        "--rootdir",
+        type=str,
+        required=True,
+        help="directory to baker dataset (contains Wave/ and PhoneLabeling/)."
+    )
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    parser.add_argument(
+        "--num_cpu", type=int, default=1, help="number of process.")
+    args = parser.parse_args()
+
+    C = get_cfg_default()
+    if args.config:
+        C.merge_from_file(args.config)
+    C.freeze()
+
+    if args.verbose > 1:
+        print(vars(args))
+        print(C)
+
+    root_dir = Path(args.rootdir).expanduser()
+    dumpdir = Path(args.dumpdir).expanduser()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
+    alignment_files = sorted(
+        list((root_dir / "PhoneLabeling").rglob("*.interval")))
+
+    # filter out several files that have errors in annotation
+    exclude = {'000611', '000662', '002365', '005107'}
+    wav_files = [f for f in wav_files if f.stem not in exclude]
+    alignment_files = [f for f in alignment_files if f.stem not in exclude]
+
+    # split data into 3 sections
+    num_train = 9800
+    num_dev = 100
+
+    train_wav_files = wav_files[:num_train]
+    dev_wav_files = wav_files[num_train:num_train + num_dev]
+    test_wav_files = wav_files[num_train + num_dev:]
+
+    train_alignment_files = alignment_files[:num_train]
+    dev_alignment_files = alignment_files[num_train:num_train + num_dev]
+    test_alignment_files = alignment_files[num_train + num_dev:]
+
+    train_dump_dir = dumpdir / "train" / "raw"
+    train_dump_dir.mkdir(parents=True, exist_ok=True)
+    dev_dump_dir = dumpdir / "dev" / "raw"
+    dev_dump_dir.mkdir(parents=True, exist_ok=True)
+    test_dump_dir = dumpdir / "test" / "raw"
+    test_dump_dir.mkdir(parents=True, exist_ok=True)
+
+    # process for the 3 sections
+    process_sentences(
+        C,
+        train_wav_files,
+        train_alignment_files,
+        train_dump_dir,
+        nprocs=args.num_cpu)
+    process_sentences(
+        C,
+        dev_wav_files,
+        dev_alignment_files,
+        dev_dump_dir,
+        nprocs=args.num_cpu)
+    process_sentences(
+        C,
+        test_wav_files,
+        test_alignment_files,
+        test_dump_dir,
+        nprocs=args.num_cpu)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/speedyspeech/baker/speedyspeech_updater.py b/examples/speedyspeech/baker/speedyspeech_updater.py
new file mode 100644
index 0000000..3e3a32b
--- /dev/null
+++ b/examples/speedyspeech/baker/speedyspeech_updater.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle.nn import functional as F
+from paddle.fluid.layers import huber_loss
+
+from parakeet.modules.ssim import ssim
+from parakeet.modules.losses import masked_l1_loss, weighted_mean
+from parakeet.training.reporter import report
+from parakeet.training.updaters.standard_updater import StandardUpdater
+from parakeet.training.extensions.evaluator import StandardEvaluator
+from parakeet.models.speedyspeech import SpeedySpeech
+
+
+class SpeedySpeechUpdater(StandardUpdater):
+    def update_core(self, batch):
+        """One training step with L1, duration (Huber) and SSIM losses."""
+        decoded, predicted_durations = self.model(
+            text=batch["phones"],
+            tones=batch["tones"],
+            plens=batch["num_phones"],
+            durations=batch["durations"])
+        # NOTE: "durations" must be added to the batch by the collate function
+        target_mel = batch["feats"]
+        spec_mask = F.sequence_mask(
+            batch["num_frames"], dtype=target_mel.dtype).unsqueeze(-1)
+        text_mask = F.sequence_mask(
+            batch["num_phones"], dtype=predicted_durations.dtype)
+
+        # spec loss
+        l1_loss = masked_l1_loss(decoded, target_mel, spec_mask)
+
+        # duration loss (durations are cast to float before clip and log)
+        target_durations = batch["durations"].astype(predicted_durations.dtype)
+        target_durations = paddle.clip(target_durations, min=1.0)
+        duration_loss = weighted_mean(
+            huber_loss(
+                predicted_durations, paddle.log(target_durations), delta=1.0),
+            text_mask, )
+
+        # ssim loss
+        ssim_loss = 1.0 - ssim((decoded * spec_mask).unsqueeze(1),
+                               (target_mel * spec_mask).unsqueeze(1))
+
+        loss = l1_loss + duration_loss + ssim_loss
+
+        self.optimizer.clear_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        report("train/l1_loss", float(l1_loss))
+        report("train/duration_loss", float(duration_loss))
+        report("train/ssim_loss", float(ssim_loss))
diff --git a/examples/speedyspeech/baker/tg_utils.py b/examples/speedyspeech/baker/tg_utils.py
new file mode 100644
index 0000000..18c0385
--- /dev/null
+++ b/examples/speedyspeech/baker/tg_utils.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import librosa
+from praatio import tgio
+
+
+def validate_textgrid(text_grid, num_samples, sr):
+    """Check that the time span annotated by the text grid starts at 0 and
+    does not go beyond the end of the audio. Returns True if it is valid.
+    """
+    start = text_grid.minTimestamp
+    end = text_grid.maxTimestamp
+
+    end_audio = librosa.samples_to_time(num_samples, sr=sr)
+    return start == 0.0 and end <= end_audio
diff --git a/examples/speedyspeech/baker/tones.txt b/examples/speedyspeech/baker/tones.txt
new file mode 100644
index 0000000..e8371f0
--- /dev/null
+++ b/examples/speedyspeech/baker/tones.txt
@@ -0,0 +1,6 @@
+0
+1
+2
+3
+4
+5
diff --git a/examples/speedyspeech/baker/train.py b/examples/speedyspeech/baker/train.py
new file mode 100644
index 0000000..d0d1df4
--- /dev/null
+++ b/examples/speedyspeech/baker/train.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import logging
+import argparse
+import dataclasses
+from pathlib import Path
+
+import yaml
+import jsonlines
+import paddle
+import numpy as np
+from paddle import nn
+from paddle.nn import functional as F
+from paddle import distributed as dist
+from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.optimizer import Adam  # No RAdam
+from paddle.optimizer.lr import StepDecay
+from paddle import DataParallel
+from visualdl import LogWriter
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.models.speedyspeech import SpeedySpeech
+
+from parakeet.training.updater import UpdaterBase
+from parakeet.training.trainer import Trainer
+from parakeet.training.reporter import report
+from parakeet.training import extension
+from parakeet.training.extensions.snapshot import Snapshot
+from parakeet.training.extensions.visualizer import VisualDL
+from parakeet.training.seeding import seed_everything
+
+from batch_fn import collate_baker_examples
+from config import get_cfg_default
+
+
+def train_sp(args, config):
+    """Set up the device, the dataloaders and the model (one process)."""
+    # honor --device; is_compiled_with_cuda is a function and must be called
+    if args.device == "cpu" or not paddle.is_compiled_with_cuda():
+        paddle.set_device("cpu")
+    else:
+        paddle.set_device("gpu")
+        world_size = paddle.distributed.get_world_size()
+        if world_size > 1:
+            paddle.distributed.init_parallel_env()
+
+    # set the random seed, it is a must for multiprocess training
+    seed_everything(config.seed)
+
+    print(
+        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+    )
+
+    # dataloader has been too verbose
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for training and validation
+    with jsonlines.open(args.train_metadata, 'r') as reader:
+        train_metadata = list(reader)
+    train_dataset = DataTable(
+        data=train_metadata,
+        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        converters={"feats": np.load, }, )
+    with jsonlines.open(args.dev_metadata, 'r') as reader:
+        dev_metadata = list(reader)
+    dev_dataset = DataTable(
+        data=dev_metadata,
+        fields=["phones", "tones", "num_phones", "num_frames", "feats"],
+        converters={"feats": np.load, }, )
+
+    # collate function and dataloader
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        drop_last=True)
+    dev_sampler = DistributedBatchSampler(
+        dev_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        drop_last=False)
+    print("samplers done!")
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        collate_fn=collate_baker_examples,
+        num_workers=config.num_workers)
+    dev_dataloader = DataLoader(
+        dev_dataset,
+        batch_sampler=dev_sampler,
+        collate_fn=collate_baker_examples,
+        num_workers=config.num_workers)
+    print("dataloaders done!")
+
+    # NOTE: WIP -- only the model is built here; the optimizer, the updater
+    # and the trainer (with its extensions) are still to be wired in, so no
+    # training step is run yet.
+    model = SpeedySpeech(**config["model"])
+    print(model)
+
+
+def main():
+    """Parse args and config, then dispatch to train_sp."""
+    parser = argparse.ArgumentParser(description="Train a SpeedySpeech "
+                                     "model with Baker Mandarin TTS dataset.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config")
+    parser.add_argument("--train-metadata", type=str, help="training data")
+    parser.add_argument("--dev-metadata", type=str, help="dev data")
+    parser.add_argument("--output-dir", type=str, help="output dir")
+    parser.add_argument(
+        "--device", type=str, default="gpu", help="device type to use")
+    parser.add_argument(
+        "--nprocs", type=int, default=1, help="number of processes")
+    parser.add_argument("--verbose", type=int, default=1, help="verbose")
+
+    args = parser.parse_args()
+    if args.device == "cpu" and args.nprocs > 1:
+        raise RuntimeError("Multiprocess training on CPU is not supported.")
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master see the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.nprocs > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/parakeet/data/batch.py b/parakeet/data/batch.py
index 1397f55..d5f5e91 100644
--- a/parakeet/data/batch.py
+++ b/parakeet/data/batch.py
@@ -161,3 +161,27 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
                        mode='constant',
                        constant_values=pad_value))
     return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
+
+
+def batch_sequences(sequences, axis=0, pad_value=0):
+    """Pad sequences along ``axis`` to a common length and stack them."""
+    seq = sequences[0]
+    ndim = seq.ndim
+    if axis < 0:
+        axis += ndim
+    dtype = seq.dtype
+    pad_value = dtype.type(pad_value)
+    seq_lengths = [seq.shape[axis] for seq in sequences]
+    max_length = np.max(seq_lengths)
+
+    padded_sequences = []
+    for seq, length in zip(sequences, seq_lengths):
+        padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (
+            ndim - axis - 1)
+        padded_seq = np.pad(seq,
+                            padding,
+                            mode='constant',
+                            constant_values=pad_value)
+        padded_sequences.append(padded_seq)
+    batch = np.stack(padded_sequences)
+    return batch
diff --git a/parakeet/models/speedyspeech.py b/parakeet/models/speedyspeech.py
new file mode 100644
index 0000000..e33ff1f
--- /dev/null
+++ b/parakeet/models/speedyspeech.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
# ---- parakeet/models/speedyspeech.py ----


class ResidualBlock(nn.Layer):
    """``n`` stacked Conv1D -> ReLU -> BatchNorm1D units with a skip connection.

    All layers use ``data_format="NLC"`` and the convolutions use
    ``padding="same"`` with ``channels`` input and output channels.
    """

    def __init__(self, channels, kernel_size, dilation, n=2):
        super().__init__()
        blocks = [
            nn.Sequential(
                nn.Conv1D(
                    channels,
                    channels,
                    kernel_size,
                    dilation=dilation,
                    padding="same",
                    data_format="NLC"),
                nn.ReLU(),
                nn.BatchNorm1D(channels, data_format="NLC"),
            ) for _ in range(n)
        ]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        """Return ``x`` plus the output of the stacked units applied to ``x``."""
        return x + self.blocks(x)


class TextEmbedding(nn.Layer):
    """Embedding for text ids, optionally combined with a tone embedding.

    When ``tone_vocab_size`` is given a second embedding table is created.
    The two embeddings are added, or concatenated on the last axis when
    ``concat`` is true.

    Raises
    ------
    ValueError
        If the two embedding sizes differ while ``concat`` is false.
    """

    def __init__(self,
                 vocab_size: int,
                 embedding_size: int,
                 tone_vocab_size: int=None,
                 tone_embedding_size: int=None,
                 padding_idx: int=None,
                 tone_padding_idx: int=None,
                 concat: bool=False):
        super().__init__()
        self.text_embedding = nn.Embedding(vocab_size, embedding_size,
                                           padding_idx)
        # remembered so that forward() can fail clearly instead of hitting a
        # missing ``tone_embedding`` attribute
        self.has_tone_embedding = bool(tone_vocab_size)
        if tone_vocab_size:
            tone_embedding_size = tone_embedding_size or embedding_size
            if tone_embedding_size != embedding_size and not concat:
                raise ValueError(
                    "embedding size != tone_embedding size, only conat is avaiable."
                )
            self.tone_embedding = nn.Embedding(
                tone_vocab_size, tone_embedding_size, tone_padding_idx)
        self.concat = concat

    def forward(self, text, tone=None):
        """Embed ``text`` and, when ``tone`` is given, combine it with the tone embedding.

        Raises
        ------
        ValueError
            If ``tone`` is given but the layer has no tone embedding table.
        """
        text_embed = self.text_embedding(text)
        if tone is None:
            return text_embed
        if not self.has_tone_embedding:
            raise ValueError(
                "tone ids were given, but this TextEmbedding was built "
                "without tone_vocab_size.")
        tone_embed = self.tone_embedding(tone)
        if self.concat:
            return paddle.concat([text_embed, tone_embed], -1)
        return text_embed + tone_embed


class SpeedySpeechEncoder(nn.Layer):
    """Text encoder: embedding, prenet, dilated residual blocks and postnets.

    Index 0 is the padding index of both the text and the tone embedding.
    One two-unit ``ResidualBlock`` is created per entry of ``dilations``.
    """

    def __init__(self, vocab_size, tone_size, hidden_size, kernel_size,
                 dilations):
        super().__init__()
        self.embedding = TextEmbedding(
            vocab_size,
            hidden_size,
            tone_size,
            padding_idx=0,
            tone_padding_idx=0)
        self.prenet = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        res_blocks = [
            ResidualBlock(hidden_size, kernel_size, d, n=2)
            for d in dilations
        ]
        self.res_blocks = nn.Sequential(*res_blocks)

        self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
        self.postnet2 = nn.Sequential(
            nn.ReLU(),
            nn.BatchNorm1D(hidden_size, data_format="NLC"),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, text, tones):
        """Encode ``text`` (and ``tones``) into hidden representations."""
        embedding = self.prenet(self.embedding(text, tones))
        x = self.res_blocks(embedding)
        # skip connection from the prenet output around the residual stack
        x = embedding + self.postnet1(x)
        return self.postnet2(x)


class DurationPredictor(nn.Layer):
    """Predict one scalar per position from encoder outputs.

    Three single-unit residual blocks (kernel sizes 4, 3, 1) followed by a
    linear projection to one value.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.layers = nn.Sequential(
            ResidualBlock(hidden_size, 4, 1, n=1),
            ResidualBlock(hidden_size, 3, 1, n=1),
            ResidualBlock(hidden_size, 1, 1, n=1),
            nn.Linear(hidden_size, 1))

    def forward(self, x):
        """Return the predictions with the trailing size-1 axis squeezed out."""
        return paddle.squeeze(self.layers(x), -1)


class SpeedySpeechDecoder(nn.Layer):
    """Decoder: dilated residual blocks followed by two postnets.

    The final linear layer projects ``hidden_size`` to ``output_size``.
    """

    def __init__(self, hidden_size, output_size, kernel_size, dilations):
        super().__init__()
        res_blocks = [
            ResidualBlock(hidden_size, kernel_size, d, n=2)
            for d in dilations
        ]
        self.res_blocks = nn.Sequential(*res_blocks)

        self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
        self.postnet2 = nn.Sequential(
            ResidualBlock(hidden_size, kernel_size, 1, n=2),
            nn.Linear(hidden_size, output_size))

    def forward(self, x):
        """Decode expanded encodings ``x`` into output features."""
        xx = self.res_blocks(x)
        x = x + self.postnet1(xx)
        return self.postnet2(x)


class SpeedySpeech(nn.Layer):
    """SpeedySpeech model: encoder, duration predictor and decoder."""

    def __init__(
            self,
            vocab_size,
            encoder_hidden_size,
            encoder_kernel_size,
            encoder_dilations,
            duration_predictor_hidden_size,
            decoder_hidden_size,
            decoder_output_size,
            decoder_kernel_size,
            decoder_dilations,
            tone_size=None, ):
        super().__init__()
        # NOTE(review): the duration predictor and the decoder consume the
        # encoder output directly, so the three hidden sizes presumably have
        # to be equal - confirm against the config.
        self.encoder = SpeedySpeechEncoder(vocab_size, tone_size,
                                           encoder_hidden_size,
                                           encoder_kernel_size,
                                           encoder_dilations)
        self.duration_predictor = DurationPredictor(
            duration_predictor_hidden_size)
        self.decoder = SpeedySpeechDecoder(
            decoder_hidden_size, decoder_output_size, decoder_kernel_size,
            decoder_dilations)

    def forward(self, text, tones, plens, durations):
        """Teacher-forced pass using the given ``durations`` for expansion.

        ``plens`` is accepted but not used by this method.

        Returns
        -------
        tuple
            ``(decoded, pred_durations)``; ``pred_durations`` is (B, T).
        """
        encodings = self.encoder(text, tones)
        # detach: the duration predictor's loss does not reach the encoder
        pred_durations = self.duration_predictor(encodings.detach())  # (B, T)

        # expand encodings with the given durations
        encodings = expand(encodings, durations)

        # NOTE(review): an earlier comment at this point read "remove
        # positional encoding here" - confirm whether adding it is intended.
        _, t_dec, feature_size = encodings.shape
        encodings += sinusoid_position_encoding(t_dec, feature_size)
        decoded = self.decoder(encodings)
        return decoded, pred_durations

    def inference(self, text, tones):
        """Synthesize from a single unbatched utterance.

        ``text`` and ``tones`` are [T]; a batch axis is added in front.
        Predicted durations are exponentiated and rounded before expansion.

        Returns
        -------
        tuple
            ``(decoded, pred_durations)``; ``pred_durations`` is (1, T).
        """
        text = text.unsqueeze(0)
        if tones is not None:
            tones = tones.unsqueeze(0)

        encodings = self.encoder(text, tones)
        pred_durations = self.duration_predictor(encodings)  # (1, T)
        durations_to_expand = paddle.round(pred_durations.exp())
        durations_to_expand = (durations_to_expand).astype(paddle.int64)
        encodings = expand(encodings, durations_to_expand)

        shape = paddle.shape(encodings)
        t_dec, feature_size = shape[1], shape[2]
        encodings += sinusoid_position_encoding(t_dec, feature_size)
        decoded = self.decoder(encodings)
        return decoded, pred_durations
# ---- parakeet/modules/expansion.py ----


def expand(encodings: Tensor, durations: Tensor) -> Tensor:
    """Repeat every encoder step according to its duration.

    Parameters
    ----------
    encodings : Tensor
        Encoder outputs, shape (B, T, C).
    durations : Tensor
        Number of output steps per encoder step, shape (B, T). Values are
        truncated to integers.

    Returns
    -------
    Tensor
        Shape (B, T_dec, C), where ``T_dec`` is the largest per-sequence sum
        of durations. Steps beyond a sequence's own total are zeros.
    """
    # durations are used as frame counts, so force an integer dtype
    durations = durations.numpy().astype(np.int64)
    slens = np.sum(durations, -1)
    t_dec = int(np.max(slens))

    # encoder step j of sequence b covers output steps [starts, ends)
    ends = np.cumsum(durations, axis=-1)  # (B, T)
    starts = ends - durations  # (B, T)
    steps = np.arange(t_dec)[None, :, None]  # (1, T_dec, 1)

    # M[b, t, j] == 1 iff output step t of sequence b copies encoder step j
    M = (steps >= starts[:, None, :]) & (steps < ends[:, None, :])
    M = paddle.to_tensor(M.astype(np.float64), dtype=encodings.dtype)
    encodings = paddle.matmul(M, encodings)
    return encodings
# ---- parakeet/modules/positional_encoding.py ----

__all__ = ["sinusoid_position_encoding", "scaled_position_encoding"]


def sinusoid_position_encoding(num_positions: int,
                               feature_size: int,
                               omega: float=1.0,
                               start_pos: int=0,
                               dtype=None) -> Tensor:
    """Build a sinusoidal position encoding matrix.

    Positions ``start_pos .. start_pos + num_positions - 1`` are scaled by
    ``omega``. Even feature columns hold the sine and odd columns the cosine
    of ``position * omega / 10000 ** (2i / feature_size)``.

    Returns
    -------
    Tensor
        Shape (num_positions, feature_size), of ``dtype`` (paddle's default
        dtype when ``dtype`` is not given).

    Raises
    ------
    ValueError
        If ``feature_size`` is odd.
    """
    if (feature_size % 2 != 0):
        raise ValueError("size should be divisible by 2")
    dtype = dtype or paddle.get_default_dtype()

    channel = paddle.arange(0, feature_size, 2, dtype=dtype)
    index = paddle.arange(start_pos, start_pos + num_positions, 1, dtype=dtype)
    p = (paddle.unsqueeze(index, -1) *
         omega) / (10000.0**(channel / float(feature_size)))
    encodings = paddle.zeros([num_positions, feature_size], dtype=dtype)
    encodings[:, 0::2] = paddle.sin(p)
    encodings[:, 1::2] = paddle.cos(p)
    return encodings


def scaled_position_encoding(num_positions: int,
                             feature_size: int,
                             omega: Tensor,
                             start_pos: int=0,
                             dtype=None) -> Tensor:
    """Batched sinusoidal position encoding with one ``omega`` per example.

    Parameters
    ----------
    omega : Tensor
        Shape (batch_size, ); per-example scale applied to the positions.
        The position index is created with ``omega``'s dtype.

    Returns
    -------
    Tensor
        Shape (batch_size, num_positions, feature_size).

    Raises
    ------
    ValueError
        If ``feature_size`` is odd.
    """
    if (feature_size % 2 != 0):
        raise ValueError("size should be divisible by 2")
    dtype = dtype or paddle.get_default_dtype()

    channel = paddle.arange(0, feature_size, 2, dtype=dtype)
    index = paddle.arange(
        start_pos, start_pos + num_positions, 1, dtype=omega.dtype)
    batch_size = omega.shape[0]
    # (B,) -> (B, 1, 1) so that it broadcasts against (num_positions, 1)
    omega = paddle.unsqueeze(omega, [1, 2])
    p = (paddle.unsqueeze(index, -1) *
         omega) / (10000.0**(channel / float(feature_size)))
    encodings = paddle.zeros(
        [batch_size, num_positions, feature_size], dtype=dtype)
    encodings[:, :, 0::2] = paddle.sin(p)
    encodings[:, :, 1::2] = paddle.cos(p)
    return encodings


# ---- parakeet/modules/ssim.py ----


def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length ``window_size`` normalized to sum to 1.

    The curve is centered on index ``window_size // 2``.
    """
    gauss = paddle.to_tensor([
        exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """Build a Gaussian window of shape (channel, 1, window_size, window_size).

    The 2-D window is the outer product of a 1-D Gaussian (sigma 1.5) with
    itself, expanded to one copy per channel.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = paddle.matmul(_1D_window,
                               paddle.transpose(_1D_window,
                                                [1, 0])).unsqueeze([0, 1])
    window = paddle.expand(_2D_window, [channel, 1, window_size, window_size])
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between ``img1`` and ``img2`` with the given window.

    Local means, variances and the covariance are obtained by convolving
    with ``window`` using ``groups=channel`` and ``window_size // 2``
    padding.

    Returns
    -------
    Tensor
        The mean over the whole SSIM map when ``size_average`` is true,
        otherwise the map averaged three times over axis 1 (one value per
        leading-axis entry for 4-D inputs).
    """
    pad = window_size // 2
    mu1 = F.conv2d(img1, window, padding=pad, groups=channel)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # E[x^2] - E[x]^2 and E[xy] - E[x]E[y] under the window weights
    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=pad, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=pad, groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=pad, groups=channel) - mu1_mu2

    # constants that keep the two ratios finite when the denominators vanish
    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) \
        / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(nn.Layer):
    """Layer wrapper around ``_ssim`` that caches the Gaussian window.

    The window is built for one channel up front and rebuilt whenever the
    channel count of the input (axis 1 of ``img1``) differs from the cached
    one.
    """

    def __init__(self, window_size=11, size_average=True):
        super().__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        """Return the SSIM of ``img1`` and ``img2``."""
        channel = img1.shape[1]
        if channel != self.channel:
            # the grouped convolution needs one window per input channel
            self.window = create_window(self.window_size, channel)
            self.channel = channel
        return _ssim(img1, img2, self.window, self.window_size, self.channel,
                     self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM for 4-D inputs; the channel count is ``img1.shape[1]``."""
    (_, channel, _, _) = img1.shape
    window = create_window(window_size, channel)
    return _ssim(img1, img2, window, window_size, channel, size_average)
# ---- tests/test_expansion.py ----


def test_expand():
    """``expand`` repeats each step by its duration and zero-pads the tail."""
    import numpy as np

    x = paddle.randn([2, 4, 3])  # (B, T, C)
    lengths = paddle.to_tensor([[1, 2, 2, 1], [3, 1, 4, 0]])
    y = expansion.expand(x, lengths)

    assert y.shape == [2, 8, 3]

    x_np, y_np = x.numpy(), y.numpy()
    for b, durations in enumerate(lengths.numpy()):
        expected = np.repeat(x_np[b], durations, axis=0)
        n_steps = expected.shape[0]
        np.testing.assert_allclose(y_np[b, :n_steps], expected, rtol=1e-6)
        # steps beyond this sequence's total duration must stay zero
        assert not y_np[b, n_steps:].any()


# ---- tests/test_to_static.py ----


def test_applicative_evaluation():
    """A function exported with ``to_static`` survives a save/load round trip."""
    import os
    import tempfile

    import numpy as np

    def m_sqrt2(x):
        return paddle.scale(x, math.sqrt(2))

    subgraph = to_static(m_sqrt2, input_spec=[InputSpec([-1])])

    # save into a temporary directory so no files are left in the cwd
    with tempfile.TemporaryDirectory() as tmpdir:
        prefix = os.path.join(tmpdir, 'temp_test_to_static')
        paddle.jit.save(subgraph, prefix)

        fn = paddle.jit.load(prefix)
        x = paddle.arange(10, dtype=paddle.float32)
        y = fn(x)
        expected = x.numpy() * math.sqrt(2)
        np.testing.assert_allclose(y.numpy(), expected, rtol=1e-5)