update capsule

leo 2019-12-05 21:40:35 +08:00
parent 355a5883a6
commit 1f92922c2e
3 changed files with 113 additions and 14 deletions

View File

@@ -1,6 +1,20 @@
model_name: capsule
share_weights: True
num_iterations: 5 # number of routing iterations
dropout: 0.3
input_dim_capsule: ??? # derived from the upstream convolution output, usually the convolution's hidden_size
dim_capsule: 50 # dimension of each output capsule
num_capsule: ??? # number of output capsules, equal to the number of classes, == num_relations
# primary capsule construction
# can use embedding / cnn / rnn
# use cnn for now
in_channels: ??? # taken from the embedding output, no need to specify
out_channels: 100 # == input_dim_capsule
kernel_sizes: [9] # must be odd, and should be fairly large
activation: 'lrelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
keep_length: False # no need to pad with too much useless information
pooling_strategy: cls # irrelevant here, never actually used
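The three ??? fields above are derived at runtime rather than set by hand. A minimal sketch of that derivation, using a hypothetical SimpleNamespace config with assumed embedding dimensions (the logic mirrors the model's __init__ in the next file):

from types import SimpleNamespace

# assumed values for illustration; only out_channels, dim_capsule and
# num_relations come from the yaml above
cfg = SimpleNamespace(
    word_dim=50, pos_dim=5, dim_strategy='cat',  # embedding side (assumed)
    out_channels=100,                            # cnn side, from the yaml
    dim_capsule=50, num_relations=10,            # capsule side
)

# in_channels: taken from the embedding output, so it depends on dim_strategy
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim if cfg.dim_strategy == 'cat' else cfg.word_dim
# input_dim_capsule: the hidden size produced by the upstream convolution
cfg.input_dim_capsule = cfg.out_channels   # == 100
# num_capsule: one output capsule per relation class
cfg.num_capsule = cfg.num_relations        # == 10

print(cfg.in_channels, cfg.input_dim_capsule, cfg.num_capsule)  # 60 100 10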

View File

@@ -1,6 +1,50 @@
# coding=utf-8
# Version: Python 3.7.3
# Tools: PyCharm 2019.02
import torch
from . import BasicModule
from module import Embedding, CNN
from module import Capsule as CapsuleLayer
from utils import seq_len_to_mask, to_one_hot
__date__ = '2019/12/1 12:00 AM'
__author__ = 'Haiyang Yu'
class Capsule(BasicModule):
def __init__(self, cfg):
super(Capsule, self).__init__()
if cfg.dim_strategy == 'cat':
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
else:
cfg.in_channels = cfg.word_dim
# capsule config
cfg.input_dim_capsule = cfg.out_channels
cfg.num_capsule = cfg.num_relations
self.num_relations = cfg.num_relations
self.embedding = Embedding(cfg)
self.cnn = CNN(cfg)
self.capsule = CapsuleLayer(cfg)
def forward(self, x):
word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
mask = seq_len_to_mask(lens)
inputs = self.embedding(word, head_pos, tail_pos)
        primary, _ = self.cnn(inputs)  # the CNN changes the sequence length, so the mask can't be applied precisely; not masking is fine, since primary capsules only carry coarse-grained information anyway
output = self.capsule(primary)
        output = output.norm(p=2, dim=-1)  # compute each capsule's L2 norm and return it as the class score
return output # [B, N]
def loss(self, predict, target, reduction='mean'):
m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5
target = to_one_hot(target, self.num_relations)
max_l = (torch.relu(m_plus - predict))**2
max_r = (torch.relu(predict - m_minus))**2
loss = target * max_l + loss_lambda * (1 - target) * max_r
loss = torch.sum(loss, dim=-1)
if reduction == 'sum':
return loss.sum()
else:
            # default: mean reduction
return loss.mean()
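For reference, a self-contained sketch of the margin loss above on dummy capsule norms; shapes are hypothetical, and torch.nn.functional.one_hot stands in for the repo's utils.to_one_hot:

import torch
import torch.nn.functional as F

B, N = 4, 10                        # batch size, num_relations (assumed)
predict = torch.rand(B, N)          # capsule norms in [0, 1), as returned by forward()
target = torch.randint(0, N, (B,))  # gold relation ids

m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5
one_hot = F.one_hot(target, N).float()      # [B, N]
max_l = torch.relu(m_plus - predict) ** 2   # penalize low norm on the gold class
max_r = torch.relu(predict - m_minus) ** 2  # penalize high norm on the other classes
loss = one_hot * max_l + loss_lambda * (1 - one_hot) * max_r
print(loss.sum(dim=-1).mean())              # scalar, matches reduction='mean'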

View File

@@ -1,13 +1,54 @@
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class Capsule(nn.Module):
    def __init__(self, cfg):
        super(Capsule, self).__init__()
        # self.xxx = cfg.xxx
self.input_dim_capsule = cfg.input_dim_capsule
self.dim_capsule = cfg.dim_capsule
self.num_capsule = cfg.num_capsule
self.batch_size = cfg.batch_size
self.share_weights = cfg.share_weights
self.num_iterations = cfg.num_iterations
if self.share_weights:
W = torch.zeros(1, self.input_dim_capsule, self.num_capsule * self.dim_capsule)
else:
W = torch.zeros(self.batch_size, self.input_dim_capsule, self.num_capsule * self.dim_capsule)
W = nn.init.xavier_normal_(W)
self.W = nn.Parameter(W)
def forward(self, x):
"""
        x: [B, L, H]  # output of the CNN / RNN encoder
        L serves as input_num_capsules, H as input_dim_capsule
"""
        B, I, _ = x.size()  # I is input_num_capsules
O, F = self.num_capsule, self.dim_capsule
u = torch.matmul(x, self.W)
u = u.view(B, I, O, F).transpose(1, 2) # [B, O, I, F]
        b = torch.zeros_like(u[:, :, :, 0])  # [B, O, I] routing logits; zeros_like already matches u's device
for i in range(self.num_iterations):
            c = torch.softmax(b, dim=1)  # [B, O, I], softmax over the output-capsule dimension
v = torch.einsum('boi,boif->bof', [c, u]) # [B, O, F]
v = self.squash(v)
b = torch.einsum('bof,boif->boi', [v, u]) # [B, O, I]
return v # [B, O, F] [B, num_capsule, dim_capsule]
@staticmethod
    def squash(x: torch.Tensor):
        # scale x to norm ||x||^2 / (1 + ||x||^2), preserving its direction
        x_norm = x.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)  # guard against division by zero
        mag = x_norm**2
        out = x / x_norm * mag / (1 + mag)
return out
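A quick shape check for this routing layer, with hypothetical cfg values consistent with the yaml above (assumes the Capsule class in this file is in scope):

from types import SimpleNamespace
import torch

cfg = SimpleNamespace(input_dim_capsule=100, dim_capsule=50, num_capsule=10,
                      batch_size=4, share_weights=True, num_iterations=3)
layer = Capsule(cfg)

x = torch.randn(4, 60, 100)       # [B, L, H]: 60 primary capsules of size 100
v = layer(x)
print(v.shape)                    # torch.Size([4, 10, 50]) == [B, num_capsule, dim_capsule]
print(v.norm(p=2, dim=-1).max())  # always < 1: squash maps each norm to ||x||^2 / (1 + ||x||^2)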