commit
cf43f2cf03
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
_C = CN()
|
||||
_C.data = CN(
|
||||
dict(
|
||||
batch_size=32, # batch size
|
||||
valid_size=64, # the first N examples are reserved for validation
|
||||
sample_rate=22050, # Hz, sample rate
|
||||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
f_max=8000, # Hz, max frequency when converting to mel
|
||||
f_min=0, # Hz, min frequency when converting to mel
|
||||
d_mels=80, # mel bands
|
||||
padding_idx=0, # text embedding's padding index
|
||||
))
|
||||
|
||||
_C.model = CN(
|
||||
dict(
|
||||
reduction_factor=1, # reduction factor
|
||||
d_encoder=512, # embedding & encoder's internal size
|
||||
encoder_conv_layers=3, # number of conv layer in tacotron2 encoder
|
||||
encoder_kernel_size=5, # kernel size of conv layers in tacotron2 encoder
|
||||
d_prenet=256, # hidden size of decoder prenet
|
||||
d_attention_rnn=1024, # hidden size of the first rnn layer in tacotron2 decoder
|
||||
d_decoder_rnn=1024, #hidden size of the second rnn layer in tacotron2 decoder
|
||||
d_attention=128, # hidden size of decoder location linear layer
|
||||
attention_filters=32, # number of filter in decoder location conv layer
|
||||
attention_kernel_size=31, # kernel size of decoder location conv layer
|
||||
d_postnet=512, # hidden size of decoder postnet
|
||||
postnet_kernel_size=5, # kernel size of conv layers in postnet
|
||||
postnet_conv_layers=5, # number of conv layer in decoder postnet
|
||||
p_encoder_dropout=0.5, # droput probability in encoder
|
||||
p_prenet_dropout=0.5, # droput probability in decoder prenet
|
||||
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
|
||||
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
|
||||
p_postnet_dropout=0.5, #droput probability in decoder postnet
|
||||
))
|
||||
|
||||
_C.training = CN(
|
||||
dict(
|
||||
lr=1e-3, # learning rate
|
||||
weight_decay=1e-6, # the coeff of weight decay
|
||||
grad_clip_thresh=1.0, # the clip norm of grad clip.
|
||||
plot_interval=1000, # plot attention and spectrogram
|
||||
valid_interval=1000, # validation
|
||||
save_interval=1000, # checkpoint
|
||||
max_iteration=500000, # max iteration to train
|
||||
))
|
||||
|
||||
|
||||
def get_cfg_defaults():
|
||||
"""Get a yacs CfgNode object with default values for my_project."""
|
||||
# Return a clone so that the defaults will not be altered
|
||||
# This is for the "local variable" use pattern
|
||||
return _C.clone()
|
|
@ -0,0 +1,106 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import numpy as np
|
||||
from paddle.io import Dataset, DataLoader
|
||||
|
||||
from parakeet.data.batch import batch_spec, batch_text_id
|
||||
from parakeet.data import dataset
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
"""A simple dataset adaptor for the processed ljspeech dataset."""
|
||||
|
||||
def __init__(self, root):
|
||||
self.root = Path(root).expanduser()
|
||||
records = []
|
||||
with open(self.root / "metadata.pkl", 'rb') as f:
|
||||
metadata = pickle.load(f)
|
||||
for mel_name, text, ids in metadata:
|
||||
mel_name = self.root / "mel" / (mel_name + ".npy")
|
||||
records.append((mel_name, text, ids))
|
||||
self.records = records
|
||||
|
||||
def __getitem__(self, i):
|
||||
mel_name, _, ids = self.records[i]
|
||||
mel = np.load(mel_name)
|
||||
return ids, mel
|
||||
|
||||
def __len__(self):
|
||||
return len(self.records)
|
||||
|
||||
|
||||
class LJSpeechCollector(object):
|
||||
"""A simple callable to batch LJSpeech examples."""
|
||||
|
||||
def __init__(self, padding_idx=0, padding_value=0.,
|
||||
padding_stop_token=1.0):
|
||||
self.padding_idx = padding_idx
|
||||
self.padding_value = padding_value
|
||||
self.padding_stop_token = padding_stop_token
|
||||
|
||||
def __call__(self, examples):
|
||||
texts = []
|
||||
mels = []
|
||||
text_lens = []
|
||||
mel_lens = []
|
||||
stop_tokens = []
|
||||
for data in examples:
|
||||
text, mel = data
|
||||
text = np.array(text, dtype=np.int64)
|
||||
text_lens.append(len(text))
|
||||
mels.append(mel)
|
||||
texts.append(text)
|
||||
mel_lens.append(mel.shape[1])
|
||||
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
|
||||
stop_tokens.append(np.append(stop_token, 1.0))
|
||||
|
||||
# Sort by text_len in descending order
|
||||
texts = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(texts, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
mels = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(mels, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
mel_lens = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
stop_tokens = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(stop_tokens, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
text_lens = sorted(text_lens, reverse=True)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts = batch_text_id(texts, pad_id=self.padding_idx)
|
||||
mels = np.transpose(
|
||||
batch_spec(
|
||||
mels, pad_value=self.padding_value), axes=(0, 2, 1))
|
||||
stop_tokens = batch_text_id(
|
||||
stop_tokens, pad_id=self.padding_stop_token, dtype=mels[0].dtype)
|
||||
|
||||
return (texts, mels, text_lens, mel_lens, stop_tokens)
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tqdm
|
||||
import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
from parakeet.datasets import LJSpeechMetaData
|
||||
from parakeet.audio import AudioProcessor, LogMagnitude
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
def create_dataset(config, source_path, target_path, verbose=False):
|
||||
# create output dir
|
||||
target_path = Path(target_path).expanduser()
|
||||
mel_path = target_path / "mel"
|
||||
os.makedirs(mel_path, exist_ok=True)
|
||||
|
||||
meta_data = LJSpeechMetaData(source_path)
|
||||
frontend = EnglishCharacter()
|
||||
processor = AudioProcessor(
|
||||
sample_rate=config.data.sample_rate,
|
||||
n_fft=config.data.n_fft,
|
||||
n_mels=config.data.d_mels,
|
||||
win_length=config.data.win_length,
|
||||
hop_length=config.data.hop_length,
|
||||
f_max=config.data.f_max,
|
||||
f_min=config.data.f_min)
|
||||
normalizer = LogMagnitude()
|
||||
|
||||
records = []
|
||||
for (fname, text, _) in tqdm.tqdm(meta_data):
|
||||
wav = processor.read_wav(fname)
|
||||
mel = processor.mel_spectrogram(wav)
|
||||
mel = normalizer.transform(mel)
|
||||
ids = frontend(text)
|
||||
mel_name = os.path.splitext(os.path.basename(fname))[0]
|
||||
|
||||
# save mel spectrogram
|
||||
records.append((mel_name, text, ids))
|
||||
np.save(mel_path / mel_name, mel)
|
||||
if verbose:
|
||||
print("save mel spectrograms into {}".format(mel_path))
|
||||
|
||||
# save meta data as pickle archive
|
||||
with open(target_path / "metadata.pkl", 'wb') as f:
|
||||
pickle.dump(records, f)
|
||||
if verbose:
|
||||
print("saved metadata into {}".format(target_path /
|
||||
"metadata.pkl"))
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="create dataset")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
metavar="FILE",
|
||||
help="extra config to overwrite the default config")
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="path of the ljspeech dataset")
|
||||
parser.add_argument(
|
||||
"--output", type=str, help="path to save output dataset")
|
||||
parser.add_argument(
|
||||
"--opts",
|
||||
nargs=argparse.REMAINDER,
|
||||
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="print msg")
|
||||
|
||||
config = get_cfg_defaults()
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config.data)
|
||||
|
||||
create_dataset(config, args.input, args.output, args.verbose)
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
import parakeet
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
from parakeet.models.tacotron2 import Tacotron2
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
def main(config, args):
|
||||
paddle.set_device(args.device)
|
||||
|
||||
# model
|
||||
frontend = EnglishCharacter()
|
||||
model = Tacotron2.from_pretrained(frontend, config, args.checkpoint_path)
|
||||
model.eval()
|
||||
|
||||
# inputs
|
||||
input_path = Path(args.input).expanduser()
|
||||
with open(input_path, "rt") as f:
|
||||
sentences = f.readlines()
|
||||
|
||||
if args.output is None:
|
||||
output_dir = input_path.parent / "synthesis"
|
||||
else:
|
||||
output_dir = Path(args.output).expanduser()
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
for i, sentence in enumerate(sentences):
|
||||
mel_output, _ = model.predict(sentence)
|
||||
mel_output = mel_output.T
|
||||
|
||||
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
||||
if args.verbose:
|
||||
print("spectrogram saved at {}".format(output_dir /
|
||||
f"sentence_{i}.npy"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="generate mel spectrogram with TransformerTTS.")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
metavar="FILE",
|
||||
help="extra config to overwrite the default config")
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="path of the checkpoint to load.")
|
||||
parser.add_argument("--input", type=str, help="path of the text sentences")
|
||||
parser.add_argument("--output", type=str, help="path to save outputs")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cpu", help="device type to use.")
|
||||
parser.add_argument(
|
||||
"--opts",
|
||||
nargs=argparse.REMAINDER,
|
||||
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="print msg")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
|
@ -0,0 +1,211 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
|
||||
import parakeet
|
||||
from parakeet.data import dataset
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
from parakeet.utils import display, mp_tools
|
||||
from parakeet.models.tacotron2 import Tacotron2, Tacotron2Loss
|
||||
|
||||
from config import get_cfg_defaults
|
||||
from ljspeech import LJSpeech, LJSpeechCollector
|
||||
|
||||
|
||||
class Experiment(ExperimentBase):
|
||||
def compute_losses(self, inputs, outputs):
|
||||
_, mel_targets, _, _, stop_tokens = inputs
|
||||
|
||||
mel_outputs = outputs["mel_output"]
|
||||
mel_outputs_postnet = outputs["mel_outputs_postnet"]
|
||||
stop_logits = outputs["stop_logits"]
|
||||
|
||||
losses = self.criterion(mel_outputs, mel_outputs_postnet, stop_logits,
|
||||
mel_targets, stop_tokens)
|
||||
return losses
|
||||
|
||||
def train_batch(self):
|
||||
start = time.time()
|
||||
batch = self.read_batch()
|
||||
data_loader_time = time.time() - start
|
||||
|
||||
self.optimizer.clear_grad()
|
||||
self.model.train()
|
||||
texts, mels, text_lens, output_lens, stop_tokens = batch
|
||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
loss = losses["loss"]
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
iteration_time = time.time() - start
|
||||
|
||||
losses_np = {k: float(v) for k, v in losses.items()}
|
||||
# logging
|
||||
msg = "Rank: {}, ".format(dist.get_rank())
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
|
||||
iteration_time)
|
||||
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_np.items())
|
||||
self.logger.info(msg)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
for k, v in losses_np.items():
|
||||
self.visualizer.add_scalar(f"train_loss/{k}", v,
|
||||
self.iteration)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
valid_losses = defaultdict(list)
|
||||
for i, batch in enumerate(self.valid_loader):
|
||||
texts, mels, text_lens, output_lens, stop_tokens = batch
|
||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
for k, v in losses.items():
|
||||
valid_losses[k].append(float(v))
|
||||
|
||||
attention_weights = outputs["alignments"]
|
||||
display.add_attention_plots(self.visualizer,
|
||||
f"valid_sentence_{i}_alignments",
|
||||
attention_weights[0], self.iteration)
|
||||
display.add_spectrogram_plots(
|
||||
self.visualizer, f"valid_sentence_{i}_target_spectrogram",
|
||||
mels[0], self.iteration)
|
||||
display.add_spectrogram_plots(
|
||||
self.visualizer, f"valid_sentence_{i}_predicted_spectrogram",
|
||||
outputs['mel_outputs_postnet'][0], self.iteration)
|
||||
|
||||
# write visual log
|
||||
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
||||
|
||||
# logging
|
||||
msg = "Valid: "
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in valid_losses.items())
|
||||
self.logger.info(msg)
|
||||
|
||||
for k, v in valid_losses.items():
|
||||
self.visualizer.add_scalar(f"valid/{k}", v, self.iteration)
|
||||
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
frontend = EnglishCharacter()
|
||||
model = Tacotron2(
|
||||
frontend,
|
||||
d_mels=config.data.d_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||
d_prenet=config.model.d_prenet,
|
||||
d_attention_rnn=config.model.d_attention_rnn,
|
||||
d_decoder_rnn=config.model.d_decoder_rnn,
|
||||
attention_filters=config.model.attention_filters,
|
||||
attention_kernel_size=config.model.attention_kernel_size,
|
||||
d_attention=config.model.d_attention,
|
||||
d_postnet=config.model.d_postnet,
|
||||
postnet_kernel_size=config.model.postnet_kernel_size,
|
||||
postnet_conv_layers=config.model.postnet_conv_layers,
|
||||
reduction_factor=config.model.reduction_factor,
|
||||
p_encoder_dropout=config.model.p_encoder_dropout,
|
||||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout)
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
grad_clip = paddle.nn.ClipGradByGlobalNorm(
|
||||
config.training.grad_clip_thresh)
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
learning_rate=config.training.lr,
|
||||
parameters=model.parameters(),
|
||||
weight_decay=paddle.regularizer.L2Decay(
|
||||
config.training.weight_decay),
|
||||
grad_clip=grad_clip)
|
||||
criterion = Tacotron2Loss()
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
|
||||
def setup_dataloader(self):
|
||||
args = self.args
|
||||
config = self.config
|
||||
ljspeech_dataset = LJSpeech(args.data)
|
||||
|
||||
valid_set, train_set = dataset.split(ljspeech_dataset,
|
||||
config.data.valid_size)
|
||||
batch_fn = LJSpeechCollector(padding_idx=config.data.padding_idx)
|
||||
|
||||
if not self.parallel:
|
||||
self.train_loader = DataLoader(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=batch_fn)
|
||||
else:
|
||||
sampler = DistributedBatchSampler(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
self.train_loader = DataLoader(
|
||||
train_set, batch_sampler=sampler, collate_fn=batch_fn)
|
||||
|
||||
self.valid_loader = DataLoader(
|
||||
valid_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
collate_fn=batch_fn)
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp.setup()
|
||||
exp.run()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
if args.nprocs > 1 and args.device == "gpu":
|
||||
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
else:
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
parser = default_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
Loading…
Reference in New Issue