diff --git a/conf/config.yaml b/conf/config.yaml
index edf9c0e..bf6dce9 100644
--- a/conf/config.yaml
+++ b/conf/config.yaml
@@ -11,6 +11,6 @@ defaults:
   - preprocess
   - train
   - embedding
-  - model: cnn
+  - model: capsule
diff --git a/conf/embedding.yaml b/conf/embedding.yaml
index 82da364..8a91bbe 100644
--- a/conf/embedding.yaml
+++ b/conf/embedding.yaml
@@ -1,6 +1,6 @@
 # populated at runtime
 vocab_size: ???
-word_dim: 50
+word_dim: 60
 pos_size: ???  # 2 * pos_limit + 2
 pos_dim: 10  # 当为 sum 时,此值无效,和 word_dim 强行相同
diff --git a/conf/train.yaml b/conf/train.yaml
index 7bf4624..d7b25f4 100644
--- a/conf/train.yaml
+++ b/conf/train.yaml
@@ -3,7 +3,7 @@
 seed: 1
 use_gpu: True
 gpu_id: 0
-epoch: 50
+epoch: 5
 batch_size: 32
 learning_rate: 3e-4
 lr_factor: 0.7  # 学习率的衰减率
diff --git a/dataset.py b/dataset.py
index c59f728..c55bdb1 100644
--- a/dataset.py
+++ b/dataset.py
@@ -24,8 +24,9 @@ def collate_fn(cfg):
         if cfg.model_name != 'lm':
             head_pos.append(_padding(data['head_pos'], max_len))
             tail_pos.append(_padding(data['tail_pos'], max_len))
-            if cfg.use_pcnn:
-                pcnn_mask.append(_padding(data['entities_pos'], max_len))
+            if cfg.model_name == 'cnn':
+                if cfg.use_pcnn:
+                    pcnn_mask.append(_padding(data['entities_pos'], max_len))
 
     x['word'] = torch.tensor(word)
     x['lens'] = torch.tensor(word_len)
diff --git a/main.py b/main.py
index 6d736c6..3df28aa 100644
--- a/main.py
+++ b/main.py
@@ -27,6 +27,11 @@ def main(cfg):
     __Model__ = {
         'cnn': models.PCNN,
+        'rnn': models.BiLSTM,
+        'transformer': models.Transformer,
+        'gcn': models.GCN,
+        'capsule': models.Capsule,
+        'lm': models.LM,
     }
 
     # device
diff --git a/models/__init__.py b/models/__init__.py
index 90cd98b..7d049fc 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,2 +1,7 @@
 from .BasicModule import BasicModule
 from .PCNN import PCNN
+from .BiLSTM import BiLSTM
+from .Transformer import Transformer
+from .Capsule import Capsule
+from .GCN import GCN
+from .LM import LM
\ No newline at end of file
diff --git a/preprocess.py b/preprocess.py
index 081dfca..866039c 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -30,11 +30,12 @@ def _add_pos_seq(train_data: List[Dict], cfg):
         d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len']))))
         d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))
 
-        if cfg.use_pcnn:
-            # 当句子无法分隔成三段时,无法使用PCNN
-            # 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段
-            d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\
-                [3] * (d['seq_len'] - entities_idx[1])
+        if cfg.model_name == 'cnn':
+            if cfg.use_pcnn:
+                # 当句子无法分隔成三段时,无法使用PCNN
+                # 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段
+                d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\
+                    [3] * (d['seq_len'] - entities_idx[1])
 
 def _convert_tokens_into_index(data: List[Dict], vocab):