diff --git a/parakeet/models/fastspeech/config/fastapeech.yaml b/parakeet/models/fastspeech/config/fastapeech.yaml deleted file mode 100644 index 3e62846..0000000 --- a/parakeet/models/fastspeech/config/fastapeech.yaml +++ /dev/null @@ -1,41 +0,0 @@ -audio: - num_mels: 80 - n_fft: 2048 - sr: 22050 - preemphasis: 0.97 - hop_length: 275 - win_length: 1102 - power: 1.2 - min_level_db: -100 - ref_level_db: 20 - outputs_per_step: 1 - -encoder_n_layer: 6 -encoder_head: 2 -encoder_conv1d_filter_size: 1536 -max_sep_len: 2048 -encoder_output_size: 384 -word_vec_dim: 384 -decoder_n_layer: 6 -decoder_head: 2 -decoder_conv1d_filter_size: 1536 -decoder_output_size: 384 -d_model: 384 -duration_predictor_output_size: 256 -duration_predictor_filter_size: 3 -fft_conv1d_filter: 3 -fft_conv1d_padding: 1 - - -batch_size: 32 -epochs: 10000 -lr: 0.001 -save_step: 500 -image_step: 2000 -use_gpu: False -use_data_parallel: False - -data_path: ../../../dataset/LJSpeech-1.1 -transtts_path: ./checkpoint -transformer_step: 70000 -log_dir: ./log \ No newline at end of file diff --git a/parakeet/models/transformerTTS/config/train_postnet.yaml b/parakeet/models/transformerTTS/config/train_postnet.yaml index 74e1b5a..091758f 100644 --- a/parakeet/models/transformerTTS/config/train_postnet.yaml +++ b/parakeet/models/transformerTTS/config/train_postnet.yaml @@ -24,4 +24,6 @@ use_data_parallel: True data_path: ../../../dataset/LJSpeech-1.1 save_path: ./checkpoint -log_dir: ./log \ No newline at end of file +log_dir: ./log +#checkpoint_path: ./checkpoint +#transformer_step: 27000 \ No newline at end of file diff --git a/parakeet/models/transformerTTS/config/train_transformer.yaml b/parakeet/models/transformerTTS/config/train_transformer.yaml index 0fbde62..8847a6e 100644 --- a/parakeet/models/transformerTTS/config/train_transformer.yaml +++ b/parakeet/models/transformerTTS/config/train_transformer.yaml @@ -28,6 +28,7 @@ use_data_parallel: False data_path: ../../../dataset/LJSpeech-1.1 save_path: ./checkpoint log_dir: ./log -#checkpoint_path: ./checkpoint/transformer/1 +#checkpoint_path: ./checkpoint +#transformer_step: 70000 \ No newline at end of file diff --git a/parakeet/models/transformerTTS/train_postnet.py b/parakeet/models/transformerTTS/train_postnet.py index fe0f379..d45a4c6 100644 --- a/parakeet/models/transformerTTS/train_postnet.py +++ b/parakeet/models/transformerTTS/train_postnet.py @@ -25,6 +25,9 @@ class MyDataParallel(dg.parallel.DataParallel): return getattr( object.__getattribute__(self, "_sub_layers")["_layers"], key) +def load_checkpoint(step, model_path): + model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step)) + return model_dict, opti_dict def main(cfg): @@ -55,9 +58,10 @@ def main(cfg): if cfg.checkpoint_path is not None: - model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) + model_dict, opti_dict = load_checkpoint(str(cfg.postnet_step), os.path.join(cfg.checkpoint_path, "postnet")) model.set_dict(model_dict) optimizer.set_dict(opti_dict) + global_step = cfg.postnet_step print("load checkpoint!!!") if cfg.use_data_parallel: diff --git a/parakeet/models/transformerTTS/train_transformer.py b/parakeet/models/transformerTTS/train_transformer.py index 8b177cd..844c56c 100644 --- a/parakeet/models/transformerTTS/train_transformer.py +++ b/parakeet/models/transformerTTS/train_transformer.py @@ -29,6 +29,10 @@ class MyDataParallel(dg.parallel.DataParallel): return getattr( object.__getattribute__(self, "_sub_layers")["_layers"], key) +def load_checkpoint(step, model_path): + model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step)) + return model_dict, opti_dict + def main(cfg): local_rank = dg.parallel.Env().local_rank if cfg.use_data_parallel else 0 @@ -62,9 +66,10 @@ def main(cfg): reader = LJSpeechLoader(cfg, nranks, local_rank).reader() if cfg.checkpoint_path is not None: - model_dict, opti_dict = fluid.dygraph.load_dygraph(cfg.checkpoint_path) + model_dict, opti_dict = load_checkpoint(str(cfg.transformer_step), os.path.join(cfg.checkpoint_path, "transformer")) model.set_dict(model_dict) optimizer.set_dict(opti_dict) + global_step = cfg.transformer_step print("load checkpoint!!!") if cfg.use_data_parallel: