commit dd2c5cc6c6

@@ -18,14 +18,13 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F


 def scaled_dot_product_attention(q,
                                  k,
                                  v,
                                  mask=None,
                                  dropout=0.0,
                                  training=True):
-    """Scaled dot product attention with masking.
+    r"""Scaled dot product attention with masking.

     Assume that q, k, v all have the same leading dimensions (denoted as * in
     descriptions below). Dropout is applied to attention weights before
@@ -34,24 +33,24 @@ def scaled_dot_product_attention(q,
     Parameters
     -----------

-    q: Tensor [shape=(*, T_q, d)]
+    q : Tensor [shape=(\*, T_q, d)]
         the query tensor.

-    k: Tensor [shape=(*, T_k, d)]
+    k : Tensor [shape=(\*, T_k, d)]
         the key tensor.

-    v: Tensor [shape=(*, T_k, d_v)]
+    v : Tensor [shape=(\*, T_k, d_v)]
         the value tensor.

-    mask: Tensor, [shape=(*, T_q, T_k) or broadcastable shape], optional
+    mask : Tensor, [shape=(\*, T_q, T_k) or broadcastable shape], optional
         the mask tensor, zeros correspond to paddings. Defaults to None.

     Returns
     ----------
-    out: Tensor [shape(*, T_q, d_v)]
+    out : Tensor [shape=(\*, T_q, d_v)]
         the context vector.

-    attn_weights [Tensor shape(*, T_q, T_k)]
+    attn_weights : Tensor [shape=(\*, T_q, T_k)]
         the attention weights.
     """
     d = q.shape[-1]  # we only support imperative execution
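
For reference, a minimal sketch of the computation this docstring describes, written with plain paddle ops rather than the function above (shapes follow the docstring; the dropout and mask handling of the real function are omitted here):

import paddle
from paddle.nn import functional as F

d = 8
q = paddle.randn([4, 6, d])         # (*, T_q, d)
k = paddle.randn([4, 10, d])        # (*, T_k, d)
v = paddle.randn([4, 10, 16])       # (*, T_k, d_v)
scores = paddle.matmul(q, k, transpose_y=True) / d**0.5   # (*, T_q, T_k)
attn_weights = F.softmax(scores, axis=-1)
out = paddle.matmul(attn_weights, v)                      # (*, T_q, d_v)
print(out.shape, attn_weights.shape)                      # [4, 6, 16] [4, 6, 10]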

@@ -67,17 +66,25 @@
     return out, attn_weights


-def drop_head(x, drop_n_heads, training):
-    """
-    Drop n heads from multiple context vectors.
-
-    Args:
-        x (Tensor): shape(batch_size, num_heads, time_steps, channels), the input.
-        drop_n_heads (int): [description]
-        training ([type]): [description]
-
-    Returns:
-        [type]: [description]
+def drop_head(x, drop_n_heads, training=True):
+    """Drop n context vectors from multiple ones.
+
+    Parameters
+    ----------
+    x : Tensor [shape=(batch_size, num_heads, time_steps, channels)]
+        The input, multiple context vectors.
+
+    drop_n_heads : int [0 <= drop_n_heads <= num_heads]
+        Number of vectors to drop.
+
+    training : bool
+        A flag indicating whether it is in training. If `False`, no dropout is
+        applied.
+
+    Returns
+    -------
+    Tensor
+        The output.
     """
     if not training or (drop_n_heads == 0):
         return x
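
A usage sketch based on the signature documented above; the module path follows the `from parakeet.modules import attention as attn` import that appears later in this commit, and only the documented early-return branch is asserted here:

import paddle
from parakeet.modules import attention as attn

x = paddle.randn([2, 4, 10, 16])   # (batch_size, num_heads, time_steps, channels)
# documented early-return path: nothing is dropped outside training or when drop_n_heads == 0
assert attn.drop_head(x, drop_n_heads=0, training=False) is x
y = attn.drop_head(x, drop_n_heads=1, training=True)   # drop 1 of the 4 context vectors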

@@ -113,21 +120,30 @@ def _concat_heads(x):

 # Standard implementations of Monohead Attention & Multihead Attention
 class MonoheadAttention(nn.Layer):
-    def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
-        """
-        Monohead Attention module.
-
-        Args:
-            model_dim (int): the feature size of query.
-            dropout (float, optional): dropout probability of scaled dot product
-                attention and final context vector. Defaults to 0.0.
-            k_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-            v_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
+    """Monohead Attention module.
+
+    Parameters
+    ----------
+    model_dim : int
+        Feature size of the query.
+
+    dropout : float, optional
+        Dropout probability of scaled dot product attention and final context
+        vector. Defaults to 0.0.
+
+    k_dim : int, optional
+        Feature size of the key of each scaled dot product attention. If not
+        provided, it is set to `model_dim / num_heads`. Defaults to None.
+
+    v_dim : int, optional
+        Feature size of the key of each scaled dot product attention. If not
+        provided, it is set to `model_dim / num_heads`. Defaults to None.
     """
+    def __init__(self,
+                 model_dim: int,
+                 dropout: float=0.0,
+                 k_dim: int=None,
+                 v_dim: int=None):
         super(MonoheadAttention, self).__init__()
         k_dim = k_dim or model_dim
         v_dim = v_dim or model_dim

@@ -140,20 +156,29 @@ class MonoheadAttention(nn.Layer):
         self.dropout = dropout

     def forward(self, q, k, v, mask):
-        """
-        Compute context vector and attention weights.
-
-        Args:
-            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
-            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
-            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
-            mask (Tensor): shape(batch_size, times_steps_q, time_steps_k) or
-                broadcastable shape, dtype: float32 or float64, the mask.
-
-        Returns:
-            (out, attention_weights)
-            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
-            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
+        """Compute context vector and attention weights.
+
+        Parameters
+        -----------
+        q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The queries.
+
+        k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The keys.
+
+        v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The values.
+
+        mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k) or broadcastable shape]
+            The mask.
+
+        Returns
+        ----------
+        out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The context vector.
+
+        attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
+            The attention weights.
         """
         q = self.affine_q(q)  # (B, T, C)
         k = self.affine_k(k)

@@ -167,34 +192,39 @@ class MonoheadAttention(nn.Layer):


 class MultiheadAttention(nn.Layer):
-    """
-    Multihead scaled dot product attention.
-    """
+    """Multihead Attention module.
+
+    Parameters
+    -----------
+    model_dim : int
+        The feature size of query.
+
+    num_heads : int
+        The number of attention heads.
+
+    dropout : float, optional
+        Dropout probability of scaled dot product attention and final context
+        vector. Defaults to 0.0.
+
+    k_dim : int, optional
+        Feature size of the key of each scaled dot product attention. If not
+        provided, it is set to ``model_dim / num_heads``. Defaults to None.
+
+    v_dim : int, optional
+        Feature size of the key of each scaled dot product attention. If not
+        provided, it is set to ``model_dim / num_heads``. Defaults to None.
+
+    Raises
+    ---------
+    ValueError
+        If ``model_dim`` is not divisible by ``num_heads``.
+    """
     def __init__(self,
-                 model_dim,
-                 num_heads,
-                 dropout=0.0,
-                 k_dim=None,
-                 v_dim=None):
-        """
-        Multihead Attention module.
-
-        Args:
-            model_dim (int): the feature size of query.
-            num_heads (int): the number of attention heads.
-            dropout (float, optional): dropout probability of scaled dot product
-                attention and final context vector. Defaults to 0.0.
-            k_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-            v_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-
-        Raises:
-            ValueError: if model_dim is not divisible by num_heads
-        """
+                 model_dim: int,
+                 num_heads: int,
+                 dropout: float=0.0,
+                 k_dim: int=None,
+                 v_dim: int=None):
         super(MultiheadAttention, self).__init__()
         if model_dim % num_heads != 0:
             raise ValueError("model_dim must be divisible by num_heads")

@@ -211,20 +241,29 @@ class MultiheadAttention(nn.Layer):
         self.dropout = dropout

     def forward(self, q, k, v, mask):
-        """
-        Compute context vector and attention weights.
-
-        Args:
-            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
-            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
-            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
-            mask (Tensor): shape(batch_size, times_steps_q, time_steps_k) or
-                broadcastable shape, dtype: float32 or float64, the mask.
-
-        Returns:
-            (out, attention_weights)
-            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
-            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
+        """Compute context vector and attention weights.
+
+        Parameters
+        -----------
+        q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The queries.
+
+        k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The keys.
+
+        v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The values.
+
+        mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k) or broadcastable shape]
+            The mask.
+
+        Returns
+        ----------
+        out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The context vector.
+
+        attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
+            The attention weights.
         """
         q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
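
A usage sketch of the class as documented above; the import follows the `from parakeet.modules import attention as attn` line that appears in the transformer hunk of this commit, and the all-ones float mask is just a stand-in for a real padding mask:

import paddle
from parakeet.modules import attention as attn

B, T_q, T_k, model_dim, num_heads = 2, 5, 7, 64, 4
mha = attn.MultiheadAttention(model_dim, num_heads, dropout=0.1)
q = paddle.randn([B, T_q, model_dim])
k = paddle.randn([B, T_k, model_dim])
v = paddle.randn([B, T_k, model_dim])
mask = paddle.ones([B, T_q, T_k])    # 1 = attend, 0 = padding
out, attn_weights = mha(q, k, v, mask)
print(out.shape)            # [2, 5, 64]
print(attn_weights.shape)   # [2, 4, 5, 7]  (batch_size, num_heads, T_q, T_k)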

@@ -8,28 +8,48 @@ __all__ = ["quantize", "dequantize", "STFT"]


 def quantize(values, n_bands):
-    """Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
+    """Linearly quantize a float Tensor in [-1, 1) to an integer Tensor in
+    [0, n_bands).

-    Args:
-        values (Tensor): dtype: flaot32 or float64. the floating point value.
-        n_bands (int): the number of bands. The output integer Tensor's value is in the range [0, n_bans).
-
-    Returns:
-        Tensor: the quantized tensor, dtype: int64.
+    Parameters
+    -----------
+    values : Tensor [dtype: float32 or float64]
+        The floating point value.
+
+    n_bands : int
+        The number of bands. The output integer Tensor's value is in the range
+        [0, n_bands).
+
+    Returns
+    ----------
+    Tensor [dtype: int64]
+        The quantized tensor.
     """
     quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
     return quantized


 def dequantize(quantized, n_bands, dtype=None):
-    """Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
+    """Linearly dequantize an integer Tensor into a float Tensor in the range
+    [-1, 1).

-    Args:
-        quantized (Tensor): dtype: int64. The quantized value in the range [0, n_bands).
-        n_bands (int): number of bands. The input integer Tensor's value is in the range [0, n_bans).
-        dtype (str, optional): data type of the output.
-    Returns:
-        Tensor: the dequantized tensor, dtype is specified by dtype.
+    Parameters
+    -----------
+    quantized : Tensor [dtype: int]
+        The quantized value in the range [0, n_bands).
+
+    n_bands : int
+        Number of bands. The input integer Tensor's value is in the range
+        [0, n_bands).
+
+    dtype : str, optional
+        Data type of the output.
+
+    Returns
+    -----------
+    Tensor
+        The dequantized tensor, dtype is specified by `dtype`. If `dtype` is
+        not specified, the default float data type is used.
     """
     dtype = dtype or paddle.get_default_dtype()
     value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
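
The two formulas above round-trip as follows; a small numeric sketch using the exact expressions from this hunk:

import paddle

n_bands = 256
x = paddle.to_tensor([-1.0, -0.5, 0.0, 0.5, 0.999])            # float in [-1, 1)
quantized = paddle.cast((x + 1.0) / 2.0 * n_bands, "int64")     # int in [0, n_bands)
restored = (paddle.cast(quantized, "float32") + 0.5) * (2.0 / n_bands) - 1.0
print(quantized.numpy())   # [  0  64 128 192 255]
print(restored.numpy())    # each value is within 1 / n_bands of the original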

@@ -37,15 +57,36 @@ def dequantize(quantized, n_bands, dtype=None):


 class STFT(nn.Layer):
-    def __init__(self, n_fft, hop_length, win_length, window="hanning"):
-        """A module for computing differentiable stft transform. See `librosa.stft` for more details.
+    """A module for computing STFT transformation in a differentiable way.
+
+    Parameters
+    ------------
+    n_fft : int
+        Number of samples in a frame.
+
+    hop_length : int
+        Number of samples shifted between adjacent frames.
+
+    win_length : int
+        Length of the window.
+
+    window : str, optional
+        Name of window function, see `scipy.signal.get_window` for more
+        details. Defaults to "hanning".
+
+    Notes
+    -----------
+    It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
+    details.
+
+    Given an audio of ``T`` samples, the STFT transformation outputs a
+    spectrum of shape (C, frames) and complex dtype, where ``C = 1 + n_fft / 2``
+    and ``frames = 1 + T // hop_length``.
+
+    Only ``center`` and ``reflect`` padding is supported now.

-        Args:
-            n_fft (int): number of samples in a frame.
-            hop_length (int): number of samples shifted between adjacent frames.
-            win_length (int): length of the window function.
-            window (str, optional): name of window function, see `scipy.signal.get_window` for more details. Defaults to "hanning".
     """
+    def __init__(self, n_fft, hop_length, win_length, window="hanning"):
         super(STFT, self).__init__()
         self.hop_length = hop_length
         self.n_bin = 1 + n_fft // 2

@@ -73,13 +114,18 @@ class STFT(nn.Layer):
     def forward(self, x):
         """Compute the stft transform.

-        Args:
-            x (Variable): shape(B, T), dtype flaot32, the input waveform.
+        Parameters
+        ------------
+        x : Tensor [shape=(B, T)]
+            The input waveform.

-        Returns:
-            (real, imag)
-            real (Variable): shape(B, C, 1, T), dtype flaot32, the real part of the spectrogram. (C = 1 + n_fft // 2)
-            imag (Variable): shape(B, C, 1, T), dtype flaot32, the image part of the spectrogram. (C = 1 + n_fft // 2)
+        Returns
+        ------------
+        real : Tensor [shape=(B, C, 1, frames)]
+            The real part of the spectrogram.
+
+        imag : Tensor [shape=(B, C, 1, frames)]
+            The imaginary part of the spectrogram.
         """
         # x(batch_size, time_steps)
         # pad it first with reflect mode

@@ -95,30 +141,34 @@ class STFT(nn.Layer):
         return real, imag

     def power(self, x):
-        """Compute the power spectrogram.
+        """Compute the power spectrum.

-        Args:
-            (real, imag)
-            real (Variable): shape(B, C, 1, T), dtype flaot32, the real part of the spectrogram.
-            imag (Variable): shape(B, C, 1, T), dtype flaot32, the image part of the spectrogram.
+        Parameters
+        ------------
+        x : Tensor [shape=(B, T)]
+            The input waveform.

-        Returns:
-            Variable: shape(B, C, 1, T), dtype flaot32, the power spectrogram.
+        Returns
+        ------------
+        Tensor [shape=(B, C, 1, T)]
+            The power spectrum.
         """
         real, imag = self(x)
         power = real**2 + imag**2
         return power

     def magnitude(self, x):
-        """Compute the magnitude spectrogram.
+        """Compute the magnitude of the spectrum.

-        Args:
-            (real, imag)
-            real (Variable): shape(B, C, 1, T), dtype flaot32, the real part of the spectrogram.
-            imag (Variable): shape(B, C, 1, T), dtype flaot32, the image part of the spectrogram.
+        Parameters
+        ------------
+        x : Tensor [shape=(B, T)]
+            The input waveform.

-        Returns:
-            Variable: shape(B, C, 1, T), dtype flaot32, the magnitude spectrogram. It is the square root of the power spectrogram.
+        Returns
+        ------------
+        Tensor [shape=(B, C, 1, T)]
+            The magnitude of the spectrum.
         """
         power = self.power(x)
         magnitude = paddle.sqrt(power)
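
A usage sketch for the STFT module documented above. The constructor arguments follow the class docstring; the import path is an assumption (written here as parakeet.modules.audio) and should be adjusted to wherever this file lives in the package:

import paddle
from parakeet.modules.audio import STFT   # path assumed; the class is defined in this file

stft = STFT(n_fft=1024, hop_length=256, win_length=1024, window="hanning")
x = paddle.randn([2, 8000])      # (B, T) waveform
real, imag = stft(x)             # each (B, C, 1, frames), C = 1 + n_fft // 2
spec_power = stft.power(x)       # real**2 + imag**2
spec_mag = stft.magnitude(x)     # square root of the power spectrum
print(real.shape, spec_mag.shape)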

@@ -4,16 +4,25 @@ import paddle
 def shuffle_dim(x, axis, perm=None):
     """Permute input tensor along axis given the permutation or randomly.

-    Args:
-        x (Tensor): shape(*, d_{axis}, *), the input tensor.
-        axis (int): the axis to shuffle.
-        perm (list[int], ndarray, optional): a permutation of [0, d_{axis}),
-            the order to reorder the tensor along the `axis`-th dimension, if
-            not provided, randomly shuffle the `axis`-th dimension. Defaults to
-            None.
-
-    Returns:
-        Tensor: the shuffled tensor, it has the same shape as x does.
+    Parameters
+    ----------
+    x : Tensor
+        The input tensor.
+
+    axis : int
+        The axis to shuffle.
+
+    perm : List[int], ndarray, optional
+        The order to reorder the tensor along the ``axis``-th dimension.
+
+        It is a permutation of ``[0, d)``, where d is the size of the
+        ``axis``-th dimension of the input tensor. If not provided,
+        a random permutation is used. Defaults to None.
+
+    Returns
+    ---------
+    Tensor
+        The shuffled tensor, which has the same shape as x does.
     """
     size = x.shape[axis]
     if perm is not None and len(perm) != size:
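
The behaviour described above amounts to an index_select along the chosen axis; a minimal sketch of the same idea with plain paddle ops (shuffle_dim itself is the helper defined in this file):

import numpy as np
import paddle

x = paddle.arange(12, dtype="float32").reshape([3, 4])
perm = np.random.permutation(4)                   # a permutation of [0, d) for axis=1
shuffled = paddle.index_select(x, paddle.to_tensor(perm), axis=1)
print(perm, shuffled.shape)                       # shuffled has the same shape as x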

@@ -4,29 +4,128 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F

 __all__ = [
     "weighted_mean",
     "masked_l1_loss",
     "masked_softmax_with_cross_entropy",
     "diagonal_loss",
 ]

 def weighted_mean(input, weight):
-    """weighted mean.(It can also be used as masked mean.)
+    """Weighted mean. It can also be used as masked mean.

-    Args:
-        input (Tensor): input tensor, floating point dtype.
-        weight (Tensor): weight tensor with broadcastable shape.
+    Parameters
+    -----------
+    input : Tensor
+        The input tensor.
+    weight : Tensor
+        The weight tensor with broadcastable shape with the input.

-    Returns:
-        Tensor: shape(1,), weighted mean tensor with the same dtype as input.
+    Returns
+    ----------
+    Tensor [shape=(1,)]
+        Weighted mean tensor with the same dtype as input.
+
+    Warnings
+    ---------
+    This is not a mathematical weighted mean. It performs weighted sum and
+    simple average.
     """
     weight = paddle.cast(weight, input.dtype)
     return paddle.mean(input * weight)


 def masked_l1_loss(prediction, target, mask):
+    """Compute masked L1 loss.
+
+    Parameters
+    ----------
+    prediction : Tensor
+        The prediction.
+
+    target : Tensor
+        The target. The shape should be broadcastable to ``prediction``.
+
+    mask : Tensor
+        The mask. The shape should be broadcastable to the broadcasted shape of
+        ``prediction`` and ``target``.
+
+    Returns
+    -------
+    Tensor [shape=(1,)]
+        The masked L1 loss.
+    """
     abs_error = F.l1_loss(prediction, target, reduction='none')
-    return weighted_mean(abs_error, mask)
+    loss = weighted_mean(abs_error, mask)
+    return loss


 def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
-    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
-    return weighted_mean(ce, mask)
+    """Compute masked softmax with cross entropy loss.
+
+    Parameters
+    ----------
+    logits : Tensor
+        The logits. The ``axis``-th axis is the class dimension.
+
+    label : Tensor [dtype: int]
+        The label. The size of the ``axis``-th axis should be 1.
+
+    mask : Tensor
+        The mask. The shape should be broadcastable to ``label``.
+
+    axis : int, optional
+        The index of the class dimension in the shape of ``logits``, by default
+        -1.
+
+    Returns
+    -------
+    Tensor [shape=(1,)]
+        The masked softmax with cross entropy loss.
+    """
+    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
+    loss = weighted_mean(ce, mask)
+    return loss


-def diagonal_loss(attentions, input_lengths, target_lengths, g=0.2, multihead=False):
-    """A metric to evaluate how diagonal a attention distribution is."""
+def diagonal_loss(
+        attentions,
+        input_lengths,
+        target_lengths,
+        g=0.2,
+        multihead=False):
+    """A metric to evaluate how diagonal an attention distribution is.
+
+    It is computed for batch attention distributions. For each attention
+    distribution, the valid decoder time steps and encoder time steps may
+    differ.
+
+    Parameters
+    ----------
+    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_dec)]
+        The attention weights from an encoder-decoder structure.
+
+    input_lengths : Tensor [shape=(B,)]
+        The valid length for each encoder output.
+
+    target_lengths : Tensor [shape=(B,)]
+        The valid length for each decoder output.
+
+    g : float, optional
+        [description], by default 0.2.
+
+    multihead : bool, optional
+        A flag indicating whether ``attentions`` is a multihead attention's
+        attention distribution.
+
+        If ``True``, the shape of attention is ``(B, H, T_dec, T_dec)``, by
+        default False.
+
+    Returns
+    -------
+    Tensor [shape=(1,)]
+        The diagonal loss.
+    """
     W = guided_attentions(input_lengths, target_lengths, g)
     W_tensor = paddle.to_tensor(W)
     if not multihead:
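
A small numeric sketch of weighted_mean and masked_l1_loss as defined above, written inline with plain paddle ops so the "weighted sum and simple average" caveat from the Warnings section is visible:

import paddle
from paddle.nn import functional as F

prediction = paddle.to_tensor([[1.0, 2.0, 3.0]])
target = paddle.to_tensor([[1.5, 2.0, 0.0]])
mask = paddle.to_tensor([[1.0, 1.0, 0.0]])       # last frame is padding

abs_error = F.l1_loss(prediction, target, reduction='none')   # [[0.5, 0.0, 3.0]]
loss = paddle.mean(abs_error * mask)             # this is weighted_mean(abs_error, mask)
print(float(loss))   # 0.5 / 3 ~= 0.1667: divides by all elements, not by mask.sum()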

@@ -1,32 +1,114 @@
 import paddle
 from paddle.fluid.layers import sequence_mask

 __all__ = [
     "id_mask",
     "feature_mask",
     "combine_mask",
     "future_mask",
 ]

 def id_mask(input, padding_index=0, dtype="bool"):
+    """Generate mask with input ids.
+
+    Those positions where the value equals ``padding_index`` correspond to 0 or
+    ``False``, otherwise, 1 or ``True``.
+
+    Parameters
+    ----------
+    input : Tensor [dtype: int]
+        The input tensor. It represents the ids.
+
+    padding_index : int, optional
+        The id which represents padding, by default 0.
+
+    dtype : str, optional
+        Data type of the returned mask, by default "bool".
+
+    Returns
+    -------
+    Tensor
+        The generated mask. It has the same shape as ``input`` does.
+    """
     return paddle.cast(input != padding_index, dtype)


 def feature_mask(input, axis, dtype="bool"):
+    """Compute mask from input features.
+
+    For input features represented as batched feature vectors, vectors that
+    are all zeros are considered padding vectors.
+
+    Parameters
+    ----------
+    input : Tensor [dtype: float]
+        The input tensor which represents features.
+
+    axis : int
+        The index of the feature dimension in ``input``. Other dimensions are
+        considered ``spatial`` dimensions.
+
+    dtype : str, optional
+        Data type of the generated mask, by default "bool".
+
+    Returns
+    -------
+    Tensor
+        The generated mask with the ``spatial`` shape as mentioned above.
+
+        It has one less dimension than ``input`` does.
+    """
     feature_sum = paddle.sum(paddle.abs(input), axis)
     return paddle.cast(feature_sum != 0, dtype)

-def combine_mask(padding_mask, no_future_mask):
-    """
-    Combine the padding mask and no future mask for transformer decoder.
-    Padding mask is used to mask padding positions and no future mask is used
-    to prevent the decoder to see future information.
-
-    Args:
-        padding_mask (Tensor): shape(batch_size, time_steps), dtype: float32 or float64, decoder padding mask.
-        no_future_mask (Tensor): shape(time_steps, time_steps), dtype: float32 or float64, no future mask.
-
-    Returns:
-        Tensor: shape(batch_size, time_steps, time_steps), combined mask.
+def combine_mask(mask1, mask2):
+    """Combine two masks with multiplication or logical and.
+
+    Parameters
+    -----------
+    mask1 : Tensor
+        The first mask.
+
+    mask2 : Tensor
+        The second mask, with a shape broadcastable with ``mask1``.
+
+    Returns
+    --------
+    Tensor
+        Combined mask.
+
+    Notes
+    ------
+    It is mainly used to combine the padding mask and no future mask for
+    transformer decoder.
+
+    Padding mask is used to mask padding positions of the decoder inputs and
+    no future mask is used to prevent the decoder from seeing future
+    information.
     """
-    # TODO: to support boolean mask by using logical_and?
-    if padding_mask.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
-        return paddle.logical_and(padding_mask, no_future_mask)
+    if mask1.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
+        return paddle.logical_and(mask1, mask2)
     else:
-        return padding_mask * no_future_mask
+        return mask1 * mask2


 def future_mask(time_steps, dtype="bool"):
+    """Generate lower triangular mask.
+
+    It is used in the transformer decoder to prevent the decoder from seeing
+    future information.
+
+    Parameters
+    ----------
+    time_steps : int
+        Decoder time steps.
+    dtype : str, optional
+        The data type of the generated mask, by default "bool".
+
+    Returns
+    -------
+    Tensor
+        The generated mask.
+    """
     mask = paddle.tril(paddle.ones([time_steps, time_steps]))
     return paddle.cast(mask, dtype)
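
A usage sketch showing how a decoder combines a padding mask with the lower triangular future mask; the import path parakeet.modules.masking is the one used by the transformer hunk later in this commit:

import paddle
from parakeet.modules.masking import id_mask, future_mask, combine_mask

ids = paddle.to_tensor([[5, 3, 9, 0, 0]])                        # 0 is the padding id
padding_mask = id_mask(ids, padding_index=0, dtype="float32")    # (B, T)
no_future = future_mask(5, dtype="float32")                      # (T, T), lower triangular
combined = combine_mask(padding_mask.unsqueeze(1), no_future)    # broadcasts to (B, T, T)
print(combined.shape)   # [1, 5, 5]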

@@ -3,21 +3,34 @@ import numpy as np
 import paddle
 from paddle.nn import functional as F

 __all__ = ["positional_encoding"]

 def positional_encoding(start_index, length, size, dtype=None):
-    """
-    Generate standard positional encoding.
-
-    pe(pos, 2i) = sin(pos / 10000 ** (2i / size))
-    pe(pos, 2i+1) = cos(pos / 10000 ** (2i / size))
-
-    Args:
-        start_index (int): the start index.
-        length (int): the length of the positional encoding.
-        size (int): positional encoding dimension.
-
-    Returns:
-        encodings (Tensor): shape(length, size), the positional encoding.
+    r"""Generate standard positional encoding matrix.
+
+    .. math::
+
+        pe(pos, 2i) = sin(\frac{pos}{10000^{\frac{2i}{size}}}) \\
+        pe(pos, 2i+1) = cos(\frac{pos}{10000^{\frac{2i}{size}}})
+
+    Parameters
+    ----------
+    start_index : int
+        The start index.
+    length : int
+        The timesteps of the positional encoding to generate.
+    size : int
+        Feature size of positional encoding.
+
+    Returns
+    -------
+    Tensor [shape=(length, size)]
+        The positional encoding.
+
+    Raises
+    ------
+    ValueError
+        If ``size`` is not divisible by 2.
     """
     if (size % 2 != 0):
         raise ValueError("size should be divisible by 2")
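
The formula in the docstring can be reproduced directly; a sketch with numpy that follows the docstring's indexing (even channels sin, odd channels cos) and the documented (length, size) layout:

import numpy as np

start_index, length, size = 0, 10, 8
pos = np.arange(start_index, start_index + length)[:, None]   # (length, 1)
i = np.arange(size // 2)[None, :]                              # (1, size // 2)
angle = pos / np.power(10000.0, 2 * i / size)
pe = np.zeros([length, size])
pe[:, 0::2] = np.sin(angle)    # pe(pos, 2i)   = sin(pos / 10000 ** (2i / size))
pe[:, 1::2] = np.cos(angle)    # pe(pos, 2i+1) = cos(pos / 10000 ** (2i / size))
print(pe.shape)                # (10, 8)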

@@ -5,23 +5,35 @@ from paddle.nn import functional as F

 from parakeet.modules import attention as attn
 from parakeet.modules.masking import combine_mask

 __all__ = [
     "PositionwiseFFN",
     "TransformerEncoderLayer",
     "TransformerDecoderLayer",
 ]

 class PositionwiseFFN(nn.Layer):
-    """
-    A faithful implementation of Position-wise Feed-Forward Network
+    """A faithful implementation of Position-wise Feed-Forward Network
     in `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
-    It is basically a 3-layer MLP, with relu actication and dropout in between.
+    It is basically a 2-layer MLP, with relu activation and dropout in between.
+
+    Parameters
+    ----------
+    input_size: int
+        The feature size of the input. It is also the feature size of the
+        output.
+
+    hidden_size: int
+        The hidden size.
+
+    dropout: float
+        The probability of the Dropout applied to the output of the first
+        layer, by default 0.
     """
     def __init__(self,
                  input_size: int,
                  hidden_size: int,
                  dropout=0.0):
-        """
-        Args:
-            input_size (int): the input feature size.
-            hidden_size (int): the hidden layer's feature size.
-            dropout (float, optional): probability of dropout applied to the
-                output of the first fully connected layer. Defaults to 0.0.
-        """
         super(PositionwiseFFN, self).__init__()
         self.linear1 = nn.Linear(input_size, hidden_size)
         self.linear2 = nn.Linear(hidden_size, input_size)

@@ -31,13 +43,17 @@ class PositionwiseFFN(nn.Layer):
         self.hidden_szie = hidden_size

     def forward(self, x):
-        """positionwise feed forward network.
+        r"""Forward pass of positionwise feed forward network.

-        Args:
-            x (Tensor): shape(*, input_size), the input tensor.
+        Parameters
+        ----------
+        x : Tensor [shape=(\*, input_size)]
+            The input tensor, where ``\*`` means arbitrary shape.

-        Returns:
-            Tensor: shape(*, input_size), the output tensor.
+        Returns
+        -------
+        Tensor [shape=(\*, input_size)]
+            The output tensor.
         """
         l1 = self.dropout(F.relu(self.linear1(x)))
         l2 = self.linear2(l1)
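
The forward pass above is just linear, relu, dropout, linear; a minimal standalone equivalent with plain paddle layers:

import paddle
from paddle import nn
from paddle.nn import functional as F

linear1 = nn.Linear(256, 1024)
linear2 = nn.Linear(1024, 256)
x = paddle.randn([2, 50, 256])                    # (*, input_size)
y = linear2(F.dropout(F.relu(linear1(x)), 0.1))   # (*, input_size)
print(y.shape)    # [2, 50, 256]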

@@ -45,18 +61,32 @@ class PositionwiseFFN(nn.Layer):


 class TransformerEncoderLayer(nn.Layer):
-    """
-    Transformer encoder layer.
+    """A faithful implementation of Transformer encoder layer in
+    `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+    Parameters
+    ----------
+    d_model : int
+        The feature size of the input. It is also the feature size of the
+        output.
+
+    n_heads : int
+        The number of heads of self attention (a ``MultiheadAttention``
+        layer).
+
+    d_ffn : int
+        The hidden size of the positional feed forward network (a
+        ``PositionwiseFFN`` layer).
+
+    dropout : float, optional
+        The probability of the dropout in MultiHeadAttention and
+        PositionwiseFFN, by default 0.
+
+    Notes
+    ------
+    It uses the PostLN (post layer norm) scheme.
     """
     def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
-        """
-        Args:
-            d_model (int): the feature size of the input, and the output.
-            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
-            d_ffn (int): the hidden size of the internal PositionwiseFFN.
-            dropout (float, optional): the probability of the dropout in
-                MultiHeadAttention and PositionwiseFFN. Defaults to 0.
-        """
         super(TransformerEncoderLayer, self).__init__()
         self.self_mha = attn.MultiheadAttention(d_model, n_heads, dropout)
         self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

@@ -64,37 +94,68 @@ class TransformerEncoderLayer(nn.Layer):
         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

         self.dropout = dropout

     def forward(self, x, mask):
-        """
-        Args:
-            x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
-            mask (Tensor): shape(batch_size, time_steps), the padding mask.
-
-        Returns:
-            (x, attn_weights)
-            x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
-            attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
-        """
-        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1))
-        x = self.layer_norm1(x + context_vector)
-
-        x = self.layer_norm2(x + self.ffn(x))
+        """Forward pass of TransformerEncoderLayer.
+
+        Parameters
+        ----------
+        x : Tensor [shape=(batch_size, time_steps, d_model)]
+            The input.
+
+        mask : Tensor
+            The padding mask. The shape is (batch_size, time_steps,
+            time_steps) or broadcastable shape.
+
+        Returns
+        -------
+        x : Tensor [shape=(batch_size, time_steps, d_model)]
+            The encoded output.
+
+        attn_weights : Tensor [shape=(batch_size, n_heads, time_steps, time_steps)]
+            The attention weights of the self attention.
+        """
+        context_vector, attn_weights = self.self_mha(x, x, x, mask)
+        x = self.layer_norm1(
+            F.dropout(x + context_vector,
+                      self.dropout,
+                      training=self.training))
+
+        x = self.layer_norm2(
+            F.dropout(x + self.ffn(x),
+                      self.dropout,
+                      training=self.training))
         return x, attn_weights
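
A usage sketch of the encoder layer as documented above; the constructor arguments come from the __init__ shown earlier, while the module path (written here as parakeet.modules.transformer) is an assumption to adjust to the actual package layout:

import paddle
from parakeet.modules.transformer import TransformerEncoderLayer   # path assumed

layer = TransformerEncoderLayer(d_model=256, n_heads=4, d_ffn=1024, dropout=0.1)
x = paddle.randn([2, 30, 256])     # (batch_size, time_steps, d_model)
mask = paddle.ones([2, 1, 30])     # broadcastable to (batch_size, time_steps, time_steps)
x, attn_weights = layer(x, mask)
print(x.shape, attn_weights.shape)   # [2, 30, 256] [2, 4, 30, 30]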


 class TransformerDecoderLayer(nn.Layer):
-    """
-    Transformer decoder layer.
+    """A faithful implementation of Transformer decoder layer in
+    `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+    Parameters
+    ----------
+    d_model : int
+        The feature size of the input. It is also the feature size of the
+        output.
+
+    n_heads : int
+        The number of heads of attentions (``MultiheadAttention``
+        layers).
+
+    d_ffn : int
+        The hidden size of the positional feed forward network (a
+        ``PositionwiseFFN`` layer).
+
+    dropout : float, optional
+        The probability of the dropout in MultiHeadAttention and
+        PositionwiseFFN, by default 0.
+
+    Notes
+    ------
+    It uses the PostLN (post layer norm) scheme.
     """
     def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
-        """
-        Args:
-            d_model (int): the feature size of the input, and the output.
-            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
-            d_ffn (int): the hidden size of the internal PositionwiseFFN.
-            dropout (float, optional): the probability of the dropout in
-                MultiHeadAttention and PositionwiseFFN. Defaults to 0.
-        """
         super(TransformerDecoderLayer, self).__init__()
         self.self_mha = attn.MultiheadAttention(d_model, n_heads, dropout)
         self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

@@ -105,29 +166,51 @@ class TransformerDecoderLayer(nn.Layer):
         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)

         self.dropout = dropout

     def forward(self, q, k, v, encoder_mask, decoder_mask):
-        """
-        Args:
-            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
-            k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
-            v (Tensor): shape(batch_size, time_steps_k, d_model), values
-            encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder padding mask.
-            decoder_mask (Tensor): shape(batch_size, time_steps_q) decoder padding mask.
-
-        Returns:
-            (q, self_attn_weights, cross_attn_weights)
-            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
-            self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
-            cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
-        """
-        tq = q.shape[1]
-        no_future_mask = paddle.tril(paddle.ones([tq, tq]))  # (tq, tq)
-        combined_mask = combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
-        context_vector, self_attn_weights = self.self_mha(q, q, q, combined_mask)
-        q = self.layer_norm1(q + context_vector)
-
-        context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1))
-        q = self.layer_norm2(q + context_vector)
-
-        q = self.layer_norm3(q + self.ffn(q))
+        """Forward pass of TransformerDecoderLayer.
+
+        Parameters
+        ----------
+        q : Tensor [shape=(batch_size, time_steps_q, d_model)]
+            The decoder input.
+        k : Tensor [shape=(batch_size, time_steps_k, d_model)]
+            The keys.
+        v : Tensor [shape=(batch_size, time_steps_k, d_model)]
+            The values.
+        encoder_mask : Tensor
+            Encoder padding mask, shape is ``(batch_size, time_steps_k,
+            time_steps_k)`` or broadcastable shape.
+        decoder_mask : Tensor
+            Decoder mask, shape is ``(batch_size, time_steps_q, time_steps_k)``
+            or broadcastable shape.
+
+        Returns
+        --------
+        q : Tensor [shape=(batch_size, time_steps_q, d_model)]
+            The decoder output.
+
+        self_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_q)]
+            Decoder self attention.
+
+        cross_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_k)]
+            Decoder-encoder cross attention.
+        """
+        context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
+        q = self.layer_norm1(
+            F.dropout(q + context_vector,
+                      self.dropout,
+                      training=self.training))
+
+        context_vector, cross_attn_weights = self.cross_mha(q, k, v, encoder_mask)
+        q = self.layer_norm2(
+            F.dropout(q + context_vector,
+                      self.dropout,
+                      training=self.training))
+
+        q = self.layer_norm3(
+            F.dropout(q + self.ffn(q),
+                      self.dropout,
+                      training=self.training))
         return q, self_attn_weights, cross_attn_weights
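
A usage sketch of the decoder layer as documented above; the mask shapes follow the docstring, the lower triangular decoder mask is built the same way future_mask builds it, and the module path (parakeet.modules.transformer) is again an assumption:

import paddle
from parakeet.modules.transformer import TransformerDecoderLayer   # path assumed

layer = TransformerDecoderLayer(d_model=256, n_heads=4, d_ffn=1024, dropout=0.1)
q = paddle.randn([2, 20, 256])             # decoder input (batch_size, T_q, d_model)
k = v = paddle.randn([2, 30, 256])         # encoder output (batch_size, T_k, d_model)
encoder_mask = paddle.ones([2, 1, 30])     # broadcastable to (batch_size, T_q, T_k)
decoder_mask = paddle.tril(paddle.ones([20, 20])).unsqueeze(0)   # (1, T_q, T_q)
q, self_attn, cross_attn = layer(q, k, v, encoder_mask, decoder_mask)
print(q.shape, self_attn.shape, cross_attn.shape)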