diff --git a/conf/model/cnn.yaml b/conf/model/cnn.yaml index 3d8c72f..ef15d2b 100644 --- a/conf/model/cnn.yaml +++ b/conf/model/cnn.yaml @@ -5,6 +5,7 @@ out_channels: 100 kernel_sizes: [3, 5, 7] # 必须为奇数,为了保证cnn的输出不改变句子长度 activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh] pooling_strategy: 'max' # [max, avg, cls] +keep_length: True dropout: 0.3 # pcnn diff --git a/module/CNN.py b/module/CNN.py index cb3448a..8d43366 100644 --- a/module/CNN.py +++ b/module/CNN.py @@ -15,6 +15,7 @@ class GELU(nn.Module): class CNN(nn.Module): """ nlp 里为了保证输出的句长 = 输入的句长,一般使用奇数 kernel_size,如 [3, 5, 7, 9] + 当然也可以不等长输出,keep_length 设为 False 此时,padding = k // 2 stride 一般为 1 """ @@ -36,6 +37,7 @@ class CNN(nn.Module): self.activation = config.activation self.pooling_strategy = config.pooling_strategy self.dropout = config.dropout + self.keep_length = config.keep_length for kernel_size in self.kernel_sizes: assert kernel_size % 2 == 1, "kernel size has to be odd numbers." @@ -45,7 +47,7 @@ class CNN(nn.Module): out_channels=self.out_channels, kernel_size=k, stride=1, - padding=k // 2, + padding=k // 2 if self.keep_length else 0, dilation=1, groups=1, bias=False) for k in self.kernel_sizes