Merge pull request #59 from iclementine/doc

update docstrings
Feiyu Chan, 2020-12-18 16:12:56 +08:00, committed via GitHub
commit dd2c5cc6c6
7 changed files with 610 additions and 235 deletions

View File

@@ -18,14 +18,13 @@ import paddle
from paddle import nn
from paddle.nn import functional as F


def scaled_dot_product_attention(q,
                                 k,
                                 v,
                                 mask=None,
                                 dropout=0.0,
                                 training=True):
    r"""Scaled dot product attention with masking.

    Assume that q, k, v all have the same leading dimensions (denoted as * in
    descriptions below). Dropout is applied to attention weights before

@@ -34,24 +33,24 @@ def scaled_dot_product_attention(q,

    Parameters
    -----------
    q : Tensor [shape=(\*, T_q, d)]
        the query tensor.
    k : Tensor [shape=(\*, T_k, d)]
        the key tensor.
    v : Tensor [shape=(\*, T_k, d_v)]
        the value tensor.
    mask : Tensor [shape=(\*, T_q, T_k) or broadcastable shape], optional
        the mask tensor, zeros correspond to paddings. Defaults to None.

    Returns
    ----------
    out : Tensor [shape=(\*, T_q, d_v)]
        the context vector.
    attn_weights : Tensor [shape=(\*, T_q, T_k)]
        the attention weights.
    """
    d = q.shape[-1]  # we only support imperative execution

@@ -67,17 +66,25 @@ def scaled_dot_product_attention(q,
    return out, attn_weights
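
For reference, a minimal usage sketch of the function documented above (module path as imported elsewhere in this diff; tensor values are random placeholders and shapes follow the docstring):

import paddle
from parakeet.modules.attention import scaled_dot_product_attention

q = paddle.randn([4, 2, 10, 64])   # (batch, heads, T_q, d)
k = paddle.randn([4, 2, 12, 64])   # (batch, heads, T_k, d)
v = paddle.randn([4, 2, 12, 64])   # (batch, heads, T_k, d_v)
mask = paddle.ones([4, 1, 1, 12])  # broadcastable mask, zeros would mark padding

out, attn = scaled_dot_product_attention(q, k, v, mask=mask, dropout=0.1, training=True)
print(out.shape)   # [4, 2, 10, 64]
print(attn.shape)  # [4, 2, 10, 12]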
def drop_head(x, drop_n_heads, training=True):
    """Drop n context vectors from multiple ones.

    Parameters
    ----------
    x : Tensor [shape=(batch_size, num_heads, time_steps, channels)]
        The input, multiple context vectors.
    drop_n_heads : int [0 <= drop_n_heads <= num_heads]
        Number of vectors to drop.
    training : bool
        A flag indicating whether it is in training. If `False`, no dropout is
        applied.

    Returns
    -------
    Tensor
        The output.
    """
    if not training or (drop_n_heads == 0):
        return x
@@ -113,21 +120,30 @@ def _concat_heads(x):
# Standard implementations of Monohead Attention & Multihead Attention


class MonoheadAttention(nn.Layer):
    """Monohead Attention module.

    Parameters
    ----------
    model_dim : int
        Feature size of the query.
    dropout : float, optional
        Dropout probability of scaled dot product attention and final context
        vector. Defaults to 0.0.
    k_dim : int, optional
        Feature size of the key of the scaled dot product attention. If not
        provided, it is set to ``model_dim``. Defaults to None.
    v_dim : int, optional
        Feature size of the value of the scaled dot product attention. If not
        provided, it is set to ``model_dim``. Defaults to None.
    """

    def __init__(self,
                 model_dim: int,
                 dropout: float=0.0,
                 k_dim: int=None,
                 v_dim: int=None):
        super(MonoheadAttention, self).__init__()
        k_dim = k_dim or model_dim
        v_dim = v_dim or model_dim

@@ -140,20 +156,29 @@ class MonoheadAttention(nn.Layer):
        self.dropout = dropout

    def forward(self, q, k, v, mask):
        """Compute context vector and attention weights.

        Parameters
        -----------
        q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
            The queries.
        k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
            The keys.
        v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
            The values.
        mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k) or broadcastable shape]
            The mask.

        Returns
        ----------
        out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
            The context vector.
        attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
            The attention weights.
        """
        q = self.affine_q(q)  # (B, T, C)
        k = self.affine_k(k)

@@ -167,34 +192,39 @@ class MonoheadAttention(nn.Layer):


class MultiheadAttention(nn.Layer):
    """Multihead Attention module.

    Parameters
    -----------
    model_dim : int
        The feature size of query.
    num_heads : int
        The number of attention heads.
    dropout : float, optional
        Dropout probability of scaled dot product attention and final context
        vector. Defaults to 0.0.
    k_dim : int, optional
        Feature size of the key of each scaled dot product attention. If not
        provided, it is set to ``model_dim / num_heads``. Defaults to None.
    v_dim : int, optional
        Feature size of the value of each scaled dot product attention. If not
        provided, it is set to ``model_dim / num_heads``. Defaults to None.

    Raises
    ---------
    ValueError
        If ``model_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self,
                 model_dim: int,
                 num_heads: int,
                 dropout: float=0.0,
                 k_dim: int=None,
                 v_dim: int=None):
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")

@@ -211,20 +241,29 @@ class MultiheadAttention(nn.Layer):
        self.dropout = dropout

    def forward(self, q, k, v, mask):
        """Compute context vector and attention weights.

        Parameters
        -----------
        q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
            The queries.
        k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
            The keys.
        v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
            The values.
        mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k) or broadcastable shape]
            The mask.

        Returns
        ----------
        out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
            The context vector.
        attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
            The attention weights.
        """
        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
        k = _split_heads(self.affine_k(k), self.num_heads)
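
A short usage sketch of the attention layer documented above (module path as imported elsewhere in this diff; the mask shape chosen here is one broadcastable option, output shapes per the docstring):

import paddle
from parakeet.modules.attention import MultiheadAttention

layer = MultiheadAttention(model_dim=256, num_heads=4, dropout=0.1)
q = paddle.randn([8, 32, 256])      # (batch, T_q, model_dim)
k = paddle.randn([8, 40, 256])      # (batch, T_k, model_dim)
v = paddle.randn([8, 40, 256])      # (batch, T_k, model_dim)
mask = paddle.ones([8, 1, 32, 40])  # broadcastable over heads, 0 would mark padding
out, weights = layer(q, k, v, mask)
print(out.shape)  # [8, 32, 256], per the docstring; weights are the attention weights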

View File

@@ -8,28 +8,48 @@ __all__ = ["quantize", "dequantize", "STFT"]


def quantize(values, n_bands):
    """Linearly quantize a float Tensor in [-1, 1) to an integer Tensor in
    [0, n_bands).

    Parameters
    -----------
    values : Tensor [dtype: float32 or float64]
        The floating point value.
    n_bands : int
        The number of bands. The output integer Tensor's value is in the range
        [0, n_bands).

    Returns
    ----------
    Tensor [dtype: int64]
        The quantized tensor.
    """
    quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
    return quantized


def dequantize(quantized, n_bands, dtype=None):
    """Linearly dequantize an integer Tensor into a float Tensor in the range
    [-1, 1).

    Parameters
    -----------
    quantized : Tensor [dtype: int]
        The quantized value in the range [0, n_bands).
    n_bands : int
        Number of bands. The input integer Tensor's value is in the range
        [0, n_bands).
    dtype : str, optional
        Data type of the output.

    Returns
    -----------
    Tensor
        The dequantized tensor, dtype is specified by `dtype`. If `dtype` is
        not specified, the default float data type is used.
    """
    dtype = dtype or paddle.get_default_dtype()
    value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
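
A quick round-trip sketch of the two helpers above: a value in [-1, 1) maps to an integer band and back to the band center (the import path is an assumption, any location of this module works the same way):

import paddle
from parakeet.modules.audio import quantize, dequantize  # module path assumed

x = paddle.to_tensor([-1.0, -0.5, 0.0, 0.5, 0.999])
q = quantize(x, n_bands=256)    # int64 values in [0, 256)
y = dequantize(q, n_bands=256)  # floats back in [-1, 1)
print(q.numpy())   # [  0  64 128 192 255]
print(y.numpy())   # band centers, close to x but not identical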
@@ -37,15 +57,36 @@ def dequantize(quantized, n_bands, dtype=None):


class STFT(nn.Layer):
    """A module for computing stft transformation in a differentiable way.

    Parameters
    ------------
    n_fft : int
        Number of samples in a frame.
    hop_length : int
        Number of samples shifted between adjacent frames.
    win_length : int
        Length of the window.
    window : str, optional
        Name of window function, see `scipy.signal.get_window` for more
        details. Defaults to "hanning".

    Notes
    -----------
    It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
    details.

    Given an audio of ``T`` samples, the STFT transformation outputs a
    spectrum with shape (C, frames) and complex dtype, where
    ``C = 1 + n_fft / 2`` and ``frames = 1 + T // hop_length``.

    Only ``center`` and ``reflect`` padding is supported now.
    """

    def __init__(self, n_fft, hop_length, win_length, window="hanning"):
        super(STFT, self).__init__()
        self.hop_length = hop_length
        self.n_bin = 1 + n_fft // 2

@@ -73,13 +114,18 @@ class STFT(nn.Layer):

    def forward(self, x):
        """Compute the stft transform.

        Parameters
        ------------
        x : Tensor [shape=(B, T)]
            The input waveform.

        Returns
        ------------
        real : Tensor [shape=(B, C, 1, frames)]
            The real part of the spectrogram.
        imag : Tensor [shape=(B, C, 1, frames)]
            The imaginary part of the spectrogram.
        """
        # x(batch_size, time_steps)
        # pad it first with reflect mode

@@ -95,30 +141,34 @@ class STFT(nn.Layer):
        return real, imag

    def power(self, x):
        """Compute the power spectrum.

        Parameters
        ------------
        x : Tensor [shape=(B, T)]
            The input waveform.

        Returns
        ------------
        Tensor [shape=(B, C, 1, T)]
            The power spectrum.
        """
        real, imag = self(x)
        power = real**2 + imag**2
        return power

    def magnitude(self, x):
        """Compute the magnitude of the spectrum.

        Parameters
        ------------
        x : Tensor [shape=(B, T)]
            The input waveform.

        Returns
        ------------
        Tensor [shape=(B, C, 1, T)]
            The magnitude of the spectrum.
        """
        power = self.power(x)
        magnitude = paddle.sqrt(power)
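
A usage sketch of the STFT module above (module path assumed; the frame count in the comment is illustrative):

import paddle
from parakeet.modules.audio import STFT  # module path assumed

stft = STFT(n_fft=1024, hop_length=256, win_length=1024, window="hanning")
wav = paddle.randn([2, 16000])  # (B, T), a fake batch of waveforms
real, imag = stft(wav)          # each (B, 1 + n_fft // 2, 1, frames)
mag = stft.magnitude(wav)       # sqrt(real**2 + imag**2)
print(mag.shape)                # [2, 513, 1, frames]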

View File

@@ -4,16 +4,25 @@ import paddle


def shuffle_dim(x, axis, perm=None):
    """Permute the input tensor along a given axis, either by a given
    permutation or randomly.

    Parameters
    ----------
    x : Tensor
        The input tensor.
    axis : int
        The axis to shuffle.
    perm : List[int], ndarray, optional
        The order to reorder the tensor along the ``axis``-th dimension.
        It is a permutation of ``[0, d)``, where d is the size of the
        ``axis``-th dimension of the input tensor. If not provided,
        a random permutation is used. Defaults to None.

    Returns
    ---------
    Tensor
        The shuffled tensor, which has the same shape as x does.
    """
    size = x.shape[axis]
    if perm is not None and len(perm) != size:
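
For instance, a minimal sketch assuming ``shuffle_dim`` from the module above is in scope:

import paddle
# assumes shuffle_dim from the module above is importable / in scope

x = paddle.arange(12).reshape([3, 4])
y = shuffle_dim(x, axis=0, perm=[2, 0, 1])  # explicit permutation of the 3 rows
z = shuffle_dim(x, axis=1)                  # random permutation of the 4 columns
print(y.numpy())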

View File

@@ -4,29 +4,128 @@ import paddle
from paddle import nn
from paddle.nn import functional as F

__all__ = [
    "weighted_mean",
    "masked_l1_loss",
    "masked_softmax_with_cross_entropy",
    "diagonal_loss",
]


def weighted_mean(input, weight):
    """Weighted mean. It can also be used as masked mean.

    Parameters
    -----------
    input : Tensor
        The input tensor.
    weight : Tensor
        The weight tensor, with a shape broadcastable to that of the input.

    Returns
    ----------
    Tensor [shape=(1,)]
        Weighted mean tensor with the same dtype as input.

    Warnings
    ---------
    This is not a mathematical weighted mean. It performs weighted sum and
    simple average.
    """
    weight = paddle.cast(weight, input.dtype)
    return paddle.mean(input * weight)


def masked_l1_loss(prediction, target, mask):
    """Compute masked L1 loss.

    Parameters
    ----------
    prediction : Tensor
        The prediction.
    target : Tensor
        The target. The shape should be broadcastable to ``prediction``.
    mask : Tensor
        The mask. The shape should be broadcastable to the broadcasted shape of
        ``prediction`` and ``target``.

    Returns
    -------
    Tensor [shape=(1,)]
        The masked L1 loss.
    """
    abs_error = F.l1_loss(prediction, target, reduction='none')
    loss = weighted_mean(abs_error, mask)
    return loss
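
A small numeric illustration of the warning above: the function divides the weighted sum by the total number of elements, not by the sum of weights (a minimal sketch, assuming ``weighted_mean`` from the module above is in scope):

import paddle
# assumes weighted_mean from the module above is in scope

x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
w = paddle.to_tensor([1.0, 1.0, 0.0, 0.0])  # mask out the last two elements
print(weighted_mean(x, w).numpy())           # 0.75 = (1 + 2) / 4, not (1 + 2) / 2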
def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
    """Compute masked softmax with cross entropy loss.

    Parameters
    ----------
    logits : Tensor
        The logits. The ``axis``-th axis is the class dimension.
    label : Tensor [dtype: int]
        The label. The size of the ``axis``-th axis should be 1.
    mask : Tensor
        The mask. The shape should be broadcastable to ``label``.
    axis : int, optional
        The index of the class dimension in the shape of ``logits``, by default
        -1.

    Returns
    -------
    Tensor [shape=(1,)]
        The masked softmax with cross entropy loss.
    """
    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
    loss = weighted_mean(ce, mask)
    return loss


def diagonal_loss(
        attentions,
        input_lengths,
        target_lengths,
        g=0.2,
        multihead=False):
    """A metric to evaluate how diagonal an attention distribution is.

    It is computed for batched attention distributions. For each attention
    distribution, the valid decoder time steps and encoder time steps may
    differ.

    Parameters
    ----------
    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_enc)]
        The attention weights from an encoder-decoder structure.
    input_lengths : Tensor [shape=(B,)]
        The valid length for each encoder output.
    target_lengths : Tensor [shape=(B,)]
        The valid length for each decoder output.
    g : float, optional
        The width of the band around the diagonal that is considered
        near-diagonal (larger ``g`` tolerates larger deviation from the
        diagonal), by default 0.2.
    multihead : bool, optional
        A flag indicating whether ``attentions`` is a multihead attention's
        attention distribution.
        If ``True``, the shape of attention is ``(B, H, T_dec, T_enc)``, by
        default False.

    Returns
    -------
    Tensor [shape=(1,)]
        The diagonal loss.
    """
    W = guided_attentions(input_lengths, target_lengths, g)
    W_tensor = paddle.to_tensor(W)
    if not multihead:
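
A brief usage sketch of ``masked_softmax_with_cross_entropy`` (assuming the functions above are in scope; shapes and values are placeholders):

import paddle
from paddle.nn import functional as F
# assumes masked_softmax_with_cross_entropy from the module above is in scope

logits = paddle.randn([4, 10, 8])               # (B, T, num_classes)
label = paddle.randint(0, 8, shape=[4, 10, 1])  # class ids, class axis has size 1
mask = paddle.ones([4, 10, 1])                  # 1 = valid step, 0 = padding
loss = masked_softmax_with_cross_entropy(logits, label, mask)
print(loss.numpy())  # a scalar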

View File

@@ -1,32 +1,114 @@
import paddle
from paddle.fluid.layers import sequence_mask

__all__ = [
    "id_mask",
    "feature_mask",
    "combine_mask",
    "future_mask",
]


def id_mask(input, padding_index=0, dtype="bool"):
    """Generate mask with input ids.

    Those positions where the value equals ``padding_index`` correspond to 0 or
    ``False``, otherwise, 1 or ``True``.

    Parameters
    ----------
    input : Tensor [dtype: int]
        The input tensor. It represents the ids.
    padding_index : int, optional
        The id which represents padding, by default 0.
    dtype : str, optional
        Data type of the returned mask, by default "bool".

    Returns
    -------
    Tensor
        The generated mask. It has the same shape as ``input`` does.
    """
    return paddle.cast(input != padding_index, dtype)


def feature_mask(input, axis, dtype="bool"):
    """Compute mask from input features.

    For input features, represented as batched feature vectors, those vectors
    which are all zeros are considered padding vectors.

    Parameters
    ----------
    input : Tensor [dtype: float]
        The input tensor which represents features.
    axis : int
        The index of the feature dimension in ``input``. Other dimensions are
        considered ``spatial`` dimensions.
    dtype : str, optional
        Data type of the generated mask, by default "bool".

    Returns
    -------
    Tensor
        The generated mask with ``spatial`` shape as mentioned above.
        It has one less dimension than ``input`` does.
    """
    feature_sum = paddle.sum(paddle.abs(input), axis)
    return paddle.cast(feature_sum != 0, dtype)


def combine_mask(mask1, mask2):
    """Combine two masks with multiplication or logical and.

    Parameters
    -----------
    mask1 : Tensor
        The first mask.
    mask2 : Tensor
        The second mask with broadcastable shape with ``mask1``.

    Returns
    --------
    Tensor
        Combined mask.

    Notes
    ------
    It is mainly used to combine the padding mask and no future mask for
    transformer decoder.

    Padding mask is used to mask padding positions of the decoder inputs and
    no future mask is used to prevent the decoder to see future information.
    """
    if mask1.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
        return paddle.logical_and(mask1, mask2)
    else:
        return mask1 * mask2


def future_mask(time_steps, dtype="bool"):
    """Generate lower triangular mask.

    It is used at transformer decoder to prevent the decoder to see future
    information.

    Parameters
    ----------
    time_steps : int
        Decoder time steps.
    dtype : str, optional
        The data type of the generated mask, by default "bool".

    Returns
    -------
    Tensor
        The generated mask.
    """
    mask = paddle.tril(paddle.ones([time_steps, time_steps]))
    return paddle.cast(mask, dtype)
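
A short sketch of how these masks combine for a transformer decoder (assuming the functions above are in scope; float masks are used so the combination is a broadcasted multiplication):

import paddle
# assumes id_mask, future_mask, combine_mask from the module above are in scope

ids = paddle.to_tensor([[5, 2, 7, 0, 0],
                        [3, 9, 0, 0, 0]])                      # 0 is the padding id
padding_mask = id_mask(ids, padding_index=0, dtype="float32")   # (B, T), 1 at non-padding
no_future = future_mask(5, dtype="float32")                     # (T, T), lower triangular
combined = combine_mask(padding_mask.unsqueeze(1), no_future)   # (B, T, T) via broadcasting
print(combined.shape)  # [2, 5, 5]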

View File

@@ -3,21 +3,34 @@ import numpy as np
import paddle
from paddle.nn import functional as F

__all__ = ["positional_encoding"]


def positional_encoding(start_index, length, size, dtype=None):
    r"""Generate standard positional encoding matrix.

    .. math::

        pe(pos, 2i) = sin(\frac{pos}{10000^{\frac{2i}{size}}}) \\
        pe(pos, 2i+1) = cos(\frac{pos}{10000^{\frac{2i}{size}}})

    Parameters
    ----------
    start_index : int
        The start index.
    length : int
        The timesteps of the positional encoding to generate.
    size : int
        Feature size of positional encoding.

    Returns
    -------
    Tensor [shape=(length, size)]
        The positional encoding.

    Raises
    ------
    ValueError
        If ``size`` is not divisible by 2.
    """
    if (size % 2 != 0):
        raise ValueError("size should be divisible by 2")
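
As a cross-check of the formula above, a minimal NumPy sketch of the same sinusoid table (an illustration of the math only, not this module's implementation):

import numpy as np

def sinusoid_table(start_index, length, size):
    # positions as a column, even channel indices as a row
    pos = np.arange(start_index, start_index + length)[:, None]  # (length, 1)
    two_i = np.arange(0, size, 2)[None, :]                       # (1, size // 2)
    angles = pos / np.power(10000.0, two_i / size)               # (length, size // 2)
    pe = np.zeros([length, size])
    pe[:, 0::2] = np.sin(angles)  # even channels: sin
    pe[:, 1::2] = np.cos(angles)  # odd channels: cos
    return pe

print(sinusoid_table(0, 4, 8).shape)  # (4, 8)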

View File

@@ -5,23 +5,35 @@ from paddle.nn import functional as F
from parakeet.modules import attention as attn
from parakeet.modules.masking import combine_mask

__all__ = [
    "PositionwiseFFN",
    "TransformerEncoderLayer",
    "TransformerDecoderLayer",
]


class PositionwiseFFN(nn.Layer):
    """A faithful implementation of Position-wise Feed-Forward Network
    in `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
    It is basically a 2-layer MLP, with relu activation and dropout in between.

    Parameters
    ----------
    input_size: int
        The feature size of the input. It is also the feature size of the
        output.
    hidden_size: int
        The hidden size.
    dropout: float
        The probability of the Dropout applied to the output of the first
        layer, by default 0.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 dropout=0.0):
        super(PositionwiseFFN, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, input_size)

@@ -31,13 +43,17 @@ class PositionwiseFFN(nn.Layer):
        self.hidden_szie = hidden_size

    def forward(self, x):
        r"""Forward pass of positionwise feed forward network.

        Parameters
        ----------
        x : Tensor [shape=(\*, input_size)]
            The input tensor, where ``\*`` means arbitrary shape.

        Returns
        -------
        Tensor [shape=(\*, input_size)]
            The output tensor.
        """
        l1 = self.dropout(F.relu(self.linear1(x)))
        l2 = self.linear2(l1)
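
A minimal usage sketch of PositionwiseFFN (the module path is an assumption; shapes follow the docstring):

import paddle
from parakeet.modules.transformer import PositionwiseFFN  # module path assumed

ffn = PositionwiseFFN(input_size=256, hidden_size=1024, dropout=0.1)
x = paddle.randn([8, 50, 256])  # (batch, time_steps, input_size)
y = ffn(x)
print(y.shape)                  # [8, 50, 256]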
@@ -45,56 +61,101 @@ class PositionwiseFFN(nn.Layer):


class TransformerEncoderLayer(nn.Layer):
    """A faithful implementation of Transformer encoder layer in
    `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Parameters
    ----------
    d_model : int
        The feature size of the input. It is also the feature size of the
        output.
    n_heads : int
        The number of heads of self attention (a ``MultiheadAttention``
        layer).
    d_ffn : int
        The hidden size of the positional feed forward network (a
        ``PositionwiseFFN`` layer).
    dropout : float, optional
        The probability of the dropout in MultiHeadAttention and
        PositionwiseFFN, by default 0.

    Notes
    ------
    It uses the PostLN (post layer norm) scheme.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
        super(TransformerEncoderLayer, self).__init__()
        self.self_mha = attn.MultiheadAttention(d_model, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def forward(self, x, mask):
        """Forward pass of TransformerEncoderLayer.

        Parameters
        ----------
        x : Tensor [shape=(batch_size, time_steps, d_model)]
            The input.
        mask : Tensor
            The padding mask. The shape is (batch_size, time_steps,
            time_steps) or broadcastable shape.

        Returns
        -------
        x : Tensor [shape=(batch_size, time_steps, d_model)]
            The encoded output.
        attn_weights : Tensor [shape=(batch_size, n_heads, time_steps, time_steps)]
            The attention weights of the self attention.
        """
        context_vector, attn_weights = self.self_mha(x, x, x, mask)
        x = self.layer_norm1(
            F.dropout(x + context_vector,
                      self.dropout,
                      training=self.training))

        x = self.layer_norm2(
            F.dropout(x + self.ffn(x),
                      self.dropout,
                      training=self.training))
        return x, attn_weights
class TransformerDecoderLayer(nn.Layer):
    """A faithful implementation of Transformer decoder layer in
    `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.

    Parameters
    ----------
    d_model : int
        The feature size of the input. It is also the feature size of the
        output.
    n_heads : int
        The number of heads of attentions (``MultiheadAttention``
        layers).
    d_ffn : int
        The hidden size of the positional feed forward network (a
        ``PositionwiseFFN`` layer).
    dropout : float, optional
        The probability of the dropout in MultiHeadAttention and
        PositionwiseFFN, by default 0.

    Notes
    ------
    It uses the PostLN (post layer norm) scheme.
    """

    def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
        super(TransformerDecoderLayer, self).__init__()
        self.self_mha = attn.MultiheadAttention(d_model, n_heads, dropout)
        self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

@@ -104,30 +165,52 @@ class TransformerDecoderLayer(nn.Layer):
        self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
        self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)

        self.dropout = dropout

    def forward(self, q, k, v, encoder_mask, decoder_mask):
        """Forward pass of TransformerDecoderLayer.

        Parameters
        ----------
        q : Tensor [shape=(batch_size, time_steps_q, d_model)]
            The decoder input.
        k : Tensor [shape=(batch_size, time_steps_k, d_model)]
            The keys.
        v : Tensor [shape=(batch_size, time_steps_k, d_model)]
            The values.
        encoder_mask : Tensor
            Encoder padding mask, shape is ``(batch_size, time_steps_q,
            time_steps_k)`` or broadcastable shape.
        decoder_mask : Tensor
            Decoder mask, shape is ``(batch_size, time_steps_q, time_steps_q)``
            or broadcastable shape.

        Returns
        --------
        q : Tensor [shape=(batch_size, time_steps_q, d_model)]
            The decoder output.
        self_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_q)]
            Decoder self attention.
        cross_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_k)]
            Decoder-encoder cross attention.
        """
        context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
        q = self.layer_norm1(
            F.dropout(q + context_vector,
                      self.dropout,
                      training=self.training))

        context_vector, cross_attn_weights = self.cross_mha(q, k, v, encoder_mask)
        q = self.layer_norm2(
            F.dropout(q + context_vector,
                      self.dropout,
                      training=self.training))

        q = self.layer_norm3(
            F.dropout(q + self.ffn(q),
                      self.dropout,
                      training=self.training))
        return q, self_attn_weights, cross_attn_weights
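
A usage sketch tying the layers above together (module path assumed; the mask shapes are one broadcastable choice, built in the spirit of the masking helpers earlier in this diff):

import paddle
from parakeet.modules.transformer import (TransformerEncoderLayer,
                                          TransformerDecoderLayer)  # path assumed

B, T_enc, T_dec, d_model = 2, 20, 15, 128
encoder_layer = TransformerEncoderLayer(d_model, n_heads=4, d_ffn=512, dropout=0.1)
decoder_layer = TransformerDecoderLayer(d_model, n_heads=4, d_ffn=512, dropout=0.1)

x = paddle.randn([B, T_enc, d_model])
enc_pad_mask = paddle.ones([B, 1, 1, T_enc])  # broadcastable over heads and query steps
memory, enc_attn = encoder_layer(x, enc_pad_mask)

q = paddle.randn([B, T_dec, d_model])
dec_mask = paddle.tril(paddle.ones([T_dec, T_dec]))  # no-future mask, broadcastable
out, self_attn, cross_attn = decoder_layer(q, memory, memory, enc_pad_mask, dec_mask)
print(out.shape)  # [2, 15, 128]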