Merge branch 'develop' of github.com:iclementine/Parakeet into doc

This commit is contained in:
iclementine 2021-01-13 11:09:05 +08:00
commit 641be1bc92
10 changed files with 253 additions and 20 deletions

View File

@@ -0,0 +1,77 @@
# Tacotron2
PaddlePaddle dynamic graph implementation of Tacotron2, a neural network architecture for speech synthesis directly from text. The implementation is based on [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/abs/1712.05884).
## Project Structure
```text
├── config.py # default configuration file
├── ljspeech.py # dataset and dataloader settings for LJSpeech
├── preprocess.py # script to preprocess LJSpeech dataset
├── synthesis.py # script to synthesize spectrogram from text
├── train.py # script for tacotron2 model training
```
## Dataset
We experiment with the LJSpeech dataset. Download and extract [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
Then preprocess the data by running ``preprocess.py``; the preprocessed data will be placed in the ``--output`` directory.
```bash
python preprocess.py \
--input=${DATAPATH} \
--output=${PREPROCESSEDDATAPATH} \
-v
```
For more help on the arguments, run ``python preprocess.py --help``.
## Train the model
The Tacotron2 model can be trained by running ``train.py``.
```bash
python train.py \
--data=${PREPROCESSEDDATAPATH} \
--output=${OUTPUTPATH} \
--device=gpu
```
If you want to train on CPU, just set ``--device=cpu``.
If you want to train on multiple GPUs, set ``--nprocs`` to the number of GPUs.
By default, training resumes from the latest checkpoint in ``--output``. If you want to start a new training run, use a new ``${OUTPUTPATH}`` with no checkpoint in it. If you want to resume from another existing model, set ``--checkpoint_path`` to the checkpoint path you want to load.
**Note: The checkpoint path must not contain the file extension.**
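For example, resuming from a specific checkpoint while training on 4 GPUs might look like the sketch below (the checkpoint file names and their location under ``${OUTPUTPATH}`` are only illustrative, and the flag is assumed to be spelled ``--checkpoint_path``, as in ``synthesis.py`` below):
```bash
# Illustrative layout only: suppose the checkpoint was saved as
#   ${OUTPUTPATH}/checkpoints/step-10000.pdparams
#   ${OUTPUTPATH}/checkpoints/step-10000.pdopt
python train.py \
    --data=${PREPROCESSEDDATAPATH} \
    --output=${OUTPUTPATH} \
    --checkpoint_path=${OUTPUTPATH}/checkpoints/step-10000 \
    --device=gpu \
    --nprocs=4
```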
For more help on the arguments, run ``python train.py --help``.
## Synthesis
After training Tacotron2, spectrograms can be synthesized by running ``synthesis.py``.
```bash
python synthesis.py \
--config=${CONFIGPATH} \
--checkpoint_path=${CHECKPOINTPATH} \
--input=${TEXTPATH} \
--output=${OUTPUTPATH} \
--device=gpu
```
The ``${CONFIGPATH}`` must match the configuration that was used to train ``${CHECKPOINTPATH}``.
For more help on the arguments, run ``python synthesis.py --help``.
You can then find the spectrogram files in ``${OUTPUTPATH}``; they can be used as input to a vocoder such as [WaveFlow](../waveflow/README.md#Synthesis) to generate audio files.
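For example, assuming a trained WaveFlow model is available, the spectrograms could be converted to audio with a command along these lines (a sketch: ``${WAVEFLOW_CHECKPOINTPATH}`` is a placeholder for a trained WaveFlow checkpoint, given without the ``.pdparams`` extension):
```bash
cd ../waveflow
python synthesize.py \
    --input=${OUTPUTPATH} \
    --output=wavs/ \
    --checkpoint_path=${WAVEFLOW_CHECKPOINTPATH} \
    --device="gpu" \
    --verbose
```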

View File

@@ -0,0 +1,48 @@
# TransformerTTS with LJSpeech
## Dataset
### Download the dataset.
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
```
### Extract the dataset.
```bash
tar xjvf LJSpeech-1.1.tar.bz2
```
### Preprocess the dataset.
Assume the path to save the preprocessed dataset is `ljspeech_transformer_tts`. Run the command below to preprocess the dataset.
```bash
python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_transformer_tts
```
## Train the model
The training script requires 4 command line arguments.
`--data` is the path of the training dataset, and `--output` is the path of the output directory (we recommend using a subdirectory of `runs` to manage different experiments).
`--device` should be "cpu" or "gpu", and `--nprocs` is the number of processes used to train the model in parallel.
```bash
python train.py --data=ljspeech_transformer_tts/ --output=runs/test --device="gpu" --nprocs=1
```
If you want distributed training, set a larger `--nprocs` (e.g. 4), as shown below. Note that distributed training on CPU is not supported yet.
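For example, a 4-GPU run only changes `--nprocs` (the data and output paths are the same as in the single-process command above):
```bash
python train.py --data=ljspeech_transformer_tts/ --output=runs/test --device="gpu" --nprocs=4
```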
## Synthesize
Synthesize mel spectrograms from text. We assume `--input` is a text file with one sentence per line, and `--output` is a directory in which to save the synthesized mel spectrograms (log magnitude) in `.npy` format. The mel spectrograms can be used with `WaveFlow` to generate waveforms.
`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extension `.pdparams` is not included here.
`--device` specifies the device to run synthesis on.
```bash
python synthesize.py --input=sentence.txt --output=mels/ --checkpoint_path='step-310000' --device="gpu" --verbose
```
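The resulting `mels/` directory can then be fed to the WaveFlow example below to generate waveforms, for example (a sketch: `${WAVEFLOW_CHECKPOINT}` is a placeholder for a trained WaveFlow checkpoint, given without the `.pdparams` extension):
```bash
python ../waveflow/synthesize.py --input=mels/ --output=wavs/ \
    --checkpoint_path=${WAVEFLOW_CHECKPOINT} --device="gpu" --verbose
```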

View File

@@ -0,0 +1,48 @@
# WaveFlow with LJSpeech
## Dataset
### Download the dataset.
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
```
### Extract the dataset.
```bash
tar xjvf LJSpeech-1.1.tar.bz2
```
### Preprocess the dataset.
Assume the path to save the preprocessed dataset is `ljspeech_waveflow`. Run the command below to preprocess the dataset.
```bash
python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_waveflow
```
## Train the model
The training script requires 4 command line arguments.
`--data` is the path of the training dataset, and `--output` is the path of the output directory (we recommend using a subdirectory of `runs` to manage different experiments).
`--device` should be "cpu" or "gpu", and `--nprocs` is the number of processes used to train the model in parallel.
```bash
python train.py --data=ljspeech_waveflow/ --output=runs/test --device="gpu" --nprocs=1
```
If you want distributed training, set a larger `--nprocs` (e.g. 4). Note that distributed training on CPU is not supported yet.
## Synthesize
Synthesize waveforms. We assume `--input` is a directory containing mel spectrograms (log magnitude) in `.npy` format. The output is saved in the `--output` directory as `.wav` files, each with the same name as its corresponding mel spectrogram.
`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extension `.pdparams` is not included here.
`--device` specifies the device to run synthesis on.
```bash
python synthesize.py --input=mels/ --output=wavs/ --checkpoint_path='step-2000000' --device="gpu" --verbose
```

View File

@@ -46,7 +46,7 @@ class Experiment(ExperimentBase):
             n_mels=config.data.n_mels,
             kernel_size=config.model.kernel_size)
-        if self.parallel > 1:
+        if self.parallel:
             model = paddle.DataParallel(model)
         optimizer = paddle.optimizer.Adam(
             config.training.lr, parameters=model.parameters())

View File

@@ -0,0 +1,48 @@
# WaveNet with LJSpeech
## Dataset
### Download the dataset.
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
```
### Extract the dataset.
```bash
tar xjvf LJSpeech-1.1.tar.bz2
```
### Preprocess the dataset.
Assume the path to save the preprocessed dataset is `ljspeech_wavenet`. Run the command below to preprocess the dataset.
```bash
python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_wavenet
```
## Train the model
The training script requires 4 command line arguments.
`--data` is the path of the training dataset, and `--output` is the path of the output directory (we recommend using a subdirectory of `runs` to manage different experiments).
`--device` should be "cpu" or "gpu", and `--nprocs` is the number of processes used to train the model in parallel.
```bash
python train.py --data=ljspeech_wavenet/ --output=runs/test --device="gpu" --nprocs=1
```
If you want distributed training, set a larger `--nprocs` (e.g. 4). Note that distributed training on CPU is not supported yet.
## Synthesize
Synthesize waveforms. We assume `--input` is a directory containing mel spectrograms (normalized to the range [0, 1)) in `.npy` format. The output is saved in the `--output` directory as `.wav` files, each with the same name as its corresponding mel spectrogram.
`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extension `.pdparams` is not included here.
`--device` specifies the device to run synthesis on. Due to the autoregressive nature of WaveNet, synthesizing on CPU may be faster.
```bash
python synthesize.py --input=mels/ --output=wavs/ --checkpoint_path='step-2450000' --device="cpu" --verbose
```

View File

@@ -49,7 +49,7 @@ class Experiment(ExperimentBase):
             loss_type=config.model.loss_type,
             log_scale_min=config.model.log_scale_min)
-        if self.parallel > 1:
+        if self.parallel:
             model = paddle.DataParallel(model)
         lr_scheduler = paddle.optimizer.lr.StepDecay(
@@ -62,7 +62,7 @@ class Experiment(ExperimentBase):
                 config.training.gradient_max_norm))
         self.model = model
-        self.model_core = model._layer if self.parallel else model
+        self.model_core = model._layers if self.parallel else model
         self.optimizer = optimizer

     def setup_dataloader(self):
@@ -119,7 +119,7 @@ class Experiment(ExperimentBase):
         mel, wav, audio_starts = batch
         y = self.model(wav, mel, audio_starts)
-        loss = self.model.loss(y, wav)
+        loss = self.model_core.loss(y, wav)
         loss.backward()
         self.optimizer.step()
         iteration_time = time.time() - start
@@ -141,7 +141,7 @@ class Experiment(ExperimentBase):
         valid_losses = []
         mel, wav, audio_starts = next(valid_iterator)
         y = self.model(wav, mel, audio_starts)
-        loss = self.model.loss(y, wav)
+        loss = self.model_core.loss(y, wav)
         valid_losses.append(float(loss))
         valid_loss = np.mean(valid_losses)
         self.visualizer.add_scalar(

View File

@@ -25,7 +25,7 @@ class LJSpeechMetaData(Dataset):
         csv_path = self.root / "metadata.csv"
         records = []
         speaker_name = "ljspeech"
-        with open(str(csv_path), 'rt') as f:
+        with open(str(csv_path), 'rt', encoding='utf-8') as f:
             for line in f:
                 filename, _, normalized_text = line.strip().split("|")
                 filename = str(wav_dir / (filename + ".wav"))

View File

@@ -71,8 +71,10 @@ class DecoderPreNet(nn.Layer):
         """
-        x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
-        output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
+        x = F.dropout(
+            F.relu(self.linear1(x)), self.dropout_rate, training=True)
+        output = F.dropout(
+            F.relu(self.linear2(x)), self.dropout_rate, training=True)
         return output
@@ -161,9 +163,13 @@ class DecoderPostNet(nn.Layer):
         for i in range(len(self.conv_batchnorms) - 1):
             input = F.dropout(
-                F.tanh(self.conv_batchnorms[i](input), self.dropout))
-        output = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
-                           self.dropout)
+                F.tanh(self.conv_batchnorms[i](input)),
+                self.dropout,
+                training=self.training)
+        output = F.dropout(
+            self.conv_batchnorms[self.num_layers - 1](input),
+            self.dropout,
+            training=self.training)
         return output
@@ -228,8 +234,10 @@ class Tacotron2Encoder(nn.Layer):
         """
         for conv_batchnorm in self.conv_batchnorms:
-            x = F.dropout(F.relu(conv_batchnorm(x)),
-                          self.p_dropout) #(B, T, C)
+            x = F.dropout(
+                F.relu(conv_batchnorm(x)),
+                self.p_dropout,
+                training=self.training)
         output, _ = self.lstm(inputs=x, sequence_length=input_lens)
         return output
@@ -350,8 +358,10 @@ class Tacotron2Decoder(nn.Layer):
         # The first lstm layer
         _, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
             cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(self.attention_hidden,
-                                          self.p_attention_dropout)
+        self.attention_hidden = F.dropout(
+            self.attention_hidden,
+            self.p_attention_dropout,
+            training=self.training)

         # Loaction sensitive attention
         attention_weights_cat = paddle.stack(
@@ -367,7 +377,9 @@ class Tacotron2Decoder(nn.Layer):
         _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
             decoder_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(
-            self.decoder_hidden, p=self.p_decoder_dropout)
+            self.decoder_hidden,
+            p=self.p_decoder_dropout,
+            training=self.training)

         # decode output one step
         decoder_hidden_attention_context = paddle.concat(

View File

@@ -391,7 +391,7 @@ class TransformerTTS(nn.Layer):
             padding_idx=frontend.vocab.padding_index,
             weight_attr=I.Uniform(-0.05, 0.05))
         # position encoding matrix may be extended later
-        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
+        self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
         self.encoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
@@ -399,7 +399,7 @@ class TransformerTTS(nn.Layer):
         # decoder
         self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
-        self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
+        self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
         self.decoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.decoder = TransformerDecoder(

View File

@@ -17,10 +17,10 @@ import numpy as np
 import paddle
 from paddle.nn import functional as F

-__all__ = ["positional_encoding"]
+__all__ = ["sinusoid_positional_encoding"]

-def positional_encoding(start_index, length, size, dtype=None):
+def sinusoid_positional_encoding(start_index, length, size, dtype=None):
     r"""Generate standard positional encoding matrix.
     .. math::