add more losses

This commit is contained in:
tlk-dsg 2021-12-07 18:53:20 +08:00
parent ba77f57f07
commit 5879785cb8
2 changed files with 37 additions and 1 deletions

View File

@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name='deepke', # 打包后的包文件名
version='0.2.91', #版本号
version='0.2.92', #版本号
keywords=["pip", "RE","NER","AE"], # 关键字
description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明
long_description="client", #详细说明

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class LabelSmoothSoftmaxCEV1(nn.Module):
def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
@ -68,3 +69,38 @@ class TaylorCrossEntropyLossV1(nn.Module):
loss = F.nll_loss(log_probs, labels, reduction=self.reduction,
ignore_index=self.ignore_index)
return loss
class FocalLoss(nn.Module):
def __init__(self, gamma=0, alpha=None, size_average=True):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
if isinstance(alpha,(float,int)):
self.alpha = torch.Tensor([alpha,1-alpha])
if isinstance(alpha,list):
self.alpha = torch.Tensor(alpha)
self.size_average = size_average
def forward(self, input, target):
if input.dim()>2:
input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
input = input.transpose(1,2) # N,C,H*W => N,H*W,C
input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C
target = target.view(-1,1)
logpt = F.log_softmax(input,dim=1)
logpt = logpt.gather(1,target)
logpt = logpt.view(-1)
pt = Variable(logpt.data.exp())
if self.alpha is not None:
if self.alpha.type()!=input.data.type():
self.alpha = self.alpha.type_as(input.data)
at = self.alpha.gather(0,target.data.view(-1))
logpt = logpt * Variable(at)
loss = -1 * (1-pt)**self.gamma * logpt
if self.size_average:
return loss.mean()
else:
return loss.sum()