update cnn

This commit is contained in:
leo 2019-12-03 22:42:15 +08:00
parent 5631f997b5
commit 03a00c202e
4 changed files with 11 additions and 11 deletions

View File

@ -1,10 +1,10 @@
model_name: cnn
#in_channels: 100 # 使用 embedding 输出的结果,不需要指定
in_channels: ??? # 使用 embedding 输出的结果,不需要指定
out_channels: 100
kernel_sizes: [3, 5, 7] # 必须为奇数为了保证cnn的输出不改变句子长度
activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: 'max' # [max, avg, cls]
kernel_sizes: [3, 5, 7] # 必须为奇数为了保证cnn的输出不改变句子长度
activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: 'max' # [max, avg, cls]
dropout: 0.3
# pcnn

View File

@ -11,6 +11,10 @@ class PCNN(BasicModule):
super(PCNN, self).__init__()
self.use_pcnn = cfg.use_pcnn
if cfg.dim_strategy == 'cat':
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
else:
cfg.in_channels = cfg.word_dim
self.embedding = Embedding(cfg)
self.cnn = CNN(cfg)

View File

@ -30,12 +30,7 @@ class CNN(nn.Module):
super(CNN, self).__init__()
# self.xxx = config.xxx
# self.in_channels = config.in_channels
if config.dim_strategy == 'cat':
self.in_channels = config.word_dim + 2 * config.pos_dim
else:
self.in_channels = config.word_dim
self.in_channels = config.in_channels
self.out_channels = config.out_channels
self.kernel_sizes = config.kernel_sizes
self.activation = config.activation

View File

@ -3,4 +3,5 @@ from .CNN import CNN
from .RNN import RNN
from .Attention import DotAttention, MultiHeadAttention
from .Transformer import Transformer
from .Capsule import Capsule
from .Capsule import Capsule
from .GCN import GCN