Merge pull request #97 from LDOUBLEV/fixocr

set mode of DB head as 'export' when export model
This commit is contained in:
Double_V 2020-05-27 09:45:22 +08:00 committed by GitHub
commit e9b7b195c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 5 deletions

View File

@ -109,6 +109,9 @@ class DetModel(object):
"""
image, labels, loader = self.create_feed(mode)
conv_feas = self.backbone(image)
if self.algorithm == "DB":
predicts = self.head(conv_feas, mode)
else:
predicts = self.head(conv_feas)
if mode == "train":
losses = self.loss(predicts, labels)

View File

@ -196,7 +196,7 @@ class DBHead(object):
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse)
if mode != "train":
return {"maps", shrink_maps}
return {"maps": shrink_maps}
threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat(

View File

@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config)
image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs])
fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name]
target_vars = fetches_var
@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path)
save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path)
save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)