commit dd2c5cc6c6

@@ -18,14 +18,13 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F


 def scaled_dot_product_attention(q,
                                  k,
                                  v,
                                  mask=None,
                                  dropout=0.0,
                                  training=True):
-    """Scaled dot product attention with masking.
+    r"""Scaled dot product attention with masking.

     Assume that q, k, v all have the same leading dimensions (denoted as * in
     descriptions below). Dropout is applied to attention weights before
@@ -34,24 +33,24 @@ def scaled_dot_product_attention(q,
     Parameters
     -----------

-    q: Tensor [shape=(*, T_q, d)]
+    q : Tensor [shape=(\*, T_q, d)]
         the query tensor.

-    k: Tensor [shape=(*, T_k, d)]
+    k : Tensor [shape=(\*, T_k, d)]
         the key tensor.

-    v: Tensor [shape=(*, T_k, d_v)]
+    v : Tensor [shape=(\*, T_k, d_v)]
         the value tensor.

-    mask: Tensor, [shape=(*, T_q, T_k) or broadcastable shape], optional
+    mask : Tensor, [shape=(\*, T_q, T_k) or broadcastable shape], optional
         the mask tensor, zeros correspond to paddings. Defaults to None.

     Returns
     ----------
-    out: Tensor [shape(*, T_q, d_v)]
+    out : Tensor [shape=(\*, T_q, d_v)]
         the context vector.

-    attn_weights [Tensor shape(*, T_q, T_k)]
+    attn_weights : Tensor [shape=(\*, T_q, T_k)]
         the attention weights.
     """
     d = q.shape[-1] # we only support imperative execution
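The docstring above fully determines the call signature, so a small usage sketch may help. It assumes the function is importable from parakeet.modules.attention (the module path is only implied by the transformer.py hunk further down) and uses made-up tensor sizes.

import paddle
from parakeet.modules.attention import scaled_dot_product_attention

# a batch of 4 sequences: 6 query steps, 8 key/value steps, feature size 16
q = paddle.randn([4, 6, 16])
k = paddle.randn([4, 8, 16])
v = paddle.randn([4, 8, 16])
mask = paddle.ones([4, 6, 8])  # 1 for valid positions, 0 for paddings

out, attn_weights = scaled_dot_product_attention(q, k, v, mask=mask)
print(out.shape)           # [4, 6, 16]
print(attn_weights.shape)  # [4, 6, 8]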
@@ -67,17 +66,25 @@ def scaled_dot_product_attention(q,
     return out, attn_weights


-def drop_head(x, drop_n_heads, training):
-    """
-    Drop n heads from multiple context vectors.
+def drop_head(x, drop_n_heads, training=True):
+    """Drop n context vectors from multiple ones.

-    Args:
-        x (Tensor): shape(batch_size, num_heads, time_steps, channels), the input.
-        drop_n_heads (int): [description]
-        training ([type]): [description]
+    Parameters
+    ----------
+    x : Tensor [shape=(batch_size, num_heads, time_steps, channels)]
+        The input, multiple context vectors.
+
+    drop_n_heads : int [0 <= drop_n_heads <= num_heads]
+        Number of vectors to drop.
+
+    training : bool
+        A flag indicating whether it is in training. If `False`, no dropout is
+        applied.

-    Returns:
-        [type]: [description]
+    Returns
+    -------
+    Tensor
+        The output.
     """
     if not training or (drop_n_heads == 0):
         return x
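A minimal sketch of calling drop_head as documented above, assuming the same parakeet.modules.attention path; which heads are dropped and whether the remaining ones are rescaled is decided by the implementation, which this hunk does not show.

import paddle
from parakeet.modules.attention import drop_head

x = paddle.randn([4, 8, 50, 64])    # (batch_size, num_heads, time_steps, channels)
y = drop_head(x, 2, training=True)  # drop 2 of the 8 context vectors
# with training=False or drop_n_heads=0 the input is returned unchanged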
@@ -113,21 +120,30 @@ def _concat_heads(x):

 # Standard implementations of Monohead Attention & Multihead Attention
 class MonoheadAttention(nn.Layer):
-    def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
-        """
-        Monohead Attention module.
+    """Monohead Attention module.

-        Args:
-            model_dim (int): the feature size of query.
-            dropout (float, optional): dropout probability of scaled dot product
-                attention and final context vector. Defaults to 0.0.
-            k_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-            v_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-        """
+    Parameters
+    ----------
+    model_dim : int
+        Feature size of the query.
+
+    dropout : float, optional
+        Dropout probability of scaled dot product attention and final context
+        vector. Defaults to 0.0.
+
+    k_dim : int, optional
+        Feature size of the key of each scaled dot product attention. If not
+        provided, it is set to `model_dim / num_heads`. Defaults to None.
+
+    v_dim : int, optional
+        Feature size of the value of each scaled dot product attention. If not
+        provided, it is set to `model_dim / num_heads`. Defaults to None.
+    """
+    def __init__(self,
+                 model_dim: int,
+                 dropout: float=0.0,
+                 k_dim: int=None,
+                 v_dim: int=None):
         super(MonoheadAttention, self).__init__()
         k_dim = k_dim or model_dim
         v_dim = v_dim or model_dim
@@ -140,20 +156,29 @@ class MonoheadAttention(nn.Layer):
         self.dropout = dropout

     def forward(self, q, k, v, mask):
-        """
-        Compute context vector and attention weights.
+        """Compute context vector and attention weights.

-        Args:
-            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
-            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
-            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
-            mask (Tensor): shape(batch_size, times_steps_q, time_steps_k) or
-                broadcastable shape, dtype: float32 or float64, the mask.
+        Parameters
+        -----------
+        q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The queries.
+
+        k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The keys.
+
+        v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The values.
+
+        mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k)] or broadcastable shape
+            The mask.

-        Returns:
-            (out, attention_weights)
-            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
-            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
+        Returns
+        ----------
+        out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The context vector.
+
+        attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
+            The attention weights.
         """
         q = self.affine_q(q) # (B, T, C)
         k = self.affine_k(k)
@@ -167,34 +192,39 @@ class MonoheadAttention(nn.Layer):


 class MultiheadAttention(nn.Layer):
-    """
-    Multihead scaled dot product attention.
-    """
+    """Multihead Attention module.
+
+    Parameters
+    -----------
+    model_dim : int
+        The feature size of query.
+
+    num_heads : int
+        The number of attention heads.
+
+    dropout : float, optional
+        Dropout probability of scaled dot product attention and final context
+        vector. Defaults to 0.0.
+
+    k_dim : int, optional
+        Feature size of the key of each scaled dot product attention. If not
+        provided, it is set to ``model_dim / num_heads``. Defaults to None.
+
+    v_dim : int, optional
+        Feature size of the value of each scaled dot product attention. If not
+        provided, it is set to ``model_dim / num_heads``. Defaults to None.
+
+    Raises
+    ---------
+    ValueError
+        If ``model_dim`` is not divisible by ``num_heads``.
+    """
     def __init__(self,
-                 model_dim,
-                 num_heads,
-                 dropout=0.0,
-                 k_dim=None,
-                 v_dim=None):
-        """
-        Multihead Attention module.
-
-        Args:
-            model_dim (int): the feature size of query.
-            num_heads (int): the number of attention heads.
-            dropout (float, optional): dropout probability of scaled dot product
-                attention and final context vector. Defaults to 0.0.
-            k_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-            v_dim (int, optional): feature size of the key of each scaled dot
-                product attention. If not provided, it is set to
-                model_dim / num_heads. Defaults to None.
-
-        Raises:
-            ValueError: if model_dim is not divisible by num_heads
-        """
+                 model_dim: int,
+                 num_heads: int,
+                 dropout: float=0.0,
+                 k_dim: int=None,
+                 v_dim: int=None):
         super(MultiheadAttention, self).__init__()
         if model_dim % num_heads != 0:
             raise ValueError("model_dim must be divisible by num_heads")
@@ -211,20 +241,29 @@ class MultiheadAttention(nn.Layer):
         self.dropout = dropout

     def forward(self, q, k, v, mask):
-        """
-        Compute context vector and attention weights.
+        """Compute context vector and attention weights.

-        Args:
-            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
-            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
-            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
-            mask (Tensor): shape(batch_size, times_steps_q, time_steps_k) or
-                broadcastable shape, dtype: float32 or float64, the mask.
+        Parameters
+        -----------
+        q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The queries.
+
+        k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The keys.
+
+        v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
+            The values.
+
+        mask : Tensor [shape=(batch_size, time_steps_q, time_steps_k)] or broadcastable shape
+            The mask.

-        Returns:
-            (out, attention_weights)
-            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
-            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
+        Returns
+        ----------
+        out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
+            The context vector.
+
+        attention_weights : Tensor [shape=(batch_size, time_steps_q, time_steps_k)]
+            The attention weights.
         """
         q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
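Putting the constructor and forward docstrings together, a hedged usage sketch; the import follows the `from parakeet.modules import attention as attn` line visible in the transformer.py hunk below, and the sizes are illustrative.

import paddle
from parakeet.modules import attention as attn

layer = attn.MultiheadAttention(model_dim=64, num_heads=4, dropout=0.1)
q = paddle.randn([2, 10, 64])    # (batch_size, time_steps_q, model_dim)
kv = paddle.randn([2, 12, 64])   # (batch_size, time_steps_k, model_dim)
mask = paddle.ones([2, 10, 12])  # or any shape broadcastable to (2, 10, 12)

out, attn_weights = layer(q, kv, kv, mask)
print(out.shape)           # [2, 10, 64]
print(attn_weights.shape)  # [2, 4, 10, 12]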
@@ -8,28 +8,48 @@ __all__ = ["quantize", "dequantize", "STFT"]


 def quantize(values, n_bands):
-    """Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
+    """Linearly quantize a float Tensor in [-1, 1) to an integer Tensor in
+    [0, n_bands).

-    Args:
-        values (Tensor): dtype: flaot32 or float64. the floating point value.
-        n_bands (int): the number of bands. The output integer Tensor's value is in the range [0, n_bans).
+    Parameters
+    -----------
+    values : Tensor [dtype: float32 or float64]
+        The floating point value.
+
+    n_bands : int
+        The number of bands. The output integer Tensor's value is in the range
+        [0, n_bands).

-    Returns:
-        Tensor: the quantized tensor, dtype: int64.
+    Returns
+    ----------
+    Tensor [dtype: int64]
+        The quantized tensor.
     """
     quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
     return quantized


 def dequantize(quantized, n_bands, dtype=None):
-    """Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
+    """Linearly dequantize an integer Tensor into a float Tensor in the range
+    [-1, 1).

-    Args:
-        quantized (Tensor): dtype: int64. The quantized value in the range [0, n_bands).
-        n_bands (int): number of bands. The input integer Tensor's value is in the range [0, n_bans).
-        dtype (str, optional): data type of the output.
-    Returns:
-        Tensor: the dequantized tensor, dtype is specified by dtype.
+    Parameters
+    -----------
+    quantized : Tensor [dtype: int]
+        The quantized value in the range [0, n_bands).
+
+    n_bands : int
+        Number of bands. The input integer Tensor's value is in the range
+        [0, n_bands).
+
+    dtype : str, optional
+        Data type of the output.
+
+    Returns
+    -----------
+    Tensor
+        The dequantized tensor, dtype is specified by `dtype`. If `dtype` is
+        not specified, the default float data type is used.
     """
     dtype = dtype or paddle.get_default_dtype()
     value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
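A short round trip through the two functions above, assuming they live in parakeet.modules.audio (the hunk header's __all__ lists them together with STFT); the arithmetic follows directly from the one-line bodies shown in the diff.

import paddle
from parakeet.modules.audio import quantize, dequantize

x = paddle.to_tensor([-1.0, -0.5, 0.0, 0.5, 0.999])
q = quantize(x, n_bands=256)        # int64 values in [0, 256)
x_hat = dequantize(q, n_bands=256)  # floats back in [-1, 1)
# e.g. -1.0 -> bucket 0 -> (0 + 0.5) * (2 / 256) - 1 ≈ -0.996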
@@ -37,15 +57,36 @@ def dequantize(quantized, n_bands, dtype=None):


 class STFT(nn.Layer):
+    """A module for computing stft transformation in a differentiable way.
+
+    Parameters
+    ------------
+    n_fft : int
+        Number of samples in a frame.
+
+    hop_length : int
+        Number of samples shifted between adjacent frames.
+
+    win_length : int
+        Length of the window.
+
+    window : str, optional
+        Name of window function, see `scipy.signal.get_window` for more
+        details. Defaults to "hanning".
+
+    Notes
+    -----------
+    It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
+    details.
+
+    Given an audio with ``T`` samples, the STFT transformation outputs a
+    spectrum with (C, frames) and complex dtype, where ``C = 1 + n_fft / 2``
+    and ``frames = 1 + T // hop_length``.
+
+    Only ``center`` and ``reflect`` padding is supported now.
+
+    """
     def __init__(self, n_fft, hop_length, win_length, window="hanning"):
-        """A module for computing differentiable stft transform. See `librosa.stft` for more details.
-
-        Args:
-            n_fft (int): number of samples in a frame.
-            hop_length (int): number of samples shifted between adjacent frames.
-            win_length (int): length of the window function.
-            window (str, optional): name of window function, see `scipy.signal.get_window` for more details. Defaults to "hanning".
-        """
         super(STFT, self).__init__()
         self.hop_length = hop_length
         self.n_bin = 1 + n_fft // 2
@@ -73,13 +114,18 @@ class STFT(nn.Layer):
     def forward(self, x):
         """Compute the stft transform.

-        Args:
-            x (Variable): shape(B, T), dtype flaot32, the input waveform.
+        Parameters
+        ------------
+        x : Tensor [shape=(B, T)]
+            The input waveform.

-        Returns:
-            (real, imag)
-            real (Variable): shape(B, C, 1, T), dtype flaot32, the real part of the spectrogram. (C = 1 + n_fft // 2)
-            imag (Variable): shape(B, C, 1, T), dtype flaot32, the image part of the spectrogram. (C = 1 + n_fft // 2)
+        Returns
+        ------------
+        real : Tensor [shape=(B, C, 1, frames)]
+            The real part of the spectrogram.
+
+        imag : Tensor [shape=(B, C, 1, frames)]
+            The imaginary part of the spectrogram.
         """
         # x(batch_size, time_steps)
         # pad it first with reflect mode
@@ -95,30 +141,34 @@ class STFT(nn.Layer):
         return real, imag

     def power(self, x):
-        """Compute the power spectrogram.
+        """Compute the power spectrum.

-        Args:
-            (real, imag)
-            real (Variable): shape(B, C, 1, T), dtype flaot32, the real part of the spectrogram.
-            imag (Variable): shape(B, C, 1, T), dtype flaot32, the image part of the spectrogram.
+        Parameters
+        ------------
+        x : Tensor [shape=(B, T)]
+            The input waveform.

-        Returns:
-            Variable: shape(B, C, 1, T), dtype flaot32, the power spectrogram.
+        Returns
+        ------------
+        Tensor [shape=(B, C, 1, T)]
+            The power spectrum.
         """
         real, imag = self(x)
         power = real**2 + imag**2
         return power

     def magnitude(self, x):
-        """Compute the magnitude spectrogram.
+        """Compute the magnitude of the spectrum.

-        Args:
-            (real, imag)
-            real (Variable): shape(B, C, 1, T), dtype flaot32, the real part of the spectrogram.
-            imag (Variable): shape(B, C, 1, T), dtype flaot32, the image part of the spectrogram.
+        Parameters
+        ------------
+        x : Tensor [shape=(B, T)]
+            The input waveform.

-        Returns:
-            Variable: shape(B, C, 1, T), dtype flaot32, the magnitude spectrogram. It is the square root of the power spectrogram.
+        Returns
+        ------------
+        Tensor [shape=(B, C, 1, T)]
+            The magnitude of the spectrum.
         """
         power = self.power(x)
         magnitude = paddle.sqrt(power)
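A usage sketch for the STFT module documented above, under the same assumed parakeet.modules.audio path; the parameter values are arbitrary.

import paddle
from parakeet.modules.audio import STFT

stft = STFT(n_fft=1024, hop_length=256, win_length=1024, window="hanning")
wav = paddle.randn([2, 16000])  # (B, T)
real, imag = stft(wav)          # each (B, C, 1, frames) with C = 1 + 1024 // 2
mag = stft.magnitude(wav)       # sqrt(real**2 + imag**2)
power = stft.power(wav)         # real**2 + imag**2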
@@ -4,16 +4,25 @@ import paddle
 def shuffle_dim(x, axis, perm=None):
     """Permute input tensor along aixs given the permutation or randomly.

-    Args:
-        x (Tensor): shape(*, d_{axis}, *), the input tensor.
-        axis (int): the axis to shuffle.
-        perm (list[int], ndarray, optional): a permutation of [0, d_{axis}),
-            the order to reorder the tensor along the `axis`-th dimension, if
-            not provided, randomly shuffle the `axis`-th dimension. Defaults to
-            None.
+    Parameters
+    ----------
+    x : Tensor
+        The input tensor.
+
+    axis : int
+        The axis to shuffle.
+
+    perm : List[int], ndarray, optional
+        The order to reorder the tensor along the ``axis``-th dimension.
+
+        It is a permutation of ``[0, d)``, where d is the size of the
+        ``axis``-th dimension of the input tensor. If not provided,
+        a random permutation is used. Defaults to None.

-    Returns:
-        Tensor: the shuffled tensor, it has the same shape as x does.
+    Returns
+    ---------
+    Tensor
+        The shuffled tensor, which has the same shape as x does.
     """
     size = x.shape[axis]
     if perm is not None and len(perm) != size:
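A small illustration of shuffle_dim with an explicit permutation; the module path is an assumption, since this hunk does not name the file.

import paddle
from parakeet.modules.geometry import shuffle_dim  # path assumed

x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
y = shuffle_dim(x, axis=1, perm=[2, 1, 0])  # reverse the second dimension
# y -> [[3, 2, 1], [6, 5, 4]]; omit perm to shuffle that axis randomly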
@@ -4,29 +4,128 @@ import paddle
 from paddle import nn
 from paddle.nn import functional as F

+__all__ = [
+    "weighted_mean",
+    "masked_l1_loss",
+    "masked_softmax_with_cross_entropy",
+    "diagonal_loss",
+]

 def weighted_mean(input, weight):
-    """weighted mean.(It can also be used as masked mean.)
+    """Weighted mean. It can also be used as masked mean.

-    Args:
-        input (Tensor): input tensor, floating point dtype.
-        weight (Tensor): weight tensor with broadcastable shape.
+    Parameters
+    -----------
+    input : Tensor
+        The input tensor.
+    weight : Tensor
+        The weight tensor with broadcastable shape with the input.

-    Returns:
-        Tensor: shape(1,), weighted mean tensor with the same dtype as input.
+    Returns
+    ----------
+    Tensor [shape=(1,)]
+        Weighted mean tensor with the same dtype as input.
+
+    Warnings
+    ---------
+    This is not a mathematical weighted mean. It performs weighted sum and
+    simple average.
     """
     weight = paddle.cast(weight, input.dtype)
     return paddle.mean(input * weight)


 def masked_l1_loss(prediction, target, mask):
+    """Compute masked L1 loss.
+
+    Parameters
+    ----------
+    prediction : Tensor
+        The prediction.
+
+    target : Tensor
+        The target. The shape should be broadcastable to ``prediction``.
+
+    mask : Tensor
+        The mask. The shape should be broadcastable to the broadcasted shape of
+        ``prediction`` and ``target``.
+
+    Returns
+    -------
+    Tensor [shape=(1,)]
+        The masked L1 loss.
+    """
     abs_error = F.l1_loss(prediction, target, reduction='none')
-    return weighted_mean(abs_error, mask)
+    loss = weighted_mean(abs_error, mask)
+    return loss


 def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
-    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
-    return weighted_mean(ce, mask)
+    """Compute masked softmax with cross entropy loss.
+
+    Parameters
+    ----------
+    logits : Tensor
+        The logits. The ``axis``-th axis is the class dimension.
+
+    label : Tensor [dtype: int]
+        The label. The size of the ``axis``-th axis should be 1.
+
+    mask : Tensor
+        The mask. The shape should be broadcastable to ``label``.
+
+    axis : int, optional
+        The index of the class dimension in the shape of ``logits``, by default
+        -1.
+
+    Returns
+    -------
+    Tensor [shape=(1,)]
+        The masked softmax with cross entropy loss.
+    """
+    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
+    loss = weighted_mean(ce, mask)
+    return loss


-def diagonal_loss(attentions, input_lengths, target_lengths, g=0.2, multihead=False):
-    """A metric to evaluate how diagonal a attention distribution is."""
+def diagonal_loss(
+        attentions,
+        input_lengths,
+        target_lengths,
+        g=0.2,
+        multihead=False):
+    """A metric to evaluate how diagonal an attention distribution is.
+
+    It is computed for batch attention distributions. For each attention
+    distribution, the valid decoder time steps and encoder time steps may
+    differ.
+
+    Parameters
+    ----------
+    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_dec)]
+        The attention weights from an encoder-decoder structure.
+
+    input_lengths : Tensor [shape=(B,)]
+        The valid length for each encoder output.
+
+    target_lengths : Tensor [shape=(B,)]
+        The valid length for each decoder output.
+
+    g : float, optional
+        [description], by default 0.2.
+
+    multihead : bool, optional
+        A flag indicating whether ``attentions`` is a multihead attention's
+        attention distribution.
+
+        If ``True``, the shape of attention is ``(B, H, T_dec, T_dec)``, by
+        default False.
+
+    Returns
+    -------
+    Tensor [shape=(1,)]
+        The diagonal loss.
+    """
     W = guided_attentions(input_lengths, target_lengths, g)
     W_tensor = paddle.to_tensor(W)
     if not multihead:
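A hedged sketch of masked_l1_loss on padded batches, assuming the functions are exposed from parakeet.modules.losses (the __all__ block added in this hunk lists them). Note the Warnings section above: weighted_mean divides by the total number of elements rather than by the mask sum, so the value scales with the amount of padding.

import numpy as np
import paddle
from parakeet.modules.losses import masked_l1_loss  # path assumed

pred = paddle.randn([2, 5, 80])
target = paddle.randn([2, 5, 80])

# 1 for valid frames, 0 for padded frames; broadcastable to (2, 5, 80)
mask = np.ones([2, 5, 1], dtype="float32")
mask[1, 3:, :] = 0.0  # the second example only has 3 valid frames
loss = masked_l1_loss(pred, target, paddle.to_tensor(mask))  # scalar tensor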
@@ -1,32 +1,114 @@
 import paddle
 from paddle.fluid.layers import sequence_mask

+__all__ = [
+    "id_mask",
+    "feature_mask",
+    "combine_mask",
+    "future_mask",
+]

 def id_mask(input, padding_index=0, dtype="bool"):
+    """Generate mask with input ids.
+
+    Those positions where the value equals ``padding_index`` correspond to 0 or
+    ``False``, otherwise, 1 or ``True``.
+
+    Parameters
+    ----------
+    input : Tensor [dtype: int]
+        The input tensor. It represents the ids.
+
+    padding_index : int, optional
+        The id which represents padding, by default 0.
+
+    dtype : str, optional
+        Data type of the returned mask, by default "bool".
+
+    Returns
+    -------
+    Tensor
+        The generated mask. It has the same shape as ``input`` does.
+    """
     return paddle.cast(input != padding_index, dtype)


 def feature_mask(input, axis, dtype="bool"):
+    """Compute mask from input features.
+
+    For input features, represented as batched feature vectors, those vectors
+    which are all zeros are considered padding vectors.
+
+    Parameters
+    ----------
+    input : Tensor [dtype: float]
+        The input tensor which represents features.
+
+    axis : int
+        The index of the feature dimension in ``input``. Other dimensions are
+        considered ``spatial`` dimensions.
+
+    dtype : str, optional
+        Data type of the generated mask, by default "bool"
+
+    Returns
+    -------
+    Tensor
+        The generated mask with ``spatial`` shape as mentioned above.
+
+        It has one less dimension than ``input`` does.
+    """
     feature_sum = paddle.sum(paddle.abs(input), axis)
     return paddle.cast(feature_sum != 0, dtype)


-def combine_mask(padding_mask, no_future_mask):
-    """
-    Combine the padding mask and no future mask for transformer decoder.
-    Padding mask is used to mask padding positions and no future mask is used
-    to prevent the decoder to see future information.
-
-    Args:
-        padding_mask (Tensor): shape(batch_size, time_steps), dtype: float32 or float64, decoder padding mask.
-        no_future_mask (Tensor): shape(time_steps, time_steps), dtype: float32 or float64, no future mask.
-
-    Returns:
-        Tensor: shape(batch_size, time_steps, time_steps), combined mask.
+def combine_mask(mask1, mask2):
+    """Combine two masks with multiplication or logical and.
+
+    Parameters
+    -----------
+    mask1 : Tensor
+        The first mask.
+
+    mask2 : Tensor
+        The second mask with broadcastable shape with ``mask1``.
+
+    Returns
+    --------
+    Tensor
+        Combined mask.
+
+    Notes
+    ------
+    It is mainly used to combine the padding mask and no future mask for
+    transformer decoder.
+
+    Padding mask is used to mask padding positions of the decoder inputs and
+    no future mask is used to prevent the decoder to see future information.
     """
-    # TODO: to support boolean mask by using logical_and?
-    if padding_mask.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
-        return paddle.logical_and(padding_mask, no_future_mask)
+    if mask1.dtype == paddle.fluid.core.VarDesc.VarType.BOOL:
+        return paddle.logical_and(mask1, mask2)
     else:
-        return padding_mask * no_future_mask
+        return mask1 * mask2


 def future_mask(time_steps, dtype="bool"):
+    """Generate lower triangular mask.
+
+    It is used at transformer decoder to prevent the decoder to see future
+    information.
+
+    Parameters
+    ----------
+    time_steps : int
+        Decoder time steps.
+    dtype : str, optional
+        The data type of the generated mask, by default "bool".
+
+    Returns
+    -------
+    Tensor
+        The generated mask.
+    """
     mask = paddle.tril(paddle.ones([time_steps, time_steps]))
     return paddle.cast(mask, dtype)
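The four helpers above are typically combined when building a transformer decoder mask; a minimal sketch (the parakeet.modules.masking path is confirmed by the transformer.py hunk below):

import paddle
from parakeet.modules.masking import id_mask, future_mask, combine_mask

ids = paddle.to_tensor([[5, 7, 2, 0, 0],
                        [3, 1, 0, 0, 0]])                      # 0 is the padding id
padding_mask = id_mask(ids, padding_index=0, dtype="float32")  # (2, 5)
no_future = future_mask(5, dtype="float32")                    # (5, 5), lower triangular
# broadcast to (2, 5, 5): attend only to valid, non-future positions
decoder_mask = combine_mask(padding_mask.unsqueeze(1), no_future)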
@@ -3,21 +3,34 @@ import numpy as np
 import paddle
 from paddle.nn import functional as F

+__all__ = ["positional_encoding"]

 def positional_encoding(start_index, length, size, dtype=None):
-    """
-    Generate standard positional encoding.
+    r"""Generate standard positional encoding matrix.

-    pe(pos, 2i) = sin(pos / 10000 ** (2i / size))
-    pe(pos, 2i+1) = cos(pos / 10000 ** (2i / size))
+    .. math::

-    Args:
-        start_index (int): the start index.
-        length (int): the length of the positional encoding.
-        size (int): positional encoding dimension.
+        pe(pos, 2i) = sin(\frac{pos}{10000^{\frac{2i}{size}}}) \\
+        pe(pos, 2i+1) = cos(\frac{pos}{10000^{\frac{2i}{size}}})

-    Returns:
-        encodings (Tensor): shape(length, size), the positional encoding.
+    Parameters
+    ----------
+    start_index : int
+        The start index.
+    length : int
+        The timesteps of the positional encoding to generate.
+    size : int
+        Feature size of positional encoding.
+
+    Returns
+    -------
+    Tensor [shape=(length, size)]
+        The positional encoding.
+
+    Raises
+    ------
+    ValueError
+        If ``size`` is not divisible by 2.
     """
     if (size % 2 != 0):
         raise ValueError("size should be divisible by 2")
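A quick check of the generated matrix, assuming the module path parakeet.modules.positional_encoding (only the function itself appears in the hunk):

import paddle
from parakeet.modules.positional_encoding import positional_encoding  # path assumed

pe = positional_encoding(start_index=0, length=100, size=256)
print(pe.shape)  # [100, 256]; even columns hold the sin terms, odd columns the cos terms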
@@ -5,23 +5,35 @@ from paddle.nn import functional as F

 from parakeet.modules import attention as attn
 from parakeet.modules.masking import combine_mask

+__all__ = [
+    "PositionwiseFFN",
+    "TransformerEncoderLayer",
+    "TransformerDecoderLayer",
+]

 class PositionwiseFFN(nn.Layer):
-    """
-    A faithful implementation of Position-wise Feed-Forward Network
+    """A faithful implementation of Position-wise Feed-Forward Network
     in `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
-    It is basically a 3-layer MLP, with relu actication and dropout in between.
+    It is basically a 2-layer MLP, with relu activation and dropout in between.
+
+    Parameters
+    ----------
+    input_size : int
+        The feature size of the input. It is also the feature size of the
+        output.
+
+    hidden_size : int
+        The hidden size.
+
+    dropout : float
+        The probability of the Dropout applied to the output of the first
+        layer, by default 0.
     """
     def __init__(self,
                  input_size: int,
                  hidden_size: int,
                  dropout=0.0):
-        """
-        Args:
-            input_size (int): the input feature size.
-            hidden_size (int): the hidden layer's feature size.
-            dropout (float, optional): probability of dropout applied to the
-                output of the first fully connected layer. Defaults to 0.0.
-        """
         super(PositionwiseFFN, self).__init__()
         self.linear1 = nn.Linear(input_size, hidden_size)
         self.linear2 = nn.Linear(hidden_size, input_size)
@@ -31,13 +43,17 @@ class PositionwiseFFN(nn.Layer):
         self.hidden_szie = hidden_size

     def forward(self, x):
-        """positionwise feed forward network.
+        r"""Forward pass of positionwise feed forward network.

-        Args:
-            x (Tensor): shape(*, input_size), the input tensor.
+        Parameters
+        ----------
+        x : Tensor [shape=(\*, input_size)]
+            The input tensor, where ``\*`` means arbitrary shape.

-        Returns:
-            Tensor: shape(*, input_size), the output tensor.
+        Returns
+        -------
+        Tensor [shape=(\*, input_size)]
+            The output tensor.
         """
         l1 = self.dropout(F.relu(self.linear1(x)))
         l2 = self.linear2(l1)
@@ -45,56 +61,101 @@ class PositionwiseFFN(nn.Layer):


 class TransformerEncoderLayer(nn.Layer):
-    """
-    Transformer encoder layer.
+    """A faithful implementation of Transformer encoder layer in
+    `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+    Parameters
+    ----------
+    d_model : int
+        The feature size of the input. It is also the feature size of the
+        output.
+
+    n_heads : int
+        The number of heads of self attention (a ``MultiheadAttention``
+        layer).
+
+    d_ffn : int
+        The hidden size of the positional feed forward network (a
+        ``PositionwiseFFN`` layer).
+
+    dropout : float, optional
+        The probability of the dropout in MultiHeadAttention and
+        PositionwiseFFN, by default 0.
+
+    Notes
+    ------
+    It uses the PostLN (post layer norm) scheme.
     """
     def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
-        """
-        Args:
-            d_model (int): the feature size of the input, and the output.
-            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
-            d_ffn (int): the hidden size of the internal PositionwiseFFN.
-            dropout (float, optional): the probability of the dropout in
-                MultiHeadAttention and PositionwiseFFN. Defaults to 0.
-        """
         super(TransformerEncoderLayer, self).__init__()
         self.self_mha = attn.MultiheadAttention(d_model, n_heads, dropout)
         self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)

         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)

+        self.dropout = dropout

     def forward(self, x, mask):
-        """
-        Args:
-            x (Tensor): shape(batch_size, time_steps, d_model), the decoder input.
-            mask (Tensor): shape(batch_size, time_steps), the padding mask.
-
-        Returns:
-            (x, attn_weights)
-            x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
-            attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
-        """
-        context_vector, attn_weights = self.self_mha(x, x, x, paddle.unsqueeze(mask, 1))
-        x = self.layer_norm1(x + context_vector)
-
-        x = self.layer_norm2(x + self.ffn(x))
+        """Forward pass of TransformerEncoderLayer.
+
+        Parameters
+        ----------
+        x : Tensor [shape=(batch_size, time_steps, d_model)]
+            The input.
+
+        mask : Tensor
+            The padding mask. The shape is (batch_size, time_steps,
+            time_steps) or broadcastable shape.
+
+        Returns
+        -------
+        x : Tensor [shape=(batch_size, time_steps, d_model)]
+            The encoded output.
+
+        attn_weights : Tensor [shape=(batch_size, n_heads, time_steps, time_steps)]
+            The attention weights of the self attention.
+        """
+        context_vector, attn_weights = self.self_mha(x, x, x, mask)
+        x = self.layer_norm1(
+            F.dropout(x + context_vector,
+                      self.dropout,
+                      training=self.training))
+
+        x = self.layer_norm2(
+            F.dropout(x + self.ffn(x),
+                      self.dropout,
+                      training=self.training))
         return x, attn_weights


 class TransformerDecoderLayer(nn.Layer):
-    """
-    Transformer decoder layer.
+    """A faithful implementation of Transformer decoder layer in
+    `Attention is All You Need <https://arxiv.org/abs/1706.03762>`_.
+
+    Parameters
+    ----------
+    d_model : int
+        The feature size of the input. It is also the feature size of the
+        output.
+
+    n_heads : int
+        The number of heads of attentions (``MultiheadAttention``
+        layers).
+
+    d_ffn : int
+        The hidden size of the positional feed forward network (a
+        ``PositionwiseFFN`` layer).
+
+    dropout : float, optional
+        The probability of the dropout in MultiHeadAttention and
+        PositionwiseFFN, by default 0.
+
+    Notes
+    ------
+    It uses the PostLN (post layer norm) scheme.
     """
     def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
-        """
-        Args:
-            d_model (int): the feature size of the input, and the output.
-            n_heads (int): the number of heads in the internal MultiHeadAttention layer.
-            d_ffn (int): the hidden size of the internal PositionwiseFFN.
-            dropout (float, optional): the probability of the dropout in
-                MultiHeadAttention and PositionwiseFFN. Defaults to 0.
-        """
         super(TransformerDecoderLayer, self).__init__()
         self.self_mha = attn.MultiheadAttention(d_model, n_heads, dropout)
         self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
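A usage sketch for the encoder layer as documented above; the import path parakeet.modules.transformer is an assumption, and note that after this commit the caller passes a mask broadcastable to (batch_size, time_steps, time_steps) directly instead of a per-step padding mask.

import paddle
from parakeet.modules.transformer import TransformerEncoderLayer  # path assumed

layer = TransformerEncoderLayer(d_model=256, n_heads=4, d_ffn=1024, dropout=0.1)
x = paddle.randn([2, 50, 256])
mask = paddle.ones([2, 1, 50])  # broadcastable to (batch_size, time_steps, time_steps)
y, attn = layer(x, mask)        # y: [2, 50, 256], attn: [2, 4, 50, 50]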
@@ -104,30 +165,52 @@ class TransformerDecoderLayer(nn.Layer):

         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)

+        self.dropout = dropout

     def forward(self, q, k, v, encoder_mask, decoder_mask):
-        """
-        Args:
-            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoder input.
-            k (Tensor): shape(batch_size, time_steps_k, d_model), keys.
-            v (Tensor): shape(batch_size, time_steps_k, d_model), values
-            encoder_mask (Tensor): shape(batch_size, time_steps_k) encoder padding mask.
-            decoder_mask (Tensor): shape(batch_size, time_steps_q) decoder padding mask.
-
-        Returns:
-            (q, self_attn_weights, cross_attn_weights)
-            q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
-            self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
-            cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
-        """
-        tq = q.shape[1]
-        no_future_mask = paddle.tril(paddle.ones([tq, tq])) #(tq, tq)
-        combined_mask = combine_mask(decoder_mask.unsqueeze(1), no_future_mask)
-        context_vector, self_attn_weights = self.self_mha(q, q, q, combined_mask)
-        q = self.layer_norm1(q + context_vector)
-
-        context_vector, cross_attn_weights = self.cross_mha(q, k, v, paddle.unsqueeze(encoder_mask, 1))
-        q = self.layer_norm2(q + context_vector)
-
-        q = self.layer_norm3(q + self.ffn(q))
+        """Forward pass of TransformerDecoderLayer.
+
+        Parameters
+        ----------
+        q : Tensor [shape=(batch_size, time_steps_q, d_model)]
+            The decoder input.
+        k : Tensor [shape=(batch_size, time_steps_k, d_model)]
+            The keys.
+        v : Tensor [shape=(batch_size, time_steps_k, d_model)]
+            The values.
+        encoder_mask : Tensor
+            Encoder padding mask, shape is ``(batch_size, time_steps_k,
+            time_steps_k)`` or broadcastable shape.
+        decoder_mask : Tensor
+            Decoder mask, shape is ``(batch_size, time_steps_q, time_steps_q)``
+            or broadcastable shape.
+
+        Returns
+        --------
+        q : Tensor [shape=(batch_size, time_steps_q, d_model)]
+            The decoder output.
+
+        self_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_q)]
+            Decoder self attention.
+
+        cross_attn_weights : Tensor [shape=(batch_size, n_heads, time_steps_q, time_steps_k)]
+            Decoder-encoder cross attention.
+        """
+        context_vector, self_attn_weights = self.self_mha(q, q, q, decoder_mask)
+        q = self.layer_norm1(
+            F.dropout(q + context_vector,
+                      self.dropout,
+                      training=self.training))
+
+        context_vector, cross_attn_weights = self.cross_mha(q, k, v, encoder_mask)
+        q = self.layer_norm2(
+            F.dropout(q + context_vector,
+                      self.dropout,
+                      training=self.training))
+
+        q = self.layer_norm3(
+            F.dropout(q + self.ffn(q),
+                      self.dropout,
+                      training=self.training))
         return q, self_attn_weights, cross_attn_weights
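After this change the decoder layer no longer builds the no-future mask internally, so the caller combines it with the padding mask before the call; a hedged sketch reusing the masking helpers from the earlier hunk (the transformer module path is again an assumption):

import paddle
from parakeet.modules.masking import future_mask, combine_mask
from parakeet.modules.transformer import TransformerDecoderLayer  # path assumed

layer = TransformerDecoderLayer(d_model=256, n_heads=4, d_ffn=1024, dropout=0.1)
q = paddle.randn([2, 30, 256])          # decoder input
kv = paddle.randn([2, 50, 256])         # encoder output
encoder_mask = paddle.ones([2, 1, 50])  # broadcastable to (B, T_q, T_k)
dec_padding = paddle.ones([2, 1, 30])   # decoder padding mask, (B, 1, T_q)
decoder_mask = combine_mask(dec_padding, future_mask(30, dtype="float32"))

out, self_attn, cross_attn = layer(q, kv, kv, encoder_mask, decoder_mask)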