add more losses
This commit is contained in:
parent
ba77f57f07
commit
5879785cb8
2
setup.py
2
setup.py
|
@ -1,7 +1,7 @@
|
|||
from setuptools import setup, find_packages
|
||||
setup(
|
||||
name='deepke', # 打包后的包文件名
|
||||
version='0.2.91', #版本号
|
||||
version='0.2.92', #版本号
|
||||
keywords=["pip", "RE","NER","AE"], # 关键字
|
||||
description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明
|
||||
long_description="client", #详细说明
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
class LabelSmoothSoftmaxCEV1(nn.Module):
|
||||
def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-100):
|
||||
|
@ -68,3 +69,38 @@ class TaylorCrossEntropyLossV1(nn.Module):
|
|||
loss = F.nll_loss(log_probs, labels, reduction=self.reduction,
|
||||
ignore_index=self.ignore_index)
|
||||
return loss
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
def __init__(self, gamma=0, alpha=None, size_average=True):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
if isinstance(alpha,(float,int)):
|
||||
self.alpha = torch.Tensor([alpha,1-alpha])
|
||||
if isinstance(alpha,list):
|
||||
self.alpha = torch.Tensor(alpha)
|
||||
self.size_average = size_average
|
||||
|
||||
def forward(self, input, target):
|
||||
if input.dim()>2:
|
||||
input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W
|
||||
input = input.transpose(1,2) # N,C,H*W => N,H*W,C
|
||||
input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C
|
||||
target = target.view(-1,1)
|
||||
|
||||
logpt = F.log_softmax(input,dim=1)
|
||||
logpt = logpt.gather(1,target)
|
||||
logpt = logpt.view(-1)
|
||||
pt = Variable(logpt.data.exp())
|
||||
|
||||
if self.alpha is not None:
|
||||
if self.alpha.type()!=input.data.type():
|
||||
self.alpha = self.alpha.type_as(input.data)
|
||||
at = self.alpha.gather(0,target.data.view(-1))
|
||||
logpt = logpt * Variable(at)
|
||||
|
||||
loss = -1 * (1-pt)**self.gamma * logpt
|
||||
if self.size_average:
|
||||
return loss.mean()
|
||||
else:
|
||||
return loss.sum()
|
Loading…
Reference in New Issue