ParakeetRebeccaRosario/parakeet/modules/connections.py

import paddle
from paddle import nn
from paddle.nn import functional as F


def residual_connection(input, layer):
    """Residual connection, only used for a single-input, single-output layer.

    y = x + F(x), where F corresponds to the layer.

    Args:
        input (Tensor): the input tensor.
        layer (callable): a callable that preserves the tensor shape.

    Returns:
        Tensor: the input plus the layer's output.
    """
    return input + layer(input)


class ResidualWrapper(nn.Layer):
    """Wrap a shape-preserving layer with a residual connection: y = x + layer(x)."""

    def __init__(self, layer):
        super(ResidualWrapper, self).__init__()
        self.layer = layer

    def forward(self, x):
        return residual_connection(x, self.layer)
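
# Usage sketch (illustrative, not part of the original module): any callable
# that preserves the tensor's shape can be wrapped, e.g.
#
#     block = ResidualWrapper(nn.Linear(d_model, d_model))
#     y = block(x)  # equivalent to x + linear(x)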


class PreLayerNormWrapper(nn.Layer):
    """Pre-norm residual connection: y = x + layer(LayerNorm(x))."""

    def __init__(self, layer, d_model):
        super(PreLayerNormWrapper, self).__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm([d_model], epsilon=1e-6)

    def forward(self, x):
        return x + self.layer(self.layer_norm(x))


class PostLayerNormWrapper(nn.Layer):
    """Post-norm residual connection: y = LayerNorm(x + layer(x))."""

    def __init__(self, layer, d_model):
        super(PostLayerNormWrapper, self).__init__()
        self.layer = layer
        self.layer_norm = nn.LayerNorm([d_model], epsilon=1e-6)

    def forward(self, x):
        return self.layer_norm(x + self.layer(x))
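
# Note (added for clarity, not in the original module): the two wrappers above
# differ only in where LayerNorm sits relative to the residual branch:
#
#     PreLayerNormWrapper :  y = x + layer(LayerNorm(x))
#     PostLayerNormWrapper:  y = LayerNorm(x + layer(x))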


def context_gate(input, axis):
    """Sigmoid-gate the content half of the input by the gate half.

    The input is split into two equal chunks along ``axis``: the first chunk
    is the content and the second is the gate.

    Args:
        input (Tensor): shape(*, d_axis, *), the input, treated as content & gate.
        axis (int): the axis along which to chunk content and gate.

    Raises:
        ValueError: if input.shape[axis] is not even.

    Returns:
        Tensor: shape(*, d_axis // 2, *), the gated content.
    """
    size = input.shape[axis]
    if size % 2 != 0:
        raise ValueError("the size of the {}-th dimension of input should "
                         "be even, but received {}".format(axis, size))
    content, gate = paddle.chunk(input, 2, axis)
    return F.sigmoid(gate) * content
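

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # module): apply each connection helper to a random (batch, time, d_model)
    # tensor and print the resulting shapes.
    d_model = 16
    x = paddle.randn([2, 5, d_model])
    ffn = nn.Linear(d_model, d_model)

    print(residual_connection(x, ffn).shape)            # [2, 5, 16]
    print(ResidualWrapper(ffn)(x).shape)                # [2, 5, 16]
    print(PreLayerNormWrapper(ffn, d_model)(x).shape)   # [2, 5, 16]
    print(PostLayerNormWrapper(ffn, d_model)(x).shape)  # [2, 5, 16]

    # context_gate halves the chunked axis: the last dim goes from 16 to 8.
    print(context_gate(x, axis=-1).shape)               # [2, 5, 8]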