update cnn
This commit is contained in:
parent
5631f997b5
commit
03a00c202e
|
@ -1,10 +1,10 @@
|
|||
model_name: cnn
|
||||
|
||||
#in_channels: 100 # 使用 embedding 输出的结果,不需要指定
|
||||
in_channels: ??? # 使用 embedding 输出的结果,不需要指定
|
||||
out_channels: 100
|
||||
kernel_sizes: [3, 5, 7] # 必须为奇数,为了保证cnn的输出不改变句子长度
|
||||
activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
|
||||
pooling_strategy: 'max' # [max, avg, cls]
|
||||
kernel_sizes: [3, 5, 7] # 必须为奇数,为了保证cnn的输出不改变句子长度
|
||||
activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
|
||||
pooling_strategy: 'max' # [max, avg, cls]
|
||||
dropout: 0.3
|
||||
|
||||
# pcnn
|
||||
|
|
|
@ -11,6 +11,10 @@ class PCNN(BasicModule):
|
|||
super(PCNN, self).__init__()
|
||||
|
||||
self.use_pcnn = cfg.use_pcnn
|
||||
if cfg.dim_strategy == 'cat':
|
||||
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
|
||||
else:
|
||||
cfg.in_channels = cfg.word_dim
|
||||
|
||||
self.embedding = Embedding(cfg)
|
||||
self.cnn = CNN(cfg)
|
||||
|
|
|
@ -30,12 +30,7 @@ class CNN(nn.Module):
|
|||
super(CNN, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
# self.in_channels = config.in_channels
|
||||
if config.dim_strategy == 'cat':
|
||||
self.in_channels = config.word_dim + 2 * config.pos_dim
|
||||
else:
|
||||
self.in_channels = config.word_dim
|
||||
|
||||
self.in_channels = config.in_channels
|
||||
self.out_channels = config.out_channels
|
||||
self.kernel_sizes = config.kernel_sizes
|
||||
self.activation = config.activation
|
||||
|
|
|
@ -3,4 +3,5 @@ from .CNN import CNN
|
|||
from .RNN import RNN
|
||||
from .Attention import DotAttention, MultiHeadAttention
|
||||
from .Transformer import Transformer
|
||||
from .Capsule import Capsule
|
||||
from .Capsule import Capsule
|
||||
from .GCN import GCN
|
Loading…
Reference in New Issue