Unify save & load interfaces

liuyibing01 2020-03-22 08:05:05 +00:00
parent be70b41fd1
commit 64790853e5
6 changed files with 173 additions and 145 deletions

View File

@@ -22,6 +22,7 @@ import paddle.fluid.dygraph as dg
 from paddle import fluid
 import utils
+from parakeet.utils import io
 from parakeet.models.waveflow import WaveFlow
@@ -98,5 +99,5 @@ if __name__ == "__main__":
     # For conflicting updates to the same field,
     # the preceding update will be overwritten by the following one.
     config = parser.parse_args()
-    config = utils.add_yaml_config(config)
+    config = io.add_yaml_config_to_args(config)
     benchmark(config)
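A minimal usage sketch of the renamed helper (not part of the commit). It assumes a config.yaml on disk that sets batch_size: 8; the --batch_size flag is hypothetical, standing in for the flags registered by add_config_options_to_parser. Command-line values that are already set win, because the helper only fills in args that are still None.

import argparse
from parakeet.utils import io

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help="Path to the config file.")
parser.add_argument('--batch_size', type=int, default=None)  # hypothetical flag

# config.yaml (assumed) contains: batch_size: 8
config = parser.parse_args(['--config', 'config.yaml', '--batch_size', '4'])
config = io.add_yaml_config_to_args(config)
print(config.batch_size)  # 4: the command-line value takes precedence over YAML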

View File

@@ -23,6 +23,7 @@ from paddle import fluid
 import utils
 from parakeet.models.waveflow import WaveFlow
+from parakeet.utils import io
 def add_options_to_parser(parser):
@@ -96,7 +97,7 @@ def synthesize(config):
     # Obtain the current iteration.
     if config.checkpoint is None:
         if config.iteration is None:
-            iteration = utils.load_latest_checkpoint(checkpoint_dir)
+            iteration = io.load_latest_checkpoint(checkpoint_dir)
         else:
             iteration = config.iteration
     else:
@@ -117,5 +118,5 @@ if __name__ == "__main__":
     # For conflicting updates to the same field,
     # the preceding update will be overwritten by the following one.
     config = parser.parse_args()
-    config = utils.add_yaml_config(config)
+    config = io.add_yaml_config_to_args(config)
     synthesize(config)
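The hunk above cuts off at the final else branch. A hedged sketch of the resolution order the scripts follow: an explicit --checkpoint path wins, then an explicit --iteration, then the latest entry in the checkpoint index; the step-parsing in the last branch is an assumption, not shown in the diff.

from parakeet.utils import io

def resolve_iteration(config, checkpoint_dir):
    if config.checkpoint is None:
        if config.iteration is None:
            # Fall back to the index maintained by io.save_latest_checkpoint.
            return io.load_latest_checkpoint(checkpoint_dir)
        return config.iteration
    # Assumption: derive the step number from a path such as ".../step-12000".
    return int(config.checkpoint.split("-")[-1])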

View File

@@ -25,6 +25,7 @@ from paddle import fluid
 from tensorboardX import SummaryWriter
 import utils
+from parakeet.utils import io
 from parakeet.models.waveflow import WaveFlow
@@ -104,7 +105,7 @@ def train(config):
     # Obtain the current iteration.
     if config.checkpoint is None:
         if config.iteration is None:
-            iteration = utils.load_latest_checkpoint(checkpoint_dir, rank)
+            iteration = io.load_latest_checkpoint(checkpoint_dir, rank)
         else:
             iteration = config.iteration
     else:
@@ -140,7 +141,7 @@ if __name__ == "__main__":
     # For conflicting updates to the same field,
     # the preceding update will be overwritten by the following one.
     config = parser.parse_args()
-    config = utils.add_yaml_config(config)
+    config = io.add_yaml_config_to_args(config)
     # Force to use fp32 in model training
     vars(config)["use_fp16"] = False
     train(config)
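Training is pinned to fp32 here, so the float16 branch of io.load_parameters only matters when loading a trained model for half-precision synthesis. A toy illustration of its cast rule (the parameter names are made up): every weight is cast to float16 except conv2d_transpose weights, which stay in float32.

import numpy as np

model_dict = {
    "conv2d_0.w_0": np.zeros((3, 3), dtype="float32"),
    "conv2d_transpose_0.w_0": np.zeros((3, 3), dtype="float32"),
}
for k, v in model_dict.items():
    # Mirrors the dtype == "float16" loop in io.load_parameters.
    if "conv2d_transpose" in k:
        model_dict[k] = v.astype("float32")
    else:
        model_dict[k] = v.astype("float16")
print({k: str(v.dtype) for k, v in model_dict.items()})
# {'conv2d_0.w_0': 'float16', 'conv2d_transpose_0.w_0': 'float32'}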

View File

@@ -12,14 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-import os
-import time
 import argparse
-import ruamel.yaml
-import numpy as np
-import paddle.fluid.dygraph as dg
 def str2bool(v):
@@ -95,131 +88,3 @@ def add_config_options_to_parser(parser):
         '--kernel_w', type=int, help="width of the kernel in the conv2d layer")
     parser.add_argument('--config', type=str, help="Path to the config file.")
-
-
-def add_yaml_config(config):
-    with open(config.config, 'rt') as f:
-        yaml_cfg = ruamel.yaml.safe_load(f)
-    cfg_vars = vars(config)
-    for k, v in yaml_cfg.items():
-        if k in cfg_vars and cfg_vars[k] is not None:
-            continue
-        cfg_vars[k] = v
-    return config
-
-
-def load_latest_checkpoint(checkpoint_dir, rank=0):
-    """Get the iteration number corresponding to the latest saved checkpoint
-
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        rank (int, optional): the rank of the process in multi-process setting.
-            Defaults to 0.
-
-    Returns:
-        int: the latest iteration number.
-    """
-    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
-    # Create checkpoint index file if not exist.
-    if (not os.path.isfile(checkpoint_path)) and rank == 0:
-        with open(checkpoint_path, "w") as handle:
-            handle.write("model_checkpoint_path: step-0")
-
-    # Make sure that other process waits until checkpoint file is created
-    # by process 0.
-    while not os.path.isfile(checkpoint_path):
-        time.sleep(1)
-
-    # Fetch the latest checkpoint index.
-    with open(checkpoint_path, "r") as handle:
-        latest_checkpoint = handle.readline().split()[-1]
-    iteration = int(latest_checkpoint.split("-")[-1])
-
-    return iteration
-
-
-def save_latest_checkpoint(checkpoint_dir, iteration):
-    """Save the iteration number of the latest model to be checkpointed.
-
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration number.
-
-    Returns:
-        None
-    """
-    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
-    # Update the latest checkpoint index.
-    with open(checkpoint_path, "w") as handle:
-        handle.write("model_checkpoint_path: step-{}".format(iteration))
-
-
-def load_parameters(checkpoint_dir,
-                    rank,
-                    model,
-                    optimizer=None,
-                    iteration=None,
-                    file_path=None,
-                    dtype="float32"):
-    """Load a specific model checkpoint from disk.
-
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        rank (int): the rank of the process in multi-process setting.
-        model (obj): model to load parameters.
-        optimizer (obj, optional): optimizer to load states if needed.
-            Defaults to None.
-        iteration (int, optional): if specified, load the specific checkpoint,
-            if not specified, load the latest one. Defaults to None.
-        file_path (str, optional): if specified, load the checkpoint
-            stored in the file_path. Defaults to None.
-        dtype (str, optional): precision of the model parameters.
-            Defaults to float32.
-
-    Returns:
-        None
-    """
-    if file_path is None:
-        if iteration is None:
-            iteration = load_latest_checkpoint(checkpoint_dir, rank)
-        if iteration == 0:
-            return
-        file_path = "{}/step-{}".format(checkpoint_dir, iteration)
-
-    model_dict, optimizer_dict = dg.load_dygraph(file_path)
-    if dtype == "float16":
-        for k, v in model_dict.items():
-            if "conv2d_transpose" in k:
-                model_dict[k] = v.astype("float32")
-            else:
-                model_dict[k] = v.astype(dtype)
-    model.set_dict(model_dict)
-    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
-
-    if optimizer and optimizer_dict:
-        optimizer.set_dict(optimizer_dict)
-        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
-            rank, file_path))
-
-
-def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
-    """Checkpoint the latest trained model parameters.
-
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration number.
-        model (obj): model to be checkpointed.
-        optimizer (obj, optional): optimizer to be checkpointed.
-            Defaults to None.
-
-    Returns:
-        None
-    """
-    file_path = "{}/step-{}".format(checkpoint_dir, iteration)
-    model_dict = model.state_dict()
-    dg.save_dygraph(model_dict, file_path)
-    print("[checkpoint] Saved model to {}".format(file_path))
-
-    if optimizer:
-        opt_dict = optimizer.state_dict()
-        dg.save_dygraph(opt_dict, file_path)
-        print("[checkpoint] Saved optimzier state to {}".format(file_path))

View File

@@ -22,6 +22,7 @@ from paddle import fluid
 from scipy.io.wavfile import write
 import utils
+from parakeet.utils import io
 from parakeet.modules import weight_norm
 from .data import LJSpeech
 from .waveflow_modules import WaveFlowLoss, WaveFlowModule
@@ -47,6 +48,7 @@ class WaveFlow():
     Returns:
         WaveFlow
     """
+
     def __init__(self,
                  config,
                  checkpoint_dir,
@@ -91,7 +93,7 @@
             parameter_list=waveflow.parameters())

         # Load parameters.
-        utils.load_parameters(
+        io.load_parameters(
            self.checkpoint_dir,
            self.rank,
            waveflow,
@@ -111,7 +113,7 @@
         else:
             # Load parameters.
-            utils.load_parameters(
+            io.load_parameters(
                self.checkpoint_dir,
                self.rank,
                waveflow,
@@ -291,6 +293,6 @@
         Returns:
             None
         """
-        utils.save_latest_parameters(self.checkpoint_dir, iteration,
-                                     self.waveflow, self.optimizer)
-        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
+        io.save_latest_parameters(self.checkpoint_dir, iteration,
+                                  self.waveflow, self.optimizer)
+        io.save_latest_checkpoint(self.checkpoint_dir, iteration)
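A minimal save/load round trip through the unified interface (not part of the commit): a toy fluid.dygraph Linear layer stands in for WaveFlow, and the checkpoints/ directory name is arbitrary.

import os
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.utils import io

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

with dg.guard():
    model = dg.Linear(4, 4)
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=1e-3, parameter_list=model.parameters())

    # Save step 100 and point the index file at it, as WaveFlow.save does above.
    io.save_latest_parameters(checkpoint_dir, 100, model, optimizer)
    io.save_latest_checkpoint(checkpoint_dir, 100)

    # Later, restore the latest checkpoint (rank 0, single process).
    io.load_parameters(checkpoint_dir, 0, model, optimizer=optimizer)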

parakeet/utils/io.py (new file, 158 additions)
View File

@@ -0,0 +1,158 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import time
+
+import ruamel.yaml
+import numpy as np
+import paddle.fluid.dygraph as dg
+
+
+def add_yaml_config_to_args(config):
+    """Add arguments from the YAML config file to the args parsed by argparse.
+    If an argument is set in both places, the argparse value takes precedence.
+
+    Args:
+        config (argparse.Namespace): the args returned by
+            `argparse.ArgumentParser().parse_args()`.
+
+    Returns:
+        config: the args updated with the values from the YAML config.
+    """
+    with open(config.config, 'rt') as f:
+        yaml_cfg = ruamel.yaml.safe_load(f)
+    cfg_vars = vars(config)
+    for k, v in yaml_cfg.items():
+        if k in cfg_vars and cfg_vars[k] is not None:
+            continue
+        cfg_vars[k] = v
+    return config
+
+
+def load_latest_checkpoint(checkpoint_dir, rank=0):
+    """Get the iteration number corresponding to the latest saved checkpoint.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+        rank (int, optional): the rank of the process in a multi-process
+            setting. Defaults to 0.
+
+    Returns:
+        int: the latest iteration number.
+    """
+    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
+    # Create the checkpoint index file if it does not exist.
+    if (not os.path.isfile(checkpoint_path)) and rank == 0:
+        with open(checkpoint_path, "w") as handle:
+            handle.write("model_checkpoint_path: step-0")
+
+    # Make sure the other processes wait until the checkpoint file has been
+    # created by process 0.
+    while not os.path.isfile(checkpoint_path):
+        time.sleep(1)
+
+    # Fetch the latest checkpoint index.
+    with open(checkpoint_path, "r") as handle:
+        latest_checkpoint = handle.readline().split()[-1]
+    iteration = int(latest_checkpoint.split("-")[-1])
+
+    return iteration
+
+
+def save_latest_checkpoint(checkpoint_dir, iteration):
+    """Record the iteration number of the latest saved model in the
+    checkpoint index file.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+        iteration (int): the latest iteration number.
+
+    Returns:
+        None
+    """
+    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
+    # Update the latest checkpoint index.
+    with open(checkpoint_path, "w") as handle:
+        handle.write("model_checkpoint_path: step-{}".format(iteration))
+
+
+def load_parameters(checkpoint_dir,
+                    rank,
+                    model,
+                    optimizer=None,
+                    iteration=None,
+                    file_path=None,
+                    dtype="float32"):
+    """Load a specific model checkpoint from disk.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+        rank (int): the rank of the process in a multi-process setting.
+        model (obj): the model to load parameters into.
+        optimizer (obj, optional): the optimizer to load state into, if
+            needed. Defaults to None.
+        iteration (int, optional): if specified, load that checkpoint;
+            otherwise load the latest one. Defaults to None.
+        file_path (str, optional): if specified, load the checkpoint stored
+            at file_path. Defaults to None.
+        dtype (str, optional): precision of the model parameters.
+            Defaults to "float32".
+
+    Returns:
+        None
+    """
+    if file_path is None:
+        if iteration is None:
+            iteration = load_latest_checkpoint(checkpoint_dir, rank)
+        if iteration == 0:
+            return
+        file_path = "{}/step-{}".format(checkpoint_dir, iteration)
+
+    model_dict, optimizer_dict = dg.load_dygraph(file_path)
+    if dtype == "float16":
+        for k, v in model_dict.items():
+            if "conv2d_transpose" in k:
+                model_dict[k] = v.astype("float32")
+            else:
+                model_dict[k] = v.astype(dtype)
+    model.set_dict(model_dict)
+    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
+
+    if optimizer and optimizer_dict:
+        optimizer.set_dict(optimizer_dict)
+        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
+            rank, file_path))
+
+
+def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
+    """Checkpoint the latest trained model parameters.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+        iteration (int): the latest iteration number.
+        model (obj): the model to be checkpointed.
+        optimizer (obj, optional): the optimizer to be checkpointed.
+            Defaults to None.
+
+    Returns:
+        None
+    """
+    file_path = "{}/step-{}".format(checkpoint_dir, iteration)
+    model_dict = model.state_dict()
+    dg.save_dygraph(model_dict, file_path)
+    print("[checkpoint] Saved model to {}".format(file_path))
+
+    if optimizer:
+        opt_dict = optimizer.state_dict()
+        dg.save_dygraph(opt_dict, file_path)
+        print("[checkpoint] Saved optimizer state to {}".format(file_path))