fix issue #206
This commit is contained in:
parent
1abb0e4afb
commit
f4094b9700
|
@ -32,6 +32,7 @@ class DetModel(object):
|
|||
params (dict): the super parameters for detection module.
|
||||
"""
|
||||
global_params = params['Global']
|
||||
self.global_params = global_params
|
||||
self.algorithm = global_params['algorithm']
|
||||
|
||||
backbone_params = deepcopy(params["Backbone"])
|
||||
|
@ -64,11 +65,23 @@ class DetModel(object):
|
|||
if mode == "train":
|
||||
if self.algorithm == "EAST":
|
||||
score = fluid.layers.data(
|
||||
name='score', shape=[1, 128, 128], dtype='float32')
|
||||
name='score',
|
||||
shape=[
|
||||
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
],
|
||||
dtype='float32')
|
||||
geo = fluid.layers.data(
|
||||
name='geo', shape=[9, 128, 128], dtype='float32')
|
||||
name='geo',
|
||||
shape=[
|
||||
9, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
],
|
||||
dtype='float32')
|
||||
mask = fluid.layers.data(
|
||||
name='mask', shape=[1, 128, 128], dtype='float32')
|
||||
name='mask',
|
||||
shape=[
|
||||
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
],
|
||||
dtype='float32')
|
||||
feed_list = [image, score, geo, mask]
|
||||
labels = {'score': score, 'geo': geo, 'mask': mask}
|
||||
elif self.algorithm == "DB":
|
||||
|
|
Loading…
Reference in New Issue