This commit is contained in:
LDOUBLEV 2020-06-16 14:55:52 +08:00
parent 1abb0e4afb
commit f4094b9700
1 changed files with 16 additions and 3 deletions

View File

@ -32,6 +32,7 @@ class DetModel(object):
params (dict): the super parameters for detection module.
"""
global_params = params['Global']
self.global_params = global_params
self.algorithm = global_params['algorithm']
backbone_params = deepcopy(params["Backbone"])
@ -64,11 +65,23 @@ class DetModel(object):
if mode == "train":
if self.algorithm == "EAST":
score = fluid.layers.data(
name='score', shape=[1, 128, 128], dtype='float32')
name='score',
shape=[
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
geo = fluid.layers.data(
name='geo', shape=[9, 128, 128], dtype='float32')
name='geo',
shape=[
9, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
mask = fluid.layers.data(
name='mask', shape=[1, 128, 128], dtype='float32')
name='mask',
shape=[
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
feed_list = [image, score, geo, mask]
labels = {'score': score, 'geo': geo, 'mask': mask}
elif self.algorithm == "DB":