add docstring for LocationSensitiveAttention
This commit is contained in:
parent
dd2c5cc6c6
commit
1af9127ee6
|
@ -32,16 +32,16 @@ class DecoderPreNet(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
d_input: int
|
d_input: int
|
||||||
input feature size
|
The input feature size.
|
||||||
|
|
||||||
d_hidden: int
|
d_hidden: int
|
||||||
hidden size
|
The hidden size.
|
||||||
|
|
||||||
d_output: int
|
d_output: int
|
||||||
output feature size
|
The output feature size.
|
||||||
|
|
||||||
dropout_rate: float
|
dropout_rate: float
|
||||||
droput probability
|
The droput probability.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ class DecoderPreNet(nn.Layer):
|
||||||
d_input: int,
|
d_input: int,
|
||||||
d_hidden: int,
|
d_hidden: int,
|
||||||
d_output: int,
|
d_output: int,
|
||||||
dropout_rate: float=0.2):
|
dropout_rate: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.dropout_rate = dropout_rate
|
self.dropout_rate = dropout_rate
|
||||||
|
@ -62,12 +62,12 @@ class DecoderPreNet(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x: Tensor [shape=(B, T_mel, C)]
|
x: Tensor [shape=(B, T_mel, C)]
|
||||||
batch of the sequences of padded mel spectrogram
|
Batch of the sequences of padded mel spectrogram.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
output: Tensor [shape=(B, T_mel, C)]
|
output: Tensor [shape=(B, T_mel, C)]
|
||||||
batch of the sequences of padded hidden state
|
Batch of the sequences of padded hidden state.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -82,28 +82,28 @@ class DecoderPostNet(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
d_mels: int
|
d_mels: int
|
||||||
number of mel bands
|
The number of mel bands.
|
||||||
|
|
||||||
d_hidden: int
|
d_hidden: int
|
||||||
hidden size of postnet
|
The hidden size of postnet.
|
||||||
|
|
||||||
kernel_size: int
|
kernel_size: int
|
||||||
kernel size of the conv layer in postnet
|
The kernel size of the conv layer in postnet.
|
||||||
|
|
||||||
num_layers: int
|
num_layers: int
|
||||||
number of conv layers in postnet
|
The number of conv layers in postnet.
|
||||||
|
|
||||||
dropout: float
|
dropout: float
|
||||||
droput probability
|
The droput probability.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
d_mels: int=80,
|
d_mels: int,
|
||||||
d_hidden: int=512,
|
d_hidden: int,
|
||||||
kernel_size: int=5,
|
kernel_size: int,
|
||||||
num_layers: int=5,
|
num_layers: int,
|
||||||
dropout: float=0.1):
|
dropout: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
@ -150,12 +150,12 @@ class DecoderPostNet(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
input: Tensor [shape=(B, T_mel, C)]
|
input: Tensor [shape=(B, T_mel, C)]
|
||||||
output sequence of features from decoder
|
Output sequence of features from decoder.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
output: Tensor [shape=(B, T_mel, C)]
|
output: Tensor [shape=(B, T_mel, C)]
|
||||||
output sequence of features after postnet
|
Output sequence of features after postnet.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -173,16 +173,16 @@ class Tacotron2Encoder(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
d_hidden: int
|
d_hidden: int
|
||||||
hidden size in encoder module
|
The hidden size in encoder module.
|
||||||
|
|
||||||
conv_layers: int
|
conv_layers: int
|
||||||
number of conv layers
|
The number of conv layers.
|
||||||
|
|
||||||
kernel_size: int
|
kernel_size: int
|
||||||
kernel size of conv layers
|
The kernel size of conv layers.
|
||||||
|
|
||||||
p_dropout: float
|
p_dropout: float
|
||||||
droput probability
|
The droput probability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -216,15 +216,15 @@ class Tacotron2Encoder(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x: Tensor [shape=(B, T)]
|
x: Tensor [shape=(B, T)]
|
||||||
batch of the sequencees of padded character ids
|
Batch of the sequencees of padded character ids.
|
||||||
|
|
||||||
text_lens: Tensor [shape=(B,)]
|
text_lens: Tensor [shape=(B,)], optional
|
||||||
batch of lengths of each text input batch.
|
Batch of lengths of each text input batch. Defaults to None.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
output : Tensor [shape=(B, T, C)]
|
output : Tensor [shape=(B, T, C)]
|
||||||
batch of the sequences of padded hidden states
|
Batch of the sequences of padded hidden states.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for conv_batchnorm in self.conv_batchnorms:
|
for conv_batchnorm in self.conv_batchnorms:
|
||||||
|
@ -241,40 +241,40 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
d_mels: int
|
d_mels: int
|
||||||
number of mel bands
|
The number of mel bands.
|
||||||
|
|
||||||
reduction_factor: int
|
reduction_factor: int
|
||||||
reduction factor of tacotron
|
The reduction factor of tacotron.
|
||||||
|
|
||||||
d_encoder: int
|
d_encoder: int
|
||||||
hidden size of encoder
|
The hidden size of encoder.
|
||||||
|
|
||||||
d_prenet: int
|
d_prenet: int
|
||||||
hidden size in decoder prenet
|
The hidden size in decoder prenet.
|
||||||
|
|
||||||
d_attention_rnn: int
|
d_attention_rnn: int
|
||||||
attention rnn layer hidden size
|
The attention rnn layer hidden size.
|
||||||
|
|
||||||
d_decoder_rnn: int
|
d_decoder_rnn: int
|
||||||
decoder rnn layer hidden size
|
The decoder rnn layer hidden size.
|
||||||
|
|
||||||
d_attention: int
|
d_attention: int
|
||||||
hidden size of the linear layer in location sensitive attention
|
The hidden size of the linear layer in location sensitive attention.
|
||||||
|
|
||||||
attention_filters: int
|
attention_filters: int
|
||||||
filter size of the conv layer in location sensitive attention
|
The filter size of the conv layer in location sensitive attention.
|
||||||
|
|
||||||
attention_kernel_size: int
|
attention_kernel_size: int
|
||||||
kernel size of the conv layer in location sensitive attention
|
The kernel size of the conv layer in location sensitive attention.
|
||||||
|
|
||||||
p_prenet_dropout: float
|
p_prenet_dropout: float
|
||||||
droput probability in decoder prenet
|
The droput probability in decoder prenet.
|
||||||
|
|
||||||
p_attention_dropout: float
|
p_attention_dropout: float
|
||||||
droput probability in location sensitive attention
|
The droput probability in location sensitive attention.
|
||||||
|
|
||||||
p_decoder_dropout: float
|
p_decoder_dropout: float
|
||||||
droput probability in decoder
|
The droput probability in decoder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -382,25 +382,25 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
keys: Tensor[shape=(B, T_text, C)]
|
keys: Tensor[shape=(B, T_key, C)]
|
||||||
batch of the sequences of padded output from encoder
|
Batch of the sequences of padded output from encoder.
|
||||||
|
|
||||||
querys: Tensor[shape(B, T_mel, C)]
|
querys: Tensor[shape(B, T_query, C)]
|
||||||
batch of the sequences of padded mel spectrogram
|
Batch of the sequences of padded mel spectrogram.
|
||||||
|
|
||||||
mask: Tensor[shape=(B, T_text, 1)]
|
mask: Tensor
|
||||||
mask generated with text length
|
Mask generated with text length. Shape should be (B, T_key, T_query) or broadcastable shape.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
mel_output: Tensor [shape=(B, T_mel, C)]
|
mel_output: Tensor [shape=(B, T_query, C)]
|
||||||
output sequence of features
|
Output sequence of features.
|
||||||
|
|
||||||
stop_logits: Tensor [shape=(B, T_mel)]
|
stop_logits: Tensor [shape=(B, T_query)]
|
||||||
output sequence of stop logits
|
Output sequence of stop logits.
|
||||||
|
|
||||||
alignments: Tensor [shape=(B, T_mel, T_text)]
|
alignments: Tensor [shape=(B, T_query, T_key)]
|
||||||
attention weights
|
Attention weights.
|
||||||
"""
|
"""
|
||||||
querys = paddle.reshape(
|
querys = paddle.reshape(
|
||||||
querys,
|
querys,
|
||||||
|
@ -437,25 +437,25 @@ class Tacotron2Decoder(nn.Layer):
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
keys: Tensor [shape=(B, T_text, C)]
|
keys: Tensor [shape=(B, T_key, C)]
|
||||||
batch of the sequences of padded output from encoder
|
Batch of the sequences of padded output from encoder.
|
||||||
|
|
||||||
stop_threshold: float
|
stop_threshold: float, optional
|
||||||
stop synthesize when stop logit is greater than this stop threshold
|
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||||
|
|
||||||
max_decoder_steps: int
|
max_decoder_steps: int, optional
|
||||||
number of max step when synthesize
|
Number of max step when synthesize. Defaults to 1000.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
mel_output: Tensor [shape=(B, T_mel, C)]
|
mel_output: Tensor [shape=(B, T_mel, C)]
|
||||||
output sequence of features
|
Output sequence of features.
|
||||||
|
|
||||||
stop_logits: Tensor [shape=(B, T_mel)]
|
stop_logits: Tensor [shape=(B, T_mel)]
|
||||||
output sequence of stop logits
|
Output sequence of stop logits.
|
||||||
|
|
||||||
alignments: Tensor [shape=(B, T_mel, T_text)]
|
alignments: Tensor [shape=(B, T_mel, T_key)]
|
||||||
attention weights
|
Attention weights.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
query = paddle.zeros(
|
query = paddle.zeros(
|
||||||
|
@ -493,75 +493,72 @@ class Tacotron2(nn.Layer):
|
||||||
"""Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
|
"""Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
|
||||||
|
|
||||||
This is a model of Spectrogram prediction network in Tacotron2 described
|
This is a model of Spectrogram prediction network in Tacotron2 described
|
||||||
in ``Natural TTS Synthesis
|
in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
|
||||||
by Conditioning WaveNet on Mel Spectrogram Predictions``,
|
<https://arxiv.org/abs/1712.05884>`_,
|
||||||
which converts the sequence of characters
|
which converts the sequence of characters
|
||||||
into the sequence of mel spectrogram.
|
into the sequence of mel spectrogram.
|
||||||
|
|
||||||
`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
|
|
||||||
<https://arxiv.org/abs/1712.05884>`_.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
frontend : parakeet.frontend.Phonetics
|
frontend : parakeet.frontend.Phonetics
|
||||||
frontend used to preprocess text
|
Frontend used to preprocess text.
|
||||||
|
|
||||||
d_mels: int
|
d_mels: int
|
||||||
number of mel bands
|
Number of mel bands.
|
||||||
|
|
||||||
d_encoder: int
|
d_encoder: int
|
||||||
hidden size in encoder module
|
Hidden size in encoder module.
|
||||||
|
|
||||||
encoder_conv_layers: int
|
encoder_conv_layers: int
|
||||||
number of conv layers in encoder
|
Number of conv layers in encoder.
|
||||||
|
|
||||||
encoder_kernel_size: int
|
encoder_kernel_size: int
|
||||||
kernel size of conv layers in encoder
|
Kernel size of conv layers in encoder.
|
||||||
|
|
||||||
d_prenet: int
|
d_prenet: int
|
||||||
hidden size in decoder prenet
|
Hidden size in decoder prenet.
|
||||||
|
|
||||||
d_attention_rnn: int
|
d_attention_rnn: int
|
||||||
attention rnn layer hidden size in decoder
|
Attention rnn layer hidden size in decoder.
|
||||||
|
|
||||||
d_decoder_rnn: int
|
d_decoder_rnn: int
|
||||||
decoder rnn layer hidden size in decoder
|
Decoder rnn layer hidden size in decoder.
|
||||||
|
|
||||||
attention_filters: int
|
attention_filters: int
|
||||||
filter size of the conv layer in location sensitive attention
|
Filter size of the conv layer in location sensitive attention.
|
||||||
|
|
||||||
attention_kernel_size: int
|
attention_kernel_size: int
|
||||||
kernel size of the conv layer in location sensitive attention
|
Kernel size of the conv layer in location sensitive attention.
|
||||||
|
|
||||||
d_attention: int
|
d_attention: int
|
||||||
hidden size of the linear layer in location sensitive attention
|
Hidden size of the linear layer in location sensitive attention.
|
||||||
|
|
||||||
d_postnet: int
|
d_postnet: int
|
||||||
hidden size of postnet
|
Hidden size of postnet.
|
||||||
|
|
||||||
postnet_kernel_size: int
|
postnet_kernel_size: int
|
||||||
kernel size of the conv layer in postnet
|
Kernel size of the conv layer in postnet.
|
||||||
|
|
||||||
postnet_conv_layers: int
|
postnet_conv_layers: int
|
||||||
number of conv layers in postnet
|
Number of conv layers in postnet.
|
||||||
|
|
||||||
reduction_factor: int
|
reduction_factor: int
|
||||||
reduction factor of tacotron
|
Reduction factor of tacotron2.
|
||||||
|
|
||||||
p_encoder_dropout: float
|
p_encoder_dropout: float
|
||||||
droput probability in encoder
|
Droput probability in encoder.
|
||||||
|
|
||||||
p_prenet_dropout: float
|
p_prenet_dropout: float
|
||||||
droput probability in decoder prenet
|
Droput probability in decoder prenet.
|
||||||
|
|
||||||
p_attention_dropout: float
|
p_attention_dropout: float
|
||||||
droput probability in location sensitive attention
|
Droput probability in location sensitive attention.
|
||||||
|
|
||||||
p_decoder_dropout: float
|
p_decoder_dropout: float
|
||||||
droput probability in decoder
|
Droput probability in decoder.
|
||||||
|
|
||||||
p_postnet_dropout: float
|
p_postnet_dropout: float
|
||||||
droput probability in postnet
|
Droput probability in postnet.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -616,28 +613,28 @@ class Tacotron2(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
text_inputs: Tensor [shape=(B, T_text)]
|
text_inputs: Tensor [shape=(B, T_text)]
|
||||||
batch of the sequencees of padded character ids
|
Batch of the sequencees of padded character ids.
|
||||||
|
|
||||||
mels: Tensor [shape(B, T_mel, C)]
|
mels: Tensor [shape(B, T_mel, C)]
|
||||||
batch of the sequences of padded mel spectrogram
|
Batch of the sequences of padded mel spectrogram.
|
||||||
|
|
||||||
text_lens: Tensor [shape=(B,)]
|
text_lens: Tensor [shape=(B,)]
|
||||||
batch of lengths of each text input batch.
|
Batch of lengths of each text input batch.
|
||||||
|
|
||||||
output_lens: Tensor [shape=(B,)]
|
output_lens: Tensor [shape=(B,)], optional
|
||||||
batch of lengths of each mels batch.
|
Batch of lengths of each mels batch. Defaults to None.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
outputs : Dict[str, Tensor]
|
outputs : Dict[str, Tensor]
|
||||||
|
|
||||||
mel_output: output sequence of features (B, T_mel, C)
|
mel_output: output sequence of features (B, T_mel, C);
|
||||||
|
|
||||||
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C)
|
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C);
|
||||||
|
|
||||||
stop_logits: output sequence of stop logits (B, T_mel)
|
stop_logits: output sequence of stop logits (B, T_mel);
|
||||||
|
|
||||||
alignments: attention weights (B, T_mel, T_text)
|
alignments: attention weights (B, T_mel, T_text).
|
||||||
"""
|
"""
|
||||||
embedded_inputs = self.embedding(text_inputs)
|
embedded_inputs = self.embedding(text_inputs)
|
||||||
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
||||||
|
@ -675,25 +672,25 @@ class Tacotron2(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
text_inputs: Tensor [shape=(B, T_text)]
|
text_inputs: Tensor [shape=(B, T_text)]
|
||||||
batch of the sequencees of padded character ids
|
Batch of the sequencees of padded character ids.
|
||||||
|
|
||||||
stop_threshold: float
|
stop_threshold: float, optional
|
||||||
stop synthesize when stop logit is greater than this stop threshold
|
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||||
|
|
||||||
max_decoder_steps: int
|
max_decoder_steps: int, optional
|
||||||
number of max step when synthesize
|
Number of max step when synthesize. Defaults to 1000.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
outputs : Dict[str, Tensor]
|
outputs : Dict[str, Tensor]
|
||||||
|
|
||||||
mel_output: output sequence of sepctrogram (B, T_mel, C)
|
mel_output: output sequence of sepctrogram (B, T_mel, C);
|
||||||
|
|
||||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (B, T_mel, C)
|
mel_outputs_postnet: output sequence of sepctrogram after postnet (B, T_mel, C);
|
||||||
|
|
||||||
stop_logits: output sequence of stop logits (B, T_mel)
|
stop_logits: output sequence of stop logits (B, T_mel);
|
||||||
|
|
||||||
alignments: attention weights (B, T_mel, T_text)
|
alignments: attention weights (B, T_mel, T_text).
|
||||||
"""
|
"""
|
||||||
embedded_inputs = self.embedding(text_inputs)
|
embedded_inputs = self.embedding(text_inputs)
|
||||||
encoder_outputs = self.encoder(embedded_inputs)
|
encoder_outputs = self.encoder(embedded_inputs)
|
||||||
|
@ -721,21 +718,21 @@ class Tacotron2(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
text: str
|
text: str
|
||||||
sequence of characters
|
Sequence of characters.
|
||||||
|
|
||||||
stop_threshold: float
|
stop_threshold: float, optional
|
||||||
stop synthesize when stop logit is greater than this stop threshold
|
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||||
|
|
||||||
max_decoder_steps: int
|
max_decoder_steps: int, optional
|
||||||
number of max step when synthesize
|
Number of max step when synthesize. Defaults to 1000.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
outputs : Dict[str, Tensor]
|
outputs : Dict[str, Tensor]
|
||||||
|
|
||||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C)
|
mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C);
|
||||||
|
|
||||||
alignments: attention weights (T_mel, T_text)
|
alignments: attention weights (T_mel, T_text).
|
||||||
"""
|
"""
|
||||||
ids = np.asarray(self.frontend(text))
|
ids = np.asarray(self.frontend(text))
|
||||||
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
|
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
|
||||||
|
@ -750,21 +747,21 @@ class Tacotron2(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
frontend: parakeet.frontend.Phonetics
|
frontend: parakeet.frontend.Phonetics
|
||||||
frontend used to preprocess text
|
Frontend used to preprocess text.
|
||||||
|
|
||||||
config: yacs.config.CfgNode
|
config: yacs.config.CfgNode
|
||||||
model configs
|
Model configs.
|
||||||
|
|
||||||
checkpoint_path: Path
|
checkpoint_path: Path
|
||||||
the path of pretrained model checkpoint
|
The path of pretrained model checkpoint.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
mel_outputs_postnet: Tensor [shape=(T_mel, C)]
|
mel_outputs_postnet: Tensor [shape=(T_mel, C)]
|
||||||
output sequence of sepctrogram after postnet
|
Output sequence of sepctrogram after postnet.
|
||||||
|
|
||||||
alignments: Tensor [shape=(T_mel, T_text)]
|
alignments: Tensor [shape=(T_mel, T_text)]
|
||||||
attention weights
|
Attention weights.
|
||||||
"""
|
"""
|
||||||
model = cls(frontend,
|
model = cls(frontend,
|
||||||
d_mels=config.data.d_mels,
|
d_mels=config.data.d_mels,
|
||||||
|
@ -805,31 +802,31 @@ class Tacotron2Loss(nn.Layer):
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
mel_outputs: Tensor [shape=(B, T_mel, C)]
|
mel_outputs: Tensor [shape=(B, T_mel, C)]
|
||||||
output mel spectrogram sequence
|
Output mel spectrogram sequence.
|
||||||
|
|
||||||
mel_outputs_postnet: Tensor [shape(B, T_mel, C)]
|
mel_outputs_postnet: Tensor [shape(B, T_mel, C)]
|
||||||
output mel spectrogram sequence after postnet
|
Output mel spectrogram sequence after postnet.
|
||||||
|
|
||||||
stop_logits: Tensor [shape=(B, T_mel)]
|
stop_logits: Tensor [shape=(B, T_mel)]
|
||||||
output sequence of stop logits befor sigmoid
|
Output sequence of stop logits befor sigmoid.
|
||||||
|
|
||||||
mel_targets: Tensor [shape=(B, T_mel, C)]
|
mel_targets: Tensor [shape=(B, T_mel, C)]
|
||||||
target mel spectrogram sequence
|
Target mel spectrogram sequence.
|
||||||
|
|
||||||
stop_tokens: Tensor [shape=(B,)]
|
stop_tokens: Tensor [shape=(B,)]
|
||||||
target stop token
|
Target stop token.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
losses : Dict[str, Tensor]
|
losses : Dict[str, Tensor]
|
||||||
|
|
||||||
loss: the sum of the other three losses
|
loss: the sum of the other three losses;
|
||||||
|
|
||||||
mel_loss: MSE loss compute by mel_targets and mel_outputs
|
mel_loss: MSE loss compute by mel_targets and mel_outputs;
|
||||||
|
|
||||||
post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet
|
post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet;
|
||||||
|
|
||||||
stop_loss: stop loss computed by stop_logits and stop token
|
stop_loss: stop loss computed by stop_logits and stop token.
|
||||||
"""
|
"""
|
||||||
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
||||||
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
||||||
|
|
|
@ -18,6 +18,7 @@ import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
def scaled_dot_product_attention(q,
|
def scaled_dot_product_attention(q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
|
@ -139,10 +140,11 @@ class MonoheadAttention(nn.Layer):
|
||||||
Feature size of the key of each scaled dot product attention. If not
|
Feature size of the key of each scaled dot product attention. If not
|
||||||
provided, it is set to `model_dim / num_heads`. Defaults to None.
|
provided, it is set to `model_dim / num_heads`. Defaults to None.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
model_dim: int,
|
def __init__(self,
|
||||||
dropout: float=0.0,
|
model_dim: int,
|
||||||
k_dim: int=None,
|
dropout: float=0.0,
|
||||||
|
k_dim: int=None,
|
||||||
v_dim: int=None):
|
v_dim: int=None):
|
||||||
super(MonoheadAttention, self).__init__()
|
super(MonoheadAttention, self).__init__()
|
||||||
k_dim = k_dim or model_dim
|
k_dim = k_dim or model_dim
|
||||||
|
@ -219,6 +221,7 @@ class MultiheadAttention(nn.Layer):
|
||||||
ValueError
|
ValueError
|
||||||
If ``model_dim`` is not divisible by ``num_heads``.
|
If ``model_dim`` is not divisible by ``num_heads``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_dim: int,
|
model_dim: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
|
@ -279,6 +282,28 @@ class MultiheadAttention(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class LocationSensitiveAttention(nn.Layer):
|
class LocationSensitiveAttention(nn.Layer):
|
||||||
|
"""Location Sensitive Attention module.
|
||||||
|
|
||||||
|
Reference: `Attention-Based Models for Speech Recognition <https://arxiv.org/pdf/1506.07503.pdf>`_
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
d_query: int
|
||||||
|
The feature size of query.
|
||||||
|
|
||||||
|
d_key : int
|
||||||
|
The feature size of key.
|
||||||
|
|
||||||
|
d_attention : int
|
||||||
|
The feature size of dimension.
|
||||||
|
|
||||||
|
location_filters : int
|
||||||
|
Filter size of attention convolution.
|
||||||
|
|
||||||
|
location_kernel_size : int
|
||||||
|
Kernel size of attention convolution.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
d_query: int,
|
d_query: int,
|
||||||
d_key: int,
|
d_key: int,
|
||||||
|
@ -310,6 +335,34 @@ class LocationSensitiveAttention(nn.Layer):
|
||||||
value,
|
value,
|
||||||
attention_weights_cat,
|
attention_weights_cat,
|
||||||
mask=None):
|
mask=None):
|
||||||
|
"""Compute context vector and attention weights.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
query : Tensor [shape=(batch_size, d_query)]
|
||||||
|
The queries.
|
||||||
|
|
||||||
|
processed_key : Tensor [shape=(batch_size, time_steps_k, d_attention)]
|
||||||
|
The keys after linear layer.
|
||||||
|
|
||||||
|
value : Tensor [shape=(batch_size, time_steps_k, d_key)]
|
||||||
|
The values.
|
||||||
|
|
||||||
|
attention_weights_cat : Tensor [shape=(batch_size, time_step_k, 2)]
|
||||||
|
Attention weights concat.
|
||||||
|
|
||||||
|
mask : Tensor, optional
|
||||||
|
The mask. Shape should be (batch_size, times_steps_q, time_steps_k) or broadcastable shape.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
attention_context : Tensor [shape=(batch_size, time_steps_q, d_attention)]
|
||||||
|
The context vector.
|
||||||
|
|
||||||
|
attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)]
|
||||||
|
The attention weights.
|
||||||
|
"""
|
||||||
|
|
||||||
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
||||||
processed_attention_weights = self.location_layer(
|
processed_attention_weights = self.location_layer(
|
||||||
|
|
Loading…
Reference in New Issue