deepke/models/Capsule.py

51 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from . import BasicModule
from module import Embedding, CNN
from module import Capsule as CapsuleLayer
from utils import seq_len_to_mask, to_one_hot
class Capsule(BasicModule):
    """Relation classifier: embedding -> CNN -> capsule layer -> capsule lengths.

    ``forward`` returns the L2 norm of each output capsule, and ``loss``
    scores those norms against one-hot targets with a two-sided margin.

    Note that the constructor writes derived fields onto the ``cfg`` object
    it receives (``in_channels``, ``input_dim_capsule``, ``num_capsule``)
    before handing it to the sub-modules, so ``cfg`` is mutated in place.
    """

    def __init__(self, cfg):
        """Derive the dependent config fields and build the sub-modules.

        Reads ``cfg.dim_strategy``, ``cfg.word_dim``, ``cfg.pos_dim``,
        ``cfg.out_channels`` and ``cfg.num_relations``.
        Writes ``cfg.in_channels``, ``cfg.input_dim_capsule`` and
        ``cfg.num_capsule``.
        """
        super().__init__()

        # With the 'cat' strategy the CNN input width is the word dimension
        # plus two position dimensions (head and tail); otherwise it is the
        # word dimension alone.
        if cfg.dim_strategy == 'cat':
            cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.in_channels = cfg.word_dim

        # Capsule config: the capsule layer consumes the CNN output channels
        # and emits one capsule per relation.
        cfg.input_dim_capsule = cfg.out_channels
        cfg.num_capsule = cfg.num_relations

        self.num_relations = cfg.num_relations
        self.embedding = Embedding(cfg)
        self.cnn = CNN(cfg)
        self.capsule = CapsuleLayer(cfg)

    def forward(self, x):
        """Return the length of every output capsule.

        Args:
            x: mapping that provides the keys ``'word'``, ``'head_pos'`` and
                ``'tail_pos'``. Any other keys (e.g. ``'lens'``) are ignored.

        Returns:
            The L2 norm of the capsule layer output along its last
            dimension, shape ``[B, N]``.
        """
        word, head_pos, tail_pos = x['word'], x['head_pos'], x['tail_pos']
        inputs = self.embedding(word, head_pos, tail_pos)

        # No padding mask is applied: the CNN changes the sequence length, so
        # a length-based mask can no longer be aligned with its output. That
        # is acceptable because the primary capsules only carry
        # coarse-grained information. (The mask that used to be built here
        # was never used, so it is no longer computed.)
        primary, _ = self.cnn(inputs)
        output = self.capsule(primary)

        # The prediction for each capsule is its vector length.
        return output.norm(p=2, dim=-1)  # [B, N]

    def loss(self, predict, target, reduction='mean'):
        """Two-sided margin loss over capsule lengths.

        For each sample the loss is
        ``T * relu(m_plus - p)**2 + loss_lambda * (1 - T) * relu(p - m_minus)**2``
        summed over the last dimension, with ``m_plus = 0.9``,
        ``m_minus = 0.1`` and ``loss_lambda = 0.5``.

        Args:
            predict: capsule lengths as returned by ``forward``.
            target: labels, converted with
                ``to_one_hot(target, self.num_relations)``
                (presumably class indices -> ``[B, N]`` one-hot; verify
                against ``utils.to_one_hot``).
            reduction: ``'sum'`` sums the per-sample losses; any other value
                (including the default ``'mean'``) averages them.

        Returns:
            A scalar tensor holding the reduced loss.
        """
        m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5

        target = to_one_hot(target, self.num_relations)
        # Penalise the true class when its length falls below m_plus ...
        max_l = torch.relu(m_plus - predict) ** 2
        # ... and every other class when its length rises above m_minus.
        max_r = torch.relu(predict - m_minus) ** 2
        loss = target * max_l + loss_lambda * (1 - target) * max_r
        loss = torch.sum(loss, dim=-1)

        if reduction == 'sum':
            return loss.sum()
        # Default: average over the batch.
        return loss.mean()