diff --git a/setup.py b/setup.py index 71b0860..5fbb814 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( name='deepke', # 打包后的包文件名 - version='0.2.91', #版本号 + version='0.2.92', #版本号 keywords=["pip", "RE","NER","AE"], # 关键字 description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明 long_description="client", #详细说明 diff --git a/src/deepke/relation_extraction/standard/tools/loss.py b/src/deepke/relation_extraction/standard/tools/loss.py index c874ac0..90f5689 100644 --- a/src/deepke/relation_extraction/standard/tools/loss.py +++ b/src/deepke/relation_extraction/standard/tools/loss.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.autograd import Variable class LabelSmoothSoftmaxCEV1(nn.Module): def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100): @@ -68,3 +69,38 @@ class TaylorCrossEntropyLossV1(nn.Module): loss = F.nll_loss(log_probs, labels, reduction=self.reduction, ignore_index=self.ignore_index) return loss + +class FocalLoss(nn.Module): + def __init__(self, gamma=0, alpha=None, size_average=True): + super(FocalLoss, self).__init__() + self.gamma = gamma + self.alpha = alpha + if isinstance(alpha,(float,int)): + self.alpha = torch.Tensor([alpha,1-alpha]) + if isinstance(alpha,list): + self.alpha = torch.Tensor(alpha) + self.size_average = size_average + + def forward(self, input, target): + if input.dim()>2: + input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W + input = input.transpose(1,2) # N,C,H*W => N,H*W,C + input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C + target = target.view(-1,1) + + logpt = F.log_softmax(input,dim=1) + logpt = logpt.gather(1,target) + logpt = logpt.view(-1) + pt = Variable(logpt.data.exp()) + + if self.alpha is not None: + if self.alpha.type()!=input.data.type(): + self.alpha = self.alpha.type_as(input.data) + at = self.alpha.gather(0,target.data.view(-1)) + logpt = logpt * Variable(at) + + loss = -1 * (1-pt)**self.gamma * logpt + if self.size_average: + return loss.mean() + else: + return loss.sum() \ No newline at end of file