update config

This commit is contained in:
leo 2019-12-03 22:36:59 +08:00
parent a2e5a91f19
commit b8b01398fc
7 changed files with 22 additions and 10 deletions

View File

@ -11,6 +11,6 @@ defaults:
- preprocess - preprocess
- train - train
- embedding - embedding
- model: cnn - model: capsule

View File

@ -1,6 +1,6 @@
# populated at runtime # populated at runtime
vocab_size: ??? vocab_size: ???
word_dim: 50 word_dim: 60
pos_size: ??? # 2 * pos_limit + 2 pos_size: ??? # 2 * pos_limit + 2
pos_dim: 10 # 当为 sum 时,此值无效,和 word_dim 强行相同 pos_dim: 10 # 当为 sum 时,此值无效,和 word_dim 强行相同

View File

@ -3,7 +3,7 @@ seed: 1
use_gpu: True use_gpu: True
gpu_id: 0 gpu_id: 0
epoch: 50 epoch: 5
batch_size: 32 batch_size: 32
learning_rate: 3e-4 learning_rate: 3e-4
lr_factor: 0.7 # 学习率的衰减率 lr_factor: 0.7 # 学习率的衰减率

View File

@ -24,6 +24,7 @@ def collate_fn(cfg):
if cfg.model_name != 'lm': if cfg.model_name != 'lm':
head_pos.append(_padding(data['head_pos'], max_len)) head_pos.append(_padding(data['head_pos'], max_len))
tail_pos.append(_padding(data['tail_pos'], max_len)) tail_pos.append(_padding(data['tail_pos'], max_len))
if cfg.model_name == 'cnn':
if cfg.use_pcnn: if cfg.use_pcnn:
pcnn_mask.append(_padding(data['entities_pos'], max_len)) pcnn_mask.append(_padding(data['entities_pos'], max_len))

View File

@ -27,6 +27,11 @@ def main(cfg):
__Model__ = { __Model__ = {
'cnn': models.PCNN, 'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
} }
# device # device

View File

@ -1,2 +1,7 @@
from .BasicModule import BasicModule from .BasicModule import BasicModule
from .PCNN import PCNN from .PCNN import PCNN
from .BiLSTM import BiLSTM
from .Transformer import Transformer
from .Capsule import Capsule
from .GCN import GCN
from .LM import LM

View File

@ -30,6 +30,7 @@ def _add_pos_seq(train_data: List[Dict], cfg):
d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len'])))) d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len']))))
d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit)) d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))
if cfg.model_name == 'cnn':
if cfg.use_pcnn: if cfg.use_pcnn:
# 当句子无法分隔成三段时无法使用PCNN # 当句子无法分隔成三段时无法使用PCNN
# 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段 # 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段