Merge branch 'add_TranTTS' into 'master'
modified load checkpoint in train See merge request !2
This commit is contained in:
commit
df8e44bbde
|
@ -1,41 +0,0 @@
|
||||||
audio:
|
|
||||||
num_mels: 80
|
|
||||||
n_fft: 2048
|
|
||||||
sr: 22050
|
|
||||||
preemphasis: 0.97
|
|
||||||
hop_length: 275
|
|
||||||
win_length: 1102
|
|
||||||
power: 1.2
|
|
||||||
min_level_db: -100
|
|
||||||
ref_level_db: 20
|
|
||||||
outputs_per_step: 1
|
|
||||||
|
|
||||||
encoder_n_layer: 6
|
|
||||||
encoder_head: 2
|
|
||||||
encoder_conv1d_filter_size: 1536
|
|
||||||
max_sep_len: 2048
|
|
||||||
encoder_output_size: 384
|
|
||||||
word_vec_dim: 384
|
|
||||||
decoder_n_layer: 6
|
|
||||||
decoder_head: 2
|
|
||||||
decoder_conv1d_filter_size: 1536
|
|
||||||
decoder_output_size: 384
|
|
||||||
d_model: 384
|
|
||||||
duration_predictor_output_size: 256
|
|
||||||
duration_predictor_filter_size: 3
|
|
||||||
fft_conv1d_filter: 3
|
|
||||||
fft_conv1d_padding: 1
|
|
||||||
|
|
||||||
|
|
||||||
batch_size: 32
|
|
||||||
epochs: 10000
|
|
||||||
lr: 0.001
|
|
||||||
save_step: 500
|
|
||||||
image_step: 2000
|
|
||||||
use_gpu: False
|
|
||||||
use_data_parallel: False
|
|
||||||
|
|
||||||
data_path: ../../../dataset/LJSpeech-1.1
|
|
||||||
transtts_path: ./checkpoint
|
|
||||||
transformer_step: 70000
|
|
||||||
log_dir: ./log
|
|
|
@ -24,4 +24,6 @@ use_data_parallel: True
|
||||||
|
|
||||||
data_path: ../../../dataset/LJSpeech-1.1
|
data_path: ../../../dataset/LJSpeech-1.1
|
||||||
save_path: ./checkpoint
|
save_path: ./checkpoint
|
||||||
log_dir: ./log
|
log_dir: ./log
|
||||||
|
#checkpoint_path: ./checkpoint
|
||||||
|
#transformer_step: 27000
|
|
@ -28,6 +28,7 @@ use_data_parallel: False
|
||||||
data_path: ../../../dataset/LJSpeech-1.1
|
data_path: ../../../dataset/LJSpeech-1.1
|
||||||
save_path: ./checkpoint
|
save_path: ./checkpoint
|
||||||
log_dir: ./log
|
log_dir: ./log
|
||||||
#checkpoint_path: ./checkpoint/transformer/1
|
#checkpoint_path: ./checkpoint
|
||||||
|
#transformer_step: 70000
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,9 @@ class MyDataParallel(dg.parallel.DataParallel):
|
||||||
return getattr(
|
return getattr(
|
||||||
object.__getattribute__(self, "_sub_layers")["_layers"], key)
|
object.__getattribute__(self, "_sub_layers")["_layers"], key)
|
||||||
|
|
||||||
|
def load_checkpoint(step, model_path):
|
||||||
|
model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||||
|
return model_dict, opti_dict
|
||||||
|
|
||||||
def main(cfg):
|
def main(cfg):
|
||||||
|
|
||||||
|
@ -55,9 +58,10 @@ def main(cfg):
|
||||||
|
|
||||||
|
|
||||||
if cfg.checkpoint_path is not None:
|
if cfg.checkpoint_path is not None:
|
||||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
|
model_dict, opti_dict = load_checkpoint(str(cfg.postnet_step), os.path.join(cfg.checkpoint_path, "postnet"))
|
||||||
model.set_dict(model_dict)
|
model.set_dict(model_dict)
|
||||||
optimizer.set_dict(opti_dict)
|
optimizer.set_dict(opti_dict)
|
||||||
|
global_step = cfg.postnet_step
|
||||||
print("load checkpoint!!!")
|
print("load checkpoint!!!")
|
||||||
|
|
||||||
if cfg.use_data_parallel:
|
if cfg.use_data_parallel:
|
||||||
|
|
|
@ -29,6 +29,10 @@ class MyDataParallel(dg.parallel.DataParallel):
|
||||||
return getattr(
|
return getattr(
|
||||||
object.__getattribute__(self, "_sub_layers")["_layers"], key)
|
object.__getattribute__(self, "_sub_layers")["_layers"], key)
|
||||||
|
|
||||||
|
def load_checkpoint(step, model_path):
|
||||||
|
model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
|
||||||
|
return model_dict, opti_dict
|
||||||
|
|
||||||
|
|
||||||
def main(cfg):
|
def main(cfg):
|
||||||
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0
|
||||||
|
@ -62,9 +66,10 @@ def main(cfg):
|
||||||
reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
|
reader = LJSpeechLoader(cfg, nranks, local_rank).reader()
|
||||||
|
|
||||||
if cfg.checkpoint_path is not None:
|
if cfg.checkpoint_path is not None:
|
||||||
model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path)
|
model_dict, opti_dict = load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.checkpoint_path, "transformer"))
|
||||||
model.set_dict(model_dict)
|
model.set_dict(model_dict)
|
||||||
optimizer.set_dict(opti_dict)
|
optimizer.set_dict(opti_dict)
|
||||||
|
global_step = cfg.transformer_step
|
||||||
print("load checkpoint!!!")
|
print("load checkpoint!!!")
|
||||||
|
|
||||||
if cfg.use_data_parallel:
|
if cfg.use_data_parallel:
|
||||||
|
|
Loading…
Reference in New Issue