diff --git a/conf/model/lm.yaml b/conf/model/lm.yaml
index bc823dc..3cf6cf4 100644
--- a/conf/model/lm.yaml
+++ b/conf/model/lm.yaml
@@ -1,12 +1,20 @@
 model_name: lm
 
-# lm_name = 'bert-base-chinese'  # download usage
-# cache file usage
-#lm_file: 'pretrained'  # where the pretrained LM is stored when a pretrained model is used
-lm_file: '/Users/leo/transformers/bert-base-chinese'
-
+# lm_name = 'bert-base-chinese'  # download usage
+#lm_file: 'pretrained'
+lm_file: '/home/yhy/transformers/bert-base-chinese'
 
 
 # number of transformer layers; the original base BERT has 12,
 # but with a small dataset fewer layers converge faster and work better
-num_hidden_layers: 2
\ No newline at end of file
+num_hidden_layers: 1
+
+
+# parameters of the BiLSTM stacked after BERT
+type_rnn: 'LSTM'     # [RNN, GRU, LSTM]
+input_size: 768      # given by BERT's hidden size
+hidden_size: 100     # must be even
+num_layers: 1
+dropout: 0.3
+bidirectional: True
+last_layer_hn: True
\ No newline at end of file
diff --git a/models/LM.py b/models/LM.py
new file mode 100644
index 0000000..44eb60b
--- /dev/null
+++ b/models/LM.py
@@ -0,0 +1,24 @@
+from torch import nn
+from . import BasicModule
+from module import RNN
+from transformers import BertModel
+from utils import seq_len_to_mask
+
+
+class LM(BasicModule):
+    def __init__(self, cfg):
+        super(LM, self).__init__()
+        self.bert = BertModel.from_pretrained(cfg.lm_file, num_hidden_layers=cfg.num_hidden_layers)
+        self.bilstm = RNN(cfg)
+        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)
+        self.dropout = nn.Dropout(cfg.dropout)
+
+    def forward(self, x):
+        word, lens = x['word'], x['lens']
+        mask = seq_len_to_mask(lens, mask_pos_to_true=False)
+        last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
+        out, out_pool = self.bilstm(last_hidden_state, lens)
+        out_pool = self.dropout(out_pool)
+        output = self.fc(out_pool)
+
+        return output
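
Note: below is a minimal, self-contained sketch of the BERT -> BiLSTM -> FC pipeline this diff wires together, using only torch and transformers. LMSketch, the plain nn.LSTM, the masking/packing helpers, and the random inputs are illustrative substitutes for the repo's BasicModule, module.RNN, and utils.seq_len_to_mask, which are not part of this diff; the default argument values simply mirror conf/model/lm.yaml, and num_relations=10 is an arbitrary placeholder. Also note that the tuple unpacking last_hidden_state, pooler_output = self.bert(...) in models/LM.py assumes an older transformers API (or return_dict=False); the sketch reads .last_hidden_state from the model output, which is how transformers 4.x exposes it.

# Illustrative sketch only -- LMSketch is NOT the repo's models.LM;
# repo internals (BasicModule, module.RNN, utils.seq_len_to_mask) are
# replaced with plain PyTorch so the snippet runs on its own.
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence
from transformers import BertModel


class LMSketch(nn.Module):
    def __init__(self, lm_file, num_hidden_layers=1, input_size=768,
                 hidden_size=100, dropout=0.3, num_relations=10):
        super().__init__()
        # num_hidden_layers overrides the pretrained config, as in models/LM.py
        self.bert = BertModel.from_pretrained(lm_file, num_hidden_layers=num_hidden_layers)
        # bidirectional LSTM: hidden_size // 2 per direction, so the
        # concatenated state matches hidden_size (hence "must be even")
        self.bilstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size // 2,
                              num_layers=1, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_relations)

    def forward(self, word, lens):
        # attention mask: True for real tokens, False for padding
        mask = torch.arange(word.size(1))[None, :] < lens[:, None]
        hidden = self.bert(word, attention_mask=mask.long()).last_hidden_state  # (B, L, 768)
        packed = pack_padded_sequence(hidden, lens.cpu(), batch_first=True,
                                      enforce_sorted=False)
        _, (h_n, _) = self.bilstm(packed)
        # last forward / backward hidden states -> (B, hidden_size)
        pooled = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        return self.fc(self.dropout(pooled))          # (B, num_relations)


if __name__ == '__main__':
    # fake batch: token ids from bert-base-chinese's 21128-word vocab
    model = LMSketch('bert-base-chinese')             # or a local path, as in lm.yaml
    word = torch.randint(0, 21128, (2, 16))
    lens = torch.tensor([16, 9])
    print(model(word, lens).shape)                    # torch.Size([2, 10])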