update rnn

leo 2019-12-03 22:42:34 +08:00
parent 03a00c202e
commit 67d4f0f6c9
3 changed files with 15 additions and 17 deletions

View File

@@ -1,9 +1,9 @@
model_name: rnn
type_rnn: 'RNN' # [RNN, GRU, LSTM]
type_rnn: 'LSTM' # [RNN, GRU, LSTM]
#input_size: 100 # taken from the embedding output, no need to specify
hidden_size: 150 # must be an even number
input_size: ??? # taken from the embedding output, no need to specify
hidden_size: 150 # must be an even number
num_layers: 2
dropout: 0.3
bidirectional: True
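
For context, a minimal sketch of how a config like this is typically consumed: the string in type_rnn selects the torch.nn class (the module below does this with eval(f'nn.{type_rnn}')), input_size is filled in at runtime from the embedding output, and the even hidden_size is assumed here to be split across the two directions of a bidirectional RNN, which would explain the "must be an even number" comment. All values are illustrative.

import torch
import torch.nn as nn

cfg = {
    'type_rnn': 'LSTM',      # one of [RNN, GRU, LSTM]
    'input_size': 100,       # normally filled from the embedding output dim (the ??? above)
    'hidden_size': 150,      # assumed to be split across the two directions, hence "even"
    'num_layers': 2,
    'dropout': 0.3,
    'bidirectional': True,
}

rnn_cls = getattr(nn, cfg['type_rnn'])   # equivalent to the eval(f'nn.{...}') used in the module
per_dir_hidden = cfg['hidden_size'] // 2 if cfg['bidirectional'] else cfg['hidden_size']
rnn = rnn_cls(input_size=cfg['input_size'],
              hidden_size=per_dir_hidden,
              num_layers=cfg['num_layers'],
              dropout=cfg['dropout'],
              bidirectional=cfg['bidirectional'],
              bias=True,
              batch_first=True)

x = torch.randn(4, 20, cfg['input_size'])   # [B, L, H_in]
out, _ = rnn(x)
print(out.shape)                            # torch.Size([4, 20, 150]) when bidirectional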

View File

@@ -1,24 +1,26 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import BasicModule
from module import Embedding, RNN
from utils import seq_len_to_mask
class BiLSTM(BasicModule):
    def __init__(self, cfg):
        super(BiLSTM, self).__init__()
        self.use_pcnn = cfg.use_pcnn
        if cfg.dim_strategy == 'cat':
            cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.input_size = cfg.word_dim
        self.embedding = Embedding(cfg)
        self.bilsm = RNN(cfg)
        self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)
        self.fc2 = nn.Linear(cfg.intermediate, cfg.num_relations)
        self.bilstm = RNN(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
        self.dropout = nn.Dropout(cfg.dropout)
    def forward(self, x):
        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
        inputs = self.embedding(word, head_pos, tail_pos)
        out, out_pool = self.rnn(inputs)
        out, out_pool = self.bilstm(inputs, lens)
        output = self.fc(out_pool)
        return output
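
A self-contained approximation of the model above, assuming the 'cat' dim_strategy and a max-pool over time for out_pool; the real model relies on the repo's Embedding and RNN wrappers, so every layer here is a stand-in. It mainly illustrates why the new fc takes cfg.hidden_size: the bidirectional outputs concatenate to exactly that size, unlike the old len(kernel_sizes) * out_channels head left over from a CNN model.

import torch
import torch.nn as nn

class ToyBiLSTMRE(nn.Module):
    """Standalone approximation of the BiLSTM relation classifier above.
    Names, sizes and the pooling step are assumptions for illustration."""
    def __init__(self, vocab_size=1000, word_dim=60, pos_dim=10,
                 hidden_size=150, num_relations=10, dropout=0.3):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, word_dim, padding_idx=0)
        self.head_pos_emb = nn.Embedding(512, pos_dim, padding_idx=0)
        self.tail_pos_emb = nn.Embedding(512, pos_dim, padding_idx=0)
        input_size = word_dim + 2 * pos_dim            # the 'cat' dim_strategy
        self.bilstm = nn.LSTM(input_size, hidden_size // 2, num_layers=2,
                              batch_first=True, bidirectional=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, num_relations)   # matches the new fc above
        self.dropout = nn.Dropout(dropout)

    def forward(self, word, head_pos, tail_pos):
        inputs = torch.cat([self.word_emb(word),
                            self.head_pos_emb(head_pos),
                            self.tail_pos_emb(tail_pos)], dim=-1)
        out, _ = self.bilstm(inputs)                   # out: [B, L, hidden_size]
        out_pool, _ = out.max(dim=1)                   # max-pool over time as a stand-in
        return self.fc(self.dropout(out_pool))         # [B, num_relations]

model = ToyBiLSTMRE()
word = torch.randint(1, 1000, (4, 20))
pos = torch.randint(1, 512, (4, 20))
logits = model(word, pos, pos)
print(logits.shape)    # torch.Size([4, 10])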

View File

@@ -19,7 +19,6 @@ class RNN(nn.Module):
        self.last_layer_hn = config.last_layer_hn
        self.type_rnn = config.type_rnn
        self.h0 = self._init_h0()
        rnn = eval(f'nn.{self.type_rnn}')
        self.rnn = rnn(input_size=self.input_size,
                       hidden_size=self.hidden_size,
@@ -29,11 +28,6 @@ class RNN(nn.Module):
                       bias=True,
                       batch_first=True)
    def _init_h0(self):
        pass
        # h0 = torch.empty(1,B,H)
        # h0 = nn.init.orthogonal_(h0)
    def forward(self, x, x_len):
        """
        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in], usually the output of the embedding layer
@@ -46,8 +40,10 @@ class RNN(nn.Module):
        H, N = self.hidden_size, self.num_layers
        h0 = torch.zeros([2 * N, B, H]) if self.bidirectional else torch.zeros([N, B, H])
        h0 = h0.to(device=x_len.device)
        nn.init.orthogonal_(h0)
        c0 = torch.zeros([2 * N, B, H]) if self.bidirectional else torch.zeros([N, B, H])
        c0 = c0.to(device=x_len.device)
        nn.init.orthogonal_(c0)
        x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
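
The added lines boil down to: allocate zero initial states, overwrite them in place with nn.init.orthogonal_ (which flattens the trailing dimensions of a 3-D tensor), pack the padded batch, and presumably pass the states to the RNN. A standalone sketch under assumed sizes; the pad_packed_sequence step at the end is an assumption about how the caller unpacks the output.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

B, L, H_in, H, N = 4, 20, 100, 75, 2            # illustrative sizes
lstm = nn.LSTM(H_in, H, num_layers=N, batch_first=True, bidirectional=True)

x = torch.randn(B, L, H_in)
x_len = torch.tensor([20, 15, 10, 5])            # enforce_sorted=True needs descending lengths

h0 = torch.zeros([2 * N, B, H])                  # bidirectional: 2 * num_layers initial states
c0 = torch.zeros([2 * N, B, H])
nn.init.orthogonal_(h0)                          # same in-place orthogonal init as the diff
nn.init.orthogonal_(c0)

packed = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
out, (hn, cn) = lstm(packed, (h0, c0))
out, _ = pad_packed_sequence(out, batch_first=True, total_length=L)
print(out.shape)                                 # torch.Size([4, 20, 150])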