add error message when image shape are not divisible by 4

This commit is contained in:
LDOUBLEV 2020-06-16 15:22:19 +08:00
parent a5f1efdfc9
commit e33e38d1d0
1 changed files with 10 additions and 15 deletions

View File

@ -59,28 +59,23 @@ class DetModel(object):
return: (image, corresponding label, dataloader)
"""
image_shape = deepcopy(self.image_shape)
if image_shape[1] % 4 != 0 or image_shape[2] % 4 != 0:
raise Exception("The size of the image must be divisible by 4, "
"received image shape is {}, please reset the "
"Global.image_shape in the yml file".format(
image_shape))
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
if mode == "train":
if self.algorithm == "EAST":
h, w = int(image_shape[1] // 4), int(image_shape[2] // 4)
score = fluid.layers.data(
name='score',
shape=[
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
name='score', shape=[1, h, w], dtype='float32')
geo = fluid.layers.data(
name='geo',
shape=[
9, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
name='geo', shape=[9, h, w], dtype='float32')
mask = fluid.layers.data(
name='mask',
shape=[
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
],
dtype='float32')
name='mask', shape=[1, h, w], dtype='float32')
feed_list = [image, score, geo, mask]
labels = {'score': score, 'geo': geo, 'mask': mask}
elif self.algorithm == "DB":