Merge pull request #97 from LDOUBLEV/fixocr

set mode of DB head as 'export' when export model
This commit is contained in:
Double_V 2020-05-27 09:45:22 +08:00 committed by GitHub
commit e9b7b195c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 5 deletions

View File

@ -109,7 +109,10 @@ class DetModel(object):
""" """
image, labels, loader = self.create_feed(mode) image, labels, loader = self.create_feed(mode)
conv_feas = self.backbone(image) conv_feas = self.backbone(image)
predicts = self.head(conv_feas) if self.algorithm == "DB":
predicts = self.head(conv_feas, mode)
else:
predicts = self.head(conv_feas)
if mode == "train": if mode == "train":
losses = self.loss(predicts, labels) losses = self.loss(predicts, labels)
return loader, losses return loader, losses

View File

@ -196,7 +196,7 @@ class DBHead(object):
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1) fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse) shrink_maps = self.binarize(fuse)
if mode != "train": if mode != "train":
return {"maps", shrink_maps} return {"maps": shrink_maps}
threshold_maps = self.thresh(fuse) threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps) binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat( y = fluid.layers.concat(

View File

@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var
@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset() train_loader.reset()
if epoch == 0 and save_epoch_step == 1: if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0" save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path) save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0: if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path) save_model(train_info_dict['train_program'], save_path)
@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
train_loader.reset() train_loader.reset()
if epoch == 0 and save_epoch_step == 1: if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0" save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'],save_path) save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0: if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path) save_model(train_info_dict['train_program'], save_path)