Merge pull request #97 from LDOUBLEV/fixocr

set mode of DB head as 'export' when export model
This commit is contained in:
Double_V 2020-05-27 09:45:22 +08:00 committed by GitHub
commit e9b7b195c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 5 deletions

View File

@ -109,6 +109,9 @@ class DetModel(object):
""" """
image, labels, loader = self.create_feed(mode) image, labels, loader = self.create_feed(mode)
conv_feas = self.backbone(image) conv_feas = self.backbone(image)
if self.algorithm == "DB":
predicts = self.head(conv_feas, mode)
else:
predicts = self.head(conv_feas) predicts = self.head(conv_feas)
if mode == "train": if mode == "train":
losses = self.loss(predicts, labels) losses = self.loss(predicts, labels)

View File

@ -196,7 +196,7 @@ class DBHead(object):
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1) fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse) shrink_maps = self.binarize(fuse)
if mode != "train": if mode != "train":
return {"maps", shrink_maps} return {"maps": shrink_maps}
threshold_maps = self.thresh(fuse) threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps) binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat( y = fluid.layers.concat(

View File

@ -191,7 +191,7 @@ def build_export(config, main_prog, startup_prog):
func_infor = config['Architecture']['function'] func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config) model = create_module(func_infor)(params=config)
image, outputs = model(mode='export') image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs]) fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name] fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name] feeded_var_names = [image.name]
target_vars = fetches_var target_vars = fetches_var