Merge pull request #60 from lfchener/doc
add docstring for LocationSensitiveAttention
This commit is contained in:
commit
d08eb72791
|
@ -19,6 +19,8 @@ from parakeet.frontend.normalizer.numbers import normalize_numbers
|
|||
|
||||
|
||||
def normalize(sentence):
|
||||
""" Normalize English text.
|
||||
"""
|
||||
# preprocessing
|
||||
sentence = unicode(sentence)
|
||||
sentence = normalize_numbers(sentence)
|
||||
|
|
|
@ -75,6 +75,8 @@ def _expand_number(m):
|
|||
|
||||
|
||||
def normalize_numbers(text):
|
||||
""" Normalize numbers in English text.
|
||||
"""
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r'\1 pounds', text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
|
|
|
@ -39,6 +39,9 @@ class Phonetics(ABC):
|
|||
|
||||
|
||||
class English(Phonetics):
|
||||
""" Normalize the input text sequence and convert into pronunciation id sequence.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.backend = G2p()
|
||||
self.phonemes = list(self.backend.phonemes)
|
||||
|
@ -46,6 +49,18 @@ class English(Phonetics):
|
|||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||
|
||||
def phoneticize(self, sentence):
|
||||
""" Normalize the input text sequence and convert it into pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
start = self.vocab.start_symbol
|
||||
end = self.vocab.end_symbol
|
||||
phonemes = ([] if start is None else [start]) \
|
||||
|
@ -54,6 +69,18 @@ class English(Phonetics):
|
|||
return phonemes
|
||||
|
||||
def numericalize(self, phonemes):
|
||||
""" Convert pronunciation sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
phonemes: List[str]
|
||||
The list of pronunciation sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
ids = [
|
||||
self.vocab.lookup(item) for item in phonemes
|
||||
if item in self.vocab.stoi
|
||||
|
@ -61,17 +88,46 @@ class English(Phonetics):
|
|||
return ids
|
||||
|
||||
def reverse(self, ids):
|
||||
""" Reverse the list of pronunciation id sequence to a list of pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
ids: List[int]
|
||||
The list of pronunciation id sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
return [self.vocab.reverse(i) for i in ids]
|
||||
|
||||
def __call__(self, sentence):
|
||||
""" Convert the input text sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
return self.numericalize(self.phoneticize(sentence))
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
return len(self.vocab)
|
||||
|
||||
|
||||
class EnglishCharacter(Phonetics):
|
||||
""" Normalize the input text sequence and convert it into character id sequence.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.backend = G2p()
|
||||
self.graphemes = list(self.backend.graphemes)
|
||||
|
@ -79,10 +135,34 @@ class EnglishCharacter(Phonetics):
|
|||
self.vocab = Vocab(self.graphemes + self.punctuations)
|
||||
|
||||
def phoneticize(self, sentence):
|
||||
""" Normalize the input text sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
str
|
||||
A text sequence after normalize.
|
||||
"""
|
||||
words = normalize(sentence)
|
||||
return words
|
||||
|
||||
def numericalize(self, sentence):
|
||||
""" Convert a text sequence into ids.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
List of a character id sequence.
|
||||
"""
|
||||
ids = [
|
||||
self.vocab.lookup(item) for item in sentence
|
||||
if item in self.vocab.stoi
|
||||
|
@ -90,17 +170,46 @@ class EnglishCharacter(Phonetics):
|
|||
return ids
|
||||
|
||||
def reverse(self, ids):
|
||||
""" Convert a character id sequence into text.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
ids: List[int]
|
||||
List of a character id sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
str
|
||||
The input text sequence.
|
||||
|
||||
"""
|
||||
return [self.vocab.reverse(i) for i in ids]
|
||||
|
||||
def __call__(self, sentence):
|
||||
""" Normalize the input text sequence and convert it into character id sequence.
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
List of a character id sequence.
|
||||
"""
|
||||
return self.numericalize(self.phoneticize(sentence))
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
return len(self.vocab)
|
||||
|
||||
|
||||
class Chinese(Phonetics):
|
||||
"""Normalize Chinese text sequence and convert it into ids.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.opencc_backend = OpenCC('t2s.json')
|
||||
self.backend = G2pM()
|
||||
|
@ -115,6 +224,18 @@ class Chinese(Phonetics):
|
|||
return list(all_syllables)
|
||||
|
||||
def phoneticize(self, sentence):
|
||||
""" Normalize the input text sequence and convert it into pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
simplified = self.opencc_backend.convert(sentence)
|
||||
phonemes = self.backend(simplified)
|
||||
start = self.vocab.start_symbol
|
||||
|
@ -136,15 +257,53 @@ class Chinese(Phonetics):
|
|||
return cleaned_phonemes
|
||||
|
||||
def numericalize(self, phonemes):
|
||||
""" Convert pronunciation sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
phonemes: List[str]
|
||||
The list of pronunciation sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
ids = [self.vocab.lookup(item) for item in phonemes]
|
||||
return ids
|
||||
|
||||
def __call__(self, sentence):
|
||||
""" Convert the input text sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
return self.numericalize(self.phoneticize(sentence))
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
return len(self.vocab)
|
||||
|
||||
def reverse(self, ids):
|
||||
""" Reverse the list of pronunciation id sequence to a list of pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
ids: List[int]
|
||||
The list of pronunciation id sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
return [self.vocab.reverse(i) for i in ids]
|
||||
|
|
|
@ -1,32 +1,64 @@
|
|||
from typing import Dict, Iterable, List
|
||||
from ruamel import yaml
|
||||
from collections import OrderedDict
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Iterable, List
|
||||
from collections import OrderedDict
|
||||
|
||||
__all__ = ["Vocab"]
|
||||
|
||||
|
||||
class Vocab(object):
|
||||
def __init__(self, symbols: Iterable[str],
|
||||
padding_symbol="<pad>",
|
||||
unk_symbol="<unk>",
|
||||
start_symbol="<s>",
|
||||
end_symbol="</s>"):
|
||||
""" Vocabulary.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
symbols: Iterable[str]
|
||||
Common symbols.
|
||||
|
||||
padding_symbol: str, optional
|
||||
Symbol for pad. Defaults to "<pad>".
|
||||
|
||||
unk_symbol: str, optional
|
||||
Symbol for unknow. Defaults to "<unk>"
|
||||
|
||||
start_symbol: str, optional
|
||||
Symbol for start. Defaults to "<s>"
|
||||
|
||||
end_symbol: str, optional
|
||||
Symbol for end. Defaults to "</s>"
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
symbols: Iterable[str],
|
||||
padding_symbol="<pad>",
|
||||
unk_symbol="<unk>",
|
||||
start_symbol="<s>",
|
||||
end_symbol="</s>"):
|
||||
self.special_symbols = OrderedDict()
|
||||
for i, item in enumerate(
|
||||
[padding_symbol, unk_symbol, start_symbol, end_symbol]):
|
||||
if item:
|
||||
self.special_symbols[item] = len(self.special_symbols)
|
||||
|
||||
|
||||
self.padding_symbol = padding_symbol
|
||||
self.unk_symbol = unk_symbol
|
||||
self.start_symbol = start_symbol
|
||||
self.end_symbol = end_symbol
|
||||
|
||||
|
||||
|
||||
self.stoi = OrderedDict()
|
||||
self.stoi.update(self.special_symbols)
|
||||
|
||||
|
||||
for i, s in enumerate(symbols):
|
||||
if s not in self.stoi:
|
||||
self.stoi[s] = len(self.stoi)
|
||||
|
@ -34,49 +66,66 @@ class Vocab(object):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.stoi)
|
||||
|
||||
|
||||
@property
|
||||
def num_specials(self):
|
||||
""" The number of special symbols.
|
||||
"""
|
||||
return len(self.special_symbols)
|
||||
|
||||
# special tokens
|
||||
@property
|
||||
def padding_index(self):
|
||||
""" The index of padding symbol
|
||||
"""
|
||||
return self.stoi.get(self.padding_symbol, -1)
|
||||
|
||||
@property
|
||||
def unk_index(self):
|
||||
"""The index of unknow symbol.
|
||||
"""
|
||||
return self.stoi.get(self.unk_symbol, -1)
|
||||
|
||||
@property
|
||||
def start_index(self):
|
||||
"""The index of start symbol.
|
||||
"""
|
||||
return self.stoi.get(self.start_symbol, -1)
|
||||
|
||||
@property
|
||||
def end_index(self):
|
||||
""" The index of end symbol.
|
||||
"""
|
||||
return self.stoi.get(self.end_symbol, -1)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
fmt = "Vocab(size: {},\nstoi:\n{})"
|
||||
return fmt.format(len(self), self.stoi)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
def lookup(self, symbol):
|
||||
""" The index that symbol correspond.
|
||||
"""
|
||||
return self.stoi[symbol]
|
||||
|
||||
|
||||
def reverse(self, index):
|
||||
""" The symbol thar index cottespond.
|
||||
"""
|
||||
return self.itos[index]
|
||||
|
||||
|
||||
def add_symbol(self, symbol):
|
||||
""" Add a new symbol in vocab.
|
||||
"""
|
||||
if symbol in self.stoi:
|
||||
return
|
||||
return
|
||||
N = len(self.stoi)
|
||||
self.stoi[symbol] = N
|
||||
self.itos[N] = symbol
|
||||
|
||||
|
||||
def add_symbols(self, symbols):
|
||||
""" Add multiple symbols in vocab.
|
||||
"""
|
||||
for symbol in symbols:
|
||||
self.add_symbol(symbol)
|
||||
|
||||
|
|
|
@ -32,16 +32,16 @@ class DecoderPreNet(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
d_input: int
|
||||
input feature size
|
||||
The input feature size.
|
||||
|
||||
d_hidden: int
|
||||
hidden size
|
||||
The hidden size.
|
||||
|
||||
d_output: int
|
||||
output feature size
|
||||
The output feature size.
|
||||
|
||||
dropout_rate: float
|
||||
droput probability
|
||||
The droput probability.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -49,7 +49,7 @@ class DecoderPreNet(nn.Layer):
|
|||
d_input: int,
|
||||
d_hidden: int,
|
||||
d_output: int,
|
||||
dropout_rate: float=0.2):
|
||||
dropout_rate: float):
|
||||
super().__init__()
|
||||
|
||||
self.dropout_rate = dropout_rate
|
||||
|
@ -62,12 +62,12 @@ class DecoderPreNet(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
x: Tensor [shape=(B, T_mel, C)]
|
||||
batch of the sequences of padded mel spectrogram
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: Tensor [shape=(B, T_mel, C)]
|
||||
batch of the sequences of padded hidden state
|
||||
Batch of the sequences of padded hidden state.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -82,28 +82,28 @@ class DecoderPostNet(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
d_mels: int
|
||||
number of mel bands
|
||||
The number of mel bands.
|
||||
|
||||
d_hidden: int
|
||||
hidden size of postnet
|
||||
The hidden size of postnet.
|
||||
|
||||
kernel_size: int
|
||||
kernel size of the conv layer in postnet
|
||||
The kernel size of the conv layer in postnet.
|
||||
|
||||
num_layers: int
|
||||
number of conv layers in postnet
|
||||
The number of conv layers in postnet.
|
||||
|
||||
dropout: float
|
||||
droput probability
|
||||
The droput probability.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_mels: int=80,
|
||||
d_hidden: int=512,
|
||||
kernel_size: int=5,
|
||||
num_layers: int=5,
|
||||
dropout: float=0.1):
|
||||
d_mels: int,
|
||||
d_hidden: int,
|
||||
kernel_size: int,
|
||||
num_layers: int,
|
||||
dropout: float):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
self.num_layers = num_layers
|
||||
|
@ -150,12 +150,12 @@ class DecoderPostNet(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
input: Tensor [shape=(B, T_mel, C)]
|
||||
output sequence of features from decoder
|
||||
Output sequence of features from decoder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: Tensor [shape=(B, T_mel, C)]
|
||||
output sequence of features after postnet
|
||||
Output sequence of features after postnet.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -173,16 +173,16 @@ class Tacotron2Encoder(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
d_hidden: int
|
||||
hidden size in encoder module
|
||||
The hidden size in encoder module.
|
||||
|
||||
conv_layers: int
|
||||
number of conv layers
|
||||
The number of conv layers.
|
||||
|
||||
kernel_size: int
|
||||
kernel size of conv layers
|
||||
The kernel size of conv layers.
|
||||
|
||||
p_dropout: float
|
||||
droput probability
|
||||
The droput probability.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -216,15 +216,15 @@ class Tacotron2Encoder(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
x: Tensor [shape=(B, T)]
|
||||
batch of the sequencees of padded character ids
|
||||
Batch of the sequencees of padded character ids.
|
||||
|
||||
text_lens: Tensor [shape=(B,)]
|
||||
batch of lengths of each text input batch.
|
||||
text_lens: Tensor [shape=(B,)], optional
|
||||
Batch of lengths of each text input batch. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : Tensor [shape=(B, T, C)]
|
||||
batch of the sequences of padded hidden states
|
||||
Batch of the sequences of padded hidden states.
|
||||
|
||||
"""
|
||||
for conv_batchnorm in self.conv_batchnorms:
|
||||
|
@ -241,40 +241,40 @@ class Tacotron2Decoder(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
d_mels: int
|
||||
number of mel bands
|
||||
The number of mel bands.
|
||||
|
||||
reduction_factor: int
|
||||
reduction factor of tacotron
|
||||
The reduction factor of tacotron.
|
||||
|
||||
d_encoder: int
|
||||
hidden size of encoder
|
||||
The hidden size of encoder.
|
||||
|
||||
d_prenet: int
|
||||
hidden size in decoder prenet
|
||||
The hidden size in decoder prenet.
|
||||
|
||||
d_attention_rnn: int
|
||||
attention rnn layer hidden size
|
||||
The attention rnn layer hidden size.
|
||||
|
||||
d_decoder_rnn: int
|
||||
decoder rnn layer hidden size
|
||||
The decoder rnn layer hidden size.
|
||||
|
||||
d_attention: int
|
||||
hidden size of the linear layer in location sensitive attention
|
||||
The hidden size of the linear layer in location sensitive attention.
|
||||
|
||||
attention_filters: int
|
||||
filter size of the conv layer in location sensitive attention
|
||||
The filter size of the conv layer in location sensitive attention.
|
||||
|
||||
attention_kernel_size: int
|
||||
kernel size of the conv layer in location sensitive attention
|
||||
The kernel size of the conv layer in location sensitive attention.
|
||||
|
||||
p_prenet_dropout: float
|
||||
droput probability in decoder prenet
|
||||
The droput probability in decoder prenet.
|
||||
|
||||
p_attention_dropout: float
|
||||
droput probability in location sensitive attention
|
||||
The droput probability in location sensitive attention.
|
||||
|
||||
p_decoder_dropout: float
|
||||
droput probability in decoder
|
||||
The droput probability in decoder.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -382,25 +382,25 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
keys: Tensor[shape=(B, T_text, C)]
|
||||
batch of the sequences of padded output from encoder
|
||||
keys: Tensor[shape=(B, T_key, C)]
|
||||
Batch of the sequences of padded output from encoder.
|
||||
|
||||
querys: Tensor[shape(B, T_mel, C)]
|
||||
batch of the sequences of padded mel spectrogram
|
||||
querys: Tensor[shape(B, T_query, C)]
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
mask: Tensor[shape=(B, T_text, 1)]
|
||||
mask generated with text length
|
||||
mask: Tensor
|
||||
Mask generated with text length. Shape should be (B, T_key, T_query) or broadcastable shape.
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_output: Tensor [shape=(B, T_mel, C)]
|
||||
output sequence of features
|
||||
mel_output: Tensor [shape=(B, T_query, C)]
|
||||
Output sequence of features.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_mel)]
|
||||
output sequence of stop logits
|
||||
stop_logits: Tensor [shape=(B, T_query)]
|
||||
Output sequence of stop logits.
|
||||
|
||||
alignments: Tensor [shape=(B, T_mel, T_text)]
|
||||
attention weights
|
||||
alignments: Tensor [shape=(B, T_query, T_key)]
|
||||
Attention weights.
|
||||
"""
|
||||
querys = paddle.reshape(
|
||||
querys,
|
||||
|
@ -437,25 +437,25 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
keys: Tensor [shape=(B, T_text, C)]
|
||||
batch of the sequences of padded output from encoder
|
||||
keys: Tensor [shape=(B, T_key, C)]
|
||||
Batch of the sequences of padded output from encoder.
|
||||
|
||||
stop_threshold: float
|
||||
stop synthesize when stop logit is greater than this stop threshold
|
||||
stop_threshold: float, optional
|
||||
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||
|
||||
max_decoder_steps: int
|
||||
number of max step when synthesize
|
||||
max_decoder_steps: int, optional
|
||||
Number of max step when synthesize. Defaults to 1000.
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_output: Tensor [shape=(B, T_mel, C)]
|
||||
output sequence of features
|
||||
Output sequence of features.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_mel)]
|
||||
output sequence of stop logits
|
||||
Output sequence of stop logits.
|
||||
|
||||
alignments: Tensor [shape=(B, T_mel, T_text)]
|
||||
attention weights
|
||||
alignments: Tensor [shape=(B, T_mel, T_key)]
|
||||
Attention weights.
|
||||
|
||||
"""
|
||||
query = paddle.zeros(
|
||||
|
@ -493,75 +493,72 @@ class Tacotron2(nn.Layer):
|
|||
"""Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
|
||||
|
||||
This is a model of Spectrogram prediction network in Tacotron2 described
|
||||
in ``Natural TTS Synthesis
|
||||
by Conditioning WaveNet on Mel Spectrogram Predictions``,
|
||||
in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
|
||||
<https://arxiv.org/abs/1712.05884>`_,
|
||||
which converts the sequence of characters
|
||||
into the sequence of mel spectrogram.
|
||||
|
||||
`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
|
||||
<https://arxiv.org/abs/1712.05884>`_.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frontend : parakeet.frontend.Phonetics
|
||||
frontend used to preprocess text
|
||||
Frontend used to preprocess text.
|
||||
|
||||
d_mels: int
|
||||
number of mel bands
|
||||
Number of mel bands.
|
||||
|
||||
d_encoder: int
|
||||
hidden size in encoder module
|
||||
Hidden size in encoder module.
|
||||
|
||||
encoder_conv_layers: int
|
||||
number of conv layers in encoder
|
||||
Number of conv layers in encoder.
|
||||
|
||||
encoder_kernel_size: int
|
||||
kernel size of conv layers in encoder
|
||||
Kernel size of conv layers in encoder.
|
||||
|
||||
d_prenet: int
|
||||
hidden size in decoder prenet
|
||||
Hidden size in decoder prenet.
|
||||
|
||||
d_attention_rnn: int
|
||||
attention rnn layer hidden size in decoder
|
||||
Attention rnn layer hidden size in decoder.
|
||||
|
||||
d_decoder_rnn: int
|
||||
decoder rnn layer hidden size in decoder
|
||||
Decoder rnn layer hidden size in decoder.
|
||||
|
||||
attention_filters: int
|
||||
filter size of the conv layer in location sensitive attention
|
||||
Filter size of the conv layer in location sensitive attention.
|
||||
|
||||
attention_kernel_size: int
|
||||
kernel size of the conv layer in location sensitive attention
|
||||
Kernel size of the conv layer in location sensitive attention.
|
||||
|
||||
d_attention: int
|
||||
hidden size of the linear layer in location sensitive attention
|
||||
Hidden size of the linear layer in location sensitive attention.
|
||||
|
||||
d_postnet: int
|
||||
hidden size of postnet
|
||||
Hidden size of postnet.
|
||||
|
||||
postnet_kernel_size: int
|
||||
kernel size of the conv layer in postnet
|
||||
Kernel size of the conv layer in postnet.
|
||||
|
||||
postnet_conv_layers: int
|
||||
number of conv layers in postnet
|
||||
Number of conv layers in postnet.
|
||||
|
||||
reduction_factor: int
|
||||
reduction factor of tacotron
|
||||
Reduction factor of tacotron2.
|
||||
|
||||
p_encoder_dropout: float
|
||||
droput probability in encoder
|
||||
Droput probability in encoder.
|
||||
|
||||
p_prenet_dropout: float
|
||||
droput probability in decoder prenet
|
||||
Droput probability in decoder prenet.
|
||||
|
||||
p_attention_dropout: float
|
||||
droput probability in location sensitive attention
|
||||
Droput probability in location sensitive attention.
|
||||
|
||||
p_decoder_dropout: float
|
||||
droput probability in decoder
|
||||
Droput probability in decoder.
|
||||
|
||||
p_postnet_dropout: float
|
||||
droput probability in postnet
|
||||
Droput probability in postnet.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -616,28 +613,28 @@ class Tacotron2(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
text_inputs: Tensor [shape=(B, T_text)]
|
||||
batch of the sequencees of padded character ids
|
||||
Batch of the sequencees of padded character ids.
|
||||
|
||||
mels: Tensor [shape(B, T_mel, C)]
|
||||
batch of the sequences of padded mel spectrogram
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
text_lens: Tensor [shape=(B,)]
|
||||
batch of lengths of each text input batch.
|
||||
Batch of lengths of each text input batch.
|
||||
|
||||
output_lens: Tensor [shape=(B,)]
|
||||
batch of lengths of each mels batch.
|
||||
output_lens: Tensor [shape=(B,)], optional
|
||||
Batch of lengths of each mels batch. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_output: output sequence of features (B, T_mel, C)
|
||||
mel_output: output sequence of features (B, T_mel, C);
|
||||
|
||||
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C)
|
||||
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C);
|
||||
|
||||
stop_logits: output sequence of stop logits (B, T_mel)
|
||||
stop_logits: output sequence of stop logits (B, T_mel);
|
||||
|
||||
alignments: attention weights (B, T_mel, T_text)
|
||||
alignments: attention weights (B, T_mel, T_text).
|
||||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
||||
|
@ -675,25 +672,25 @@ class Tacotron2(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
text_inputs: Tensor [shape=(B, T_text)]
|
||||
batch of the sequencees of padded character ids
|
||||
Batch of the sequencees of padded character ids.
|
||||
|
||||
stop_threshold: float
|
||||
stop synthesize when stop logit is greater than this stop threshold
|
||||
stop_threshold: float, optional
|
||||
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||
|
||||
max_decoder_steps: int
|
||||
number of max step when synthesize
|
||||
max_decoder_steps: int, optional
|
||||
Number of max step when synthesize. Defaults to 1000.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_output: output sequence of sepctrogram (B, T_mel, C)
|
||||
mel_output: output sequence of sepctrogram (B, T_mel, C);
|
||||
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (B, T_mel, C)
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (B, T_mel, C);
|
||||
|
||||
stop_logits: output sequence of stop logits (B, T_mel)
|
||||
stop_logits: output sequence of stop logits (B, T_mel);
|
||||
|
||||
alignments: attention weights (B, T_mel, T_text)
|
||||
alignments: attention weights (B, T_mel, T_text).
|
||||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
encoder_outputs = self.encoder(embedded_inputs)
|
||||
|
@ -721,21 +718,21 @@ class Tacotron2(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
text: str
|
||||
sequence of characters
|
||||
Sequence of characters.
|
||||
|
||||
stop_threshold: float
|
||||
stop synthesize when stop logit is greater than this stop threshold
|
||||
stop_threshold: float, optional
|
||||
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||
|
||||
max_decoder_steps: int
|
||||
number of max step when synthesize
|
||||
max_decoder_steps: int, optional
|
||||
Number of max step when synthesize. Defaults to 1000.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C)
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C);
|
||||
|
||||
alignments: attention weights (T_mel, T_text)
|
||||
alignments: attention weights (T_mel, T_text).
|
||||
"""
|
||||
ids = np.asarray(self.frontend(text))
|
||||
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
|
||||
|
@ -750,21 +747,18 @@ class Tacotron2(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
frontend: parakeet.frontend.Phonetics
|
||||
frontend used to preprocess text
|
||||
Frontend used to preprocess text.
|
||||
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
Model configs.
|
||||
|
||||
checkpoint_path: Path
|
||||
the path of pretrained model checkpoint
|
||||
checkpoint_path: Path or str
|
||||
The path of pretrained model checkpoint, without extension name.
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_outputs_postnet: Tensor [shape=(T_mel, C)]
|
||||
output sequence of sepctrogram after postnet
|
||||
|
||||
alignments: Tensor [shape=(T_mel, T_text)]
|
||||
attention weights
|
||||
Tacotron2
|
||||
The model build from pretrined result.
|
||||
"""
|
||||
model = cls(frontend,
|
||||
d_mels=config.data.d_mels,
|
||||
|
@ -805,31 +799,31 @@ class Tacotron2Loss(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
mel_outputs: Tensor [shape=(B, T_mel, C)]
|
||||
output mel spectrogram sequence
|
||||
Output mel spectrogram sequence.
|
||||
|
||||
mel_outputs_postnet: Tensor [shape(B, T_mel, C)]
|
||||
output mel spectrogram sequence after postnet
|
||||
Output mel spectrogram sequence after postnet.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_mel)]
|
||||
output sequence of stop logits befor sigmoid
|
||||
Output sequence of stop logits befor sigmoid.
|
||||
|
||||
mel_targets: Tensor [shape=(B, T_mel, C)]
|
||||
target mel spectrogram sequence
|
||||
Target mel spectrogram sequence.
|
||||
|
||||
stop_tokens: Tensor [shape=(B,)]
|
||||
target stop token
|
||||
Target stop token.
|
||||
|
||||
Returns
|
||||
-------
|
||||
losses : Dict[str, Tensor]
|
||||
|
||||
loss: the sum of the other three losses
|
||||
loss: the sum of the other three losses;
|
||||
|
||||
mel_loss: MSE loss compute by mel_targets and mel_outputs
|
||||
mel_loss: MSE loss compute by mel_targets and mel_outputs;
|
||||
|
||||
post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet
|
||||
post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet;
|
||||
|
||||
stop_loss: stop loss computed by stop_logits and stop token
|
||||
stop_loss: stop loss computed by stop_logits and stop token.
|
||||
"""
|
||||
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
||||
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
||||
|
|
|
@ -18,6 +18,7 @@ import paddle
|
|||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
def scaled_dot_product_attention(q,
|
||||
k,
|
||||
v,
|
||||
|
@ -139,10 +140,11 @@ class MonoheadAttention(nn.Layer):
|
|||
Feature size of the key of each scaled dot product attention. If not
|
||||
provided, it is set to `model_dim / num_heads`. Defaults to None.
|
||||
"""
|
||||
def __init__(self,
|
||||
model_dim: int,
|
||||
dropout: float=0.0,
|
||||
k_dim: int=None,
|
||||
|
||||
def __init__(self,
|
||||
model_dim: int,
|
||||
dropout: float=0.0,
|
||||
k_dim: int=None,
|
||||
v_dim: int=None):
|
||||
super(MonoheadAttention, self).__init__()
|
||||
k_dim = k_dim or model_dim
|
||||
|
@ -219,6 +221,7 @@ class MultiheadAttention(nn.Layer):
|
|||
ValueError
|
||||
If ``model_dim`` is not divisible by ``num_heads``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_dim: int,
|
||||
num_heads: int,
|
||||
|
@ -279,6 +282,28 @@ class MultiheadAttention(nn.Layer):
|
|||
|
||||
|
||||
class LocationSensitiveAttention(nn.Layer):
|
||||
"""Location Sensitive Attention module.
|
||||
|
||||
Reference: `Attention-Based Models for Speech Recognition <https://arxiv.org/pdf/1506.07503.pdf>`_
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
d_query: int
|
||||
The feature size of query.
|
||||
|
||||
d_key : int
|
||||
The feature size of key.
|
||||
|
||||
d_attention : int
|
||||
The feature size of dimension.
|
||||
|
||||
location_filters : int
|
||||
Filter size of attention convolution.
|
||||
|
||||
location_kernel_size : int
|
||||
Kernel size of attention convolution.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_query: int,
|
||||
d_key: int,
|
||||
|
@ -310,6 +335,34 @@ class LocationSensitiveAttention(nn.Layer):
|
|||
value,
|
||||
attention_weights_cat,
|
||||
mask=None):
|
||||
"""Compute context vector and attention weights.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
query : Tensor [shape=(batch_size, d_query)]
|
||||
The queries.
|
||||
|
||||
processed_key : Tensor [shape=(batch_size, time_steps_k, d_attention)]
|
||||
The keys after linear layer.
|
||||
|
||||
value : Tensor [shape=(batch_size, time_steps_k, d_key)]
|
||||
The values.
|
||||
|
||||
attention_weights_cat : Tensor [shape=(batch_size, time_step_k, 2)]
|
||||
Attention weights concat.
|
||||
|
||||
mask : Tensor, optional
|
||||
The mask. Shape should be (batch_size, times_steps_q, time_steps_k) or broadcastable shape.
|
||||
Defaults to None.
|
||||
|
||||
Returns
|
||||
----------
|
||||
attention_context : Tensor [shape=(batch_size, time_steps_q, d_attention)]
|
||||
The context vector.
|
||||
|
||||
attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)]
|
||||
The attention weights.
|
||||
"""
|
||||
|
||||
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
||||
processed_attention_weights = self.location_layer(
|
||||
|
|
Loading…
Reference in New Issue