update gcn

This commit is contained in:
leo 2019-12-06 10:08:57 +08:00
parent c629185459
commit 0786b2b30e
3 changed files with 9 additions and 6 deletions

View File

@ -37,7 +37,11 @@ def collate_fn(cfg):
x['tail_pos'] = torch.tensor(tail_pos)
if cfg.model_name == 'cnn' and cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor(pcnn_mask)
if cfg.model == 'gcn':
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
B, L = len(batch), max_len
adj = torch.empty(B, L, L).random_(2)
x['adj'] = adj
return x, y
return collate_fn_intra

View File

@ -20,12 +20,8 @@ class GCN(BasicModule):
self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
def forward(self, x):
word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
word, lens, head_pos, tail_pos, adj = x['word'], x['lens'], x['head_pos'], x['tail_pos'], x['adj']
# adj = x['adj']
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
B, L = len(x['lens']), x['lens'][0]
adj = torch.empty(B, L, L).random_(2).to(device=x['lens'].device)
inputs = self.embedding(word, head_pos, tail_pos)
output = self.gcn(inputs, adj)

View File

@ -23,11 +23,14 @@ class GCN(nn.Module):
def forward(self, x, adj):
L = x.size(1)
AxW = self.fc1(torch.bmm(adj, x)) + self.fc1(x)
AxW = AxW / L
AxW = F.leaky_relu(AxW)
AxW = self.dropout(AxW)
for fc in self.fcs:
AxW = fc(torch.bmm(adj, AxW)) + fc(AxW)
AxW = AxW / L
AxW = F.leaky_relu(AxW)
AxW = self.dropout(AxW)