change interface for io.py
This commit is contained in:
parent
64790853e5
commit
c845fbd51d
|
@ -20,6 +20,11 @@ import numpy as np
|
|||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
def is_main_process():
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
return local_rank == 0
|
||||
|
||||
|
||||
def add_yaml_config_to_args(config):
|
||||
""" Add args in yaml config to the args parsed by argparse. The argument in
|
||||
yaml config will be overwritten by the same argument in argparse if they
|
||||
|
@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
|
|||
return config
|
||||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
||||
def _load_latest_checkpoint(checkpoint_dir):
|
||||
"""Get the iteration number corresponding to the latest saved checkpoint
|
||||
|
||||
Args:
|
||||
|
@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
|
|||
Returns:
|
||||
int: the latest iteration number.
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Create checkpoint index file if not exist.
|
||||
if (not os.path.isfile(checkpoint_path)) and rank == 0:
|
||||
with open(checkpoint_path, "w") as handle:
|
||||
handle.write("model_checkpoint_path: step-0")
|
||||
|
||||
# Make sure that other process waits until checkpoint file is created
|
||||
# by process 0.
|
||||
while not os.path.isfile(checkpoint_path):
|
||||
time.sleep(1)
|
||||
if (not os.path.isfile(checkpoint_record)):
|
||||
return 0
|
||||
|
||||
# Fetch the latest checkpoint index.
|
||||
with open(checkpoint_path, "r") as handle:
|
||||
with open(checkpoint_record, "r") as handle:
|
||||
latest_checkpoint = handle.readline().split()[-1]
|
||||
iteration = int(latest_checkpoint.split("-")[-1])
|
||||
|
||||
return iteration
|
||||
|
||||
|
||||
def save_latest_checkpoint(checkpoint_dir, iteration):
|
||||
def _save_checkpoint(checkpoint_dir, iteration):
|
||||
"""Save the iteration number of the latest model to be checkpointed.
|
||||
|
||||
Args:
|
||||
|
@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Update the latest checkpoint index.
|
||||
with open(checkpoint_path, "w") as handle:
|
||||
with open(checkpoint_record, "w") as handle:
|
||||
handle.write("model_checkpoint_path: step-{}".format(iteration))
|
||||
|
||||
|
||||
def load_parameters(checkpoint_dir,
|
||||
rank,
|
||||
model,
|
||||
def load_parameters(model,
|
||||
optimizer=None,
|
||||
checkpoint_dir=None,
|
||||
iteration=None,
|
||||
file_path=None,
|
||||
checkpoint_path=None,
|
||||
dtype="float32"):
|
||||
"""Load a specific model checkpoint from disk.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
rank (int): the rank of the process in multi-process setting.
|
||||
model (obj): model to load parameters.
|
||||
optimizer (obj, optional): optimizer to load states if needed.
|
||||
Defaults to None.
|
||||
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
||||
iteration (int, optional): if specified, load the specific checkpoint,
|
||||
if not specified, load the latest one. Defaults to None.
|
||||
file_path (str, optional): if specified, load the checkpoint
|
||||
stored in the file_path. Defaults to None.
|
||||
checkpoint_path (str, optional): if specified, load the checkpoint
|
||||
stored in the checkpoint_path. Defaults to None.
|
||||
dtype (str, optional): precision of the model parameters.
|
||||
Defaults to float32.
|
||||
|
||||
Returns:
|
||||
None
|
||||
iteration (int): number of iterations that the loaded checkpoint has
|
||||
been trained.
|
||||
"""
|
||||
if file_path is None:
|
||||
if iteration is None:
|
||||
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
||||
if iteration == 0:
|
||||
return
|
||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
||||
if checkpoint_dir is not None and checkpoint_path is not None:
|
||||
raise ValueError(
|
||||
"Load from either from (checkpoint_dir and iteration) \n"
|
||||
"or checkpoint_path. Do not pass both.")
|
||||
if iteration is not None and checkpoint_dir is None:
|
||||
raise ValueError(
|
||||
"When iteration is specified, checkpoint_dir should not be None")
|
||||
|
||||
if checkpoint_dir is not None:
|
||||
if iteration is None:
|
||||
iteration = _load_latest_checkpoint(checkpoint_dir)
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"step-{}".format(iteration))
|
||||
if iteration == 0 and not os.path.exists(checkpoint_path):
|
||||
# if step-0 exist, it is also loaded
|
||||
return iteration
|
||||
else:
|
||||
# checkpoint is not None
|
||||
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
|
||||
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
|
||||
|
||||
# cast to desired data type
|
||||
for k, v in model_dict.items():
|
||||
model_dict[k] = v.astype(dtype)
|
||||
|
||||
model_dict, optimizer_dict = dg.load_dygraph(file_path)
|
||||
if dtype == "float16":
|
||||
for k, v in model_dict.items():
|
||||
if "conv2d_transpose" in k:
|
||||
model_dict[k] = v.astype("float32")
|
||||
else:
|
||||
model_dict[k] = v.astype(dtype)
|
||||
model.set_dict(model_dict)
|
||||
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
|
||||
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
|
||||
local_rank, checkpoint_path))
|
||||
|
||||
if optimizer and optimizer_dict:
|
||||
optimizer.set_dict(optimizer_dict)
|
||||
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
|
||||
rank, file_path))
|
||||
print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
|
||||
format(local_rank, checkpoint_path))
|
||||
|
||||
return iteration
|
||||
|
||||
|
||||
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||
"""Checkpoint the latest trained model parameters.
|
||||
|
||||
Args:
|
||||
|
@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
|
||||
model_dict = model.state_dict()
|
||||
dg.save_dygraph(model_dict, file_path)
|
||||
print("[checkpoint] Saved model to {}".format(file_path))
|
||||
dg.save_dygraph(model_dict, checkpoint_path)
|
||||
print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))
|
||||
|
||||
if optimizer:
|
||||
opt_dict = optimizer.state_dict()
|
||||
dg.save_dygraph(opt_dict, file_path)
|
||||
print("[checkpoint] Saved optimzier state to {}".format(file_path))
|
||||
dg.save_dygraph(opt_dict, checkpoint_path)
|
||||
print("[checkpoint] Saved optimzier state to {}.pdopt".format(
|
||||
checkpoint_path))
|
||||
|
||||
_save_checkpoint(checkpoint_dir, iteration)
|
||||
|
|
Loading…
Reference in New Issue