import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class RNN(nn.Module):
    def __init__(self, config):
        """
        type_rnn: one of 'RNN', 'GRU', 'LSTM'
        """
        super(RNN, self).__init__()

        # copy the hyper-parameters we need from config
        self.input_size = config.input_size
        # each direction of a bidirectional RNN gets half of hidden_size,
        # so the concatenated output is still hidden_size wide
        self.hidden_size = config.hidden_size // 2 if config.bidirectional else config.hidden_size
        self.num_layers = config.num_layers
        self.dropout = config.dropout
        self.bidirectional = config.bidirectional
        self.last_layer_hn = config.last_layer_hn
        self.type_rnn = config.type_rnn

        # look up the requested recurrent module (nn.RNN / nn.GRU / nn.LSTM) by name
        rnn = getattr(nn, self.type_rnn)
        self.rnn = rnn(input_size=self.input_size,
                       hidden_size=self.hidden_size,
                       num_layers=self.num_layers,
                       dropout=self.dropout,
                       bidirectional=self.bidirectional,
                       bias=True,
                       batch_first=True)

    def forward(self, x, x_len):
        """
        :param x: torch.Tensor [batch_size, seq_max_length, input_size], i.e. [B, L, H_in],
                  usually the output of an embedding layer
        :param x_len: torch.Tensor [B], sequence lengths, already sorted in descending order
        :return:
            output: torch.Tensor [B, L, H_out], per-token features, used for sequence labeling
            hn: torch.Tensor [B, N, H_out] / [B, H_out], features for classification;
                only the last layer is returned when last_layer_hn is set
        """
        B, L, _ = x.size()
        H, N = self.hidden_size, self.num_layers

        # initial hidden (and cell) states, orthogonally initialized;
        # a bidirectional RNN needs 2 * num_layers state tensors
        h0 = torch.zeros([2 * N, B, H]) if self.bidirectional else torch.zeros([N, B, H])
        h0 = h0.to(device=x.device)
        nn.init.orthogonal_(h0)
        c0 = torch.zeros([2 * N, B, H]) if self.bidirectional else torch.zeros([N, B, H])
        c0 = c0.to(device=x.device)
        nn.init.orthogonal_(c0)

        # pack the padded batch so the RNN skips padding positions;
        # lengths must be a CPU tensor and sorted in descending order
        x = pack_padded_sequence(x, x_len.cpu(), batch_first=True, enforce_sorted=True)
        if self.type_rnn == 'LSTM':
            output, hn = self.rnn(x, (h0, c0))
        else:
            output, hn = self.rnn(x, h0)

        # unpack back to a padded tensor with the original max length L
        output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)

        if self.type_rnn == 'LSTM':
            # an LSTM returns hn as a tuple (h_n, c_n); keep only the hidden state
            hn = hn[0]
        if self.bidirectional:
            # [2N, B, H] -> [B, N, 2H]: concatenate the two directions of each layer
            hn = hn.view(N, 2, B, H).transpose(1, 2).contiguous().view(N, B, 2 * H).transpose(0, 1)
        else:
            # [N, B, H] -> [B, N, H]
            hn = hn.transpose(0, 1)
        if self.last_layer_hn:
            # keep only the top layer: [B, H_out]
            hn = hn[:, -1, :]

        return output, hn


if __name__ == '__main__':
    class Config(object):
        type_rnn = 'LSTM'
        input_size = 5
        hidden_size = 4
        num_layers = 3
        dropout = 0.0
        last_layer_hn = False
        bidirectional = True

    config = Config()
    model = RNN(config)
    print(model)

    torch.manual_seed(1)
    x = torch.tensor([[4, 3, 2, 1], [5, 6, 7, 0], [8, 10, 0, 0]])
    x = torch.nn.Embedding(11, 5, padding_idx=0)(x)  # B, L, H_in = 3, 4, 5
    x_len = torch.tensor([4, 3, 2])  # lengths already sorted in descending order

    o, h = model(x, x_len)

    print(o.shape, h.shape, sep='\n\n')
    print(o[-1].data, h[-1].data, sep='\n\n')
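
    # The forward pass packs with enforce_sorted=True, so a batch must already be
    # ordered longest-first. Below is a minimal sketch (not part of the original
    # demo, values are illustrative only) of handling an unsorted batch: permute
    # it before the call and restore the original order afterwards.
    unsorted_len = torch.tensor([2, 4, 3])
    unsorted_x = torch.nn.Embedding(11, 5, padding_idx=0)(
        torch.tensor([[8, 10, 0, 0], [4, 3, 2, 1], [5, 6, 7, 0]]))
    perm = torch.argsort(unsorted_len, descending=True)  # longest-first order
    inv_perm = torch.argsort(perm)                        # permutation that undoes the sort
    o2, h2 = model(unsorted_x[perm], unsorted_len[perm])
    o2, h2 = o2[inv_perm], h2[inv_perm]                   # back to the original batch order
    print(o2.shape, h2.shape)

    # A quick sketch of the non-LSTM branch, assuming the same Config fields;
    # only type_rnn changes (the docstring allows 'RNN', 'GRU' or 'LSTM').
    class GRUConfig(Config):
        type_rnn = 'GRU'

    gru_model = RNN(GRUConfig())
    o3, h3 = gru_model(x, x_len)
    print(o3.shape, h3.shape)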