Merge branch 'add_TranTTS' into 'master'

modified load checkpoint in train

See merge request !2
This commit is contained in:
lifuchen 2020-01-08 14:09:19 +08:00
commit df8e44bbde
5 changed files with 16 additions and 45 deletions

View File

@ -1,41 +0,0 @@
audio:
num_mels: 80
n_fft: 2048
sr: 22050
preemphasis: 0.97
hop_length: 275
win_length: 1102
power: 1.2
min_level_db: -100
ref_level_db: 20
outputs_per_step: 1
encoder_n_layer: 6
encoder_head: 2
encoder_conv1d_filter_size: 1536
max_sep_len: 2048
encoder_output_size: 384
word_vec_dim: 384
decoder_n_layer: 6
decoder_head: 2
decoder_conv1d_filter_size: 1536
decoder_output_size: 384
d_model: 384
duration_predictor_output_size: 256
duration_predictor_filter_size: 3
fft_conv1d_filter: 3
fft_conv1d_padding: 1
batch_size: 32
epochs: 10000
lr: 0.001
save_step: 500
image_step: 2000
use_gpu: False
use_data_parallel: False
data_path: ../../../dataset/LJSpeech-1.1
transtts_path: ./checkpoint
transformer_step: 70000
log_dir: ./log

View File

@ -24,4 +24,6 @@ use_data_parallel: True
data_path: ../../../dataset/LJSpeech-1.1 data_path: ../../../dataset/LJSpeech-1.1
save_path: ./checkpoint save_path: ./checkpoint
log_dir: ./log log_dir: ./log
#checkpoint_path: ./checkpoint
#transformer_step: 27000

View File

@ -28,6 +28,7 @@ use_data_parallel: False
data_path: ../../../dataset/LJSpeech-1.1 data_path: ../../../dataset/LJSpeech-1.1
save_path: ./checkpoint save_path: ./checkpoint
log_dir: ./log log_dir: ./log
#checkpoint_path: ./checkpoint/transformer/1 #checkpoint_path: ./checkpoint
#transformer_step: 70000

View File

@ -25,6 +25,9 @@ class MyDataParallel(dg.parallel.DataParallel):
return getattr( return getattr(
object.__getattribute__(self, "_sub_layers")["_layers"], key) object.__getattribute__(self, "_sub_layers")["_layers"], key)
def load_checkpoint(step, model_path):
model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
return model_dict, opti_dict
def main(cfg): def main(cfg):
@ -55,9 +58,10 @@ def main(cfg):
if cfg.checkpoint_path is not None: if cfg.checkpoint_path is not None:
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) model_dict, opti_dict = load_checkpoint(str(cfg.postnet_step), os.path.join(cfg.checkpoint_path, "postnet"))
model.set_dict(model_dict) model.set_dict(model_dict)
optimizer.set_dict(opti_dict) optimizer.set_dict(opti_dict)
global_step = cfg.postnet_step
print("load checkpoint!!!") print("load checkpoint!!!")
if cfg.use_data_parallel: if cfg.use_data_parallel:

View File

@ -29,6 +29,10 @@ class MyDataParallel(dg.parallel.DataParallel):
return getattr( return getattr(
object.__getattribute__(self, "_sub_layers")["_layers"], key) object.__getattribute__(self, "_sub_layers")["_layers"], key)
def load_checkpoint(step, model_path):
model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
return model_dict, opti_dict
def main(cfg): def main(cfg):
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0 local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
@ -62,9 +66,10 @@ def main(cfg):
reader = LJSpeechLoader(cfg, nranks, local_rank).reader() reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
if cfg.checkpoint_path is not None: if cfg.checkpoint_path is not None:
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) model_dict, opti_dict = load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.checkpoint_path, "transformer"))
model.set_dict(model_dict) model.set_dict(model_dict)
optimizer.set_dict(opti_dict) optimizer.set_dict(opti_dict)
global_step = cfg.transformer_step
print("load checkpoint!!!") print("load checkpoint!!!")
if cfg.use_data_parallel: if cfg.use_data_parallel: