update docstrings for tacotron

This commit is contained in:
chenfeiyu 2021-05-07 16:08:31 +08:00
parent f197e4d04f
commit b9aa61b5eb
1 changed files with 61 additions and 8 deletions

View File

@ -266,6 +266,10 @@ class Tacotron2Decoder(nn.Layer):
p_decoder_dropout: float
The dropout probability in decoder.
use_stop_token: bool
Whether to use a binary classifier for stop token prediction.
Defaults to False.
"""
def __init__(self,
d_mels: int,
@ -530,8 +534,12 @@ class Tacotron2(nn.Layer):
Parameters
----------
frontend : parakeet.frontend.Phonetics
Frontend used to preprocess text.
vocab_size : int
Vocabulary size of phones of the model.
n_tones: int
Vocabulary size of tones of the model. Defaults to None. If provided,
the model has an extra tone embedding.
d_mels: int
Number of mel bands.
@ -590,6 +598,11 @@ class Tacotron2(nn.Layer):
p_postnet_dropout: float
Dropout probability in postnet.
d_global_condition: int
Feature size of global condition. Defaults to None. If provided, the
model assumes a global condition that is concatenated to the encoder
outputs.
"""
def __init__(self,
vocab_size,
@ -669,15 +682,27 @@ class Tacotron2(nn.Layer):
text_inputs: Tensor [shape=(B, T_text)]
Batch of the sequences of padded character ids.
mels: Tensor [shape=(B, T_mel, C)]
Batch of the sequences of padded mel spectrogram.
text_lens: Tensor [shape=(B,)]
Lengths of each text input in the batch.
mels: Tensor [shape=(B, T_mel, C)]
Batch of the sequences of padded mel spectrogram.
output_lens: Tensor [shape=(B,)], optional
Lengths of each mel spectrogram in the batch. Defaults to None.
tones: Tensor [shape=(B, T_text)]
Batch of sequences of padded tone ids.
global_condition: Tensor [shape=(B, C)]
Batch of global conditions. Defaults to None. If the
`d_global_condition` of the model is not None, this input should be
provided.
use_stop_token: bool
Whether to include a binary classifier to predict the stop token.
Defaults to False.
Returns
-------
outputs : Dict[str, Tensor]
@ -686,9 +711,9 @@ class Tacotron2(nn.Layer):
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C);
stop_logits: output sequence of stop logits (B, T_mel);
alignments: attention weights (B, T_mel, T_text);
alignments: attention weights (B, T_mel, T_text).
stop_logits: output sequence of stop logits (B, T_mel)
"""
embedded_inputs = self.embedding(text_inputs)
if self.toned:
@ -757,7 +782,8 @@ class Tacotron2(nn.Layer):
stop_logits: output sequence of stop logits (B, T_mel);
alignments: attention weights (B, T_mel, T_text).
alignments: attention weights (B, T_mel, T_text). This key is only
present when `use_stop_token` is True.
"""
embedded_inputs = self.embedding(text_inputs)
if self.toned:
@ -841,6 +867,13 @@ class Tacotron2Loss(nn.Layer):
use_stop_token_loss=True,
use_guided_attention_loss=False,
sigma=0.2):
"""Tacotron 2 Criterion.
Args:
use_stop_token_loss (bool, optional): Whether to use a loss for stop token prediction. Defaults to True.
use_guided_attention_loss (bool, optional): Whether to use a loss for attention weights. Defaults to False.
sigma (float, optional): Hyper-parameter sigma for guided attention loss. Defaults to 0.2.
"""
super().__init__()
self.spec_criterion = nn.MSELoss()
self.use_stop_token_loss = use_stop_token_loss
@ -870,6 +903,22 @@ class Tacotron2Loss(nn.Layer):
mel_targets: Tensor [shape=(B, T_mel, C)]
Target mel spectrogram sequence.
attention_weights: Tensor [shape=(B, T_mel, T_enc)]
Attention weights. This should be provided when
`use_guided_attention_loss` is True.
slens: Tensor [shape=(B,)]
Number of frames of mel spectrograms. This should be provided when
`use_guided_attention_loss` is True.
plens: Tensor [shape=(B, )]
Number of text or phone ids of each utterance. This should be
provided when `use_guided_attention_loss` is True.
stop_logits: Tensor [shape=(B, T_mel)]
Stop logits of each mel spectrogram frame. This should be provided
when `use_stop_token_loss` is True.
Returns
-------
losses : Dict[str, Tensor]
@ -879,6 +928,10 @@ class Tacotron2Loss(nn.Layer):
mel_loss: MSE loss computed from mel_targets and mel_outputs;
post_mel_loss: MSE loss computed from mel_targets and mel_outputs_postnet;
guided_attn_loss: Guided attention loss for attention weights;
stop_loss: Binary cross entropy loss for stop token prediction.
"""
mel_loss = self.spec_criterion(mel_outputs, mel_targets)
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)