update

parent b75312d53f
commit 567673e7ac
@@ -11,6 +11,7 @@ checkpoints/*.pth
bert_pretrained/*
!bert_pretrained/readme.md

data/out
data/origin/statics.py

demo.py
@@ -25,11 +25,11 @@ class ModelConfig(object):
class CNNConfig(object):
    use_pcnn = True
    out_channels = 100
    kernel_size = [3, 5]
    kernel_size = [3, 5, 7]


class RNNConfig(object):
    lstm_layers = 2
    lstm_layers = 3
    last_hn = False
@@ -39,7 +39,7 @@ class GCNConfig(object):


class TransformerConfig(object):
    transformer_layers = 2
    transformer_layers = 3


class CapsuleConfig(object):
@@ -48,7 +48,7 @@ class CapsuleConfig(object):
    primary_channels = 1
    primary_unit_size = 768
    output_unit_size = 128
    num_iterations = 5
    num_iterations = 3


class LMConfig(object):
@@ -62,6 +62,8 @@ class Config(object):
    # 预处理后存放文件的位置
    out_path = 'data/out'

    # 是否将句子中实体替换为实体类型
    replace_entity_by_type = True
    # 是否为中文数据
    is_chinese = True
    # 是否需要分词操作
@@ -77,7 +79,7 @@ class Config(object):
    pos_limit = 50  # [-50, 50]

    # (CNN, RNN, GCN, Transformer, Capsule, LM)
    model_name = 'CNN'
    model_name = 'Capsule'

    training = TrainingConfig()
    model = ModelConfig()
@@ -88,22 +90,5 @@ class Config(object):
    capsule = CapsuleConfig()
    lm = LMConfig()

    def parse(self, kwargs, verbose=False):
        '''
        user can update the default hyper parameters
        '''
        for k, v in kwargs.items():
            if not hasattr(self, k):
                raise Exception('opt has No key: {}'.format(k))
            setattr(self, k, v)

        if verbose:
            print('*************************************************')
            print('user config:')
            for k, v in kwargs.items():
                if not k.startswith('__'):
                    print("{} => {}".format(k, getattr(self, k)))
            print('*************************************************')


config = Config()
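Note: the shrinking hunk (22 old lines down to 5) indicates the parse() helper is dropped from Config here. For reference, a minimal standalone sketch of the same update-attributes-from-kwargs pattern (the class and default values below are hypothetical, not taken from this repo):

class DemoConfig:
    # hypothetical defaults, only to illustrate the kwargs-override pattern removed above
    model_name = 'CNN'
    learning_rate = 1e-3

    def parse(self, kwargs, verbose=False):
        for k, v in kwargs.items():
            if not hasattr(self, k):
                raise KeyError('config has no key: {}'.format(k))
            setattr(self, k, v)
        if verbose:
            for k in kwargs:
                print('{} => {}'.format(k, getattr(self, k)))

cfg = DemoConfig()
cfg.parse({'model_name': 'Capsule'}, verbose=True)  # prints: model_name => Capsule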
@@ -44,8 +44,8 @@ def collate_fn(batch):
        tail_pos.append(_padding(data['tail_pos'], max_len))
        mask_pos.append(_padding(data['mask_pos'], max_len))
        y.append(data['target'])
    return torch.Tensor(sent), torch.Tensor(head_pos), torch.Tensor(
        tail_pos), torch.Tensor(mask_pos), torch.Tensor(y)
    return torch.tensor(sent), torch.tensor(head_pos), torch.tensor(
        tail_pos), torch.tensor(mask_pos), torch.tensor(y)


if __name__ == '__main__':
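Note: the switch from torch.Tensor(...) to torch.tensor(...) changes the dtype of the collated batches. torch.Tensor is the legacy float32 constructor, while torch.tensor infers the dtype from its input, so lists of integer ids and positions come back as int64, which is what nn.Embedding and nn.CrossEntropyLoss expect. A quick check:

import torch

ids = [[2, 5, 7], [3, 1, 0]]        # e.g. padded token ids
print(torch.Tensor(ids).dtype)      # torch.float32
print(torch.tensor(ids).dtype)      # torch.int64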
@@ -28,7 +28,7 @@ class CNN(BasicModule):
                                      self.pos_dim)
        # PCNN embedding
        self.mask_embed = nn.Embedding(4, 3)
        masks = torch.Tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
        masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
                                                                     100]])
        self.mask_embed.weight.data.copy_(masks)
        self.mask_embed.weight.requires_grad = False
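Note: the frozen nn.Embedding(4, 3) with rows [0,0,0], [100,0,0], [0,100,0], [0,0,100] is the usual device for piecewise max pooling in PCNN: tokens are tagged 1/2/3 depending on whether they fall before the head entity, between the two entities, or after the tail (0 for padding), and the +100 offset lets a single max over the sequence pool each piece separately. A rough self-contained sketch of that pooling step (shapes and the mask values below are illustrative, not this repo's forward pass):

import torch
import torch.nn as nn

B, L, C = 2, 6, 4                      # batch, sequence length, conv channels
conv_out = torch.randn(B, L, C)        # one feature vector per token position
mask = torch.tensor([[1, 1, 2, 2, 3, 3],
                     [1, 2, 2, 3, 3, 0]])   # piece id per token, 0 = padding

mask_embed = nn.Embedding(4, 3)
mask_embed.weight.data.copy_(torch.tensor(
    [[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]], dtype=torch.float))
mask_embed.weight.requires_grad = False

# (B, L, C, 1) + (B, L, 1, 3): tokens inside piece j get a +100 boost in slot j,
# so the max over the length dimension is effectively taken piece by piece.
piece_scores = conv_out.unsqueeze(-1) + mask_embed(mask).unsqueeze(-2)
pooled = piece_scores.max(dim=1)[0] - 100    # (B, C, 3)
pooled = pooled.view(B, -1)                  # (B, 3 * C) piecewise-pooled feature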
@@ -40,29 +40,20 @@ def _pos_feature(sent_len: int, entity_idx: int, entity_len: int,


def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:

    if vocab.name == 'LM':
        for d in data:
            d['seq_len'] = len(d['lm_idx'])
            d['target'] = relations[d['relation']]

        return data

    for d in data:
        word2idx = [vocab.word2idx.get(w, 1) for w in d['sentence']]
        seq_len = len(word2idx)
        head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
        if vocab.name == 'word':
            word2idx = [vocab.word2idx.get(w, 1) for w in d['words']]
            seq_len = len(word2idx)
            head_idx, tail_idx = d['head_idx'], d['tail_idx']
            head_len, tail_len = 1, 1

        elif vocab.name == 'char':
            word2idx = [
                vocab.word2idx.get(w, 1) for w in d['sentence'].strip()
            ]
            seq_len = len(word2idx)
            head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
            head_len, tail_len = len(d['head']), len(d['tail'])

        else:
        head_len, tail_len = len(d['head_type']), len(d['tail_type'])
        entities_idx = [head_idx, tail_idx
                        ] if tail_idx > head_idx else [tail_idx, head_idx]
        head_pos = _pos_feature(seq_len, head_idx, head_len, config.pos_limit)
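Note: _pos_feature(seq_len, entity_idx, entity_len, pos_limit) supplies the relative-position feature used by the models, with pos_limit = 50 clipping positions to [-50, 50] (see the config hunk above). The repo's exact implementation is not visible in this hunk; a rough sketch of what such a feature typically computes:

from typing import List

def pos_feature_sketch(sent_len: int, entity_idx: int, entity_len: int,
                       pos_limit: int = 50) -> List[int]:
    # relative position of each token w.r.t. the entity span, clipped to
    # [-pos_limit, pos_limit] and shifted to be non-negative for nn.Embedding
    pos = []
    for i in range(sent_len):
        if i < entity_idx:
            p = i - entity_idx                    # left of the entity: negative
        elif i >= entity_idx + entity_len:
            p = i - entity_idx - entity_len + 1   # right of the entity: positive
        else:
            p = 0                                 # inside the entity span
        p = max(-pos_limit, min(pos_limit, p))
        pos.append(p + pos_limit + 1)             # map [-limit, limit] -> [1, 2*limit+1]
    return pos

print(pos_feature_sketch(5, 2, 1))  # [49, 50, 51, 52, 53]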
@@ -83,12 +74,11 @@ def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:
def _build_vocab(data: List[Dict], out_path: Path) -> Vocab:
    if config.word_segment:
        vocab = Vocab('word')
        for d in data:
            vocab.add_sent(d['words'])
    else:
        vocab = Vocab('char')
        for d in data:
            vocab.add_sent(d['sentence'].strip())

    for d in data:
        vocab.add_sent(d['sentence'])
    vocab.trim(config.min_freq)

    ensure_dir(out_path)
@@ -108,29 +98,44 @@ def _split_sent(data: List[Dict], verbose: bool = True) -> List[Dict]:
    jieba.add_word('TAIL')

    for d in data:
        sent = d['sentence'].strip()
        sent = sent.replace(d['head'], 'HEAD', 1)
        sent = sent.replace(d['tail'], 'TAIL', 1)
        sent = d['sentence']
        sent = sent.replace(d['head_type'], 'HEAD', 1)
        sent = sent.replace(d['tail_type'], 'TAIL', 1)
        sent = jieba.lcut(sent)
        head_idx, tail_idx = sent.index('HEAD'), sent.index('TAIL')
        sent[head_idx], sent[tail_idx] = d['head'], d['tail']
        d['words'] = sent
        d['head_idx'] = head_idx
        d['tail_idx'] = tail_idx
        sent[head_idx], sent[tail_idx] = d['head_type'], d['tail_type']
        d['sentence'] = sent
        d['head_offset'] = head_idx
        d['tail_offset'] = tail_idx

    return data


def _add_lm_data(data: List[Dict]) -> List[Dict]:
    '使用语言模型的词表,序列化输入的句子'
    tokenizer = BertTokenizer.from_pretrained('../bert_pretrained')
    tokenizer = BertTokenizer.from_pretrained(config.lm.lm_file)

    for d in data:
        sent = d['sentence']
        sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']

        d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)
        d['seq_len'] = len(d['lm_idx'])

    return data


def _replace_entity_by_type(data: List[Dict]) -> List[Dict]:
    for d in data:
        sent = d['sentence'].strip()
        d['seq_len'] = len(sent)
        sent = sent.replace(d['head'], d['head_type'], 1)
        sent = sent.replace(d['tail'], d['tail_type'], 1)
        sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']
        d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)
        head_offset = sent.index(d['head_type'])
        tail_offset = sent.index(d['tail_type'])

        d['sentence'] = sent
        d['head_offset'] = head_offset
        d['tail_offset'] = tail_offset

    return data
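Note: the HEAD/TAIL placeholders in _split_sent exist because jieba may otherwise cut an entity mention (or, after this change, its entity type) across several tokens. Registering the placeholders with jieba.add_word, segmenting, and then looking up their indices yields the entities' word-level offsets. A minimal sketch with a made-up sentence and entities:

import jieba

jieba.add_word('HEAD')
jieba.add_word('TAIL')

sentence = '乔布斯创立了苹果公司'
head, tail = '乔布斯', '苹果公司'

sent = sentence.replace(head, 'HEAD', 1).replace(tail, 'TAIL', 1)
words = jieba.lcut(sent)                        # placeholders survive as single tokens
head_idx, tail_idx = words.index('HEAD'), words.index('TAIL')
words[head_idx], words[tail_idx] = head, tail   # put the original strings back
print(words, head_idx, tail_idx)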
@@ -163,6 +168,12 @@ def process(data_path: Path, out_path: Path) -> None:
    test_raw_data = load_csv(test_fp)
    relations = _load_relations(relation_fp)

    # 使用 entity type 替换句子中的 entity
    # 这样训练效果会提升很多
    if config.replace_entity_by_type:
        train_raw_data = _replace_entity_by_type(train_raw_data)
        test_raw_data = _replace_entity_by_type(test_raw_data)

    # 使用预训练语言模型时
    if config.model_name == 'LM':
        print('\nuse pretrained language model serialize sentence...')
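Note: as the comment above says, replacing the entity mentions with their types before training helps noticeably. The new _replace_entity_by_type boils down to a string substitution plus recomputed character offsets; a minimal sketch on a made-up record:

# hypothetical record, only to illustrate the replacement and offset bookkeeping
d = {'sentence': '乔布斯创立了苹果公司',
     'head': '乔布斯', 'head_type': '人物',
     'tail': '苹果公司', 'tail_type': '公司'}

sent = d['sentence'].strip()
sent = sent.replace(d['head'], d['head_type'], 1)
sent = sent.replace(d['tail'], d['tail_type'], 1)
d['sentence'] = sent                            # '人物创立了公司'
d['head_offset'] = sent.index(d['head_type'])   # 0
d['tail_offset'] = sent.index(d['tail_type'])   # 5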
@@ -28,14 +28,14 @@ def train(epoch, device, dataloader, model, optimizer, criterion, config):
        # logging
        data_cal = len(dataloader.dataset) if batch_idx == len(
            dataloader) else batch_idx * len(y)
        if (config.train_log and batch_idx %
                config.log_interval == 0) or batch_idx == len(dataloader):
        if (config.training.train_log and batch_idx %
                config.training.log_interval == 0) or batch_idx == len(dataloader):
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, data_cal, len(dataloader.dataset),
                100. * batch_idx / len(dataloader), loss.item()))

    # plot
    if config.show_plot:
    if config.training.show_plot:
        plt.plot(total_loss)
        plt.title('loss')
        plt.show()
@@ -65,7 +65,7 @@ def validate(dataloader, model, device, config):
        total_y_pred = np.append(total_y_pred, y_pred)

    total_f1 = []
    for average in config.f1_norm:
    for average in config.training.f1_norm:
        p, r, f1, _ = precision_recall_fscore_support(total_y_true,
                                                      total_y_pred,
                                                      average=average)
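Note: config.training.f1_norm presumably holds the averaging modes passed to scikit-learn (e.g. 'macro' and 'micro'); macro averaging weights every class equally, while micro averaging pools all decisions. A small self-contained check with made-up labels:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([0, 1, 2, 2, 1, 0])
y_pred = np.array([0, 1, 2, 1, 1, 0])
for average in ('micro', 'macro'):
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=average)
    print(average, round(p, 3), round(r, 3), round(f1, 3))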
main.py
@@ -1,5 +1,6 @@
import os
import argparse
import warnings
import torch
import torch.nn as nn
import torch.optim as optim
@@ -11,6 +12,8 @@ from deepke.trainer import train, validate
from deepke.preprocess import process
from deepke.dataset import CustomDataset, collate_fn

warnings.filterwarnings("ignore")

__Models__ = {
    "CNN": model.CNN,
    "RNN": model.BiLSTM,
@@ -21,14 +24,14 @@ __Models__ = {
}

parser = argparse.ArgumentParser(description='choose your model')
parser.add_argument('--model_name', type=str, help='model name')
parser.add_argument('--model', type=str, help='model name')
args = parser.parse_args()
model_name = args.model_name if args.model_name else config.model_name
model_name = args.model if args.model else config.model_name

make_seed(config.training.seed)

if config.training.use_gpu and torch.cuda.is_available():
    device = torch.device('cuda', config.gpu_id)
    device = torch.device('cuda', config.training.gpu_id)
else:
    device = torch.device('cpu')
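Note: after the rename, the model is picked on the command line with --model, falling back to config.model_name when the flag is absent. A tiny sketch of the same flag handling (the default value here is hypothetical):

import argparse

parser = argparse.ArgumentParser(description='choose your model')
parser.add_argument('--model', type=str, help='model name')
args = parser.parse_args(['--model', 'Capsule'])   # i.e. `python main.py --model Capsule`
model_name = args.model if args.model else 'CNN'   # 'CNN' stands in for config.model_name
print(model_name)                                  # Capsule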
@@ -64,7 +67,10 @@ model.to(device)

optimizer = optim.Adam(model.parameters(), lr=config.training.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', factor=config.training.decay_rate, patience=config.training.decay_patience)
    optimizer,
    'max',
    factor=config.training.decay_rate,
    patience=config.training.decay_patience)
criterion = nn.CrossEntropyLoss()

best_macro_f1, best_macro_epoch = 0, 1
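Note: ReduceLROnPlateau in 'max' mode expects to be fed a metric that should increase (here presumably the validation macro F1) via scheduler.step(metric); the learning rate is multiplied by factor once the metric has not improved for patience epochs. A minimal sketch (the model, factor and patience values are placeholders):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', factor=0.5, patience=2)

for epoch in range(5):
    # ... train and validate ...
    macro_f1 = 0.5 + 0.01 * epoch    # stand-in for the real validation metric
    scheduler.step(macro_f1)         # LR drops once the metric stops improving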