update gcn
This commit is contained in:
parent
fabcef1825
commit
c629185459
@@ -1 +1,7 @@
-num_layers: 3
+model_name: gcn
+
+num_layers: 3
+
+input_size: ???   # derived from the embedding output size at runtime, no need to set it here
+hidden_size: 100
+dropout: 0.3
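Note on input_size: ??? — the value is intentionally left unset because the model fills it in at construction time, based on how the embedding layer combines word and position vectors. A sketch only, assuming the config also carries the usual word_dim / pos_dim / dim_strategy fields (as the model file below does):

    # mirrors the resolution done in GCN.__init__ below
    if cfg.dim_strategy == 'cat':
        cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim   # word + head/tail position embeddings concatenated
    else:
        cfg.input_size = cfg.word_dim                     # embeddings combined without concatenation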
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn
from . import BasicModule
from module import Embedding
from module import GCN as GCNBlock
from utils import seq_len_to_mask


class GCN(BasicModule):
    def __init__(self, cfg):
        super(GCN, self).__init__()

        # input_size depends on how word and position embeddings are combined
        if cfg.dim_strategy == 'cat':
            cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.input_size = cfg.word_dim

        self.embedding = Embedding(cfg)
        self.gcn = GCNBlock(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)

    def forward(self, x):
        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']

        # adj = x['adj']
        # No suitable dependency-parsing tool has been wired in yet, so the adjacency matrix is randomly initialized for now.
        B, L = len(x['lens']), x['lens'][0]
        adj = torch.empty(B, L, L).random_(2).to(device=x['lens'].device)

        inputs = self.embedding(word, head_pos, tail_pos)
        output = self.gcn(inputs, adj)
        output = output.max(dim=1)[0]   # max-pool over the sequence dimension
        output = self.fc(output)

        return output
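The adjacency matrix in forward() above is only a random placeholder, as the comment notes. If the batch carried per-token dependency heads, a real adjacency could be built with the head_to_adj / pad_adj helpers from the module file below. A rough sketch, assuming a hypothetical x['head'] field (lists of 1-based head indices, 0 marking the root) and that the helpers are importable from the module package in the same way GCN is:

    import numpy as np
    import torch
    from module import head_to_adj, pad_adj  # assumed to be exported alongside GCN

    def build_batch_adj(heads_batch, max_len, device):
        # one padded adjacency matrix per sentence, stacked into (B, L, L)
        mats = [pad_adj(head_to_adj(h, directed=False, self_loop=True), max_len)
                for h in heads_batch]
        return torch.from_numpy(np.stack(mats)).to(device)

    # inside forward(), instead of the random matrix:
    # adj = build_batch_adj(x['head'], L, x['lens'].device)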
@@ -0,0 +1,145 @@
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class GCN(nn.Module):
    def __init__(self, cfg):
        super(GCN, self).__init__()

        # self.xxx = config.xxx
        self.num_layers = cfg.num_layers
        self.input_size = cfg.input_size
        self.hidden_size = cfg.hidden_size
        self.dropout = cfg.dropout

        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.fcs = nn.ModuleList([nn.Linear(self.hidden_size, self.hidden_size) for i in range(self.num_layers - 1)])
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x, adj):
        # Each layer computes fc(adj @ h) + fc(h), i.e. message passing over the
        # adjacency matrix plus an implicit self-loop, followed by LeakyReLU and dropout.
        AxW = self.fc1(torch.bmm(adj, x)) + self.fc1(x)
        AxW = F.leaky_relu(AxW)
        AxW = self.dropout(AxW)
        for fc in self.fcs:
            AxW = fc(torch.bmm(adj, AxW)) + fc(AxW)
            AxW = F.leaky_relu(AxW)
            AxW = self.dropout(AxW)

        return AxW


class Tree(object):
    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = list()

    def add_child(self, child):
        child.parent = self
        self.num_children += 1
        self.children.append(child)

    def size(self):
        s = getattr(self, '_size', -1)
        if s != -1:
            return self._size
        else:
            count = 1
            for i in range(self.num_children):
                count += self.children[i].size()
            self._size = count
            return self._size

    def __iter__(self):
        yield self
        for c in self.children:
            for x in c:
                yield x

    def depth(self):
        d = getattr(self, '_depth', -1)
        if d != -1:
            return self._depth
        else:
            count = 0
            if self.num_children > 0:
                for i in range(self.num_children):
                    child_depth = self.children[i].depth()
                    if child_depth > count:
                        count = child_depth
                count += 1
            self._depth = count
            return self._depth


def head_to_adj(head, directed=True, self_loop=False):
    """
    Convert a sequence of head indexes (1-based, 0 = root) into a numpy adjacency matrix.
    """
    seq_len = len(head)
    head = head[:seq_len]
    root = None
    nodes = [Tree() for _ in head]

    for i in range(seq_len):
        h = head[i]
        setattr(nodes[i], 'idx', i)
        if h == 0:
            root = nodes[i]
        else:
            nodes[h - 1].add_child(nodes[i])

    assert root is not None

    ret = np.zeros((seq_len, seq_len), dtype=np.float32)
    queue = [root]
    idx = []
    while len(queue) > 0:
        t, queue = queue[0], queue[1:]
        idx += [t.idx]
        for c in t.children:
            ret[t.idx, c.idx] = 1
        queue += t.children

    if not directed:
        ret = ret + ret.T

    if self_loop:
        for i in idx:
            ret[i, i] = 1

    return ret


def pad_adj(adj, max_len):
    # Zero-pad a (seq_len, seq_len) adjacency matrix up to (max_len, max_len).
    pad_len = max_len - adj.shape[0]
    for i in range(pad_len):
        adj = np.insert(adj, adj.shape[-1], 0, axis=1)
    for i in range(pad_len):
        adj = np.insert(adj, adj.shape[0], 0, axis=0)
    return adj


if __name__ == '__main__':
    class Config():
        num_layers = 3
        input_size = 50
        hidden_size = 100
        dropout = 0.3

    cfg = Config()
    x = torch.randn(1, 10, 50)
    adj = torch.empty(1, 10, 10).random_(2)
    m = GCN(cfg)
    print(m)
    out = m(x, adj)
    print(out.shape)
    print(out)
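As a quick sanity check of the tree utilities, converting a small, made-up head sequence into an adjacency matrix and padding it could look like this (run inside this module, or after importing head_to_adj and pad_adj):

    # hypothetical 4-token sentence: heads are 1-based, 0 marks the root,
    # so token 2 is the root, tokens 1 and 3 attach to it, and token 4 to token 3
    heads = [2, 0, 2, 3]
    adj = head_to_adj(heads, directed=False, self_loop=True)
    print(adj.shape)                 # (4, 4)
    adj = pad_adj(adj, max_len=6)    # pad up to the batch's longest sentence
    print(adj.shape)                 # (6, 6)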