update lm

leo 2019-12-03 22:42:51 +08:00
parent 67d4f0f6c9
commit de97e387d9
2 changed files with 38 additions and 6 deletions


@@ -1,12 +1,20 @@
 model_name: lm
-# lm_name = 'bert-base-chinese' # download usage
-# cache file usage
-#lm_file: 'pretrained'
-lm_file: '/Users/leo/transformers/bert-base-chinese'
+# When using a pretrained language model, this is where the pretrained model is stored
+# lm_name = 'bert-base-chinese' # download usage
+#lm_file: 'pretrained'
+lm_file: '/home/yhy/transformers/bert-base-chinese'
+# Number of transformer layers; the original base BERT has 12,
+# but with a small dataset, using fewer layers converges faster and works better
-num_hidden_layers: 2
+num_hidden_layers: 1
 # Parameters of the BiLSTM that follows
 type_rnn: 'LSTM' # [RNN, GRU, LSTM]
 input_size: 768 # this value comes from BERT
 hidden_size: 100 # must be an even number
 num_layers: 1
 dropout: 0.3
 bidirectional: True
 last_layer_hn: True
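
For reference, a minimal sketch of reading this config into the attribute-style object that models/LM.py expects (cfg.lm_file, cfg.num_hidden_layers, ...). The file name and the PyYAML/SimpleNamespace loader here are illustrative assumptions, not the project's own config machinery:

# Illustrative only: load the YAML above into an object with attribute access.
from types import SimpleNamespace
import yaml

with open('lm.yaml') as f:          # hypothetical file name for this config
    cfg = SimpleNamespace(**yaml.safe_load(f))

print(cfg.lm_file, cfg.num_hidden_layers, cfg.hidden_size)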

models/LM.py Normal file

@@ -0,0 +1,24 @@
from torch import nn
from . import BasicModule
from module import RNN
from transformers import BertModel
from utils import seq_len_to_mask


class LM(BasicModule):
    def __init__(self, cfg):
        super(LM, self).__init__()
        # Pretrained BERT encoder, truncated to cfg.num_hidden_layers layers
        self.bert = BertModel.from_pretrained(cfg.lm_file, num_hidden_layers=cfg.num_hidden_layers)
        # BiLSTM stacked on top of the BERT hidden states
        self.bilstm = RNN(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        word, lens = x['word'], x['lens']
        # Build the attention mask from the true sequence lengths
        mask = seq_len_to_mask(lens, mask_pos_to_true=False)
        last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
        # Encode the BERT hidden states with the BiLSTM, then classify the pooled output
        out, out_pool = self.bilstm(last_hidden_state, lens)
        out_pool = self.dropout(out_pool)
        output = self.fc(out_pool)
        return output
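
A hedged usage sketch of the new model: the cfg fields mirror the YAML above plus whatever the repository's RNN module and classifier read (num_relations is an assumed example value, and the RNN field names are taken from the config diff, not verified against module/RNN). The input dict follows the forward() signature above.

# Illustrative usage sketch, not from the repository.
from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    lm_file='bert-base-chinese',   # or a local path, as in the YAML above
    num_hidden_layers=1,
    type_rnn='LSTM', input_size=768, hidden_size=100,
    num_layers=1, dropout=0.3, bidirectional=True, last_layer_hn=True,
    num_relations=10,              # assumed: number of relation classes
)
model = LM(cfg)

x = {
    'word': torch.randint(0, 21128, (2, 16)),  # token ids (bert-base-chinese vocab)
    'lens': torch.tensor([16, 10]),            # true sequence length of each sample
}
logits = model(x)                              # expected shape: (batch_size, num_relations)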