diff --git a/examples/transformer_tts/config.py b/examples/transformer_tts/config.py
new file mode 100644
index 0000000..fef9ed8
--- /dev/null
+++ b/examples/transformer_tts/config.py
@@ -0,0 +1,55 @@
+from yacs.config import CfgNode as CN
+
+_C = CN()
+_C.data = CN(
+    dict(
+        batch_size=16,  # batch size
+        valid_size=64,  # the first N examples are reserved for validation
+        sample_rate=22050,  # Hz, sample rate
+        n_fft=1024,  # fft frame size
+        win_length=1024,  # window size
+        hop_length=256,  # hop size between adjacent frames
+        f_max=8000,  # Hz, max frequency when converting to mel
+        d_mel=80,  # mel bands
+        padding_idx=0,  # text embedding's padding index
+        mel_start_value=0.5,  # value for the starting frame
+        mel_end_value=-0.5,  # value for the ending frame
+    )
+)
+
+_C.model = CN(
+    dict(
+        d_encoder=512,  # embedding & encoder's internal size
+        d_decoder=256,  # decoder's internal size
+        n_heads=4,  # actually it can differ at each layer
+        d_ffn=1024,  # encoder_d_ffn & decoder_d_ffn
+        encoder_layers=4,  # number of transformer encoder layers
+        decoder_layers=4,  # number of transformer decoder layers
+        d_prenet=256,  # decoder prenet's hidden size (d_mel => d_prenet => d_decoder)
+        d_postnet=256,  # decoder postnet (cnn)'s internal channels
+        postnet_layers=5,  # decoder postnet (cnn)'s layers
+        postnet_kernel_size=5,  # decoder postnet (cnn)'s kernel size
+        max_reduction_factor=10,  # maximum reduction factor
+        dropout=0.1,  # global dropout probability
+        stop_loss_scale=8.0,  # scale factor for the stop loss
+        decoder_prenet_dropout=0.5,  # decoder prenet dropout probability
+    )
+)
+
+_C.training = CN(
+    dict(
+        lr=1e-4,  # learning rate
+        drop_n_heads=[[0, 0], [15000, 1]],
+        reduction_factor=[[0, 10], [80000, 4], [200000, 2]],
+        plot_interval=1000,  # plot attention and spectrogram
+        valid_interval=1000,  # validation
+        save_interval=10000,  # checkpoint
+        max_iteration=900000,  # max iteration to train
+    )
+)
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    return _C.clone()
diff --git a/examples/transformer_tts/ljspeech.py b/examples/transformer_tts/ljspeech.py
new file mode 100644
index 0000000..245b475
--- /dev/null
+++ b/examples/transformer_tts/ljspeech.py
@@ -0,0 +1,88 @@
+import os
+from pathlib import Path
+import pickle
+import numpy as np
+from paddle.io import Dataset, DataLoader
+
+from parakeet.data.batch import batch_spec, batch_text_id
+from parakeet.data import dataset
+
+class LJSpeech(Dataset):
+    """A simple dataset adaptor for the processed ljspeech dataset."""
+    def __init__(self, root):
+        self.root = Path(root).expanduser()
+        records = []
+        with open(self.root / "metadata.pkl", 'rb') as f:
+            metadata = pickle.load(f)
+        for mel_name, text, phonemes, ids in metadata:
+            mel_name = self.root / "mel" / (mel_name + ".npy")
+            records.append((mel_name, text, phonemes, ids))
+        self.records = records
+
+    def __getitem__(self, i):
+        mel_name, _, _, ids = self.records[i]
+        mel = np.load(mel_name)
+        return ids, mel
+
+    def __len__(self):
+        return len(self.records)
+
+
+# decorate mel & create stop probability
+class Transform(object):
+    def __init__(self, start_value, end_value):
+        self.start_value = start_value
+        self.end_value = end_value
+
+    def __call__(self, example):
+        ids, mel = example  # ids already contain the start and end tokens
+        ids = np.array(ids, dtype=np.int64)
+        # add start and end frame
+        mel = np.pad(mel,
+                     [(0, 0), (1, 1)],
+                     mode='constant',
+                     constant_values=[(0, 0),
(self.start_value, self.end_value)]) + stop_labels = np.ones([mel.shape[1]], dtype=np.int64) + stop_labels[-1] = 2 + # actually this thing can also be done within the model + return ids, mel, stop_labels + + +class LJSpeechCollector(object): + """A simple callable to batch LJSpeech examples.""" + def __init__(self, padding_idx=0, padding_value=0.): + self.padding_idx = padding_idx + self.padding_value = padding_value + + def __call__(self, examples): + ids = [example[0] for example in examples] + mels = [example[1] for example in examples] + stop_probs = [example[2] for example in examples] + + ids = batch_text_id(ids, pad_id=self.padding_idx) + mels = batch_spec(mels, pad_value=self.padding_value) + stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx) + return ids, np.transpose(mels, [0, 2, 1]), stop_probs + + +def create_dataloader(config, source_path): + lj = LJSpeech(source_path) + transform = Transform(config.data.mel_start_value, config.data.mel_end_value) + lj = dataset.TransformDataset(lj, transform) + + valid_set, train_set = dataset.split(lj, config.data.valid_size) + data_collator = LJSpeechCollector(padding_idx=config.data.padding_idx) + train_loader = DataLoader( + train_set, + batch_size=config.data.batch_size, + shuffle=True, + drop_last=True, + collate_fn=data_collator) + valid_loader = DataLoader( + valid_set, + batch_size=config.data.batch_size, + shuffle=False, + drop_last=False, + collate_fn=data_collator) + return train_loader, valid_loader + diff --git a/examples/transformer_tts/preprocess.py b/examples/transformer_tts/preprocess.py new file mode 100644 index 0000000..001f04c --- /dev/null +++ b/examples/transformer_tts/preprocess.py @@ -0,0 +1,82 @@ +import os +import tqdm +import pickle +import argparse +import numpy as np +from pathlib import Path + +from parakeet.datasets import LJSpeechMetaData +from parakeet.audio import AudioProcessor, LogMagnitude +from parakeet.frontend import English + +from config import get_cfg_defaults + +def create_dataset(config, source_path, target_path, verbose=False): + # create output dir + target_path = Path(target_path).expanduser() + mel_path = target_path / "mel" + os.makedirs(mel_path, exist_ok=True) + + meta_data = LJSpeechMetaData(source_path) + frontend = English() + processor = AudioProcessor( + sample_rate=config.data.sample_rate, + n_fft=config.data.n_fft, + n_mels=config.data.d_mel, + win_length=config.data.win_length, + hop_length=config.data.hop_length, + f_max=config.data.f_max) + normalizer = LogMagnitude() + + records = [] + for (fname, text, _) in tqdm.tqdm(meta_data): + wav = processor.read_wav(fname) + mel = processor.mel_spectrogram(wav) + mel = normalizer.transform(mel) + phonemes = frontend.phoneticize(text) + ids = frontend.numericalize(phonemes) + mel_name = os.path.splitext(os.path.basename(fname))[0] + + # save mel spectrogram + records.append((mel_name, text, phonemes, ids)) + np.save(mel_path / mel_name, mel) + if verbose: + print("save mel spectrograms into {}".format(mel_path)) + + # save meta data as pickle archive + with open(target_path / "metadata.pkl", 'wb') as f: + pickle.dump(records, f) + if verbose: + print("saved metadata into {}".format(target_path / "metadata.pkl")) + + # also save meta data into text format for inspection + with open(target_path / "metadata.txt", 'wt') as f: + for mel_name, text, phonemes, _ in records: + phoneme_str = "|".join(phonemes) + f.write("{}\t{}\t{}\n".format(mel_name, text, phoneme_str)) + if verbose: + print("saved metadata into 
{}".format(target_path / "metadata.txt")) + + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="create dataset") + parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config") + parser.add_argument("--input", type=str, help="path of the ljspeech dataset") + parser.add_argument("--output", type=str, help="path to save output dataset") + parser.add_argument("--opts", nargs=argparse.REMAINDER, + help="options to overwrite --config file and the default config, passing in KEY VALUE pairs" + ) + parser.add_argument("-v", "--verbose", action="store_true", help="print msg") + + config = get_cfg_defaults() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config.data) + + create_dataset(config, args.input, args.output, args.verbose) diff --git a/examples/transformer_tts/synthesize.py b/examples/transformer_tts/synthesize.py new file mode 100644 index 0000000..b8f352f --- /dev/null +++ b/examples/transformer_tts/synthesize.py @@ -0,0 +1,64 @@ +import argparse +import time +from pathlib import Path +import numpy as np +import paddle + +import parakeet +from parakeet.frontend import English +from parakeet.models.transformer_tts import TransformerTTS +from parakeet.utils import scheduler +from parakeet.training.cli import default_argument_parser +from parakeet.utils.display import add_attention_plots + +from config import get_cfg_defaults + +@paddle.fluid.dygraph.no_grad +def main(config, args): + paddle.set_device(args.device) + + # model + frontend = English() + model = TransformerTTS.from_pretrained( + frontend, config, args.checkpoint_path) + model.eval() + + # inputs + input_path = Path(args.input).expanduser() + with open(input_path, "rt") as f: + sentences = f.readlines() + + output_dir = Path(args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + + for i, sentence in enumerate(sentences): + outputs = model.predict(sentence, verbose=args.verbose) + mel_output = outputs["mel_output"] + # cross_attention_weights = outputs["cross_attention_weights"] + mel_output = mel_output.T #(C, T) + np.save(str(output_dir / f"sentence_{i}"), mel_output) + if args.verbose: + print("spectrogram saved at {}".format(output_dir / f"sentence_{i}.npy")) + +if __name__ == "__main__": + config = get_cfg_defaults() + + parser = argparse.ArgumentParser(description="generate mel spectrogram with TransformerTTS.") + parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config") + parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.") + parser.add_argument("--input", type=str, help="path of the text sentences") + parser.add_argument("--output", type=str, help="path to save outputs") + parser.add_argument("--device", type=str, default="cpu", help="device type to use.") + parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + parser.add_argument("-v", "--verbose", action="store_true", help="print msg") + + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + print(args) + + main(config, args) diff --git a/examples/transformer_tts/train.py b/examples/transformer_tts/train.py new file mode 100644 index 
0000000..59ec7aa --- /dev/null +++ b/examples/transformer_tts/train.py @@ -0,0 +1,202 @@ +import time +import logging +from pathlib import Path +import numpy as np +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader, DistributedBatchSampler +from tensorboardX import SummaryWriter +from collections import defaultdict + +import parakeet +from parakeet.data import dataset +from parakeet.frontend import English +from parakeet.models.transformer_tts import TransformerTTS, TransformerTTSLoss +from parakeet.utils import scheduler, checkpoint, mp_tools, display +from parakeet.training.cli import default_argument_parser +from parakeet.training.experiment import ExperimentBase + +from config import get_cfg_defaults +from ljspeech import LJSpeech, LJSpeechCollector, Transform + +class Experiment(ExperimentBase): + def setup_model(self): + config = self.config + frontend = English() + model = TransformerTTS( + frontend, + d_encoder=config.model.d_encoder, + d_decoder=config.model.d_decoder, + d_mel=config.data.d_mel, + n_heads=config.model.n_heads, + d_ffn=config.model.d_ffn, + encoder_layers=config.model.encoder_layers, + decoder_layers=config.model.decoder_layers, + d_prenet=config.model.d_prenet, + d_postnet=config.model.d_postnet, + postnet_layers=config.model.postnet_layers, + postnet_kernel_size=config.model.postnet_kernel_size, + max_reduction_factor=config.model.max_reduction_factor, + decoder_prenet_dropout=config.model.decoder_prenet_dropout, + dropout=config.model.dropout) + if self.parallel: + model = paddle.DataParallel(model) + optimizer = paddle.optimizer.Adam( + learning_rate=config.training.lr, + beta1=0.9, + beta2=0.98, + epsilon=1e-9, + parameters=model.parameters() + ) + criterion = TransformerTTSLoss(config.model.stop_loss_scale) + drop_n_heads = scheduler.StepWise(config.training.drop_n_heads) + reduction_factor = scheduler.StepWise(config.training.reduction_factor) + + self.model = model + self.optimizer = optimizer + self.criterion = criterion + self.drop_n_heads = drop_n_heads + self.reduction_factor = reduction_factor + + def setup_dataloader(self): + args = self.args + config = self.config + + ljspeech_dataset = LJSpeech(args.data) + transform = Transform(config.data.mel_start_value, config.data.mel_end_value) + ljspeech_dataset = dataset.TransformDataset(ljspeech_dataset, transform) + valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size) + batch_fn = LJSpeechCollector(padding_idx=config.data.padding_idx) + + if not self.parallel: + train_loader = DataLoader( + train_set, + batch_size=config.data.batch_size, + shuffle=True, + drop_last=True, + collate_fn=batch_fn) + else: + sampler = DistributedBatchSampler( + train_set, + batch_size=config.data.batch_size, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True, + drop_last=True) + train_loader = DataLoader( + train_set, batch_sampler=sampler, collate_fn=batch_fn) + + valid_loader = DataLoader( + valid_set, batch_size=config.data.batch_size, collate_fn=batch_fn) + + self.train_loader = train_loader + self.valid_loader = valid_loader + + def compute_outputs(self, text, mel, stop_label): + model_core = self.model._layers if self.parallel else self.model + model_core.set_constants( + self.reduction_factor(self.iteration), + self.drop_n_heads(self.iteration)) + + # TODO(chenfeiyu): we can combine these 2 slices + mel_input = mel[:,:-1, :] + reduced_mel_input = mel_input[:, ::model_core.r, :] + outputs = self.model(text, reduced_mel_input) + 
return outputs + + def compute_losses(self, inputs, outputs): + _, mel, stop_label = inputs + mel_target = mel[:, 1:, :] + stop_label_target = stop_label[:, 1:] + + mel_output = outputs["mel_output"] + mel_intermediate = outputs["mel_intermediate"] + stop_logits = outputs["stop_logits"] + + time_steps = mel_target.shape[1] + losses = self.criterion( + mel_output[:,:time_steps, :], + mel_intermediate[:,:time_steps, :], + mel_target, + stop_logits[:,:time_steps, :], + stop_label_target) + return losses + + def train_batch(self): + start = time.time() + batch = self.read_batch() + data_loader_time = time.time() - start + + self.optimizer.clear_grad() + self.model.train() + text, mel, stop_label = batch + outputs = self.compute_outputs(text, mel, stop_label) + losses = self.compute_losses(batch, outputs) + loss = losses["loss"] + loss.backward() + self.optimizer.step() + iteration_time = time.time() - start + + losses_np = {k: float(v) for k, v in losses.items()} + # logging + msg = "Rank: {}, ".format(dist.get_rank()) + msg += "step: {}, ".format(self.iteration) + msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time) + msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items()) + self.logger.info(msg) + + if dist.get_rank() == 0: + for k, v in losses_np.items(): + self.visualizer.add_scalar(f"train_loss/{k}", v, self.iteration) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def valid(self): + valid_losses = defaultdict(list) + for i, batch in enumerate(self.valid_loader): + text, mel, stop_label = batch + outputs = self.compute_outputs(text, mel, stop_label) + losses = self.compute_losses(batch, outputs) + for k, v in losses.items(): + valid_losses[k].append(float(v)) + + if i < 2: + attention_weights = outputs["cross_attention_weights"] + display.add_multi_attention_plots( + self.visualizer, + f"valid_sentence_{i}_cross_attention_weights", + attention_weights, + self.iteration) + + # write visual log + valid_losses = {k: np.mean(v) for k, v in valid_losses.items()} + for k, v in valid_losses.items(): + self.visualizer.add_scalar(f"valid/{k}", v, self.iteration) + + +def main_sp(config, args): + exp = Experiment(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.nprocs > 1 and args.device == "gpu": + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + config = get_cfg_defaults() + parser = default_argument_parser() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + print(args) + + main(config, args) diff --git a/examples/waveflow/config.py b/examples/waveflow/config.py new file mode 100644 index 0000000..97a877a --- /dev/null +++ b/examples/waveflow/config.py @@ -0,0 +1,43 @@ +from yacs.config import CfgNode as CN + +_C = CN() +_C.data = CN( + dict( + batch_size=8, # batch size + valid_size=16, # the first N examples are reserved for validation + sample_rate=22050, # Hz, sample rate + n_fft=1024, # fft frame size + win_length=1024, # window size + hop_length=256, # hop size between ajacent frame + f_max=8000, # Hz, max frequency when converting to mel + n_mels=80, # mel bands + clip_frames=65, # mel clip frames + ) +) + +_C.model = CN( + dict( + upsample_factors=[16, 16], + n_flows=8, # number of flows in WaveFlow + n_layers=8, # number of conv block in each flow + n_group=16, # folding factor of audio and spectrogram + 
channels=128, # resiaudal channel in each flow + kernel_size=[3, 3], # kernel size in each conv block + sigma=1.0, # stddev of the random noise + ) +) + +_C.training = CN( + dict( + lr=2e-4, # learning rates + valid_interval=1000, # validation + save_interval=10000, # checkpoint + max_iteration=3000000, # max iteration to train + ) +) + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() diff --git a/examples/waveflow/ljspeech.py b/examples/waveflow/ljspeech.py new file mode 100644 index 0000000..d7f5425 --- /dev/null +++ b/examples/waveflow/ljspeech.py @@ -0,0 +1,78 @@ +import os +from pathlib import Path +import pickle +import numpy as np +import pandas +from paddle.io import Dataset, DataLoader + +from parakeet.data.batch import batch_spec, batch_wav +from parakeet.data import dataset +from parakeet.audio import AudioProcessor + +class LJSpeech(Dataset): + """A simple dataset adaptor for the processed ljspeech dataset.""" + def __init__(self, root): + self.root = Path(root).expanduser() + meta_data = pandas.read_csv( + str(self.root / "metadata.csv"), + sep="\t", + header=None, + names=["fname", "frames", "samples"] + ) + + records = [] + for row in meta_data.itertuples() : + mel_path = str(self.root / "mel" / (row.fname + ".npy")) + wav_path = str(self.root / "wav" / (row.fname + ".npy")) + records.append((mel_path, wav_path)) + self.records = records + + def __getitem__(self, i): + mel_name, wav_name = self.records[i] + mel = np.load(mel_name) + wav = np.load(wav_name) + return mel, wav + + def __len__(self): + return len(self.records) + + +class LJSpeechCollector(object): + """A simple callable to batch LJSpeech examples.""" + def __init__(self, padding_value=0.): + self.padding_value = padding_value + + def __call__(self, examples): + mels = [example[0] for example in examples] + wavs = [example[1] for example in examples] + mels = batch_spec(mels, pad_value=self.padding_value) + wavs = batch_wav(wavs, pad_value=self.padding_value) + return mels, wavs + + +class LJSpeechClipCollector(object): + def __init__(self, clip_frames=65, hop_length=256): + self.clip_frames = clip_frames + self.hop_length = hop_length + + def __call__(self, examples): + mels = [] + wavs = [] + for example in examples: + mel_clip, wav_clip = self.clip(example) + mels.append(mel_clip) + wavs.append(wav_clip) + mels = np.stack(mels) + wavs = np.stack(wavs) + return mels, wavs + + def clip(self, example): + mel, wav = example + frames = mel.shape[-1] + start = np.random.randint(0, frames - self.clip_frames) + mel_clip = mel[:, start: start + self.clip_frames] + wav_clip = wav[start * self.hop_length: (start + self.clip_frames) * self.hop_length] + return mel_clip, wav_clip + + + diff --git a/examples/waveflow/preprocess.py b/examples/waveflow/preprocess.py new file mode 100644 index 0000000..d4bdc8e --- /dev/null +++ b/examples/waveflow/preprocess.py @@ -0,0 +1,138 @@ +import os +import tqdm +import csv +import argparse +import numpy as np +import librosa +from pathlib import Path +import pandas as pd + +from paddle.io import Dataset +from parakeet.data import batch_spec, batch_wav +from parakeet.datasets import LJSpeechMetaData +from parakeet.audio import AudioProcessor, LogMagnitude + +from config import get_cfg_defaults + + +class Transform(object): + def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels): + self.sample_rate 
= sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.n_mels = n_mels + + self.spec_normalizer = LogMagnitude(min=1e-5) + + def __call__(self, example): + wav_path, _, _ = example + + sr = self.sample_rate + n_fft = self.n_fft + win_length = self.win_length + hop_length = self.hop_length + n_mels = self.n_mels + + wav, loaded_sr = librosa.load(wav_path, sr=None) + assert loaded_sr == sr, "sample rate does not match, resampling applied" + + # Pad audio to the right size. + frames = int(np.ceil(float(wav.size) / hop_length)) + fft_padding = (n_fft - hop_length) // 2 # sound + desired_length = frames * hop_length + fft_padding * 2 + pad_amount = (desired_length - wav.size) // 2 + + if wav.size % 2 == 0: + wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect') + else: + wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect') + + # Normalize audio. + wav = wav / np.abs(wav).max() * 0.999 + + # Compute mel-spectrogram. + # Turn center to False to prevent internal padding. + spectrogram = librosa.core.stft( + wav, + hop_length=hop_length, + win_length=win_length, + n_fft=n_fft, + center=False) + spectrogram_magnitude = np.abs(spectrogram) + + # Compute mel-spectrograms. + mel_filter_bank = librosa.filters.mel(sr=sr, + n_fft=n_fft, + n_mels=n_mels) + mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude) + mel_spectrogram = mel_spectrogram + + # log scale mel_spectrogram. + mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram) + + # Extract the center of audio that corresponds to mel spectrograms. + audio = wav[fft_padding:-fft_padding] + assert mel_spectrogram.shape[1] * hop_length == audio.size + + # there is no clipping here + return audio, mel_spectrogram + + +def create_dataset(config, input_dir, output_dir, verbose=True): + input_dir = Path(input_dir).expanduser() + dataset = LJSpeechMetaData(input_dir) + + output_dir = Path(output_dir).expanduser() + output_dir.mkdir(exist_ok=True) + + transform = Transform( + config.sample_rate, + config.n_fft, + config.win_length, + config.hop_length, + config.n_mels) + file_names = [] + + for example in tqdm.tqdm(dataset): + fname, _, _ = example + base_name = os.path.splitext(os.path.basename(fname))[0] + wav_dir = output_dir / "wav" + mel_dir = output_dir / "mel" + wav_dir.mkdir(exist_ok=True) + mel_dir.mkdir(exist_ok=True) + + audio, mel = transform(example) + np.save(str(wav_dir / base_name), audio) + np.save(str(mel_dir / base_name), mel) + + file_names.append((base_name, mel.shape[-1], audio.shape[-1])) + + meta_data = pd.DataFrame.from_records(file_names) + meta_data.to_csv(str(output_dir / "metadata.csv"), sep="\t", index=None, header=None) + print("saved meta data in to {}".format(os.path.join(output_dir, "metadata.csv"))) + + print("Done!") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="create dataset") + parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config") + parser.add_argument("--input", type=str, help="path of the ljspeech dataset") + parser.add_argument("--output", type=str, help="path to save output dataset") + parser.add_argument("--opts", nargs=argparse.REMAINDER, + help="options to overwrite --config file and the default config, passing in KEY VALUE pairs" + ) + parser.add_argument("-v", "--verbose", action="store_true", help="print msg") + + config = get_cfg_defaults() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if 
args.opts: + config.merge_from_list(args.opts) + config.freeze() + if args.verbose: + print(config.data) + print(args) + + create_dataset(config.data, args.input, args.output, args.verbose) diff --git a/examples/waveflow/synthesize.py b/examples/waveflow/synthesize.py new file mode 100644 index 0000000..1856eb2 --- /dev/null +++ b/examples/waveflow/synthesize.py @@ -0,0 +1,52 @@ +import argparse +import numpy as np +import soundfile as sf +import os +from pathlib import Path +import paddle +import parakeet +from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow +from parakeet.utils import layer_tools, checkpoint + + +from config import get_cfg_defaults + +def main(config, args): + paddle.set_device(args.device) + model = ConditionalWaveFlow.from_pretrained(config, args.checkpoint_path) + layer_tools.recursively_remove_weight_norm(model) + model.eval() + + mel_dir = Path(args.input).expanduser() + output_dir = Path(args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + for file_path in mel_dir.iterdir(): + mel = np.load(str(file_path)) + audio = model.predict(mel) + audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav") + sf.write(audio_path, audio, config.data.sample_rate) + print("[synthesize] {} -> {}".format(file_path, audio_path)) + + +if __name__ == "__main__": + config = get_cfg_defaults() + + parser = argparse.ArgumentParser(description="generate mel spectrogram with TransformerTTS.") + parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config") + parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.") + parser.add_argument("--input", type=str, help="path of directory containing mel spectrogram (in .npy format)") + parser.add_argument("--output", type=str, help="path to save outputs") + parser.add_argument("--device", type=str, default="cpu", help="device type to use.") + parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + parser.add_argument("-v", "--verbose", action="store_true", help="print msg") + + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + print(args) + + main(config, args) \ No newline at end of file diff --git a/examples/waveflow/train.py b/examples/waveflow/train.py new file mode 100644 index 0000000..ae19994 --- /dev/null +++ b/examples/waveflow/train.py @@ -0,0 +1,147 @@ +import time +from pathlib import Path +import numpy as np +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader, DistributedBatchSampler +from tensorboardX import SummaryWriter +from collections import defaultdict + +import parakeet +from parakeet.data import dataset +from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow, WaveFlowLoss +from parakeet.audio import AudioProcessor +from parakeet.utils import scheduler, mp_tools +from parakeet.training.cli import default_argument_parser +from parakeet.training.experiment import ExperimentBase +from parakeet.utils.mp_tools import rank_zero_only + +from config import get_cfg_defaults +from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector + + +class Experiment(ExperimentBase): + def setup_model(self): + config = self.config + model = ConditionalWaveFlow( + upsample_factors=config.model.upsample_factors, + 
n_flows=config.model.n_flows, + n_layers=config.model.n_layers, + n_group=config.model.n_group, + channels=config.model.channels, + n_mels=config.data.n_mels, + kernel_size=config.model.kernel_size) + + if self.parallel > 1: + model = paddle.DataParallel(model) + optimizer = paddle.optimizer.Adam(config.training.lr, parameters=model.parameters()) + criterion = WaveFlowLoss(sigma=config.model.sigma) + + self.model = model + self.optimizer = optimizer + self.criterion = criterion + + def setup_dataloader(self): + config = self.config + args = self.args + + ljspeech_dataset = LJSpeech(args.data) + valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size) + + batch_fn = LJSpeechClipCollector(config.data.clip_frames, config.data.hop_length) + + if not self.parallel: + train_loader = DataLoader( + train_set, + batch_size=config.data.batch_size, + shuffle=True, + drop_last=True, + collate_fn=batch_fn) + else: + sampler = DistributedBatchSampler( + train_set, + batch_size=config.data.batch_size, + num_replicas=dist.get_world_size(), + rank=dist.get_rank(), + shuffle=True, + drop_last=True) + train_loader = DataLoader( + train_set, batch_sampler=sampler, collate_fn=batch_fn) + + valid_batch_fn = LJSpeechCollector() + valid_loader = DataLoader( + valid_set, batch_size=1, collate_fn=valid_batch_fn) + + self.train_loader = train_loader + self.valid_loader = valid_loader + + def compute_outputs(self, mel, wav): + # model_core = model._layers if isinstance(model, paddle.DataParallel) else model + z, log_det_jocobian = self.model(wav, mel) + return z, log_det_jocobian + + def compute_losses(self, outputs): + loss = self.criterion(outputs) + return loss + + def train_batch(self): + start = time.time() + batch = self.read_batch() + data_loader_time = time.time() - start + + self.model.train() + self.optimizer.clear_grad() + mel, wav = batch + outputs = self.compute_outputs(mel, wav) + loss = self.compute_losses(outputs) + loss.backward() + self.optimizer.step() + iteration_time = time.time() - start + + loss_value = float(loss) + msg = "Rank: {}, ".format(dist.get_rank()) + msg += "step: {}, ".format(self.iteration) + msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time) + msg += "loss: {:>.6f}".format(loss_value) + self.logger.info(msg) + self.visualizer.add_scalar("train/loss", loss_value, global_step=self.iteration) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def valid(self): + valid_iterator = iter(self.valid_loader) + valid_losses = [] + mel, wav = next(valid_iterator) + outputs = self.compute_outputs(mel, wav) + loss = self.compute_losses(outputs) + valid_losses.append(float(loss)) + valid_loss = np.mean(valid_losses) + self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration) + + +def main_sp(config, args): + exp = Experiment(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.nprocs > 1 and args.device == "gpu": + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + config = get_cfg_defaults() + parser = default_argument_parser() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + print(args) + + main(config, args) diff --git a/examples/wavenet/config.py b/examples/wavenet/config.py new file mode 100644 index 0000000..58f9beb --- /dev/null +++ b/examples/wavenet/config.py @@ -0,0 +1,47 @@ +from yacs.config 
import CfgNode as CN
+
+_C = CN()
+_C.data = CN(
+    dict(
+        batch_size=8,  # batch size
+        valid_size=16,  # the first N examples are reserved for validation
+        sample_rate=22050,  # Hz, sample rate
+        n_fft=2048,  # fft frame size
+        win_length=1024,  # window size
+        hop_length=256,  # hop size between adjacent frames
+        # f_max=8000,  # Hz, max frequency when converting to mel
+        n_mels=80,  # mel bands
+        train_clip_seconds=0.5,  # audio clip length (in seconds)
+    )
+)
+
+_C.model = CN(
+    dict(
+        upsample_factors=[16, 16],
+        n_stack=3,
+        n_loop=10,
+        filter_size=2,
+        residual_channels=128,  # residual channels in each residual block
+        loss_type="mog",
+        output_dim=3,  # single gaussian
+        log_scale_min=-9.0,
+    )
+)
+
+_C.training = CN(
+    dict(
+        lr=1e-3,  # learning rate
+        anneal_rate=0.5,  # learning rate decay rate
+        anneal_interval=200000,  # decrease lr by anneal_rate every anneal_interval steps
+        valid_interval=1000,  # validation
+        save_interval=10000,  # checkpoint
+        max_iteration=3000000,  # max iteration to train
+        gradient_max_norm=100.0  # global norm of gradients
+    )
+)
+
+def get_cfg_defaults():
+    """Get a yacs CfgNode object with default values for my_project."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    return _C.clone()
diff --git a/examples/wavenet/ljspeech.py b/examples/wavenet/ljspeech.py
new file mode 100644
index 0000000..18dc388
--- /dev/null
+++ b/examples/wavenet/ljspeech.py
@@ -0,0 +1,138 @@
+import os
+from pathlib import Path
+import pickle
+import numpy as np
+import pandas
+from paddle.io import Dataset, DataLoader
+
+from parakeet.data.batch import batch_spec, batch_wav
+from parakeet.data import dataset
+from parakeet.audio import AudioProcessor
+
+class LJSpeech(Dataset):
+    """A simple dataset adaptor for the processed ljspeech dataset."""
+    def __init__(self, root):
+        self.root = Path(root).expanduser()
+        meta_data = pandas.read_csv(
+            str(self.root / "metadata.csv"),
+            sep="\t",
+            header=None,
+            names=["fname", "frames", "samples"]
+        )
+
+        records = []
+        for row in meta_data.itertuples():
+            mel_path = str(self.root / "mel" / (row.fname + ".npy"))
+            wav_path = str(self.root / "wav" / (row.fname + ".npy"))
+            records.append((mel_path, wav_path))
+        self.records = records
+
+    def __getitem__(self, i):
+        mel_name, wav_name = self.records[i]
+        mel = np.load(mel_name)
+        wav = np.load(wav_name)
+        return mel, wav
+
+    def __len__(self):
+        return len(self.records)
+
+
+class LJSpeechCollector(object):
+    """A simple callable to batch LJSpeech examples."""
+    def __init__(self, padding_value=0.):
+        self.padding_value = padding_value
+
+    def __call__(self, examples):
+        batch_size = len(examples)
+        mels = [example[0] for example in examples]
+        wavs = [example[1] for example in examples]
+        mels = batch_spec(mels, pad_value=self.padding_value)
+        wavs = batch_wav(wavs, pad_value=self.padding_value)
+        audio_starts = np.zeros((batch_size,), dtype=np.int64)
+        return mels, wavs, audio_starts
+
+
+class LJSpeechClipCollector(object):
+    def __init__(self, clip_frames=65, hop_length=256):
+        self.clip_frames = clip_frames
+        self.hop_length = hop_length
+
+    def __call__(self, examples):
+        mels = []
+        wavs = []
+        starts = []
+        for example in examples:
+            mel, wav_clip, start = self.clip(example)
+            mels.append(mel)
+            wavs.append(wav_clip)
+            starts.append(start)
+        mels = batch_spec(mels)
+        wavs = np.stack(wavs)
+        starts = np.array(starts, dtype=np.int64)
+        return mels, wavs, starts
+
+    def clip(self, example):
+        mel, wav = example
+        frames = mel.shape[-1] +
start = np.random.randint(0, frames - self.clip_frames) + wav_clip = wav[start * self.hop_length: (start + self.clip_frames) * self.hop_length] + return mel, wav_clip, start + + +class DataCollector(object): + def __init__(self, + context_size, + sample_rate, + hop_length, + train_clip_seconds, + valid=False): + frames_per_second = sample_rate // hop_length + train_clip_frames = int( + np.ceil(train_clip_seconds * frames_per_second)) + context_frames = context_size // hop_length + self.num_frames = train_clip_frames + context_frames + + self.sample_rate = sample_rate + self.hop_length = hop_length + self.valid = valid + + def random_crop(self, sample): + audio, mel_spectrogram = sample + audio_frames = int(audio.size) // self.hop_length + max_start_frame = audio_frames - self.num_frames + assert max_start_frame >= 0, "audio is too short to be cropped" + + frame_start = np.random.randint(0, max_start_frame) + # frame_start = 0 # norandom + frame_end = frame_start + self.num_frames + + audio_start = frame_start * self.hop_length + audio_end = frame_end * self.hop_length + + audio = audio[audio_start:audio_end] + return audio, mel_spectrogram, audio_start + + def __call__(self, samples): + # transform them first + if self.valid: + samples = [(audio, mel_spectrogram, 0) + for audio, mel_spectrogram in samples] + else: + samples = [self.random_crop(sample) for sample in samples] + # batch them + audios = [sample[0] for sample in samples] + audio_starts = [sample[2] for sample in samples] + mels = [sample[1] for sample in samples] + + mels = batch_spec(mels) + + if self.valid: + audios = batch_wav(audios, dtype=np.float32) + else: + audios = np.array(audios, dtype=np.float32) + audio_starts = np.array(audio_starts, dtype=np.int64) + return audios, mels, audio_starts + + + + diff --git a/examples/wavenet/preprocess.py b/examples/wavenet/preprocess.py new file mode 100644 index 0000000..29b140c --- /dev/null +++ b/examples/wavenet/preprocess.py @@ -0,0 +1,139 @@ +import os +import tqdm +import csv +import argparse +import numpy as np +import librosa +from pathlib import Path +import pandas as pd + +from paddle.io import Dataset +from parakeet.data import batch_spec, batch_wav +from parakeet.datasets import LJSpeechMetaData +from parakeet.audio import AudioProcessor +from parakeet.audio.spec_normalizer import UnitMagnitude + +from config import get_cfg_defaults + + +class Transform(object): + def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels): + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length + self.hop_length = hop_length + self.n_mels = n_mels + + self.spec_normalizer = UnitMagnitude(min=1e-5) + + def __call__(self, example): + wav_path, _, _ = example + + sr = self.sample_rate + n_fft = self.n_fft + win_length = self.win_length + hop_length = self.hop_length + n_mels = self.n_mels + + wav, loaded_sr = librosa.load(wav_path, sr=None) + assert loaded_sr == sr, "sample rate does not match, resampling applied" + + # Pad audio to the right size. + frames = int(np.ceil(float(wav.size) / hop_length)) + fft_padding = (n_fft - hop_length) // 2 # sound + desired_length = frames * hop_length + fft_padding * 2 + pad_amount = (desired_length - wav.size) // 2 + + if wav.size % 2 == 0: + wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect') + else: + wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect') + + # Normalize audio. + wav = wav / np.abs(wav).max() * 0.999 + + # Compute mel-spectrogram. 
+ # Turn center to False to prevent internal padding. + spectrogram = librosa.core.stft( + wav, + hop_length=hop_length, + win_length=win_length, + n_fft=n_fft, + center=False) + spectrogram_magnitude = np.abs(spectrogram) + + # Compute mel-spectrograms. + mel_filter_bank = librosa.filters.mel(sr=sr, + n_fft=n_fft, + n_mels=n_mels) + mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude) + mel_spectrogram = mel_spectrogram + + # log scale mel_spectrogram. + mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram) + + # Extract the center of audio that corresponds to mel spectrograms. + audio = wav[fft_padding:-fft_padding] + assert mel_spectrogram.shape[1] * hop_length == audio.size + + # there is no clipping here + return audio, mel_spectrogram + + +def create_dataset(config, input_dir, output_dir, verbose=True): + input_dir = Path(input_dir).expanduser() + dataset = LJSpeechMetaData(input_dir) + + output_dir = Path(output_dir).expanduser() + output_dir.mkdir(exist_ok=True) + + transform = Transform( + config.sample_rate, + config.n_fft, + config.win_length, + config.hop_length, + config.n_mels) + file_names = [] + + for example in tqdm.tqdm(dataset): + fname, _, _ = example + base_name = os.path.splitext(os.path.basename(fname))[0] + wav_dir = output_dir / "wav" + mel_dir = output_dir / "mel" + wav_dir.mkdir(exist_ok=True) + mel_dir.mkdir(exist_ok=True) + + audio, mel = transform(example) + np.save(str(wav_dir / base_name), audio) + np.save(str(mel_dir / base_name), mel) + + file_names.append((base_name, mel.shape[-1], audio.shape[-1])) + + meta_data = pd.DataFrame.from_records(file_names) + meta_data.to_csv(str(output_dir / "metadata.csv"), sep="\t", index=None, header=None) + print("saved meta data in to {}".format(os.path.join(output_dir, "metadata.csv"))) + + print("Done!") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="create dataset") + parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config") + parser.add_argument("--input", type=str, help="path of the ljspeech dataset") + parser.add_argument("--output", type=str, help="path to save output dataset") + parser.add_argument("--opts", nargs=argparse.REMAINDER, + help="options to overwrite --config file and the default config, passing in KEY VALUE pairs" + ) + parser.add_argument("-v", "--verbose", action="store_true", help="print msg") + + config = get_cfg_defaults() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + if args.verbose: + print(config.data) + print(args) + + create_dataset(config.data, args.input, args.output, args.verbose) diff --git a/examples/wavenet/synthesize.py b/examples/wavenet/synthesize.py new file mode 100644 index 0000000..80b96a2 --- /dev/null +++ b/examples/wavenet/synthesize.py @@ -0,0 +1,51 @@ +import argparse +import numpy as np +import soundfile as sf +import os +from pathlib import Path +import paddle +import parakeet +from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet +from parakeet.utils import layer_tools, checkpoint + +from config import get_cfg_defaults + +def main(config, args): + paddle.set_device(args.device) + model = ConditionalWaveNet.from_pretrained(config, args.checkpoint_path) + layer_tools.recursively_remove_weight_norm(model) + model.eval() + + mel_dir = Path(args.input).expanduser() + output_dir = Path(args.output).expanduser() + 
    output_dir.mkdir(parents=True, exist_ok=True)
+    for file_path in mel_dir.iterdir():
+        mel = np.load(str(file_path))
+        audio = model.predict(mel)
+        audio_path = output_dir / (os.path.splitext(file_path.name)[0] + ".wav")
+        sf.write(audio_path, audio, config.data.sample_rate)
+        print("[synthesize] {} -> {}".format(file_path, audio_path))
+
+
+if __name__ == "__main__":
+    config = get_cfg_defaults()
+
+    parser = argparse.ArgumentParser(description="synthesize waveform from mel spectrograms with WaveNet.")
+    parser.add_argument("--config", type=str, metavar="FILE", help="extra config to overwrite the default config")
+    parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load.")
+    parser.add_argument("--input", type=str, help="path of directory containing mel spectrogram (in .npy format)")
+    parser.add_argument("--output", type=str, help="path to save outputs")
+    parser.add_argument("--device", type=str, default="cpu", help="device type to use.")
+    parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
+    parser.add_argument("-v", "--verbose", action="store_true", help="print msg")
+
+    args = parser.parse_args()
+    if args.config:
+        config.merge_from_file(args.config)
+    if args.opts:
+        config.merge_from_list(args.opts)
+    config.freeze()
+    print(config)
+    print(args)
+
+    main(config, args)
\ No newline at end of file
diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py
new file mode 100644
index 0000000..77c54e3
--- /dev/null
+++ b/examples/wavenet/train.py
@@ -0,0 +1,157 @@
+import time
+from pathlib import Path
+import math
+import numpy as np
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader, DistributedBatchSampler
+from tensorboardX import SummaryWriter
+from collections import defaultdict
+
+import parakeet
+from parakeet.data import dataset
+from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
+from parakeet.audio import AudioProcessor
+from parakeet.utils import scheduler, mp_tools
+from parakeet.training.cli import default_argument_parser
+from parakeet.training.experiment import ExperimentBase
+from parakeet.utils.mp_tools import rank_zero_only
+
+from config import get_cfg_defaults
+from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
+
+
+class Experiment(ExperimentBase):
+    def setup_model(self):
+        config = self.config
+        model = ConditionalWaveNet(
+            upsample_factors=config.model.upsample_factors,
+            n_stack=config.model.n_stack,
+            n_loop=config.model.n_loop,
+            residual_channels=config.model.residual_channels,
+            output_dim=config.model.output_dim,
+            n_mels=config.data.n_mels,
+            filter_size=config.model.filter_size,
+            loss_type=config.model.loss_type,
+            log_scale_min=config.model.log_scale_min)
+
+        if self.parallel:
+            model = paddle.DataParallel(model)
+
+        lr_scheduler = paddle.optimizer.lr.StepDecay(
+            config.training.lr,
+            config.training.anneal_interval,
+            config.training.anneal_rate)
+        optimizer = paddle.optimizer.Adam(
+            lr_scheduler,
+            parameters=model.parameters(),
+            grad_clip=paddle.nn.ClipGradByGlobalNorm(config.training.gradient_max_norm))
+
+        self.model = model
+        self.model_core = model._layers if self.parallel else model
+        self.optimizer = optimizer
+
+    def setup_dataloader(self):
+        config = self.config
+        args = self.args
+
+        ljspeech_dataset = LJSpeech(args.data)
+        valid_set, train_set = dataset.split(ljspeech_dataset, config.data.valid_size)
+
+        # convolutional
net's causal padding size + context_size = config.model.n_stack \ + * sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \ + + 1 + context_frames = context_size // config.data.hop_length + + # frames used to compute loss + frames_per_second = config.data.sample_rate // config.data.hop_length + train_clip_frames = math.ceil(config.data.train_clip_seconds * frames_per_second) + + num_frames = train_clip_frames + context_frames + batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length) + if not self.parallel: + train_loader = DataLoader( + train_set, + batch_size=config.data.batch_size, + shuffle=True, + drop_last=True, + collate_fn=batch_fn) + else: + sampler = DistributedBatchSampler( + train_set, + batch_size=config.data.batch_size, + shuffle=True, + drop_last=True) + train_loader = DataLoader( + train_set, batch_sampler=sampler, collate_fn=batch_fn) + + valid_batch_fn = LJSpeechCollector() + valid_loader = DataLoader( + valid_set, batch_size=1, collate_fn=valid_batch_fn) + + self.train_loader = train_loader + self.valid_loader = valid_loader + + def train_batch(self): + start = time.time() + batch = self.read_batch() + data_loader_time = time.time() - start + + self.model.train() + self.optimizer.clear_grad() + mel, wav, audio_starts = batch + + y = self.model(wav, mel, audio_starts) + loss = self.model.loss(y, wav) + loss.backward() + self.optimizer.step() + iteration_time = time.time() - start + + loss_value = float(loss) + msg = "Rank: {}, ".format(dist.get_rank()) + msg += "step: {}, ".format(self.iteration) + msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, iteration_time) + msg += "loss: {:>.6f}".format(loss_value) + self.logger.info(msg) + self.visualizer.add_scalar("train/loss", loss_value, global_step=self.iteration) + + @mp_tools.rank_zero_only + @paddle.no_grad() + def valid(self): + valid_iterator = iter(self.valid_loader) + valid_losses = [] + mel, wav, audio_starts = next(valid_iterator) + y = self.model(wav, mel, audio_starts) + loss = self.model.loss(y, wav) + valid_losses.append(float(loss)) + valid_loss = np.mean(valid_losses) + self.visualizer.add_scalar("valid/loss", valid_loss, global_step=self.iteration) + + +def main_sp(config, args): + exp = Experiment(config, args) + exp.setup() + exp.run() + + +def main(config, args): + if args.nprocs > 1 and args.device == "gpu": + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) + else: + main_sp(config, args) + + +if __name__ == "__main__": + config = get_cfg_defaults() + parser = default_argument_parser() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + print(args) + + main(config, args)
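Note: the training clip length used by the WaveNet example above is derived from the model's receptive field. Below is a small, self-contained sketch of that arithmetic, mirroring the context_size expression in examples/wavenet/train.py and plugging in the default values from examples/wavenet/config.py; the helper name wavenet_context_size is illustrative and not part of the codebase.

import math

def wavenet_context_size(n_stack, n_loop, filter_size):
    # receptive field (in samples) of the stacked dilated convolutions, plus one,
    # following the context_size expression in setup_dataloader above
    return n_stack * sum((filter_size - 1) * 2 ** i for i in range(n_loop)) + 1

# defaults from examples/wavenet/config.py
context_size = wavenet_context_size(n_stack=3, n_loop=10, filter_size=2)  # 3 * 1023 + 1 = 3070 samples
context_frames = context_size // 256                     # hop_length=256 -> 11 frames
frames_per_second = 22050 // 256                         # sample_rate=22050 -> 86
train_clip_frames = math.ceil(0.5 * frames_per_second)   # train_clip_seconds=0.5 -> 43
num_frames = train_clip_frames + context_frames          # 54 frames per clip passed to LJSpeechClipCollector
print(context_size, context_frames, num_frames)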