update rnn
parent 03a00c202e
commit 67d4f0f6c9
@@ -1,9 +1,9 @@
 model_name: rnn
 
-type_rnn: 'RNN' # [RNN, GRU, LSTM]
+type_rnn: 'LSTM' # [RNN, GRU, LSTM]
 
-#input_size: 100 # taken from the embedding output, no need to specify
+input_size: ??? # taken from the embedding output, no need to specify
 hidden_size: 150 # must be an even number
 num_layers: 2
 dropout: 0.3
 bidirectional: True
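For reference, a rough sketch of the recurrent layer such a config describes, assuming the per-direction size is hidden_size // 2 (suggested by the "must be an even number" comment but not stated in the diff):

import torch.nn as nn

# Values mirror the config above; hidden_size // 2 per direction is an assumption.
rnn = nn.LSTM(input_size=100,        # in practice filled in from the embedding output size
              hidden_size=150 // 2,  # an even hidden_size splits cleanly over two directions
              num_layers=2,
              dropout=0.3,
              bidirectional=True,
              bias=True,
              batch_first=True)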
@@ -1,24 +1,26 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from . import BasicModule
 from module import Embedding, RNN
 from utils import seq_len_to_mask
 
 
 class BiLSTM(BasicModule):
     def __init__(self, cfg):
         super(BiLSTM, self).__init__()
 
-        self.use_pcnn = cfg.use_pcnn
         if cfg.dim_strategy == 'cat':
             cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
         else:
             cfg.input_size = cfg.word_dim
 
         self.embedding = Embedding(cfg)
-        self.bilsm = RNN(cfg)
-        self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)
-        self.fc2 = nn.Linear(cfg.intermediate, cfg.num_relations)
+        self.bilstm = RNN(cfg)
+        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
         self.dropout = nn.Dropout(cfg.dropout)
 
     def forward(self, x):
         word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
         inputs = self.embedding(word, head_pos, tail_pos)
-        out, out_pool = self.rnn(inputs)
+        out, out_pool = self.bilstm(inputs, lens)
         output = self.fc(out_pool)
 
         return output
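To make the refactor above concrete, here is a self-contained toy version of the same head: embedding, BiLSTM, then a single linear layer over the pooled state, replacing the CNN-style fc1/fc2 stack. The nn.Embedding/nn.LSTM wrappers, the pooling choice, and the sizes below are stand-ins, not the repository's Embedding/RNN modules.

import torch
import torch.nn as nn

class TinyBiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size=1000, input_size=100, hidden_size=150, num_relations=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, input_size)
        # hidden_size // 2 per direction so the concatenated output is hidden_size wide
        self.bilstm = nn.LSTM(input_size, hidden_size // 2, num_layers=2,
                              batch_first=True, bidirectional=True)
        # single classification layer over the pooled state (the fc in the diff)
        self.fc = nn.Linear(hidden_size, num_relations)
        self.dropout = nn.Dropout(0.3)

    def forward(self, word):
        inputs = self.embedding(word)                   # [B, L, input_size]
        out, (hn, _) = self.bilstm(inputs)              # out: [B, L, hidden_size]
        out_pool = torch.cat([hn[-2], hn[-1]], dim=-1)  # [B, hidden_size]
        return self.fc(self.dropout(out_pool))          # [B, num_relations]

word = torch.randint(0, 1000, (4, 12))  # toy batch: B=4, L=12
logits = TinyBiLSTMClassifier()(word)
print(logits.shape)                     # torch.Size([4, 10])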
@@ -19,7 +19,6 @@ class RNN(nn.Module):
         self.last_layer_hn = config.last_layer_hn
         self.type_rnn = config.type_rnn
 
-        self.h0 = self._init_h0()
         rnn = eval(f'nn.{self.type_rnn}')
         self.rnn = rnn(input_size=self.input_size,
                        hidden_size=self.hidden_size,
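A quick note on the class lookup kept in this hunk: eval(f'nn.{self.type_rnn}') turns the config string into the matching torch class. A standalone check (getattr is shown only as an eval-free equivalent, not what the repository uses):

import torch.nn as nn

type_rnn = 'LSTM'                         # from the config: one of RNN, GRU, LSTM
assert eval(f'nn.{type_rnn}') is nn.LSTM  # the lookup used in the module
assert getattr(nn, type_rnn) is nn.LSTM   # an eval-free equivalent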
@@ -29,11 +28,6 @@ class RNN(nn.Module):
                        bias=True,
                        batch_first=True)
 
-    def _init_h0(self):
-        pass
-        # h0 = torch.empty(1,B,H)
-        # h0 = nn.init.orthogonal_(h0)
-
     def forward(self, x, x_len):
         """
         :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in], usually the output of the embedding layer
@@ -46,8 +40,10 @@ class RNN(nn.Module):
         H, N = self.hidden_size, self.num_layers
 
         h0 = torch.zeros([2 * N, B, H]) if self.bidirectional else torch.zeros([N, B, H])
+        h0 = h0.to(device=x_len.device)
         nn.init.orthogonal_(h0)
         c0 = torch.zeros([2 * N, B, H]) if self.bidirectional else torch.zeros([N, B, H])
+        c0 = c0.to(device=x_len.device)
         nn.init.orthogonal_(c0)
 
         x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
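For context on the added device handling, a minimal end-to-end sketch of the same pattern: explicit h0/c0 moved to a target device, orthogonally initialised, then fed together with a packed batch. The standalone nn.LSTM and the sizes (B=3, L=5) are illustrative stand-ins for self.rnn:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.randn(3, 5, 8)         # [B, L, H_in]
x_len = torch.tensor([5, 4, 2])  # descending lengths, as enforce_sorted=True expects
lstm = nn.LSTM(8, 6, num_layers=1, batch_first=True, bidirectional=True)

h0 = torch.zeros(2, 3, 6).to(device=x_len.device)  # [2 * N, B, H]
nn.init.orthogonal_(h0)
c0 = torch.zeros(2, 3, 6).to(device=x_len.device)
nn.init.orthogonal_(c0)

packed = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
out, (hn, cn) = lstm(packed, (h0, c0))
out, _ = pad_packed_sequence(out, batch_first=True)  # back to [B, L, 2 * H]
print(out.shape)                                     # torch.Size([3, 5, 12])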