From 567673e7acd2f7f41851ac1cf6cd77eff498183b Mon Sep 17 00:00:00 2001
From: leo
Date: Mon, 16 Sep 2019 19:48:24 +0800
Subject: [PATCH] update

---
 .gitignore           |  1 +
 deepke/config.py     | 29 +++++--------------
 deepke/dataset.py    |  4 +--
 deepke/model/CNN.py  |  2 +-
 deepke/preprocess.py | 69 +++++++++++++++++++++++++-------------------
 deepke/trainer.py    |  8 ++---
 main.py              | 14 ++++++---
 7 files changed, 65 insertions(+), 62 deletions(-)

diff --git a/.gitignore b/.gitignore
index aa5e8d5..5eb72ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@ checkpoints/*.pth
 bert_pretrained/*
 !bert_pretrained/readme.md
 
+data/out
 data/origin/statics.py
 
 demo.py
diff --git a/deepke/config.py b/deepke/config.py
index dcfa108..e8bfb35 100644
--- a/deepke/config.py
+++ b/deepke/config.py
@@ -25,11 +25,11 @@ class ModelConfig(object):
 
 
 class CNNConfig(object):
     use_pcnn = True
     out_channels = 100
-    kernel_size = [3, 5]
+    kernel_size = [3, 5, 7]
 
 
 class RNNConfig(object):
-    lstm_layers = 2
+    lstm_layers = 3
     last_hn = False
@@ -39,7 +39,7 @@ class GCNConfig(object):
 
 
 class TransformerConfig(object):
-    transformer_layers = 2
+    transformer_layers = 3
 
 
 class CapsuleConfig(object):
@@ -48,7 +48,7 @@ class CapsuleConfig(object):
     primary_channels = 1
     primary_unit_size = 768
     output_unit_size = 128
-    num_iterations = 5
+    num_iterations = 3
 
 
 class LMConfig(object):
@@ -62,6 +62,8 @@ class Config(object):
     # 预处理后存放文件的位置
     out_path = 'data/out'
 
+    # 是否将句子中实体替换为实体类型
+    replace_entity_by_type = True
     # 是否为中文数据
     is_chinese = True
     # 是否需要分词操作
@@ -77,7 +79,7 @@ class Config(object):
     pos_limit = 50  # [-50, 50]
 
     # (CNN, RNN, GCN, Transformer, Capsule, LM)
-    model_name = 'CNN'
+    model_name = 'Capsule'
 
     training = TrainingConfig()
     model = ModelConfig()
@@ -88,22 +90,5 @@ class Config(object):
     capsule = CapsuleConfig()
     lm = LMConfig()
 
-    def parse(self, kwargs, verbose=False):
-        '''
-        user can update the default hyper parameters
-        '''
-        for k, v in kwargs.items():
-            if not hasattr(self, k):
-                raise Exception('opt has No key: {}'.format(k))
-            setattr(self, k, v)
-
-        if verbose:
-            print('*************************************************')
-            print('user config:')
-            for k, v in kwargs.items():
-                if not k.startswith('__'):
-                    print("{} => {}".format(k, getattr(self, k)))
-            print('*************************************************')
-
 
 config = Config()
diff --git a/deepke/dataset.py b/deepke/dataset.py
index 3faee5a..f0e32b1 100644
--- a/deepke/dataset.py
+++ b/deepke/dataset.py
@@ -44,8 +44,8 @@ def collate_fn(batch):
         tail_pos.append(_padding(data['tail_pos'], max_len))
         mask_pos.append(_padding(data['mask_pos'], max_len))
         y.append(data['target'])
-    return torch.Tensor(sent), torch.Tensor(head_pos), torch.Tensor(
-        tail_pos), torch.Tensor(mask_pos), torch.Tensor(y)
+    return torch.tensor(sent), torch.tensor(head_pos), torch.tensor(
+        tail_pos), torch.tensor(mask_pos), torch.tensor(y)
 
 
 if __name__ == '__main__':
diff --git a/deepke/model/CNN.py b/deepke/model/CNN.py
index e0f3896..416f301 100644
--- a/deepke/model/CNN.py
+++ b/deepke/model/CNN.py
@@ -28,7 +28,7 @@ class CNN(BasicModule):
                                        self.pos_dim)
         # PCNN embedding
         self.mask_embed = nn.Embedding(4, 3)
-        masks = torch.Tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
+        masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
                                                                      100]])
         self.mask_embed.weight.data.copy_(masks)
         self.mask_embed.weight.requires_grad = False
diff --git a/deepke/preprocess.py b/deepke/preprocess.py
index 600f399..4aa26db 100644
--- a/deepke/preprocess.py
+++ b/deepke/preprocess.py
@@ -40,29 +40,20 @@ def _pos_feature(sent_len: int, entity_idx: int, entity_len: int,
 
 
 def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:
-
     if vocab.name == 'LM':
         for d in data:
-            d['seq_len'] = len(d['lm_idx'])
             d['target'] = relations[d['relation']]
         return data
 
     for d in data:
+        word2idx = [vocab.word2idx.get(w, 1) for w in d['sentence']]
+        seq_len = len(word2idx)
+        head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
         if vocab.name == 'word':
-            word2idx = [vocab.word2idx.get(w, 1) for w in d['words']]
-            seq_len = len(word2idx)
-            head_idx, tail_idx = d['head_idx'], d['tail_idx']
             head_len, tail_len = 1, 1
-
-        elif vocab.name == 'char':
-            word2idx = [
-                vocab.word2idx.get(w, 1) for w in d['sentence'].strip()
-            ]
-            seq_len = len(word2idx)
-            head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
-            head_len, tail_len = len(d['head']), len(d['tail'])
-
+        else:
+            head_len, tail_len = len(d['head_type']), len(d['tail_type'])
         entities_idx = [head_idx, tail_idx
                         ] if tail_idx > head_idx else [tail_idx, head_idx]
         head_pos = _pos_feature(seq_len, head_idx, head_len, config.pos_limit)
@@ -83,12 +74,11 @@ def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:
 def _build_vocab(data: List[Dict], out_path: Path) -> Vocab:
     if config.word_segment:
         vocab = Vocab('word')
-        for d in data:
-            vocab.add_sent(d['words'])
     else:
         vocab = Vocab('char')
-        for d in data:
-            vocab.add_sent(d['sentence'].strip())
+
+    for d in data:
+        vocab.add_sent(d['sentence'])
 
     vocab.trim(config.min_freq)
     ensure_dir(out_path)
@@ -108,29 +98,44 @@ def _split_sent(data: List[Dict], verbose: bool = True) -> List[Dict]:
     jieba.add_word('TAIL')
 
     for d in data:
-        sent = d['sentence'].strip()
-        sent = sent.replace(d['head'], 'HEAD', 1)
-        sent = sent.replace(d['tail'], 'TAIL', 1)
+        sent = d['sentence']
+        sent = sent.replace(d['head_type'], 'HEAD', 1)
+        sent = sent.replace(d['tail_type'], 'TAIL', 1)
         sent = jieba.lcut(sent)
         head_idx, tail_idx = sent.index('HEAD'), sent.index('TAIL')
-        sent[head_idx], sent[tail_idx] = d['head'], d['tail']
-        d['words'] = sent
-        d['head_idx'] = head_idx
-        d['tail_idx'] = tail_idx
+        sent[head_idx], sent[tail_idx] = d['head_type'], d['tail_type']
+        d['sentence'] = sent
+        d['head_offset'] = head_idx
+        d['tail_offset'] = tail_idx
+
     return data
 
 
 def _add_lm_data(data: List[Dict]) -> List[Dict]:
     '使用语言模型的词表,序列化输入的句子'
-    tokenizer = BertTokenizer.from_pretrained('../bert_pretrained')
+    tokenizer = BertTokenizer.from_pretrained(config.lm.lm_file)
+    for d in data:
+        sent = d['sentence']
+        sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']
+
+        d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)
+        d['seq_len'] = len(d['lm_idx'])
+
+    return data
+
+
+def _replace_entity_by_type(data: List[Dict]) -> List[Dict]:
     for d in data:
         sent = d['sentence'].strip()
-        d['seq_len'] = len(sent)
         sent = sent.replace(d['head'], d['head_type'], 1)
         sent = sent.replace(d['tail'], d['tail_type'], 1)
-        sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']
-        d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)
+        head_offset = sent.index(d['head_type'])
+        tail_offset = sent.index(d['tail_type'])
+
+        d['sentence'] = sent
+        d['head_offset'] = head_offset
+        d['tail_offset'] = tail_offset
 
     return data
 
 
@@ -163,6 +168,12 @@ def process(data_path: Path, out_path: Path) -> None:
     test_raw_data = load_csv(test_fp)
     relations = _load_relations(relation_fp)
 
+    # 使用 entity type 替换句子中的 entity
+    # 这样训练效果会提升很多
+    if config.replace_entity_by_type:
+        train_raw_data = _replace_entity_by_type(train_raw_data)
+        test_raw_data = _replace_entity_by_type(test_raw_data)
+
     # 使用预训练语言模型时
     if config.model_name == 'LM':
         print('\nuse pretrained language model serialize sentence...')
diff --git a/deepke/trainer.py b/deepke/trainer.py
index 6241132..0c7f651 100644
--- a/deepke/trainer.py
+++ b/deepke/trainer.py
@@ -28,14 +28,14 @@ def train(epoch, device, dataloader, model, optimizer, criterion, config):
         # logging
         data_cal = len(dataloader.dataset) if batch_idx == len(
             dataloader) else batch_idx * len(y)
-        if (config.train_log and batch_idx %
-                config.log_interval == 0) or batch_idx == len(dataloader):
+        if (config.training.train_log and batch_idx %
+                config.training.log_interval == 0) or batch_idx == len(dataloader):
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch, data_cal, len(dataloader.dataset),
                 100. * batch_idx / len(dataloader), loss.item()))
 
     # plot
-    if config.show_plot:
+    if config.training.show_plot:
         plt.plot(total_loss)
         plt.title('loss')
         plt.show()
@@ -65,7 +65,7 @@ def validate(dataloader, model, device, config):
            total_y_pred = np.append(total_y_pred, y_pred)
 
    total_f1 = []
-    for average in config.f1_norm:
+    for average in config.training.f1_norm:
        p, r, f1, _ = precision_recall_fscore_support(total_y_true,
                                                      total_y_pred,
                                                      average=average)
diff --git a/main.py b/main.py
index b9f69c7..3d02ffc 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
 import os
 import argparse
+import warnings
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -11,6 +12,8 @@ from deepke.trainer import train, validate
 from deepke.preprocess import process
 from deepke.dataset import CustomDataset, collate_fn
 
+warnings.filterwarnings("ignore")
+
 __Models__ = {
     "CNN": model.CNN,
     "RNN": model.BiLSTM,
@@ -21,14 +24,14 @@ __Models__ = {
 }
 
 parser = argparse.ArgumentParser(description='choose your model')
-parser.add_argument('--model_name', type=str, help='model name')
+parser.add_argument('--model', type=str, help='model name')
 args = parser.parse_args()
-model_name = args.model_name if args.model_name else config.model_name
+model_name = args.model if args.model else config.model_name
 
 make_seed(config.training.seed)
 
 if config.training.use_gpu and torch.cuda.is_available():
-    device = torch.device('cuda', config.gpu_id)
+    device = torch.device('cuda', config.training.gpu_id)
 else:
     device = torch.device('cpu')
 
@@ -64,7 +67,10 @@ model.to(device)
 
 optimizer = optim.Adam(model.parameters(), lr=config.training.learning_rate)
 scheduler = optim.lr_scheduler.ReduceLROnPlateau(
-    optimizer, 'max', factor=config.training.decay_rate, patience=config.training.decay_patience)
+    optimizer,
+    'max',
+    factor=config.training.decay_rate,
+    patience=config.training.decay_patience)
 criterion = nn.CrossEntropyLoss()
 
 best_macro_f1, best_macro_epoch = 0, 1