Unify save & load interfaces
This commit is contained in:
parent
be70b41fd1
commit
64790853e5
|
@ -22,6 +22,7 @@ import paddle.fluid.dygraph as dg
|
|||
from paddle import fluid
|
||||
|
||||
import utils
|
||||
from parakeet.utils import io
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
|
@ -98,5 +99,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
config = io.add_yaml_config_to_args(config)
|
||||
benchmark(config)
|
||||
|
|
|
@ -23,6 +23,7 @@ from paddle import fluid
|
|||
|
||||
import utils
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
from parakeet.utils import io
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
|
@ -96,7 +97,7 @@ def synthesize(config):
|
|||
# Obtain the current iteration.
|
||||
if config.checkpoint is None:
|
||||
if config.iteration is None:
|
||||
iteration = utils.load_latest_checkpoint(checkpoint_dir)
|
||||
iteration = io.load_latest_checkpoint(checkpoint_dir)
|
||||
else:
|
||||
iteration = config.iteration
|
||||
else:
|
||||
|
@ -117,5 +118,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
config = io.add_yaml_config_to_args(config)
|
||||
synthesize(config)
|
||||
|
|
|
@ -25,6 +25,7 @@ from paddle import fluid
|
|||
from tensorboardX import SummaryWriter
|
||||
|
||||
import utils
|
||||
from parakeet.utils import io
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
|
@ -104,7 +105,7 @@ def train(config):
|
|||
# Obtain the current iteration.
|
||||
if config.checkpoint is None:
|
||||
if config.iteration is None:
|
||||
iteration = utils.load_latest_checkpoint(checkpoint_dir, rank)
|
||||
iteration = io.load_latest_checkpoint(checkpoint_dir, rank)
|
||||
else:
|
||||
iteration = config.iteration
|
||||
else:
|
||||
|
@ -140,7 +141,7 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
config = io.add_yaml_config_to_args(config)
|
||||
# Force to use fp32 in model training
|
||||
vars(config)["use_fp16"] = False
|
||||
train(config)
|
||||
|
|
|
@ -12,14 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
|
||||
import argparse
|
||||
import ruamel.yaml
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
|
@ -95,131 +88,3 @@ def add_config_options_to_parser(parser):
|
|||
'--kernel_w', type=int, help="width of the kernel in the conv2d layer")
|
||||
|
||||
parser.add_argument('--config', type=str, help="Path to the config file.")
|
||||
|
||||
|
||||
def add_yaml_config(config):
    """Merge options from a YAML config file into parsed command-line args.

    Values given on the command line win: a YAML entry is applied only when
    the corresponding attribute is missing from `config` or is None.

    Args:
        config (argparse.Namespace): the args returned by
            `argparse.ArgumentParser().parse_args()`. Its `config` attribute
            holds the path to the YAML file.

    Returns:
        argparse.Namespace: the same `config` object, updated in place.
    """
    config_path = getattr(config, "config", None)
    if config_path is None:
        # No config file was supplied, so there is nothing to merge.
        return config

    with open(config_path, 'rt') as f:
        # An empty YAML document is parsed as None; treat it as "no options".
        yaml_cfg = ruamel.yaml.safe_load(f) or {}

    cfg_vars = vars(config)
    for k, v in yaml_cfg.items():
        # Only fill in options that the command line left unset.
        if cfg_vars.get(k) is None:
            cfg_vars[k] = v
    return config
|
||||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
    """Get the iteration number corresponding to the latest saved checkpoint.

    The index is a one-line text file named "checkpoint" inside
    `checkpoint_dir`, of the form "model_checkpoint_path: step-<N>".

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        rank (int, optional): the rank of the process in multi-process setting.
            Defaults to 0.

    Returns:
        int: the latest iteration number.

    Raises:
        ValueError: if the index line does not end with an integer step.
    """
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")

    def _read_fields():
        # Whitespace-separated tokens of the index line; empty when the file
        # is missing or has not been filled in yet.
        if not os.path.isfile(checkpoint_path):
            return []
        with open(checkpoint_path, "r") as handle:
            return handle.readline().split()

    fields = _read_fields()

    # Process 0 creates the index when it is missing, and also repairs an
    # empty one, which previously raised IndexError.
    if not fields and rank == 0:
        with open(checkpoint_path, "w") as handle:
            handle.write("model_checkpoint_path: step-0")
        fields = _read_fields()

    # Other processes wait until process 0 has written a usable index. A file
    # that exists but is still empty counts as "not ready yet".
    while not fields:
        time.sleep(1)
        fields = _read_fields()

    # The last token looks like "step-<N>"; keep the part after the last dash.
    return int(fields[-1].split("-")[-1])
|
||||
|
||||
|
||||
def save_latest_checkpoint(checkpoint_dir, iteration):
    """Save the iteration number of the latest model to be checkpointed.

    Writes "model_checkpoint_path: step-<iteration>" to the "checkpoint"
    index file inside `checkpoint_dir`.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        iteration (int): the latest iteration number.

    Returns:
        None
    """
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
    tmp_path = checkpoint_path + ".tmp"

    # Write to a sibling temp file and swap it in, so a concurrent reader or
    # a crash mid-write never sees a truncated index.
    with open(tmp_path, "w") as handle:
        handle.write("model_checkpoint_path: step-{}".format(iteration))
    os.replace(tmp_path, checkpoint_path)
|
||||
|
||||
|
||||
def load_parameters(checkpoint_dir,
                    rank,
                    model,
                    optimizer=None,
                    iteration=None,
                    file_path=None,
                    dtype="float32"):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        rank (int): the rank of the process in multi-process setting.
        model (obj): model whose parameters are restored via `set_dict`.
        optimizer (obj, optional): optimizer whose state is restored when the
            checkpoint contains one. Defaults to None.
        iteration (int, optional): the step to load; when None, the latest
            step recorded in the checkpoint index is used. Defaults to None.
        file_path (str, optional): explicit checkpoint path; when given,
            `checkpoint_dir` and `iteration` are not used to locate the file.
            Defaults to None.
        dtype (str, optional): precision of the model parameters.
            Defaults to float32.

    Returns:
        None
    """
    if file_path is None:
        step = iteration
        if step is None:
            step = load_latest_checkpoint(checkpoint_dir, rank)
        # Step 0 means there is nothing to load.
        if step == 0:
            return
        file_path = "{}/step-{}".format(checkpoint_dir, step)

    model_dict, optimizer_dict = dg.load_dygraph(file_path)

    if dtype == "float16":
        for name, value in model_dict.items():
            # Parameters whose name contains "conv2d_transpose" stay float32;
            # everything else is cast to the requested dtype.
            target = "float32" if "conv2d_transpose" in name else dtype
            model_dict[name] = value.astype(target)

    model.set_dict(model_dict)
    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))

    if optimizer and optimizer_dict:
        optimizer.set_dict(optimizer_dict)
        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
            rank, file_path))
|
||||
|
||||
|
||||
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Checkpoint the latest trained model parameters.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        iteration (int): the latest iteration number.
        model (obj): model to be checkpointed.
        optimizer (obj, optional): optimizer to be checkpointed.
            Defaults to None.

    Returns:
        None
    """
    # Model and optimizer state are both saved under the same path prefix.
    # NOTE(review): presumably save_dygraph picks distinct file suffixes for
    # the two kinds of state dict -- confirm against the Paddle docs.
    file_path = "{}/step-{}".format(checkpoint_dir, iteration)

    dg.save_dygraph(model.state_dict(), file_path)
    print("[checkpoint] Saved model to {}".format(file_path))

    if optimizer:
        dg.save_dygraph(optimizer.state_dict(), file_path)
        print("[checkpoint] Saved optimizer state to {}".format(file_path))
|
||||
|
|
|
@ -22,6 +22,7 @@ from paddle import fluid
|
|||
from scipy.io.wavfile import write
|
||||
|
||||
import utils
|
||||
from parakeet.utils import io
|
||||
from parakeet.modules import weight_norm
|
||||
from .data import LJSpeech
|
||||
from .waveflow_modules import WaveFlowLoss, WaveFlowModule
|
||||
|
@ -47,6 +48,7 @@ class WaveFlow():
|
|||
Returns:
|
||||
WaveFlow
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
|
@ -91,7 +93,7 @@ class WaveFlow():
|
|||
parameter_list=waveflow.parameters())
|
||||
|
||||
# Load parameters.
|
||||
utils.load_parameters(
|
||||
io.load_parameters(
|
||||
self.checkpoint_dir,
|
||||
self.rank,
|
||||
waveflow,
|
||||
|
@ -111,7 +113,7 @@ class WaveFlow():
|
|||
|
||||
else:
|
||||
# Load parameters.
|
||||
utils.load_parameters(
|
||||
io.load_parameters(
|
||||
self.checkpoint_dir,
|
||||
self.rank,
|
||||
waveflow,
|
||||
|
@ -291,6 +293,6 @@ class WaveFlow():
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
utils.save_latest_parameters(self.checkpoint_dir, iteration,
|
||||
io.save_latest_parameters(self.checkpoint_dir, iteration,
|
||||
self.waveflow, self.optimizer)
|
||||
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
|
||||
io.save_latest_checkpoint(self.checkpoint_dir, iteration)
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import ruamel.yaml
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
def add_yaml_config_to_args(config):
    """Add args in a YAML config file to the args parsed by argparse.

    Values given on the command line win: a YAML entry is applied only when
    the corresponding attribute is missing from `config` or is None.

    Args:
        config (argparse.Namespace): the args returned by
            `argparse.ArgumentParser().parse_args()`. Its `config` attribute
            holds the path to the YAML file.

    Returns:
        argparse.Namespace: the same `config` object, updated in place.
    """
    config_path = getattr(config, "config", None)
    if config_path is None:
        # No config file was supplied, so there is nothing to merge.
        return config

    with open(config_path, 'rt') as f:
        # An empty YAML document is parsed as None; treat it as "no options".
        yaml_cfg = ruamel.yaml.safe_load(f) or {}

    cfg_vars = vars(config)
    for k, v in yaml_cfg.items():
        # Only fill in options that the command line left unset.
        if cfg_vars.get(k) is None:
            cfg_vars[k] = v
    return config
|
||||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
    """Get the iteration number corresponding to the latest saved checkpoint.

    The index is a one-line text file named "checkpoint" inside
    `checkpoint_dir`, of the form "model_checkpoint_path: step-<N>".

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        rank (int, optional): the rank of the process in multi-process setting.
            Defaults to 0.

    Returns:
        int: the latest iteration number.

    Raises:
        ValueError: if the index line does not end with an integer step.
    """
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")

    def _read_fields():
        # Whitespace-separated tokens of the index line; empty when the file
        # is missing or has not been filled in yet.
        if not os.path.isfile(checkpoint_path):
            return []
        with open(checkpoint_path, "r") as handle:
            return handle.readline().split()

    fields = _read_fields()

    # Process 0 creates the index when it is missing, and also repairs an
    # empty one, which previously raised IndexError.
    if not fields and rank == 0:
        with open(checkpoint_path, "w") as handle:
            handle.write("model_checkpoint_path: step-0")
        fields = _read_fields()

    # Other processes wait until process 0 has written a usable index. A file
    # that exists but is still empty counts as "not ready yet".
    while not fields:
        time.sleep(1)
        fields = _read_fields()

    # The last token looks like "step-<N>"; keep the part after the last dash.
    return int(fields[-1].split("-")[-1])
|
||||
|
||||
|
||||
def save_latest_checkpoint(checkpoint_dir, iteration):
    """Save the iteration number of the latest model to be checkpointed.

    Writes "model_checkpoint_path: step-<iteration>" to the "checkpoint"
    index file inside `checkpoint_dir`.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        iteration (int): the latest iteration number.

    Returns:
        None
    """
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
    tmp_path = checkpoint_path + ".tmp"

    # Write to a sibling temp file and swap it in, so a concurrent reader or
    # a crash mid-write never sees a truncated index.
    with open(tmp_path, "w") as handle:
        handle.write("model_checkpoint_path: step-{}".format(iteration))
    os.replace(tmp_path, checkpoint_path)
|
||||
|
||||
|
||||
def load_parameters(checkpoint_dir,
                    rank,
                    model,
                    optimizer=None,
                    iteration=None,
                    file_path=None,
                    dtype="float32"):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        rank (int): the rank of the process in multi-process setting.
        model (obj): model whose parameters are restored via `set_dict`.
        optimizer (obj, optional): optimizer whose state is restored when the
            checkpoint contains one. Defaults to None.
        iteration (int, optional): the step to load; when None, the latest
            step recorded in the checkpoint index is used. Defaults to None.
        file_path (str, optional): explicit checkpoint path; when given,
            `checkpoint_dir` and `iteration` are not used to locate the file.
            Defaults to None.
        dtype (str, optional): precision of the model parameters.
            Defaults to float32.

    Returns:
        None
    """
    if file_path is None:
        step = iteration
        if step is None:
            step = load_latest_checkpoint(checkpoint_dir, rank)
        # Step 0 means there is nothing to load.
        if step == 0:
            return
        file_path = "{}/step-{}".format(checkpoint_dir, step)

    model_dict, optimizer_dict = dg.load_dygraph(file_path)

    if dtype == "float16":
        for name, value in model_dict.items():
            # Parameters whose name contains "conv2d_transpose" stay float32;
            # everything else is cast to the requested dtype.
            target = "float32" if "conv2d_transpose" in name else dtype
            model_dict[name] = value.astype(target)

    model.set_dict(model_dict)
    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))

    if optimizer and optimizer_dict:
        optimizer.set_dict(optimizer_dict)
        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
            rank, file_path))
|
||||
|
||||
|
||||
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Checkpoint the latest trained model parameters.

    Args:
        checkpoint_dir (str): the directory where checkpoint is saved.
        iteration (int): the latest iteration number.
        model (obj): model to be checkpointed.
        optimizer (obj, optional): optimizer to be checkpointed.
            Defaults to None.

    Returns:
        None
    """
    # Model and optimizer state are both saved under the same path prefix.
    # NOTE(review): presumably save_dygraph picks distinct file suffixes for
    # the two kinds of state dict -- confirm against the Paddle docs.
    file_path = "{}/step-{}".format(checkpoint_dir, iteration)

    dg.save_dygraph(model.state_dict(), file_path)
    print("[checkpoint] Saved model to {}".format(file_path))

    if optimizer:
        dg.save_dygraph(optimizer.state_dict(), file_path)
        print("[checkpoint] Saved optimizer state to {}".format(file_path))
|
Loading…
Reference in New Issue