update gcn
This commit is contained in:
parent
c629185459
commit
0786b2b30e
|
@ -37,7 +37,11 @@ def collate_fn(cfg):
|
|||
x['tail_pos'] = torch.tensor(tail_pos)
|
||||
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
||||
|
||||
if cfg.model == 'gcn':
|
||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||
B, L = len(batch), max_len
|
||||
adj = torch.empty(B, L, L).random_(2)
|
||||
x['adj'] = adj
|
||||
return x, y
|
||||
|
||||
return collate_fn_intra
|
||||
|
|
|
@ -20,12 +20,8 @@ class GCN(BasicModule):
|
|||
self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
|
||||
|
||||
def forward(self, x):
|
||||
word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
|
||||
word, lens, head_pos, tail_pos, adj = x['word'], x['lens'], x['head_pos'], x['tail_pos'], x['adj']
|
||||
|
||||
# adj = x['adj']
|
||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||
B, L = len(x['lens']), x['lens'][0]
|
||||
adj = torch.empty(B, L, L).random_(2).to(device=x['lens'].device)
|
||||
|
||||
inputs = self.embedding(word, head_pos, tail_pos)
|
||||
output = self.gcn(inputs, adj)
|
||||
|
|
|
@ -23,11 +23,14 @@ class GCN(nn.Module):
|
|||
|
||||
|
||||
def forward(self, x, adj):
|
||||
L = x.size(1)
|
||||
AxW = self.fc1(torch.bmm(adj, x)) + self.fc1(x)
|
||||
AxW = AxW / L
|
||||
AxW = F.leaky_relu(AxW)
|
||||
AxW = self.dropout(AxW)
|
||||
for fc in self.fcs:
|
||||
AxW = fc(torch.bmm(adj, AxW)) + fc(AxW)
|
||||
AxW = AxW / L
|
||||
AxW = F.leaky_relu(AxW)
|
||||
AxW = self.dropout(AxW)
|
||||
|
||||
|
|
Loading…
Reference in New Issue