use eps instald of 0.001

This commit is contained in:
WenmuZhou 2021-09-01 12:43:11 +08:00
parent a4b0241a90
commit f8f7153526
1 changed files with 4 additions and 4 deletions

View File

@ -25,7 +25,7 @@ class PSELoss(nn.Layer):
ohem_ratio=3, ohem_ratio=3,
kernel_sample_mask='pred', kernel_sample_mask='pred',
reduction='sum', reduction='sum',
**kwargs): eps=1e-6**kwargs):
"""Implement PSE Loss. """Implement PSE Loss.
""" """
super(PSELoss, self).__init__() super(PSELoss, self).__init__()
@ -34,6 +34,7 @@ class PSELoss(nn.Layer):
self.ohem_ratio = ohem_ratio self.ohem_ratio = ohem_ratio
self.kernel_sample_mask = kernel_sample_mask self.kernel_sample_mask = kernel_sample_mask
self.reduction = reduction self.reduction = reduction
self.eps = eps
def forward(self, outputs, labels): def forward(self, outputs, labels):
predicts = outputs['maps'] predicts = outputs['maps']
@ -92,8 +93,8 @@ class PSELoss(nn.Layer):
target = target * mask target = target * mask
a = paddle.sum(input * target, 1) a = paddle.sum(input * target, 1)
b = paddle.sum(input * input, 1) + 0.001 b = paddle.sum(input * input, 1) + self.eps
c = paddle.sum(target * target, 1) + 0.001 c = paddle.sum(target * target, 1) + self.eps
d = (2 * a) / (b + c) d = (2 * a) / (b + c)
return 1 - d return 1 - d
@ -104,7 +105,6 @@ class PSELoss(nn.Layer):
.astype('float32'))) .astype('float32')))
if pos_num == 0: if pos_num == 0:
# selected_mask = gt_text.copy() * 0 # may be not good
selected_mask = training_mask selected_mask = training_mask
selected_mask = selected_mask.reshape( selected_mask = selected_mask.reshape(
[1, selected_mask.shape[0], selected_mask.shape[1]]).astype( [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(