update config
This commit is contained in:
parent
a2e5a91f19
commit
b8b01398fc
|
@ -11,6 +11,6 @@ defaults:
|
|||
- preprocess
|
||||
- train
|
||||
- embedding
|
||||
- model: cnn
|
||||
- model: capsule
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# populated at runtime
|
||||
vocab_size: ???
|
||||
word_dim: 50
|
||||
word_dim: 60
|
||||
pos_size: ??? # 2 * pos_limit + 2
|
||||
pos_dim: 10 # 当为 sum 时,此值无效,和 word_dim 强行相同
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ seed: 1
|
|||
use_gpu: True
|
||||
gpu_id: 0
|
||||
|
||||
epoch: 50
|
||||
epoch: 5
|
||||
batch_size: 32
|
||||
learning_rate: 3e-4
|
||||
lr_factor: 0.7 # 学习率的衰减率
|
||||
|
|
|
@ -24,8 +24,9 @@ def collate_fn(cfg):
|
|||
if cfg.model_name != 'lm':
|
||||
head_pos.append(_padding(data['head_pos'], max_len))
|
||||
tail_pos.append(_padding(data['tail_pos'], max_len))
|
||||
if cfg.use_pcnn:
|
||||
pcnn_mask.append(_padding(data['entities_pos'], max_len))
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
pcnn_mask.append(_padding(data['entities_pos'], max_len))
|
||||
|
||||
x['word'] = torch.tensor(word)
|
||||
x['lens'] = torch.tensor(word_len)
|
||||
|
|
5
main.py
5
main.py
|
@ -27,6 +27,11 @@ def main(cfg):
|
|||
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
'rnn': models.BiLSTM,
|
||||
'transformer': models.Transformer,
|
||||
'gcn': models.GCN,
|
||||
'capsule': models.Capsule,
|
||||
'lm': models.LM,
|
||||
}
|
||||
|
||||
# device
|
||||
|
|
|
@ -1,2 +1,7 @@
|
|||
from .BasicModule import BasicModule
|
||||
from .PCNN import PCNN
|
||||
from .BiLSTM import BiLSTM
|
||||
from .Transformer import Transformer
|
||||
from .Capsule import Capsule
|
||||
from .GCN import GCN
|
||||
from .LM import LM
|
|
@ -30,11 +30,12 @@ def _add_pos_seq(train_data: List[Dict], cfg):
|
|||
d['tail_pos'] = list(map(lambda i: i - d['tail_idx'], list(range(d['seq_len']))))
|
||||
d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))
|
||||
|
||||
if cfg.use_pcnn:
|
||||
# 当句子无法分隔成三段时,无法使用PCNN
|
||||
# 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段
|
||||
d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\
|
||||
[3] * (d['seq_len'] - entities_idx[1])
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
# 当句子无法分隔成三段时,无法使用PCNN
|
||||
# 比如: [head, ... tail] or [... head, tail, ...] 无法使用统一方式 mask 分段
|
||||
d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\
|
||||
[3] * (d['seq_len'] - entities_idx[1])
|
||||
|
||||
|
||||
def _convert_tokens_into_index(data: List[Dict], vocab):
|
||||
|
|
Loading…
Reference in New Issue