parent
c0f508b24b
commit
355a5883a6
|
@ -5,6 +5,7 @@ out_channels: 100
|
|||
kernel_sizes: [3, 5, 7] # 必须为奇数,为了保证cnn的输出不改变句子长度
|
||||
activation: 'gelu' # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
|
||||
pooling_strategy: 'max' # [max, avg, cls]
|
||||
keep_length: True
|
||||
dropout: 0.3
|
||||
|
||||
# pcnn
|
||||
|
|
|
@ -15,6 +15,7 @@ class GELU(nn.Module):
|
|||
class CNN(nn.Module):
|
||||
"""
|
||||
nlp 里为了保证输出的句长 = 输入的句长,一般使用奇数 kernel_size,如 [3, 5, 7, 9]
|
||||
当然也可以不等长输出,keep_length 设为 False
|
||||
When keep_length is True, padding = k // 2 (output length equals input length); when it is False, padding = 0
|
||||
stride 一般为 1
|
||||
"""
|
||||
|
@ -36,6 +37,7 @@ class CNN(nn.Module):
|
|||
self.activation = config.activation
|
||||
self.pooling_strategy = config.pooling_strategy
|
||||
self.dropout = config.dropout
|
||||
self.keep_length = config.keep_length
|
||||
for kernel_size in self.kernel_sizes:
|
||||
assert kernel_size % 2 == 1, "kernel size has to be odd numbers."
|
||||
|
||||
|
@ -45,7 +47,7 @@ class CNN(nn.Module):
|
|||
out_channels=self.out_channels,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=k // 2,
|
||||
padding=k // 2 if self.keep_length else 0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False) for k in self.kernel_sizes
|
||||
|
|
Loading…
Reference in New Issue