add error message when image shape are not divisible by 4
This commit is contained in:
parent
a5f1efdfc9
commit
e33e38d1d0
|
@ -59,28 +59,23 @@ class DetModel(object):
|
|||
return: (image, corresponding label, dataloader)
|
||||
"""
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
if image_shape[1] % 4 != 0 or image_shape[2] % 4 != 0:
|
||||
raise Exception("The size of the image must be divisible by 4, "
|
||||
"received image shape is {}, please reset the "
|
||||
"Global.image_shape in the yml file".format(
|
||||
image_shape))
|
||||
|
||||
image = fluid.layers.data(
|
||||
name='image', shape=image_shape, dtype='float32')
|
||||
if mode == "train":
|
||||
if self.algorithm == "EAST":
|
||||
h, w = int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
score = fluid.layers.data(
|
||||
name='score',
|
||||
shape=[
|
||||
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
],
|
||||
dtype='float32')
|
||||
name='score', shape=[1, h, w], dtype='float32')
|
||||
geo = fluid.layers.data(
|
||||
name='geo',
|
||||
shape=[
|
||||
9, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
],
|
||||
dtype='float32')
|
||||
name='geo', shape=[9, h, w], dtype='float32')
|
||||
mask = fluid.layers.data(
|
||||
name='mask',
|
||||
shape=[
|
||||
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||
],
|
||||
dtype='float32')
|
||||
name='mask', shape=[1, h, w], dtype='float32')
|
||||
feed_list = [image, score, geo, mask]
|
||||
labels = {'score': score, 'geo': geo, 'mask': mask}
|
||||
elif self.algorithm == "DB":
|
||||
|
|
Loading…
Reference in New Issue