update capsule
parent 355a5883a6
commit 1f92922c2e
@@ -1,6 +1,20 @@
-num_primary_units: 8
-num_output_units: 10 # relation_type
-primary_channels: 1
-primary_unit_size: 768
-output_unit_size: 128
-num_iterations: 3
+model_name: capsule
+
+share_weights: True
+num_iterations: 5 # number of routing iterations
+dropout: 0.3
+
+input_dim_capsule: ??? # comes from the upstream convolution, usually the conv output hidden_size
+dim_capsule: 50 # dimension of each output capsule
+num_capsule: ??? # number of output capsules, same as the number of classes, == num_relations
+
+
+# how the primary capsules are produced
+# can be embedding / cnn / rnn
+# use cnn for now
+in_channels: ??? # taken from the embedding output, no need to specify
+out_channels: 100 # == input_dim_capsule
+kernel_sizes: [9] # must be odd, and fairly large
+activation: 'lrelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
+keep_length: False # no padding needed, it would add too much useless information
+pooling_strategy: cls # irrelevant here, never actually used
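The three `???` placeholders above are filled in programmatically before the layers are built. A minimal sketch of that resolution, assuming a hypothetical SimpleNamespace stand-in for the project's YAML-backed config object (the field names word_dim, pos_dim, and dim_strategy are taken from the model code below):

from types import SimpleNamespace

# hypothetical config values, for illustration only
cfg = SimpleNamespace(
    word_dim=50, pos_dim=5, dim_strategy='cat',
    out_channels=100, num_relations=10,
)

# mirrors the assignments done in the model's __init__ below
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim if cfg.dim_strategy == 'cat' else cfg.word_dim
cfg.input_dim_capsule = cfg.out_channels  # conv output width feeds the capsule layer
cfg.num_capsule = cfg.num_relations       # one output capsule per relation class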
@@ -1,6 +1,50 @@
# coding=utf-8
# Version: Python 3.7.3
# Tools: Pycharm 2019.02
import torch
from . import BasicModule
from module import Embedding, CNN
from module import Capsule as CapsuleLayer
from utils import seq_len_to_mask, to_one_hot

__date__ = '2019/12/1 12:00 AM'
__author__ = 'Haiyang Yu'


class Capsule(BasicModule):
    def __init__(self, cfg):
        super(Capsule, self).__init__()

        # input width depends on whether word and position embeddings are concatenated
        if cfg.dim_strategy == 'cat':
            cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.in_channels = cfg.word_dim

        # capsule config
        cfg.input_dim_capsule = cfg.out_channels
        cfg.num_capsule = cfg.num_relations

        self.num_relations = cfg.num_relations
        self.embedding = Embedding(cfg)
        self.cnn = CNN(cfg)
        self.capsule = CapsuleLayer(cfg)

    def forward(self, x):
        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, head_pos, tail_pos)

        # the convolution changes the sequence length, so the mask cannot be applied
        # directly; leaving it unmasked is acceptable, since the primary capsules
        # only carry coarse-grained information anyway
        primary, _ = self.cnn(inputs)
        output = self.capsule(primary)
        output = output.norm(p=2, dim=-1)  # return the capsule lengths as class scores

        return output  # [B, N]

    def loss(self, predict, target, reduction='mean'):
        m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5

        target = to_one_hot(target, self.num_relations)
        max_l = (torch.relu(m_plus - predict)) ** 2
        max_r = (torch.relu(predict - m_minus)) ** 2
        loss = target * max_l + loss_lambda * (1 - target) * max_r
        loss = torch.sum(loss, dim=-1)

        if reduction == 'sum':
            return loss.sum()
        else:
            # default: mean
            return loss.mean()
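The loss above is the capsule-network margin loss with m_plus=0.9, m_minus=0.1, lambda=0.5: the gold capsule is penalized for being shorter than m_plus, all other capsules for being longer than m_minus. A small worked example on toy tensors (it bypasses to_one_hot, which is assumed to yield a float [B, N] one-hot matrix):

import torch

m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5
predict = torch.tensor([[0.95, 0.20, 0.05]])  # capsule lengths, B=1, N=3
target = torch.tensor([[1.0, 0.0, 0.0]])      # one-hot gold relation

max_l = torch.relu(m_plus - predict) ** 2     # gold capsule shorter than m_plus is penalized
max_r = torch.relu(predict - m_minus) ** 2    # negative capsules longer than m_minus are penalized
loss = target * max_l + loss_lambda * (1 - target) * max_r
print(loss.sum(dim=-1))  # tensor([0.0050]): only the 0.20 capsule exceeds m_minus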
@@ -1,13 +1,54 @@
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class Capsule(nn.Module):
    def __init__(self, cfg):
        super(Capsule, self).__init__()

        # self.xxx = cfg.xxx
        self.input_dim_capsule = cfg.input_dim_capsule
        self.dim_capsule = cfg.dim_capsule
        self.num_capsule = cfg.num_capsule
        self.batch_size = cfg.batch_size
        self.share_weights = cfg.share_weights
        self.num_iterations = cfg.num_iterations

        # transformation matrix, shared across input positions when share_weights is True
        if self.share_weights:
            W = torch.zeros(1, self.input_dim_capsule, self.num_capsule * self.dim_capsule)
        else:
            W = torch.zeros(self.batch_size, self.input_dim_capsule, self.num_capsule * self.dim_capsule)

        W = nn.init.xavier_normal_(W)
        self.W = nn.Parameter(W)

    def forward(self, x):
        """
        x: [B, L, H]  # output of the upstream CNN / RNN
        L acts as input_num_capsules, H as input_dim_capsule
        """
        B, I, _ = x.size()  # I is input_num_capsules
        O, F = self.num_capsule, self.dim_capsule

        u = torch.matmul(x, self.W)
        u = u.view(B, I, O, F).transpose(1, 2)  # [B, O, I, F]

        # dynamic routing: iteratively refine the coupling logits b
        b = torch.zeros_like(u[:, :, :, 0]).to(device=u.device)  # [B, O, I]
        for i in range(self.num_iterations):
            c = torch.softmax(b, dim=1)  # coupling coefficients, softmax over O: [B, O, I]
            v = torch.einsum('boi,boif->bof', [c, u])  # weighted sum over input capsules, [B, O, F]
            v = self.squash(v)
            b = torch.einsum('bof,boif->boi', [v, u])  # agreement update, [B, O, I]

        return v  # [B, O, F] == [B, num_capsule, dim_capsule]

    @staticmethod
    def squash(x: torch.Tensor):
        # rescale each capsule vector to length |x|^2 / (1 + |x|^2) < 1, keeping its direction
        x_norm = x.norm(p=2, dim=-1, keepdim=True)
        mag = x_norm ** 2
        out = x / x_norm * mag / (1 + mag)

        return out
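A hedged usage sketch of the routing layer above, with a hypothetical SimpleNamespace standing in for the project's config object; it checks the output shape and that squash keeps every capsule length below 1:

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(input_dim_capsule=100, dim_capsule=50, num_capsule=10,
                      batch_size=4, share_weights=True, num_iterations=5)
layer = Capsule(cfg)  # the nn.Module defined above

x = torch.randn(4, 60, 100)       # [B, L, H]: 60 primary capsules of width 100
v = layer(x)
print(v.shape)                    # torch.Size([4, 10, 50]) -> [B, num_capsule, dim_capsule]
print(v.norm(p=2, dim=-1).max())  # < 1, since squash scales to |v|^2 / (1 + |v|^2)

With share_weights=True the weight W is [1, input_dim_capsule, num_capsule * dim_capsule] and broadcasts over the batch, so the layer is independent of batch_size; the non-shared variant pins the batch dimension to cfg.batch_size.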