update cnn

add keep_length param
This commit is contained in:
leo 2019-12-05 20:47:18 +08:00
parent c0f508b24b
commit 355a5883a6
2 changed files with 4 additions and 1 deletions

View File

@ -5,6 +5,7 @@ out_channels: 100
kernel_sizes: [3, 5, 7] # 必须为奇数为了保证cnn的输出不改变句子长度
activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: 'max' # [max, avg, cls]
keep_length: True
dropout: 0.3
# pcnn

View File

@ -15,6 +15,7 @@ class GELU(nn.Module):
class CNN(nn.Module):
"""
nlp 里为了保证输出的句长 = 输入的句长一般使用奇数 kernel_size [3, 5, 7, 9]
当然也可以不等长输出keep_length 设为 False
此时padding = k // 2
stride 一般为 1
"""
@ -36,6 +37,7 @@ class CNN(nn.Module):
self.activation = config.activation
self.pooling_strategy = config.pooling_strategy
self.dropout = config.dropout
self.keep_length = config.keep_length
for kernel_size in self.kernel_sizes:
assert kernel_size % 2 == 1, "kernel size has to be odd numbers."
@ -45,7 +47,7 @@ class CNN(nn.Module):
out_channels=self.out_channels,
kernel_size=k,
stride=1,
padding=k // 2,
padding=k // 2 if self.keep_length else 0,
dilation=1,
groups=1,
bias=False) for k in self.kernel_sizes