Merge pull request #97 from LDOUBLEV/fixocr
set mode of DB head as 'export' when export model
This commit is contained in:
commit
e9b7b195c7
|
@ -109,7 +109,10 @@ class DetModel(object):
|
||||||
"""
|
"""
|
||||||
image, labels, loader = self.create_feed(mode)
|
image, labels, loader = self.create_feed(mode)
|
||||||
conv_feas = self.backbone(image)
|
conv_feas = self.backbone(image)
|
||||||
predicts = self.head(conv_feas)
|
if self.algorithm == "DB":
|
||||||
|
predicts = self.head(conv_feas, mode)
|
||||||
|
else:
|
||||||
|
predicts = self.head(conv_feas)
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
losses = self.loss(predicts, labels)
|
losses = self.loss(predicts, labels)
|
||||||
return loader, losses
|
return loader, losses
|
||||||
|
|
|
@ -196,7 +196,7 @@ class DBHead(object):
|
||||||
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
|
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
|
||||||
shrink_maps = self.binarize(fuse)
|
shrink_maps = self.binarize(fuse)
|
||||||
if mode != "train":
|
if mode != "train":
|
||||||
return {"maps", shrink_maps}
|
return {"maps": shrink_maps}
|
||||||
threshold_maps = self.thresh(fuse)
|
threshold_maps = self.thresh(fuse)
|
||||||
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
||||||
y = fluid.layers.concat(
|
y = fluid.layers.concat(
|
||||||
|
|
|
@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog):
|
||||||
func_infor = config['Architecture']['function']
|
func_infor = config['Architecture']['function']
|
||||||
model = create_module(func_infor)(params=config)
|
model = create_module(func_infor)(params=config)
|
||||||
image, outputs = model(mode='export')
|
image, outputs = model(mode='export')
|
||||||
fetches_var_name = sorted([name for name in outputs])
|
fetches_var_name = sorted([name for name in outputs.keys()])
|
||||||
fetches_var = [outputs[name] for name in fetches_var_name]
|
fetches_var = [outputs[name] for name in fetches_var_name]
|
||||||
feeded_var_names = [image.name]
|
feeded_var_names = [image.name]
|
||||||
target_vars = fetches_var
|
target_vars = fetches_var
|
||||||
|
@ -271,7 +271,7 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
||||||
train_loader.reset()
|
train_loader.reset()
|
||||||
if epoch == 0 and save_epoch_step == 1:
|
if epoch == 0 and save_epoch_step == 1:
|
||||||
save_path = save_model_dir + "/iter_epoch_0"
|
save_path = save_model_dir + "/iter_epoch_0"
|
||||||
save_model(train_info_dict['train_program'],save_path)
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||||
save_model(train_info_dict['train_program'], save_path)
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
|
@ -350,7 +350,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
||||||
train_loader.reset()
|
train_loader.reset()
|
||||||
if epoch == 0 and save_epoch_step == 1:
|
if epoch == 0 and save_epoch_step == 1:
|
||||||
save_path = save_model_dir + "/iter_epoch_0"
|
save_path = save_model_dir + "/iter_epoch_0"
|
||||||
save_model(train_info_dict['train_program'],save_path)
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||||
save_model(train_info_dict['train_program'], save_path)
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
|
|
Loading…
Reference in New Issue