update gcn
This commit is contained in:
parent
c629185459
commit
0786b2b30e
|
@ -37,7 +37,11 @@ def collate_fn(cfg):
|
||||||
x['tail_pos'] = torch.tensor(tail_pos)
|
x['tail_pos'] = torch.tensor(tail_pos)
|
||||||
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
||||||
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
||||||
|
if cfg.model == 'gcn':
|
||||||
|
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||||
|
B, L = len(batch), max_len
|
||||||
|
adj = torch.empty(B, L, L).random_(2)
|
||||||
|
x['adj'] = adj
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
return collate_fn_intra
|
return collate_fn_intra
|
||||||
|
|
|
@ -20,12 +20,8 @@ class GCN(BasicModule):
|
||||||
self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
|
self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
|
word, lens, head_pos, tail_pos, adj = x['word'], x['lens'], x['head_pos'], x['tail_pos'], x['adj']
|
||||||
|
|
||||||
# adj = x['adj']
|
|
||||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
|
||||||
B, L = len(x['lens']), x['lens'][0]
|
|
||||||
adj = torch.empty(B, L, L).random_(2).to(device=x['lens'].device)
|
|
||||||
|
|
||||||
inputs = self.embedding(word, head_pos, tail_pos)
|
inputs = self.embedding(word, head_pos, tail_pos)
|
||||||
output = self.gcn(inputs, adj)
|
output = self.gcn(inputs, adj)
|
||||||
|
|
|
@ -23,11 +23,14 @@ class GCN(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, adj):
|
def forward(self, x, adj):
|
||||||
|
L = x.size(1)
|
||||||
AxW = self.fc1(torch.bmm(adj, x)) + self.fc1(x)
|
AxW = self.fc1(torch.bmm(adj, x)) + self.fc1(x)
|
||||||
|
AxW = AxW / L
|
||||||
AxW = F.leaky_relu(AxW)
|
AxW = F.leaky_relu(AxW)
|
||||||
AxW = self.dropout(AxW)
|
AxW = self.dropout(AxW)
|
||||||
for fc in self.fcs:
|
for fc in self.fcs:
|
||||||
AxW = fc(torch.bmm(adj, AxW)) + fc(AxW)
|
AxW = fc(torch.bmm(adj, AxW)) + fc(AxW)
|
||||||
|
AxW = AxW / L
|
||||||
AxW = F.leaky_relu(AxW)
|
AxW = F.leaky_relu(AxW)
|
||||||
AxW = self.dropout(AxW)
|
AxW = self.dropout(AxW)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue