diff --git a/examples/tacotron2/README.md b/examples/tacotron2/README.md
new file mode 100644
index 0000000..12e28da
--- /dev/null
+++ b/examples/tacotron2/README.md
@@ -0,0 +1,89 @@
+# Tacotron2
+
+PaddlePaddle dynamic graph implementation of Tacotron2, a neural network architecture for speech synthesis directly from text. The implementation is based on [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/abs/1712.05884).
+
+## Project Structure
+
+```text
+├── config.py      # default configuration file
+├── ljspeech.py    # dataset and dataloader settings for LJSpeech
+├── preprocess.py  # script to preprocess the LJSpeech dataset
+├── synthesis.py   # script to synthesize spectrograms from text
+├── train.py       # script for Tacotron2 model training
+```
+
+## Dataset
+
+We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+Then preprocess the data by running ``preprocess.py``; the preprocessed data will be placed in the ``--output`` directory.
+
+```bash
+python preprocess.py \
+--input=${DATAPATH} \
+--output=${PREPROCESSEDDATAPATH} \
+-v
+```
+
+For more help on the arguments, run
+
+``python preprocess.py --help``.
+
+## Train the model
+
+The Tacotron2 model can be trained by running ``train.py``.
+
+```bash
+python train.py \
+--data=${PREPROCESSEDDATAPATH} \
+--output=${OUTPUTPATH} \
+--device=gpu
+```
+
+If you want to train on CPU, just set ``--device=cpu``.
+If you want to train on multiple GPUs, set ``--nprocs`` to the number of GPUs.
+By default, training resumes from the latest checkpoint in ``--output``. If you want to start a new training run, use a new ``${OUTPUTPATH}`` with no checkpoint. If you want to resume from another existing model, set ``--checkpoint_path`` to the checkpoint path you want to load.
+
+**Note: The checkpoint path should not include the file extension.**
+
+For more help on the arguments, run
+
+``python train.py --help``.
+
+## Synthesis
+
+After training Tacotron2, spectrograms can be synthesized by running ``synthesis.py``.
+
+```bash
+python synthesis.py \
+--config=${CONFIGPATH} \
+--checkpoint_path=${CHECKPOINTPATH} \
+--input=${TEXTPATH} \
+--output=${OUTPUTPATH} \
+--device=gpu
+```
+
+The ``${CONFIGPATH}`` must match the ``${CHECKPOINTPATH}``.
+
+For more help on the arguments, run
+
+``python synthesis.py --help``.
+
+You can then find the spectrogram files in ``${OUTPUTPATH}``; they can be used as input to a vocoder such as [waveflow](../waveflow/README.md#Synthesis) to generate audio files.
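+
+The synthesized spectrograms are ``.npy`` files (the format the [waveflow](../waveflow/README.md#Synthesis) example expects), so you can quickly inspect one before running a vocoder. A minimal sketch, assuming a hypothetical output file name:
+
+```python
+import numpy as np
+import matplotlib.pyplot as plt
+
+mel = np.load("output/sentence_0.npy")  # hypothetical file name
+# Transpose if needed, depending on whether time is the first or second axis.
+plt.imshow(mel, origin="lower", aspect="auto")
+plt.savefig("mel.png")
+```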
diff --git a/examples/transformer_tts/README.md b/examples/transformer_tts/README.md
new file mode 100644
index 0000000..2924afb
--- /dev/null
+++ b/examples/transformer_tts/README.md
@@ -0,0 +1,48 @@
+# TransformerTTS with LJSpeech
+
+## Dataset
+
+### Download the dataset.
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+```
+
+### Extract the dataset.
+
+```bash
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+### Preprocess the dataset.
+
+Assume the path to save the preprocessed dataset is `ljspeech_transformer_tts`. Run the command below to preprocess the dataset.
+
+```bash
+python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_transformer_tts
+```
+
+## Train the model
+
+The training script requires 4 command line arguments.
+`--data` is the path of the training dataset, and `--output` is the path of the output directory (we recommend using a subdirectory in `runs` to manage different experiments).
+
+`--device` should be "cpu" or "gpu", and `--nprocs` is the number of processes used to train the model in parallel.
+
+```bash
+python train.py --data=ljspeech_transformer_tts/ --output=runs/test --device="gpu" --nprocs=1
+```
+
+If you want distributed training, set a larger `--nprocs` (e.g. 4). Note that distributed training on CPU is not supported yet.
+
+## Synthesize
+
+Synthesize mel spectrograms from text. We assume `--input` is a text file with one sentence per line, and `--output` is a directory to save the synthesized mel spectrograms (log magnitude) in `.npy` format. The mel spectrograms can be used with `WaveFlow` to generate waveforms.
+
+`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extension `.pdparams` is not included here.
+
+`--device` specifies the device to run synthesis on.
+
+```bash
+python synthesize.py --input=sentence.txt --output=mels/ --checkpoint_path='step-310000' --device="gpu" --verbose
+```
\ No newline at end of file
diff --git a/examples/waveflow/README.md b/examples/waveflow/README.md
new file mode 100644
index 0000000..8931b88
--- /dev/null
+++ b/examples/waveflow/README.md
@@ -0,0 +1,57 @@
+# WaveFlow with LJSpeech
+
+## Dataset
+
+### Download the dataset.
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+```
+
+### Extract the dataset.
+
+```bash
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+### Preprocess the dataset.
+
+Assume the path to save the preprocessed dataset is `ljspeech_waveflow`. Run the command below to preprocess the dataset.
+
+```bash
+python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_waveflow
+```
+
+## Train the model
+
+The training script requires 4 command line arguments.
+`--data` is the path of the training dataset, and `--output` is the path of the output directory (we recommend using a subdirectory in `runs` to manage different experiments).
+
+`--device` should be "cpu" or "gpu", and `--nprocs` is the number of processes used to train the model in parallel.
+
+```bash
+python train.py --data=ljspeech_waveflow/ --output=runs/test --device="gpu" --nprocs=1
+```
+
+If you want distributed training, set a larger `--nprocs` (e.g. 4). Note that distributed training on CPU is not supported yet.
+
+## Synthesize
+
+Synthesize waveforms from mel spectrograms. We assume `--input` is a directory containing several mel spectrograms (log magnitude) in `.npy` format. The output is saved in the `--output` directory as `.wav` files, each with the same name as the corresponding mel spectrogram.
+
+`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extension `.pdparams` is not included here.
+
+`--device` specifies the device to run synthesis on.
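+
+Each input file is a plain NumPy array saved in `.npy` format. A minimal sketch for sanity-checking an input file before running the synthesis command below (the file name is only an example, and the exact array shape depends on the preprocessing configuration):
+
+```python
+import numpy as np
+
+mel = np.load("mels/0.npy")  # hypothetical file name
+print(mel.dtype, mel.shape)  # log-magnitude mel spectrogram
+```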
+
+```bash
+python synthesize.py --input=mels/ --output=wavs/ --checkpoint_path='step-2000000' --device="gpu" --verbose
+```
\ No newline at end of file
diff --git a/examples/waveflow/train.py b/examples/waveflow/train.py
index 443cc8b..c64ace6 100644
--- a/examples/waveflow/train.py
+++ b/examples/waveflow/train.py
@@ -46,7 +46,7 @@ class Experiment(ExperimentBase):
             n_mels=config.data.n_mels,
             kernel_size=config.model.kernel_size)
 
-        if self.parallel > 1:
+        if self.parallel:
             model = paddle.DataParallel(model)
         optimizer = paddle.optimizer.Adam(
             config.training.lr, parameters=model.parameters())
diff --git a/examples/wavenet/README.md b/examples/wavenet/README.md
new file mode 100644
index 0000000..ef61a9f
--- /dev/null
+++ b/examples/wavenet/README.md
@@ -0,0 +1,57 @@
+# WaveNet with LJSpeech
+
+## Dataset
+
+### Download the dataset.
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+```
+
+### Extract the dataset.
+
+```bash
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+### Preprocess the dataset.
+
+Assume the path to save the preprocessed dataset is `ljspeech_wavenet`. Run the command below to preprocess the dataset.
+
+```bash
+python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_wavenet
+```
+
+## Train the model
+
+The training script requires 4 command line arguments.
+`--data` is the path of the training dataset, and `--output` is the path of the output directory (we recommend using a subdirectory in `runs` to manage different experiments).
+
+`--device` should be "cpu" or "gpu", and `--nprocs` is the number of processes used to train the model in parallel.
+
+```bash
+python train.py --data=ljspeech_wavenet/ --output=runs/test --device="gpu" --nprocs=1
+```
+
+If you want distributed training, set a larger `--nprocs` (e.g. 4). Note that distributed training on CPU is not supported yet.
+
+## Synthesize
+
+Synthesize waveforms from mel spectrograms. We assume `--input` is a directory containing several mel spectrograms (normalized into the range [0, 1)) in `.npy` format. The output is saved in the `--output` directory as `.wav` files, each with the same name as the corresponding mel spectrogram.
+
+`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extension `.pdparams` is not included here.
+
+`--device` specifies the device to run synthesis on. Due to the autoregressive nature of WaveNet, synthesis on CPU may be faster.
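+
+Since WaveNet conditions on mel spectrograms normalized into [0, 1), it can be worth verifying an input file before running the synthesis command below. A minimal sketch (the file name is only an example):
+
+```python
+import numpy as np
+
+mel = np.load("mels/0.npy")  # hypothetical file name
+assert 0.0 <= mel.min() and mel.max() < 1.0, "expected mels normalized into [0, 1)"
+```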
+
+```bash
+python synthesize.py --input=mels/ --output=wavs/ --checkpoint_path='step-2450000' --device="cpu" --verbose
+```
\ No newline at end of file
diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py
index 8e9bc0e..b62e4a3 100644
--- a/examples/wavenet/train.py
+++ b/examples/wavenet/train.py
@@ -49,7 +49,7 @@ class Experiment(ExperimentBase):
             loss_type=config.model.loss_type,
             log_scale_min=config.model.log_scale_min)
 
-        if self.parallel > 1:
+        if self.parallel:
             model = paddle.DataParallel(model)
 
         lr_scheduler = paddle.optimizer.lr.StepDecay(
@@ -62,7 +62,7 @@ class Experiment(ExperimentBase):
                 config.training.gradient_max_norm))
 
         self.model = model
-        self.model_core = model._layer if self.parallel else model
+        self.model_core = model._layers if self.parallel else model
        self.optimizer = optimizer
 
     def setup_dataloader(self):
@@ -119,7 +119,7 @@ class Experiment(ExperimentBase):
         mel, wav, audio_starts = batch
 
         y = self.model(wav, mel, audio_starts)
-        loss = self.model.loss(y, wav)
+        loss = self.model_core.loss(y, wav)
         loss.backward()
         self.optimizer.step()
         iteration_time = time.time() - start
@@ -141,7 +141,7 @@ class Experiment(ExperimentBase):
         valid_losses = []
         mel, wav, audio_starts = next(valid_iterator)
         y = self.model(wav, mel, audio_starts)
-        loss = self.model.loss(y, wav)
+        loss = self.model_core.loss(y, wav)
         valid_losses.append(float(loss))
         valid_loss = np.mean(valid_losses)
         self.visualizer.add_scalar(
diff --git a/parakeet/datasets/ljspeech.py b/parakeet/datasets/ljspeech.py
index a37863f..c34f52b 100644
--- a/parakeet/datasets/ljspeech.py
+++ b/parakeet/datasets/ljspeech.py
@@ -25,7 +25,7 @@ class LJSpeechMetaData(Dataset):
         csv_path = self.root / "metadata.csv"
         records = []
         speaker_name = "ljspeech"
-        with open(str(csv_path), 'rt') as f:
+        with open(str(csv_path), 'rt', encoding='utf-8') as f:
             for line in f:
                 filename, _, normalized_text = line.strip().split("|")
                 filename = str(wav_dir / (filename + ".wav"))
diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py
index d67f40e..1587108 100644
--- a/parakeet/models/tacotron2.py
+++ b/parakeet/models/tacotron2.py
@@ -71,8 +71,10 @@ class DecoderPreNet(nn.Layer):
 
         """
 
-        x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
-        output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
+        x = F.dropout(
+            F.relu(self.linear1(x)), self.dropout_rate, training=True)
+        output = F.dropout(
+            F.relu(self.linear2(x)), self.dropout_rate, training=True)
 
         return output
 
@@ -161,9 +163,13 @@ class DecoderPostNet(nn.Layer):
 
         for i in range(len(self.conv_batchnorms) - 1):
             input = F.dropout(
-                F.tanh(self.conv_batchnorms[i](input), self.dropout))
-        output = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
-                           self.dropout)
+                F.tanh(self.conv_batchnorms[i](input)),
+                self.dropout,
+                training=self.training)
+        output = F.dropout(
+            self.conv_batchnorms[self.num_layers - 1](input),
+            self.dropout,
+            training=self.training)
 
         return output
 
@@ -228,8 +234,10 @@ class Tacotron2Encoder(nn.Layer):
         """
 
         for conv_batchnorm in self.conv_batchnorms:
-            x = F.dropout(F.relu(conv_batchnorm(x)),
-                          self.p_dropout)  #(B, T, C)
+            x = F.dropout(
+                F.relu(conv_batchnorm(x)),
+                self.p_dropout,
+                training=self.training)
 
         output, _ = self.lstm(inputs=x, sequence_length=input_lens)
         return output
@@ -350,8 +358,10 @@ class Tacotron2Decoder(nn.Layer):
         # The first lstm layer
         _, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
             cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(self.attention_hidden,
-                                          self.p_attention_dropout)
+        self.attention_hidden = F.dropout(
+            self.attention_hidden,
+            self.p_attention_dropout,
+            training=self.training)
 
         # Loaction sensitive attention
         attention_weights_cat = paddle.stack(
@@ -367,7 +377,9 @@ class Tacotron2Decoder(nn.Layer):
         _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
             decoder_input, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(
-            self.decoder_hidden, p=self.p_decoder_dropout)
+            self.decoder_hidden,
+            p=self.p_decoder_dropout,
+            training=self.training)
 
         # decode output one step
         decoder_hidden_attention_context = paddle.concat(
diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py
index c7f0ccd..05ce008 100644
--- a/parakeet/models/transformer_tts.py
+++ b/parakeet/models/transformer_tts.py
@@ -391,7 +391,7 @@ class TransformerTTS(nn.Layer):
             padding_idx=frontend.vocab.padding_index,
             weight_attr=I.Uniform(-0.05, 0.05))
         # position encoding matrix may be extended later
-        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
+        self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
         self.encoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
@@ -399,7 +399,7 @@ class TransformerTTS(nn.Layer):
 
         # decoder
         self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
-        self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
+        self.decoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_decoder)
         self.decoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.decoder = TransformerDecoder(
diff --git a/parakeet/modules/positional_encoding.py b/parakeet/modules/positional_encoding.py
index 07a86c9..cec168c 100644
--- a/parakeet/modules/positional_encoding.py
+++ b/parakeet/modules/positional_encoding.py
@@ -17,10 +17,10 @@
 import numpy as np
 import paddle
 from paddle.nn import functional as F
 
-__all__ = ["positional_encoding"]
+__all__ = ["sinusoid_positional_encoding"]
 
 
-def positional_encoding(start_index, length, size, dtype=None):
+def sinusoid_positional_encoding(start_index, length, size, dtype=None):
     r"""Generate standard positional encoding matrix.
 
     .. math::