update lm
This commit is contained in:
parent 67d4f0f6c9
commit de97e387d9
@@ -1,12 +1,20 @@
 model_name: lm
 
 # lm_name = 'bert-base-chinese' # download usage
-# cache file usage
 #lm_file: 'pretrained'
-# location of the pretrained model when a pretrained language model is used
-lm_file: '/Users/leo/transformers/bert-base-chinese'
+lm_file: '/home/yhy/transformers/bert-base-chinese'
 
 # number of transformer layers; the original base BERT has 12,
 # but with less data, lowering it actually converges faster and works better
-num_hidden_layers: 2
+num_hidden_layers: 1
 
+
+# parameters of the BiLSTM stacked on top of BERT
+type_rnn: 'LSTM'      # one of [RNN, GRU, LSTM]
+input_size: 768       # fixed by BERT's hidden size
+hidden_size: 100      # must be even
+num_layers: 1
+dropout: 0.3
+bidirectional: True
+last_layer_hn: True
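
The "must be even" constraint on hidden_size follows from bidirectional: True: each direction typically gets hidden_size // 2 units, so the concatenated forward/backward outputs add back up to hidden_size. Below is a minimal sketch of that convention, assuming it matches this repo's module.RNN; the class BiLSTMSketch and its defaults are hypothetical stand-ins.

import torch
from torch import nn

class BiLSTMSketch(nn.Module):
    # hypothetical stand-in for module.RNN; defaults mirror the config above
    def __init__(self, input_size=768, hidden_size=100, num_layers=1,
                 dropout=0.3, bidirectional=True):
        super().__init__()
        num_directions = 2 if bidirectional else 1
        assert hidden_size % num_directions == 0  # why hidden_size must be even
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size // num_directions,   # per-direction width
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,  # torch warns otherwise
            bidirectional=bidirectional,
            batch_first=True,
        )

    def forward(self, x):
        out, _ = self.lstm(x)
        return out  # (batch, seq_len, hidden_size): both directions concatenated

x = torch.randn(2, 16, 768)      # (batch, seq_len, input_size from BERT)
print(BiLSTMSketch()(x).shape)   # torch.Size([2, 16, 100])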
@@ -0,0 +1,27 @@
+from torch import nn
+from . import BasicModule
+from module import RNN
+from transformers import BertModel
+from utils import seq_len_to_mask
+
+
+class LM(BasicModule):
+    def __init__(self, cfg):
+        super(LM, self).__init__()
+        # load pretrained BERT, truncated to cfg.num_hidden_layers transformer layers
+        self.bert = BertModel.from_pretrained(cfg.lm_file, num_hidden_layers=cfg.num_hidden_layers)
+        self.bilstm = RNN(cfg)
+        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
+        self.dropout = nn.Dropout(cfg.dropout)
+
+    def forward(self, x):
+        word, lens = x['word'], x['lens']
+        # attention mask from sequence lengths: real tokens True, padding False
+        mask = seq_len_to_mask(lens, mask_pos_to_true=False)
+        last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
+        # run BERT's token states through the BiLSTM, then classify the pooled output
+        out, out_pool = self.bilstm(last_hidden_state, lens)
+        out_pool = self.dropout(out_pool)
+        output = self.fc(out_pool)
+
+        return output
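
The forward pass reduces to: BERT token states -> BiLSTM -> pooled vector -> dropout -> linear classifier over relations. Here is a hedged sketch of those shapes, with plain-PyTorch stand-ins for the repo's module.RNN and utils.seq_len_to_mask helpers; all sizes below are hypothetical.

import torch
from torch import nn

batch, seq_len, bert_dim, hidden, n_rel = 2, 16, 768, 100, 10

lens = torch.tensor([16, 11])
# stand-in for seq_len_to_mask(lens, mask_pos_to_true=False):
# True at real-token positions, False at padding
mask = torch.arange(seq_len)[None, :] < lens[:, None]      # (batch, seq_len)

last_hidden_state = torch.randn(batch, seq_len, bert_dim)  # mock BERT output
bilstm = nn.LSTM(bert_dim, hidden // 2, batch_first=True, bidirectional=True)
out, (hn, _) = bilstm(last_hidden_state)                   # out: (batch, seq_len, hidden)
out_pool = torch.cat([hn[-2], hn[-1]], dim=-1)             # final fwd/bwd states: (batch, hidden)
logits = nn.Linear(hidden, n_rel)(nn.Dropout(0.3)(out_pool))
print(logits.shape)                                        # torch.Size([2, 10])

One caveat: the tuple unpacking last_hidden_state, pooler_output = self.bert(...) assumes an older transformers API; with transformers v4+, BertModel returns a ModelOutput by default, so return_dict=False (or attribute access on the output) is needed.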