import paddle
from paddle import nn
from paddle.nn import functional as F


def residual_connection(input, layer):
    """Residual connection, only used for a single-input, single-output layer.

    y = x + F(x), where F corresponds to the layer.

    Args:
        input (Tensor): the input tensor.
        layer (callable): a callable that preserves the tensor shape.

    Returns:
        Tensor: the output tensor, with the same shape as the input.
    """
    return input + layer(input)


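# A minimal usage sketch (the Linear layer and shapes below are illustrative
# assumptions, not part of the original module): any shape-preserving layer
# works, e.g. a Linear with equal input and output sizes.
#
#     layer = nn.Linear(8, 8)
#     x = paddle.randn([4, 8])
#     y = residual_connection(x, layer)  # y.shape == [4, 8]

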
class ResidualWrapper(nn.Layer):
    """Wraps a shape-preserving layer with a residual connection."""

    def __init__(self, layer):
        super(ResidualWrapper, self).__init__()
        self.layer = layer

    def forward(self, x):
        return residual_connection(x, self.layer)


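# Usage sketch (the wrapped Linear layer is an illustrative assumption):
#
#     block = ResidualWrapper(nn.Linear(8, 8))
#     y = block(paddle.randn([4, 8]))  # equivalent to x + layer(x)

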
class PreLayerNormWrapper(nn.Layer):
    """Pre-norm residual block: y = x + layer(LayerNorm(x))."""

    def __init__(self, layer, d_model):
        super(PreLayerNormWrapper, self).__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm([d_model], epsilon=1e-6)

    def forward(self, x):
        return x + self.layer(self.layer_norm(x))


class PostLayerNormWrapper(nn.Layer):
    """Post-norm residual block: y = LayerNorm(x + layer(x))."""

    def __init__(self, layer, d_model):
        super(PostLayerNormWrapper, self).__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm([d_model], epsilon=1e-6)

    def forward(self, x):
        return self.layer_norm(x + self.layer(x))


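# The two wrappers differ only in where LayerNorm is applied relative to the
# residual branch. A sketch of both on the same (assumed, illustrative) layer:
#
#     d_model = 8
#     pre = PreLayerNormWrapper(nn.Linear(d_model, d_model), d_model)
#     post = PostLayerNormWrapper(nn.Linear(d_model, d_model), d_model)
#     x = paddle.randn([4, d_model])
#     y_pre = pre(x)    # x + layer(norm(x)): the output is not re-normalized
#     y_post = post(x)  # norm(x + layer(x)): the output is normalized

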
def context_gate(input, axis):
    """Split the input along an axis into content and gate halves, then
    apply a sigmoid gate: output = sigmoid(gate) * content.

    Args:
        input (Tensor): shape(*, d_axis, *), the input, treated as the
            concatenation of content and gate along ``axis``.
        axis (int): the axis along which to split content and gate.

    Raises:
        ValueError: if input.shape[axis] is not even.

    Returns:
        Tensor: shape(*, d_axis / 2, *), the gated content.
    """
    size = input.shape[axis]
    if size % 2 != 0:
        raise ValueError("the size of the {}-th dimension of input should "
                         "be even, but received {}".format(axis, size))
    content, gate = paddle.chunk(input, 2, axis)
    return F.sigmoid(gate) * content
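
# Usage sketch (shapes are illustrative assumptions): the output keeps half
# the input's size along the chosen axis, similar to the gating in a GLU.
#
#     x = paddle.randn([4, 16])
#     y = context_gate(x, axis=1)  # y.shape == [4, 8]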