change interface for io.py
This commit is contained in:
parent
64790853e5
commit
c845fbd51d
|
@ -20,6 +20,11 @@ import numpy as np
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
local_rank = dg.parallel.Env().local_rank
|
||||||
|
return local_rank == 0
|
||||||
|
|
||||||
|
|
||||||
def add_yaml_config_to_args(config):
|
def add_yaml_config_to_args(config):
|
||||||
""" Add args in yaml config to the args parsed by argparse. The argument in
|
""" Add args in yaml config to the args parsed by argparse. The argument in
|
||||||
yaml config will be overwritten by the same argument in argparse if they
|
yaml config will be overwritten by the same argument in argparse if they
|
||||||
|
@ -41,7 +46,7 @@ def add_yaml_config_to_args(config):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
def _load_latest_checkpoint(checkpoint_dir):
|
||||||
"""Get the iteration number corresponding to the latest saved checkpoint
|
"""Get the iteration number corresponding to the latest saved checkpoint
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -52,26 +57,20 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
|
||||||
Returns:
|
Returns:
|
||||||
int: the latest iteration number.
|
int: the latest iteration number.
|
||||||
"""
|
"""
|
||||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
|
||||||
# Create checkpoint index file if not exist.
|
# Create checkpoint index file if not exist.
|
||||||
if (not os.path.isfile(checkpoint_path)) and rank == 0:
|
if (not os.path.isfile(checkpoint_record)):
|
||||||
with open(checkpoint_path, "w") as handle:
|
return 0
|
||||||
handle.write("model_checkpoint_path: step-0")
|
|
||||||
|
|
||||||
# Make sure that other process waits until checkpoint file is created
|
|
||||||
# by process 0.
|
|
||||||
while not os.path.isfile(checkpoint_path):
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# Fetch the latest checkpoint index.
|
# Fetch the latest checkpoint index.
|
||||||
with open(checkpoint_path, "r") as handle:
|
with open(checkpoint_record, "r") as handle:
|
||||||
latest_checkpoint = handle.readline().split()[-1]
|
latest_checkpoint = handle.readline().split()[-1]
|
||||||
iteration = int(latest_checkpoint.split("-")[-1])
|
iteration = int(latest_checkpoint.split("-")[-1])
|
||||||
|
|
||||||
return iteration
|
return iteration
|
||||||
|
|
||||||
|
|
||||||
def save_latest_checkpoint(checkpoint_dir, iteration):
|
def _save_checkpoint(checkpoint_dir, iteration):
|
||||||
"""Save the iteration number of the latest model to be checkpointed.
|
"""Save the iteration number of the latest model to be checkpointed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -81,60 +80,76 @@ def save_latest_checkpoint(checkpoint_dir, iteration):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
|
||||||
# Update the latest checkpoint index.
|
# Update the latest checkpoint index.
|
||||||
with open(checkpoint_path, "w") as handle:
|
with open(checkpoint_record, "w") as handle:
|
||||||
handle.write("model_checkpoint_path: step-{}".format(iteration))
|
handle.write("model_checkpoint_path: step-{}".format(iteration))
|
||||||
|
|
||||||
|
|
||||||
def load_parameters(checkpoint_dir,
|
def load_parameters(model,
|
||||||
rank,
|
|
||||||
model,
|
|
||||||
optimizer=None,
|
optimizer=None,
|
||||||
|
checkpoint_dir=None,
|
||||||
iteration=None,
|
iteration=None,
|
||||||
file_path=None,
|
checkpoint_path=None,
|
||||||
dtype="float32"):
|
dtype="float32"):
|
||||||
"""Load a specific model checkpoint from disk.
|
"""Load a specific model checkpoint from disk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
|
||||||
rank (int): the rank of the process in multi-process setting.
|
|
||||||
model (obj): model to load parameters.
|
model (obj): model to load parameters.
|
||||||
optimizer (obj, optional): optimizer to load states if needed.
|
optimizer (obj, optional): optimizer to load states if needed.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
||||||
iteration (int, optional): if specified, load the specific checkpoint,
|
iteration (int, optional): if specified, load the specific checkpoint,
|
||||||
if not specified, load the latest one. Defaults to None.
|
if not specified, load the latest one. Defaults to None.
|
||||||
file_path (str, optional): if specified, load the checkpoint
|
checkpoint_path (str, optional): if specified, load the checkpoint
|
||||||
stored in the file_path. Defaults to None.
|
stored in the checkpoint_path. Defaults to None.
|
||||||
dtype (str, optional): precision of the model parameters.
|
dtype (str, optional): precision of the model parameters.
|
||||||
Defaults to float32.
|
Defaults to float32.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
iteration (int): number of iterations that the loaded checkpoint has
|
||||||
|
been trained.
|
||||||
"""
|
"""
|
||||||
if file_path is None:
|
if checkpoint_dir is not None and checkpoint_path is not None:
|
||||||
if iteration is None:
|
raise ValueError(
|
||||||
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
"Load from either from (checkpoint_dir and iteration) \n"
|
||||||
if iteration == 0:
|
"or checkpoint_path. Do not pass both.")
|
||||||
return
|
if iteration is not None and checkpoint_dir is None:
|
||||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
raise ValueError(
|
||||||
|
"When iteration is specified, checkpoint_dir should not be None")
|
||||||
|
|
||||||
|
if checkpoint_dir is not None:
|
||||||
|
if iteration is None:
|
||||||
|
iteration = _load_latest_checkpoint(checkpoint_dir)
|
||||||
|
checkpoint_path = os.path.join(checkpoint_dir,
|
||||||
|
"step-{}".format(iteration))
|
||||||
|
if iteration == 0 and not os.path.exists(checkpoint_path):
|
||||||
|
# if step-0 exists, it is also loaded
|
||||||
|
return iteration
|
||||||
|
else:
|
||||||
|
# checkpoint_path is not None
|
||||||
|
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
|
||||||
|
|
||||||
|
local_rank = dg.parallel.Env().local_rank
|
||||||
|
model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
|
||||||
|
|
||||||
|
# cast to desired data type
|
||||||
|
for k, v in model_dict.items():
|
||||||
|
model_dict[k] = v.astype(dtype)
|
||||||
|
|
||||||
model_dict, optimizer_dict = dg.load_dygraph(file_path)
|
|
||||||
if dtype == "float16":
|
|
||||||
for k, v in model_dict.items():
|
|
||||||
if "conv2d_transpose" in k:
|
|
||||||
model_dict[k] = v.astype("float32")
|
|
||||||
else:
|
|
||||||
model_dict[k] = v.astype(dtype)
|
|
||||||
model.set_dict(model_dict)
|
model.set_dict(model_dict)
|
||||||
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
|
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
|
||||||
|
local_rank, checkpoint_path))
|
||||||
|
|
||||||
if optimizer and optimizer_dict:
|
if optimizer and optimizer_dict:
|
||||||
optimizer.set_dict(optimizer_dict)
|
optimizer.set_dict(optimizer_dict)
|
||||||
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
|
print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
|
||||||
rank, file_path))
|
format(local_rank, checkpoint_path))
|
||||||
|
|
||||||
|
return iteration
|
||||||
|
|
||||||
|
|
||||||
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||||
"""Checkpoint the latest trained model parameters.
|
"""Checkpoint the latest trained model parameters.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -147,12 +162,15 @@ def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
|
||||||
model_dict = model.state_dict()
|
model_dict = model.state_dict()
|
||||||
dg.save_dygraph(model_dict, file_path)
|
dg.save_dygraph(model_dict, checkpoint_path)
|
||||||
print("[checkpoint] Saved model to {}".format(file_path))
|
print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))
|
||||||
|
|
||||||
if optimizer:
|
if optimizer:
|
||||||
opt_dict = optimizer.state_dict()
|
opt_dict = optimizer.state_dict()
|
||||||
dg.save_dygraph(opt_dict, file_path)
|
dg.save_dygraph(opt_dict, checkpoint_path)
|
||||||
print("[checkpoint] Saved optimzier state to {}".format(file_path))
|
print("[checkpoint] Saved optimzier state to {}.pdopt".format(
|
||||||
|
checkpoint_path))
|
||||||
|
|
||||||
|
_save_checkpoint(checkpoint_dir, iteration)
|
||||||
|
|
Loading…
Reference in New Issue