leo 2019-09-16 19:48:24 +08:00
parent b75312d53f
commit 567673e7ac
7 changed files with 65 additions and 62 deletions

.gitignore
View File

@ -11,6 +11,7 @@ checkpoints/*.pth
bert_pretrained/*
!bert_pretrained/readme.md
data/out
data/origin/statics.py
demo.py

View File

@ -25,11 +25,11 @@ class ModelConfig(object):
class CNNConfig(object):
use_pcnn = True
out_channels = 100
kernel_size = [3, 5]
kernel_size = [3, 5, 7]
class RNNConfig(object):
lstm_layers = 2
lstm_layers = 3
last_hn = False
@ -39,7 +39,7 @@ class GCNConfig(object):
class TransformerConfig(object):
transformer_layers = 2
transformer_layers = 3
class CapsuleConfig(object):
@ -48,7 +48,7 @@ class CapsuleConfig(object):
primary_channels = 1
primary_unit_size = 768
output_unit_size = 128
num_iterations = 5
num_iterations = 3
class LMConfig(object):
@ -62,6 +62,8 @@ class Config(object):
# location of the preprocessed output files
out_path = 'data/out'
# whether to replace entities in the sentence with their entity types
replace_entity_by_type = True
# whether the data is Chinese
is_chinese = True
# whether word segmentation is required
@ -77,7 +79,7 @@ class Config(object):
pos_limit = 50 # [-50, 50]
# (CNN, RNN, GCN, Transformer, Capsule, LM)
model_name = 'CNN'
model_name = 'Capsule'
training = TrainingConfig()
model = ModelConfig()
@ -88,22 +90,5 @@ class Config(object):
capsule = CapsuleConfig()
lm = LMConfig()
def parse(self, kwargs, verbose=False):
'''
Allow the user to update the default hyperparameters.
'''
for k, v in kwargs.items():
if not hasattr(self, k):
raise Exception('opt has no key: {}'.format(k))
setattr(self, k, v)
if verbose:
print('*************************************************')
print('user config:')
for k, v in kwargs.items():
if not k.startswith('__'):
print("{} => {}".format(k, getattr(self, k)))
print('*************************************************')
config = Config()
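The parse helper above lets every top-level default in Config be overridden at run time. A minimal usage sketch, assuming the module is importable as deepke.config (the import path is an assumption; the attribute names come from the class above):

from deepke.config import config  # import path is an assumption

# Override defaults defined in Config above without editing config.py.
config.parse({'model_name': 'CNN', 'pos_limit': 30}, verbose=True)
assert config.model_name == 'CNN'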

View File

@ -44,8 +44,8 @@ def collate_fn(batch):
tail_pos.append(_padding(data['tail_pos'], max_len))
mask_pos.append(_padding(data['mask_pos'], max_len))
y.append(data['target'])
return torch.Tensor(sent), torch.Tensor(head_pos), torch.Tensor(
tail_pos), torch.Tensor(mask_pos), torch.Tensor(y)
return torch.tensor(sent), torch.tensor(head_pos), torch.tensor(
tail_pos), torch.tensor(mask_pos), torch.tensor(y)
if __name__ == '__main__':
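The change from torch.Tensor to torch.tensor is more than style: the torch.Tensor constructor always produces a float32 tensor, whereas torch.tensor infers the dtype from its input, so the padded index lists returned by collate_fn stay integer-typed and can be fed straight into embedding layers. A standalone illustration:

import torch

ids = [[2, 5, 7, 0], [3, 1, 0, 0]]   # padded word-index lists, as built in collate_fn
print(torch.Tensor(ids).dtype)        # torch.float32 -- constructor always returns floats
print(torch.tensor(ids).dtype)        # torch.int64   -- dtype inferred from the data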

View File

@ -28,7 +28,7 @@ class CNN(BasicModule):
self.pos_dim)
# PCNN embedding
self.mask_embed = nn.Embedding(4, 3)
masks = torch.Tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
100]])
self.mask_embed.weight.data.copy_(masks)
self.mask_embed.weight.requires_grad = False
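The frozen mask embedding is what turns plain CNN pooling into PCNN piecewise pooling: each token's segment (before the head entity, between the entities, after the tail entity) receives a large constant in exactly one of three extra channels, so a max over the sequence yields a separate maximum per segment. A hedged sketch of the idea with toy sizes; the forward pass in this repository may differ in detail:

import torch
import torch.nn as nn

mask_embed = nn.Embedding(4, 3)   # index 0 is reserved for padding
mask_embed.weight.data.copy_(torch.tensor(
    [[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]], dtype=torch.float))
mask_embed.weight.requires_grad = False

mask = torch.tensor([[1, 1, 2, 2, 3, 3]])   # toy segment id per token
conv_out = torch.randn(1, 6, 5)             # toy (batch, seq_len, channels) convolution output
# add the large per-segment offsets, max-pool over the sequence, then remove the offset
pooled, _ = (conv_out.unsqueeze(-1) + mask_embed(mask).unsqueeze(-2)).max(dim=1)
pooled = pooled - 100                       # shape (batch, channels, 3 segments)
print(pooled.shape)                         # torch.Size([1, 5, 3])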

View File

@ -40,29 +40,20 @@ def _pos_feature(sent_len: int, entity_idx: int, entity_len: int,
def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:
if vocab.name == 'LM':
for d in data:
d['seq_len'] = len(d['lm_idx'])
d['target'] = relations[d['relation']]
return data
for d in data:
word2idx = [vocab.word2idx.get(w, 1) for w in d['sentence']]
seq_len = len(word2idx)
head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
if vocab.name == 'word':
word2idx = [vocab.word2idx.get(w, 1) for w in d['words']]
seq_len = len(word2idx)
head_idx, tail_idx = d['head_idx'], d['tail_idx']
head_len, tail_len = 1, 1
elif vocab.name == 'char':
word2idx = [
vocab.word2idx.get(w, 1) for w in d['sentence'].strip()
]
seq_len = len(word2idx)
head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
head_len, tail_len = len(d['head']), len(d['tail'])
else:
head_len, tail_len = len(d['head_type']), len(d['tail_type'])
entities_idx = [head_idx, tail_idx
] if tail_idx > head_idx else [tail_idx, head_idx]
head_pos = _pos_feature(seq_len, head_idx, head_len, config.pos_limit)
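_pos_feature turns each token index into its distance from an entity, clipped to config.pos_limit; the hunk above only shows the call sites. A hedged re-implementation of the usual scheme (the exact clipping and shifting in this repository's _pos_feature may differ):

def relative_position(sent_len, entity_idx, entity_len, pos_limit=50):
    """Signed distance of every token to the entity span, clipped to [-pos_limit, pos_limit]
    and shifted so all indices are positive (leaving 0 free for padding)."""
    pos = []
    for i in range(sent_len):
        if i < entity_idx:
            p = i - entity_idx                     # tokens before the entity
        elif i >= entity_idx + entity_len:
            p = i - entity_idx - entity_len + 1    # tokens after the entity
        else:
            p = 0                                  # tokens inside the entity span
        p = max(-pos_limit, min(pos_limit, p))     # clip to [-pos_limit, pos_limit]
        pos.append(p + pos_limit + 1)              # shift to positive indices
    return pos

print(relative_position(6, entity_idx=2, entity_len=1, pos_limit=50))
# [49, 50, 51, 52, 53, 54]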
@ -83,12 +74,11 @@ def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:
def _build_vocab(data: List[Dict], out_path: Path) -> Vocab:
if config.word_segment:
vocab = Vocab('word')
for d in data:
vocab.add_sent(d['words'])
else:
vocab = Vocab('char')
for d in data:
vocab.add_sent(d['sentence'].strip())
for d in data:
vocab.add_sent(d['sentence'])
vocab.trim(config.min_freq)
ensure_dir(out_path)
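After the refactor both branches feed d['sentence'] to the vocabulary, whether it is a list of segmented words or a raw character string; vocab.trim then drops everything seen fewer than config.min_freq times. A minimal stand-in for the Vocab class (its real implementation is not part of this diff) showing the effect of trimming:

from collections import Counter

class TinyVocab:
    """Minimal stand-in for the Vocab used above; the real class is not shown in this diff."""
    def __init__(self, name):
        self.name = name
        self.counter = Counter()
        self.word2idx = {'[PAD]': 0, '[UNK]': 1}
    def add_sent(self, sent):          # sent: a list of words or a string of characters
        self.counter.update(sent)
    def trim(self, min_freq):          # keep only tokens seen at least min_freq times
        for w, c in self.counter.items():
            if c >= min_freq and w not in self.word2idx:
                self.word2idx[w] = len(self.word2idx)

vocab = TinyVocab('char')
for sent in ['人物创立了公司', '公司收购了公司']:
    vocab.add_sent(sent)
vocab.trim(min_freq=2)
print(vocab.word2idx)   # rare characters fall back to the [UNK] index 1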
@ -108,29 +98,44 @@ def _split_sent(data: List[Dict], verbose: bool = True) -> List[Dict]:
jieba.add_word('TAIL')
for d in data:
sent = d['sentence'].strip()
sent = sent.replace(d['head'], 'HEAD', 1)
sent = sent.replace(d['tail'], 'TAIL', 1)
sent = d['sentence']
sent = sent.replace(d['head_type'], 'HEAD', 1)
sent = sent.replace(d['tail_type'], 'TAIL', 1)
sent = jieba.lcut(sent)
head_idx, tail_idx = sent.index('HEAD'), sent.index('TAIL')
sent[head_idx], sent[tail_idx] = d['head'], d['tail']
d['words'] = sent
d['head_idx'] = head_idx
d['tail_idx'] = tail_idx
sent[head_idx], sent[tail_idx] = d['head_type'], d['tail_type']
d['sentence'] = sent
d['head_offset'] = head_idx
d['tail_offset'] = tail_idx
return data
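The HEAD/TAIL placeholders keep each entity type as a single token no matter how jieba would otherwise segment it, and their positions become head_offset/tail_offset. A small worked example with a made-up record (the sentence already has its entity mentions replaced by types, as done later in this file):

import jieba

jieba.add_word('HEAD')
jieba.add_word('TAIL')

# toy record: entity mentions have already been replaced by their types
d = {'sentence': '人物创立了公司', 'head_type': '人物', 'tail_type': '公司'}

sent = d['sentence'].replace(d['head_type'], 'HEAD', 1).replace(d['tail_type'], 'TAIL', 1)
tokens = jieba.lcut(sent)                    # e.g. ['HEAD', '创立', '了', 'TAIL']
head_idx, tail_idx = tokens.index('HEAD'), tokens.index('TAIL')
tokens[head_idx], tokens[tail_idx] = d['head_type'], d['tail_type']
print(tokens, head_idx, tail_idx)            # ['人物', '创立', '了', '公司'] 0 3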
def _add_lm_data(data: List[Dict]) -> List[Dict]:
'Serialize the input sentences with the language model vocabulary.'
tokenizer = BertTokenizer.from_pretrained('../bert_pretrained')
tokenizer = BertTokenizer.from_pretrained(config.lm.lm_file)
for d in data:
sent = d['sentence']
sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']
d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)
d['seq_len'] = len(d['lm_idx'])
return data
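The serialization appends both entity mentions to the sentence with [SEP] markers and encodes the result with the tokenizer loaded from config.lm.lm_file. A hedged sketch of what this produces; bert-base-chinese and the transformers package are stand-ins for whatever lm_file points at and whichever tokenizer library this repository actually imports:

from transformers import BertTokenizer   # assumption: the repo's actual tokenizer import is not shown here

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')   # stand-in for config.lm.lm_file
sent = '人物创立了公司' + '[SEP]' + '乔布斯' + '[SEP]' + '苹果公司'
lm_idx = tokenizer.encode(sent, add_special_tokens=True)          # wordpiece ids incl. [CLS]/[SEP]
print(len(lm_idx), lm_idx[:5])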
def _replace_entity_by_type(data: List[Dict]) -> List[Dict]:
for d in data:
sent = d['sentence'].strip()
d['seq_len'] = len(sent)
sent = sent.replace(d['head'], d['head_type'], 1)
sent = sent.replace(d['tail'], d['tail_type'], 1)
sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']
d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)
head_offset = sent.index(d['head_type'])
tail_offset = sent.index(d['tail_type'])
d['sentence'] = sent
d['head_offset'] = head_offset
d['tail_offset'] = tail_offset
return data
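Concretely, the helper records the length of the original sentence, swaps each entity mention for its type, and stores the character offsets of the types. A worked toy example of the fields it produces (the record values are made up):

# toy record, mirroring the fields used by _replace_entity_by_type above
d = {'sentence': '乔布斯创立了苹果公司', 'head': '乔布斯', 'tail': '苹果公司',
     'head_type': '人物', 'tail_type': '公司'}

sent = d['sentence'].strip()
d['seq_len'] = len(sent)                             # length of the original sentence
sent = sent.replace(d['head'], d['head_type'], 1)
sent = sent.replace(d['tail'], d['tail_type'], 1)
d['sentence'] = sent                                 # '人物创立了公司'
d['head_offset'] = sent.index(d['head_type'])        # 0
d['tail_offset'] = sent.index(d['tail_type'])        # 5
print(d['sentence'], d['head_offset'], d['tail_offset'])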
@ -163,6 +168,12 @@ def process(data_path: Path, out_path: Path) -> None:
test_raw_data = load_csv(test_fp)
relations = _load_relations(relation_fp)
# replace entities in the sentence with their entity types,
# which improves training results considerably
if config.replace_entity_by_type:
train_raw_data = _replace_entity_by_type(train_raw_data)
test_raw_data = _replace_entity_by_type(test_raw_data)
# when using a pretrained language model
if config.model_name == 'LM':
print('\nusing the pretrained language model to serialize sentences...')

View File

@ -28,14 +28,14 @@ def train(epoch, device, dataloader, model, optimizer, criterion, config):
# logging
data_cal = len(dataloader.dataset) if batch_idx == len(
dataloader) else batch_idx * len(y)
if (config.train_log and batch_idx %
config.log_interval == 0) or batch_idx == len(dataloader):
if (config.training.train_log and batch_idx %
config.training.log_interval == 0) or batch_idx == len(dataloader):
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, data_cal, len(dataloader.dataset),
100. * batch_idx / len(dataloader), loss.item()))
# plot
if config.show_plot:
if config.training.show_plot:
plt.plot(total_loss)
plt.title('loss')
plt.show()
@ -65,7 +65,7 @@ def validate(dataloader, model, device, config):
total_y_pred = np.append(total_y_pred, y_pred)
total_f1 = []
for average in config.f1_norm:
for average in config.training.f1_norm:
p, r, f1, _ = precision_recall_fscore_support(total_y_true,
total_y_pred,
average=average)
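config.training.f1_norm typically lists the averaging modes handed to scikit-learn; macro averaging weights every relation class equally while micro averaging pools all decisions. A standalone illustration with made-up labels (the actual contents of f1_norm are an assumption):

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([0, 0, 1, 2, 2, 2])
y_pred = np.array([0, 1, 1, 2, 2, 0])
for average in ('macro', 'micro'):   # assumed contents of config.training.f1_norm
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=average)
    print(average, round(p, 3), round(r, 3), round(f1, 3))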

main.py
View File

@ -1,5 +1,6 @@
import os
import argparse
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
@ -11,6 +12,8 @@ from deepke.trainer import train, validate
from deepke.preprocess import process
from deepke.dataset import CustomDataset, collate_fn
warnings.filterwarnings("ignore")
__Models__ = {
"CNN": model.CNN,
"RNN": model.BiLSTM,
@ -21,14 +24,14 @@ __Models__ = {
}
parser = argparse.ArgumentParser(description='choose your model')
parser.add_argument('--model_name', type=str, help='model name')
parser.add_argument('--model', type=str, help='model name')
args = parser.parse_args()
model_name = args.model_name if args.model_name else config.model_name
model_name = args.model if args.model else config.model_name
make_seed(config.training.seed)
if config.training.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', config.gpu_id)
device = torch.device('cuda', config.training.gpu_id)
else:
device = torch.device('cpu')
@ -64,7 +67,10 @@ model.to(device)
optimizer = optim.Adam(model.parameters(), lr=config.training.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, 'max', factor=config.training.decay_rate, patience=config.training.decay_patience)
optimizer,
'max',
factor=config.training.decay_rate,
patience=config.training.decay_patience)
criterion = nn.CrossEntropyLoss()
best_macro_f1, best_macro_epoch = 0, 1
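With the renamed flag the model class is picked on the command line (python main.py --model Capsule) and otherwise falls back to config.model_name. Because the scheduler above uses mode='max', it has to be stepped with the validation metric being maximized, not the loss; a standalone toy illustration (the linear model and the F1 values are made up):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=2)

for epoch, macro_f1 in enumerate([0.40, 0.45, 0.45, 0.45, 0.45], start=1):
    scheduler.step(macro_f1)                       # pass the metric, not the loss
    print(epoch, optimizer.param_groups[0]['lr'])  # the LR halves once patience is exhausted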