diff --git a/tools/train.py b/tools/train.py index 2091ff48..05d295aa 100755 --- a/tools/train.py +++ b/tools/train.py @@ -97,8 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - #pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) - pre_best_model_dict = {} + pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format(