update config

This commit is contained in:
leo 2019-12-03 22:36:59 +08:00
parent a2e5a91f19
commit b8b01398fc
7 changed files with 22 additions and 10 deletions

View File

@ -11,6 +11,6 @@ defaults:
- preprocess
- train
- embedding
- model: cnn
- model: capsule

View File

@ -1,6 +1,6 @@
# populated at runtime
vocab_size: ???
word_dim: 50
word_dim: 60
pos_size: ??? # 2 * pos_limit + 2
pos_dim: 10 # 当为 sum 时,此值无效,和 word_dim 强行相同

View File

@ -3,7 +3,7 @@ seed: 1
use_gpu: True
gpu_id: 0
epoch: 50
epoch: 5
batch_size: 32
learning_rate: 3e-4
lr_factor: 0.7 # 学习率的衰减率

View File

@ -24,6 +24,7 @@ def collate_fn(cfg):
if cfg.model_name != 'lm':
head_pos.append(_padding(data['head_pos'], max_len))
tail_pos.append(_padding(data['tail_pos'], max_len))
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
pcnn_mask.append(_padding(data['entities_pos'], max_len))

View File

@ -27,6 +27,11 @@ def main(cfg):
__Model__ = {
'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
}
# device

View File

@ -1,2 +1,7 @@
from .BasicModule import BasicModule
from .PCNN import PCNN
from .BiLSTM import BiLSTM
from .Transformer import Transformer
from .Capsule import Capsule
from .GCN import GCN
from .LM import LM

View File

@ -30,6 +30,7 @@ def _add_pos_seq(train_data: List[Dict], cfg):
d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len']))))
d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
# 当句子无法分隔成三段时无法使用PCNN
# 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段