update docstrings for tacotron

parent f197e4d04f
commit b9aa61b5eb
@@ -266,6 +266,10 @@ class Tacotron2Decoder(nn.Layer):
     p_decoder_dropout: float
         The dropout probability in the decoder.
+
+    use_stop_token: bool
+        Whether to use a binary classifier for stop token prediction.
+        Defaults to False.
     """

     def __init__(self,
                  d_mels: int,
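For readers unfamiliar with the new flag: the stop-token head is a per-frame binary classifier that lets the decoder learn when to stop synthesizing. A minimal, purely illustrative sketch of how such a logit is typically thresholded at synthesis time (not the decoder's actual code):

import paddle
import paddle.nn.functional as F

# Illustrative only: decoding halts once the predicted stop probability
# for the current frame crosses a threshold.
stop_logit = paddle.to_tensor([2.3])   # hypothetical logit for one frame
if F.sigmoid(stop_logit).item() > 0.5:
    print("stop decoding")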
@@ -530,8 +534,12 @@ class Tacotron2(nn.Layer):
     Parameters
     ----------
-    frontend : parakeet.frontend.Phonetics
-        Frontend used to preprocess text.
+    vocab_size : int
+        Vocabulary size of phones of the model.
+
+    n_tones: int
+        Vocabulary size of tones of the model. Defaults to None. If provided,
+        the model has an extra tone embedding.

     d_mels: int
         Number of mel bands.

@@ -590,6 +598,11 @@ class Tacotron2(nn.Layer):
     p_postnet_dropout: float
         Dropout probability in the postnet.
+
+    d_global_condition: int
+        Feature size of the global condition. Defaults to None. If provided, the
+        model assumes a global condition that is concatenated to the encoder
+        outputs.
     """

     def __init__(self,
                  vocab_size,
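A hypothetical instantiation matching the documented parameters; the import path, the remaining arguments, and the default values are assumptions, since the full signature is not part of this diff:

import paddle
from parakeet.models.tacotron2 import Tacotron2  # import path assumed

model = Tacotron2(
    vocab_size=68,           # size of the phone vocabulary
    n_tones=10,              # optional; adds an extra tone embedding
    d_mels=80,               # number of mel bands
    d_global_condition=256)  # optional; e.g. a speaker embedding size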
@@ -669,15 +682,27 @@ class Tacotron2(nn.Layer):
         text_inputs: Tensor [shape=(B, T_text)]
             Batch of the sequences of padded character ids.

-        mels: Tensor [shape=(B, T_mel, C)]
-            Batch of the sequences of padded mel spectrogram.
-
         text_lens: Tensor [shape=(B,)]
             Batch of lengths of each text input batch.

+        mels: Tensor [shape=(B, T_mel, C)]
+            Batch of the sequences of padded mel spectrogram.
+
         output_lens: Tensor [shape=(B,)], optional
             Batch of lengths of each mels batch. Defaults to None.

+        tones: Tensor [shape=(B, T_text)]
+            Batch of sequences of padded tone ids.
+
+        global_condition: Tensor [shape=(B, C)]
+            Batch of global conditions. Defaults to None. If the
+            `d_global_condition` of the model is not None, this input should be
+            provided.
+
+        use_stop_token: bool
+            Whether to include a binary classifier to predict the stop token.
+            Defaults to False.
+
         Returns
         -------
         outputs : Dict[str, Tensor]
@@ -686,9 +711,9 @@ class Tacotron2(nn.Layer):

             mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C);

-            stop_logits: output sequence of stop logits (B, T_mel);
+            alignments: attention weights (B, T_mel, T_text);

-            alignments: attention weights (B, T_mel, T_text).
+            stop_logits: output sequence of stop logits (B, T_mel).
         """
         embedded_inputs = self.embedding(text_inputs)
         if self.toned:
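A hypothetical training-time call, with shapes taken from the parameters above; the positional argument order and the returned keys follow the docstring, while the vocabulary size and batch contents are made up for illustration:

B, T_text, T_mel, C = 4, 50, 300, 80
text_inputs = paddle.randint(0, 68, [B, T_text])      # padded phone ids
text_lens = paddle.to_tensor([50, 48, 45, 40])        # valid length per utterance
mels = paddle.randn([B, T_mel, C])                    # padded target spectrograms
output_lens = paddle.to_tensor([300, 280, 260, 250])

outputs = model(text_inputs, text_lens, mels, output_lens)
mel_outputs_postnet = outputs["mel_outputs_postnet"]  # (B, T_mel, C)
alignments = outputs["alignments"]                    # (B, T_mel, T_text)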
@@ -757,7 +782,8 @@ class Tacotron2(nn.Layer):

             stop_logits: output sequence of stop logits (B, T_mel);

-            alignments: attention weights (B, T_mel, T_text).
+            alignments: attention weights (B, T_mel, T_text). This key is only
+            present when `use_stop_token` is True.
         """
         embedded_inputs = self.embedding(text_inputs)
         if self.toned:
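For completeness, a hypothetical inference call; `infer` decodes without teacher forcing, and its exact signature is assumed here since this diff only shows its returns:

model.eval()
with paddle.no_grad():
    outputs = model.infer(text_inputs)    # signature assumed
mel = outputs["mel_outputs_postnet"]      # (B, T_mel, C), input to a vocoder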
@@ -841,6 +867,13 @@ class Tacotron2Loss(nn.Layer):
                  use_stop_token_loss=True,
                  use_guided_attention_loss=False,
                  sigma=0.2):
+        """Tacotron 2 Criterion.
+
+        Args:
+            use_stop_token_loss (bool, optional): Whether to use a loss for stop token prediction. Defaults to True.
+            use_guided_attention_loss (bool, optional): Whether to use a loss for attention weights. Defaults to False.
+            sigma (float, optional): Hyper-parameter sigma for guided attention loss. Defaults to 0.2.
+        """
         super().__init__()
         self.spec_criterion = nn.MSELoss()
         self.use_stop_token_loss = use_stop_token_loss
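The `sigma` argument is the width of the soft diagonal used by guided attention loss (Tachibana et al., 2017): W[t, n] = 1 - exp(-(t/T - n/N)^2 / (2*sigma^2)), so attention mass far from the diagonal is penalized. A sketch of that weight matrix, assuming the standard formulation rather than Parakeet's exact implementation:

import paddle

def guided_attention_weight(T_dec: int, T_enc: int, sigma: float):
    # Soft diagonal penalty mask: near-diagonal entries are close to 0,
    # off-diagonal entries approach 1. The guided attention loss is (roughly)
    # the mean of W * attention_weights over the valid (unpadded) positions.
    t = paddle.arange(T_dec, dtype="float32") / T_dec  # decoder position ratios
    n = paddle.arange(T_enc, dtype="float32") / T_enc  # encoder position ratios
    diff = t.unsqueeze(1) - n.unsqueeze(0)             # (T_dec, T_enc)
    return 1.0 - paddle.exp(-diff ** 2 / (2 * sigma ** 2))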
@@ -870,6 +903,22 @@ class Tacotron2Loss(nn.Layer):
         mel_targets: Tensor [shape=(B, T_mel, C)]
             Target mel spectrogram sequence.

+        attention_weights: Tensor [shape=(B, T_mel, T_enc)]
+            Attention weights. This should be provided when
+            `use_guided_attention_loss` is True.
+
+        slens: Tensor [shape=(B,)]
+            Number of frames of mel spectrograms. This should be provided when
+            `use_guided_attention_loss` is True.
+
+        plens: Tensor [shape=(B,)]
+            Number of text or phone ids of each utterance. This should be
+            provided when `use_guided_attention_loss` is True.
+
+        stop_logits: Tensor [shape=(B, T_mel)]
+            Stop logits of each mel spectrogram frame. This should be provided
+            when `use_stop_token_loss` is True.
+
         Returns
         -------
         losses : Dict[str, Tensor]
@@ -879,6 +928,10 @@ class Tacotron2Loss(nn.Layer):
             mel_loss: MSE loss computed from mel_targets and mel_outputs;

             post_mel_loss: MSE loss computed from mel_targets and mel_outputs_postnet;
+
+            guided_attn_loss: guided attention loss for attention weights;
+
+            stop_loss: binary cross entropy loss for stop token prediction.
         """
         mel_loss = self.spec_criterion(mel_outputs, mel_targets)
         post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
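Finally, a hypothetical end-to-end use of the criterion; the forward argument order is assumed from the parameter list above, and summing the four terms with equal weight is a design choice, not something this commit prescribes:

criterion = Tacotron2Loss(use_stop_token_loss=True,
                          use_guided_attention_loss=True,
                          sigma=0.2)
losses = criterion(mel_outputs, mel_outputs_postnet, mel_targets,
                   attention_weights, slens, plens, stop_logits)
total = (losses["mel_loss"] + losses["post_mel_loss"]
         + losses["guided_attn_loss"] + losses["stop_loss"])
total.backward()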