Merge branch 'add_TranTTS' into 'master'
modified load checkpoint in train See merge request !2
This commit is contained in:
commit
df8e44bbde
|
@ -1,41 +0,0 @@
|
|||
audio:
|
||||
num_mels: 80
|
||||
n_fft: 2048
|
||||
sr: 22050
|
||||
preemphasis: 0.97
|
||||
hop_length: 275
|
||||
win_length: 1102
|
||||
power: 1.2
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
outputs_per_step: 1
|
||||
|
||||
encoder_n_layer: 6
|
||||
encoder_head: 2
|
||||
encoder_conv1d_filter_size: 1536
|
||||
max_sep_len: 2048
|
||||
encoder_output_size: 384
|
||||
word_vec_dim: 384
|
||||
decoder_n_layer: 6
|
||||
decoder_head: 2
|
||||
decoder_conv1d_filter_size: 1536
|
||||
decoder_output_size: 384
|
||||
d_model: 384
|
||||
duration_predictor_output_size: 256
|
||||
duration_predictor_filter_size: 3
|
||||
fft_conv1d_filter: 3
|
||||
fft_conv1d_padding: 1
|
||||
|
||||
|
||||
batch_size: 32
|
||||
epochs: 10000
|
||||
lr: 0.001
|
||||
save_step: 500
|
||||
image_step: 2000
|
||||
use_gpu: False
|
||||
use_data_parallel: False
|
||||
|
||||
data_path: ../../../dataset/LJSpeech-1.1
|
||||
transtts_path: ./checkpoint
|
||||
transformer_step: 70000
|
||||
log_dir: ./log
|
|
@ -25,3 +25,5 @@ use_data_parallel: True
|
|||
data_path: ../../../dataset/LJSpeech-1.1
|
||||
save_path: ./checkpoint
|
||||
log_dir: ./log
|
||||
#checkpoint_path: ./checkpoint
|
||||
#transformer_step: 27000
|
|
@ -28,6 +28,7 @@ use_data_parallel: False
|
|||
data_path: ../../../dataset/LJSpeech-1.1
|
||||
save_path: ./checkpoint
|
||||
log_dir: ./log
|
||||
#checkpoint_path: ./checkpoint/transformer/1
|
||||
#checkpoint_path: ./checkpoint
|
||||
#transformer_step: 70000
|
||||
|
||||
|
|
@ -25,6 +25,9 @@ class MyDataParallel(dg.parallel.DataParallel):
|
|||
return getattr(
|
||||
object.__getattribute__(self, "_sub_layers")["_layers"], key)
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
return model_dict, opti_dict
|
||||
|
||||
def main(cfg):
|
||||
|
||||
|
@ -55,9 +58,10 @@ def main(cfg):
|
|||
|
||||
|
||||
if cfg.checkpoint_path is not None:
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
|
||||
model_dict, opti_dict = load_checkpoint(str(cfg.postnet_step), os.path.join(cfg.checkpoint_path, "postnet"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = cfg.postnet_step
|
||||
print("load checkpoint!!!")
|
||||
|
||||
if cfg.use_data_parallel:
|
||||
|
|
|
@ -29,6 +29,10 @@ class MyDataParallel(dg.parallel.DataParallel):
|
|||
return getattr(
|
||||
object.__getattribute__(self, "_sub_layers")["_layers"], key)
|
||||
|
||||
def load_checkpoint(step, model_path):
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||
return model_dict, opti_dict
|
||||
|
||||
|
||||
def main(cfg):
|
||||
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
||||
|
@ -62,9 +66,10 @@ def main(cfg):
|
|||
reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
|
||||
|
||||
if cfg.checkpoint_path is not None:
|
||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
|
||||
model_dict, opti_dict = load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.checkpoint_path, "transformer"))
|
||||
model.set_dict(model_dict)
|
||||
optimizer.set_dict(opti_dict)
|
||||
global_step = cfg.transformer_step
|
||||
print("load checkpoint!!!")
|
||||
|
||||
if cfg.use_data_parallel:
|
||||
|
|
Loading…
Reference in New Issue