update config
This commit is contained in:
parent
a2e5a91f19
commit
b8b01398fc
|
@ -11,6 +11,6 @@ defaults:
|
||||||
- preprocess
|
- preprocess
|
||||||
- train
|
- train
|
||||||
- embedding
|
- embedding
|
||||||
- model: cnn
|
- model: capsule
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# populated at runtime
|
# populated at runtime
|
||||||
vocab_size: ???
|
vocab_size: ???
|
||||||
word_dim: 50
|
word_dim: 60
|
||||||
pos_size: ??? # 2 * pos_limit + 2
|
pos_size: ??? # 2 * pos_limit + 2
|
||||||
pos_dim: 10 # 当为 sum 时,此值无效,和 word_dim 强行相同
|
pos_dim: 10 # 当为 sum 时,此值无效,和 word_dim 强行相同
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ seed: 1
|
||||||
use_gpu: True
|
use_gpu: True
|
||||||
gpu_id: 0
|
gpu_id: 0
|
||||||
|
|
||||||
epoch: 50
|
epoch: 5
|
||||||
batch_size: 32
|
batch_size: 32
|
||||||
learning_rate: 3e-4
|
learning_rate: 3e-4
|
||||||
lr_factor: 0.7 # 学习率的衰减率
|
lr_factor: 0.7 # 学习率的衰减率
|
||||||
|
|
|
@ -24,6 +24,7 @@ def collate_fn(cfg):
|
||||||
if cfg.model_name != 'lm':
|
if cfg.model_name != 'lm':
|
||||||
head_pos.append(_padding(data['head_pos'], max_len))
|
head_pos.append(_padding(data['head_pos'], max_len))
|
||||||
tail_pos.append(_padding(data['tail_pos'], max_len))
|
tail_pos.append(_padding(data['tail_pos'], max_len))
|
||||||
|
if cfg.model_name == 'cnn':
|
||||||
if cfg.use_pcnn:
|
if cfg.use_pcnn:
|
||||||
pcnn_mask.append(_padding(data['entities_pos'], max_len))
|
pcnn_mask.append(_padding(data['entities_pos'], max_len))
|
||||||
|
|
||||||
|
|
5
main.py
5
main.py
|
@ -27,6 +27,11 @@ def main(cfg):
|
||||||
|
|
||||||
__Model__ = {
|
__Model__ = {
|
||||||
'cnn': models.PCNN,
|
'cnn': models.PCNN,
|
||||||
|
'rnn': models.BiLSTM,
|
||||||
|
'transformer': models.Transformer,
|
||||||
|
'gcn': models.GCN,
|
||||||
|
'capsule': models.Capsule,
|
||||||
|
'lm': models.LM,
|
||||||
}
|
}
|
||||||
|
|
||||||
# device
|
# device
|
||||||
|
|
|
@ -1,2 +1,7 @@
|
||||||
from .BasicModule import BasicModule
|
from .BasicModule import BasicModule
|
||||||
from .PCNN import PCNN
|
from .PCNN import PCNN
|
||||||
|
from .BiLSTM import BiLSTM
|
||||||
|
from .Transformer import Transformer
|
||||||
|
from .Capsule import Capsule
|
||||||
|
from .GCN import GCN
|
||||||
|
from .LM import LM
|
|
@ -30,6 +30,7 @@ def _add_pos_seq(train_data: List[Dict], cfg):
|
||||||
d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len']))))
|
d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len']))))
|
||||||
d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))
|
d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))
|
||||||
|
|
||||||
|
if cfg.model_name == 'cnn':
|
||||||
if cfg.use_pcnn:
|
if cfg.use_pcnn:
|
||||||
# 当句子无法分隔成三段时,无法使用PCNN
|
# 当句子无法分隔成三段时,无法使用PCNN
|
||||||
# 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段
|
# 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段
|
||||||
|
|
Loading…
Reference in New Issue