add error message when image shape are not divisible by 4
This commit is contained in:
parent
a5f1efdfc9
commit
e33e38d1d0
|
@ -59,28 +59,23 @@ class DetModel(object):
|
||||||
return: (image, corresponding label, dataloader)
|
return: (image, corresponding label, dataloader)
|
||||||
"""
|
"""
|
||||||
image_shape = deepcopy(self.image_shape)
|
image_shape = deepcopy(self.image_shape)
|
||||||
|
if image_shape[1] % 4 != 0 or image_shape[2] % 4 != 0:
|
||||||
|
raise Exception("The size of the image must be divisible by 4, "
|
||||||
|
"received image shape is {}, please reset the "
|
||||||
|
"Global.image_shape in the yml file".format(
|
||||||
|
image_shape))
|
||||||
|
|
||||||
image = fluid.layers.data(
|
image = fluid.layers.data(
|
||||||
name='image', shape=image_shape, dtype='float32')
|
name='image', shape=image_shape, dtype='float32')
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
if self.algorithm == "EAST":
|
if self.algorithm == "EAST":
|
||||||
|
h, w = int(image_shape[1] // 4), int(image_shape[2] // 4)
|
||||||
score = fluid.layers.data(
|
score = fluid.layers.data(
|
||||||
name='score',
|
name='score', shape=[1, h, w], dtype='float32')
|
||||||
shape=[
|
|
||||||
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
|
||||||
],
|
|
||||||
dtype='float32')
|
|
||||||
geo = fluid.layers.data(
|
geo = fluid.layers.data(
|
||||||
name='geo',
|
name='geo', shape=[9, h, w], dtype='float32')
|
||||||
shape=[
|
|
||||||
9, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
|
||||||
],
|
|
||||||
dtype='float32')
|
|
||||||
mask = fluid.layers.data(
|
mask = fluid.layers.data(
|
||||||
name='mask',
|
name='mask', shape=[1, h, w], dtype='float32')
|
||||||
shape=[
|
|
||||||
1, int(image_shape[1] // 4), int(image_shape[2] // 4)
|
|
||||||
],
|
|
||||||
dtype='float32')
|
|
||||||
feed_list = [image, score, geo, mask]
|
feed_list = [image, score, geo, mask]
|
||||||
labels = {'score': score, 'geo': geo, 'mask': mask}
|
labels = {'score': score, 'geo': geo, 'mask': mask}
|
||||||
elif self.algorithm == "DB":
|
elif self.algorithm == "DB":
|
||||||
|
|
Loading…
Reference in New Issue