add ge2e and tacotron2_aishell3 example (#107)
* hacky thing, add tone support for acoustic model * fix experiments for waveflow and wavenet, only write visual log in rank-0 * use emb add in tacotron2 * 1. remove space from numericalized representation; 2. fix decoder paddign mask's unsqueeze dim. * remove bn in postnet * refactoring code * add an option to normalize volume when loading audio. * add an embedding layer. * 1. change the default min value of LogMagnitude to 1e-5; 2. remove stop logit prediction from tacotron2 model. * WIP: baker * add ge2e * fix lstm speaker encoder * fix lstm speaker encoder * fix speaker encoder and add support for 2 more datasets * simplify visualization code * add a simple strategy to support multispeaker for tacotron. * add vctk example for refactored tacotron * fix indentation * fix class name * fix visualizer * fix root path * fix root path * fix root path * fix typos * fix bugs * fix text log extention name * add example for baker and aishell3 * update experiment and display * format code for tacotron_vctk, add plot_waveform to display * add new trainer * minor fix * add global condition support for tacotron2 * add gst layer * add 2 frontend * fix fmax for example/waveflow * update collate function, data loader not does not convert nested list into numpy array. * WIP: add hifigan * WIP:update hifigan * change stft to use conv1d * add audio datasets * change batch_text_id, batch_spec, batch_wav to include valid lengths in the returned value * change wavenet to use on-the-fly prepeocessing * fix typos * resolve conflict * remove imports that are removed * remove files not included in this release * remove imports to deleted modules * move tacotron2_msp * clean code * fix argument order * fix argument name * clean code for data processing * WIP: add README * add more details to thr README, fix some preprocess scripts * add voice cloning notebook * add an optional to alter the loss and model structure of tacotron2, add an alternative config * add plot_multiple_attentions and update visualization code in transformer_tts * format code * remove tacotron2_msp * update tacotron2 from_pretrained, update setup.py * update tacotron2 * update tacotron_aishell3's README * add images for exampels/tacotron2_aishell3's README * update README for examples/ge2e * add STFT back * add extra_config keys into the default config of tacotron * fix typos and docs * update README and doc * update docstrings for tacotron * update doc * update README * add links to downlaod pretrained models * refine READMEs and clean code * add praatio into requirements for running the experiments * format code with pre-commit * simplify text processing code and update notebook
This commit is contained in:
parent
0aa7088d36
commit
4f288a6d4f
28
README.md
28
README.md
|
@ -18,14 +18,14 @@ In order to facilitate exploiting the existing TTS models directly and developin
|
|||
|
||||
- Vocoders
|
||||
- [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219)
|
||||
- [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499)
|
||||
|
||||
- TTS models
|
||||
- [Neural Speech Synthesis with Transformer Network (Transformer TTS)](https://arxiv.org/abs/1809.08895)
|
||||
- [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](arxiv.org/abs/1712.05884)
|
||||
|
||||
## Updates
|
||||
|
||||
And more will be added in the future.
|
||||
May-07-2021, Add an example for voice cloning in Chinese. Check [examples/tacotron2_aishell3](./examples/tacotron2_aishell3).
|
||||
|
||||
|
||||
## Setup
|
||||
|
@ -45,7 +45,7 @@ See [install](https://www.paddlepaddle.org.cn/install/quick) for more details. T
|
|||
pip install -U paddle-parakeet
|
||||
```
|
||||
|
||||
or
|
||||
or
|
||||
```bash
|
||||
git clone https://github.com/PaddlePaddle/Parakeet
|
||||
cd Parakeet
|
||||
|
@ -59,9 +59,10 @@ See [install](https://paddle-parakeet.readthedocs.io/en/latest/install.html) for
|
|||
Entries to the introduction, and the launch of training and synthsis for different example models:
|
||||
|
||||
- [>>> WaveFlow](./examples/waveflow)
|
||||
- [>>> WaveNet](./examples/wavenet)
|
||||
- [>>> Transformer TTS](./examples/transformer_tts)
|
||||
- [>>> Tacotron2](./examples/tacotron2)
|
||||
- [>>> Tacotron2_AISHELL3](./examples/tacotron2_aishell3)
|
||||
- [>>> GE2E](./examples/ge2e)
|
||||
|
||||
|
||||
## Audio samples
|
||||
|
@ -70,6 +71,25 @@ Entries to the introduction, and the launch of training and synthsis for differe
|
|||
|
||||
Check our [website](https://paddle-parakeet.readthedocs.io/en/latest/demo.html) for audio sampels.
|
||||
|
||||
|
||||
## Checkpoints
|
||||
|
||||
### Tacotron2
|
||||
1. [tacotron2_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_ljspeech_ckpt_0.3.zip)
|
||||
2. [tacotron2_ljspeech_ckpt_0.3_alternative.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_ljspeech_ckpt_0.3_alternative.zip)
|
||||
|
||||
### Tacotron2_AISHELL3
|
||||
1. [tacotron2_aishell3_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_aishell3_ckpt_0.3.zip)
|
||||
|
||||
### TransformerTTS
|
||||
1. [transformer_tts_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/transformer_tts_ljspeech_ckpt_0.3.zip)
|
||||
|
||||
### WaveFlow
|
||||
1. [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_ljspeech_ckpt_0.3.zip)
|
||||
|
||||
### GE2E
|
||||
1. [ge2e_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/ge2e_ckpt_0.3.zip)
|
||||
|
||||
## Copyright and License
|
||||
|
||||
Parakeet is provided under the [Apache-2.0 license](LICENSE).
|
||||
|
|
|
@ -68,7 +68,6 @@ exclude_patterns = []
|
|||
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
|
|
|
@ -140,4 +140,48 @@ Vocoder audio samples
|
|||
|
||||
Audio samples generated from ground-truth spectrograms with a vocoder.
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<embed>
|
||||
<table>
|
||||
<tr>
|
||||
<th align="left"> WaveFlow res 128</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<audio controls="controls">
|
||||
<source
|
||||
src="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_samples_1.0/step_2000k_sentence_0.wav"
|
||||
type="audio/wav">
|
||||
Your browser does not support the <code>audio</code> element.
|
||||
</audio>
|
||||
<audio controls="controls">
|
||||
<source
|
||||
src="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_samples_1.0/step_2000k_sentence_1.wav"
|
||||
type="audio/wav">
|
||||
Your browser does not support the <code>audio</code> element.
|
||||
</audio>
|
||||
<audio controls="controls">
|
||||
<source
|
||||
src="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_samples_1.0/step_2000k_sentence_2.wav"
|
||||
type="audio/wav">
|
||||
Your browser does not support the <code>audio</code> element.
|
||||
</audio>
|
||||
<audio controls="controls">
|
||||
<source
|
||||
src="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_samples_1.0/step_2000k_sentence_3.wav"
|
||||
type="audio/wav">
|
||||
Your browser does not support the <code>audio</code> element.
|
||||
</audio>
|
||||
<audio controls="controls">
|
||||
<source
|
||||
src="https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_res128_ljspeech_samples_1.0/step_2000k_sentence_4.wav"
|
||||
type="audio/wav">
|
||||
Your browser does not support the <code>audio</code> element.
|
||||
</audio>
|
||||
</td>
|
||||
</tr>
|
||||
</tabel>
|
||||
</table>
|
||||
</embed>
|
||||
|
||||
|
|
|
@ -28,13 +28,6 @@ parakeet.models.waveflow module
|
|||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
parakeet.models.wavenet module
|
||||
------------------------------
|
||||
|
||||
.. automodule:: parakeet.models.wavenet
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
|
|
@ -0,0 +1,129 @@
|
|||
# Speaker Encoder
|
||||
|
||||
This experiment trains a speaker encoder with speaker verification as its task. It is done as a part of the experiment of transfer learning from speaker verification to multispeaker text-to-speech synthesis, which can be found at [tacotron2_aishell3](../tacotron2_shell3). The trained speaker encoder is used to extract utterance embeddings from utterances.
|
||||
|
||||
## Model
|
||||
|
||||
The model used in this experiment is the speaker encoder with text independent speaker verification task in [GENERALIZED END-TO-END LOSS FOR SPEAKER VERIFICATION](https://arxiv.org/pdf/1710.10467.pdf). GE2E-softmax loss is used.
|
||||
|
||||
## File Structure
|
||||
|
||||
```text
|
||||
ge2e
|
||||
├── README.md
|
||||
├── README_cn.md
|
||||
├── audio_processor.py
|
||||
├── config.py
|
||||
├── dataset_processors.py
|
||||
├── inference.py
|
||||
├── preprocess.py
|
||||
├── random_cycle.py
|
||||
├── speaker_verification_dataset.py
|
||||
└── train.py
|
||||
```
|
||||
|
||||
## Download Datasets
|
||||
|
||||
Currently supported datasets are Librispeech-other-500, VoxCeleb, VoxCeleb2,ai-datatang-200zh, magicdata, which can be downloaded from corresponding webpage.
|
||||
|
||||
1. Librispeech/train-other-500
|
||||
|
||||
An English multispeaker dataset,[URL](https://www.openslr.org/resources/12/train-other-500.tar.gz),only the `train-other-500` subset is used.
|
||||
|
||||
2. VoxCeleb1
|
||||
|
||||
An English multispeaker dataset,[URL](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) , Audio Files from Dev A to Dev D should be downloaded, combined and extracted.
|
||||
|
||||
3. VoxCeleb2
|
||||
|
||||
An English multispeaker dataset,[URL](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) , Audio Files from Dev A to Dev H should be downloaded, combined and extracted.
|
||||
|
||||
4. Aidatatang-200zh
|
||||
|
||||
A Mandarin Chinese multispeaker dataset ,[URL](https://www.openslr.org/62/) .
|
||||
|
||||
5. magicdata
|
||||
|
||||
A Mandarin Chinese multispeaker dataset ,[URL](https://www.openslr.org/68/) .
|
||||
|
||||
If you want to use other datasets, you can also download and preprocess it as long as it meets the requirements described below.
|
||||
|
||||
## Preprocess Datasets
|
||||
|
||||
Multispeaker datasets are used as training data, though the transcriptions are not used. To enlarge the amount of data used for training, several multispeaker datasets are combined. The preporcessed datasets are organized in a file structure described below. The mel spectrogram of each utterance is save in `.npy` format. The dataset is 2-stratified (speaker-utterance). Since multiple datasets are combined, to avoid conflict in speaker id, dataset name is prepended to the speake ids.
|
||||
|
||||
```text
|
||||
dataset_root
|
||||
├── dataset01_speaker01/
|
||||
│ ├── utterance01.npy
|
||||
│ ├── utterance02.npy
|
||||
│ └── utterance03.npy
|
||||
├── dataset01_speaker02/
|
||||
│ ├── utterance01.npy
|
||||
│ ├── utterance02.npy
|
||||
│ └── utterance03.npy
|
||||
├── dataset02_speaker01/
|
||||
│ ├── utterance01.npy
|
||||
│ ├── utterance02.npy
|
||||
│ └── utterance03.npy
|
||||
└── dataset02_speaker02/
|
||||
├── utterance01.npy
|
||||
├── utterance02.npy
|
||||
└── utterance03.npy
|
||||
```
|
||||
|
||||
Run the command to preprocess datasets.
|
||||
|
||||
```bash
|
||||
python preprocess.py --datasets_root=<datasets_root> --output_dir=<output_dir> --dataset_names=<dataset_names>
|
||||
```
|
||||
|
||||
Here `--datasets_root` is the directory that contains several extracted dataset; `--output_dir` is the directory to save the preprocessed dataset; `--dataset_names` is the dataset to preprocess. If there are multiple datasets in `--datasets_root` to preprocess, the names can be joined with comma. Currently supported dataset names are librispeech_other, voxceleb1, voxceleb2, aidatatang_200zh and magicdata.
|
||||
|
||||
## Training
|
||||
|
||||
When preprocessing is done, run the command below to train the mdoel.
|
||||
|
||||
```bash
|
||||
python train.py --data=<data_path> --output=<output> --device="gpu" --nprocs=1
|
||||
```
|
||||
|
||||
- `--data` is the path to the preprocessed dataset.
|
||||
- `--output` is the directory to save results,usually a subdirectory of `runs`.It contains visualdl log files, text log files, config file and a `checkpoints` directory, which contains parameter file and optimizer state file. If `--output` already has some training results in it, the most recent parameter file and optimizer state file is loaded before training.
|
||||
- `--device` is the device type to run the training, 'cpu' and 'gpu' are supported.
|
||||
- `--nprocs` is the number of replicas to run in multiprocessing based parallel training。Currently multiprocessing based parallel training is only enabled when using 'gpu' as the devicde. `CUDA_VISIBLE_DEVICES` can be used to specify visible devices with cuda.
|
||||
|
||||
Other options are described below.
|
||||
|
||||
- `--config` is a `.yaml` config file used to override the default config(which is coded in `config.py`).
|
||||
- `--opts` is command line options to further override config files. It should be the last comman line options passed with multiple key-value pairs separated by spaces.
|
||||
- `--checkpoint_path` specifies the checkpoiont to load before training, extension is not included. A parameter file ( `.pdparams`) and an optimizer state file ( `.pdopt`) with the same name is used. This option has a higher priority than auto-resuming from the `--output` directory.
|
||||
|
||||
## Pretrained Model
|
||||
|
||||
The pretrained model is first trained to 1560k steps at Librispeech-other-500 and voxceleb1. Then trained at aidatatang_200h and magic_data to 3000k steps.
|
||||
|
||||
Download URL [ge2e_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/ge2e_ckpt_0.3.zip).
|
||||
|
||||
## Inference
|
||||
|
||||
When training is done, run the command below to generate utterance embedding for each utterance in a dataset.
|
||||
|
||||
```bash
|
||||
python inference.py --input=<input> --output=<output> --checkpoint_path=<checkpoint_path> --device="gpu"
|
||||
```
|
||||
|
||||
`--input` is the path of the dataset used for inference.
|
||||
|
||||
`--output` is the directory to save the processed results. It has the same file structure as the input dataset. Each utterance in the dataset has a corrsponding utterance embedding file in `*.npy` format.
|
||||
|
||||
`--checkpoint_path` is the path of the checkpoint to use, extension not included.
|
||||
|
||||
`--pattern` is the wildcard pattern to filter audio files for inference, defaults to `*.wav`.
|
||||
|
||||
`--device` and `--opts` have the same meaning as in the training script.
|
||||
|
||||
## References
|
||||
|
||||
1. [Generalized End-to-end Loss for Speaker Verification](https://arxiv.org/pdf/1710.10467.pdf)
|
||||
2. [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf)
|
|
@ -0,0 +1,124 @@
|
|||
# Speaker Encoder
|
||||
|
||||
本实验是的在多说话人数据集上以 Speaker Verification 为任务训练一个 speaker encoder, 这是作为 transfer learning from speaker verification to multispeaker text-to-speech synthesis 实验的一部分, 可以在 [tacotron2_aishell3](../tacotron2_aishell3) 中找到。用训练好的模型来提取音频的 utterance embedding.
|
||||
|
||||
## 模型
|
||||
|
||||
本实验使用的模型是 [GENERALIZED END-TO-END LOSS FOR SPEAKER VERIFICATION](https://arxiv.org/pdf/1710.10467.pdf) 中的 speaker encoder text independent 模型。使用的是 GE2E softmax 损失函数。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```text
|
||||
ge2e
|
||||
├── README_cn.md
|
||||
├── audio_processor.py
|
||||
├── config.py
|
||||
├── dataset_processors.py
|
||||
├── inference.py
|
||||
├── preprocess.py
|
||||
├── random_cycle.py
|
||||
├── speaker_verification_dataset.py
|
||||
└── train.py
|
||||
```
|
||||
|
||||
## 数据集下载
|
||||
|
||||
本实验支持了 Librispeech-other-500, VoxCeleb, VoxCeleb2,ai-datatang-200zh, magicdata 数据集。可以在对应的页面下载。
|
||||
|
||||
1. Librispeech/train-other-500
|
||||
|
||||
英文多说话人数据集,[下载链接](https://www.openslr.org/resources/12/train-other-500.tar.gz),我们的实验中仅用到了 train-other-500 这个子集。
|
||||
|
||||
2. VoxCeleb1
|
||||
|
||||
英文多说话人数据集,[下载链接](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html),需要下载其中的 Audio Files 中的 Dev A 到 Dev D 四个压缩文件并合并解压。
|
||||
|
||||
3. VoxCeleb2
|
||||
|
||||
英文多说话人数据集,[下载链接](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html),需要下载其中的 Audio Files 中的 Dev A 到 Dev H 八个压缩文件并合并解压。
|
||||
|
||||
4. Aidatatang-200zh
|
||||
|
||||
中文多说话人数据集,[下载链接](https://www.openslr.org/62/)。
|
||||
|
||||
5. magicdata
|
||||
|
||||
中文多说话人数据集,[下载链接](https://www.openslr.org/68/)。
|
||||
|
||||
如果用户需要使用其他的数据集,也可以自行下载并进行数据处理,只要符合如下的要求。
|
||||
|
||||
## 数据集预处理
|
||||
|
||||
训练中使用的数据集是多说话人数据集,transcription 并不会被使用。为了扩大数据的量,训练过程可以将多个数据集合并为一个。处理后的文件结果组织方式如下,每个句子的频谱存储为 `.npy` 格式。以 speaker-utterance 的两层目录结构存储。因为合并数据集的原因,为了避免 speaker id 冲突,dataset 名会被添加到 speaker id 前面。
|
||||
|
||||
```text
|
||||
dataset_root
|
||||
├── dataset01_speaker01/
|
||||
│ ├── utterance01.npy
|
||||
│ ├── utterance02.npy
|
||||
│ └── utterance03.npy
|
||||
├── dataset01_speaker02/
|
||||
│ ├── utterance01.npy
|
||||
│ ├── utterance02.npy
|
||||
│ └── utterance03.npy
|
||||
├── dataset02_speaker01/
|
||||
│ ├── utterance01.npy
|
||||
│ ├── utterance02.npy
|
||||
│ └── utterance03.npy
|
||||
└── dataset02_speaker02/
|
||||
├── utterance01.npy
|
||||
├── utterance02.npy
|
||||
└── utterance03.npy
|
||||
```
|
||||
|
||||
运行数据处理脚本
|
||||
|
||||
```bash
|
||||
python preprocess.py --datasets_root=<datasets_root> --output_dir=<output_dir> --dataset_names=<dataset_names>
|
||||
```
|
||||
|
||||
其中 datasets_root 是包含多个原始数据集的路径,--output_dir 是多个数据集合并后输出的路径,dataset_names 是数据集的名称,多个数据集可以用逗号分割,比如 'librispeech_other, voxceleb1'. 目前支持的数据集有 librispeech_other, voxceleb1, voxceleb2, aidatatang_200zh, magicdata.
|
||||
|
||||
## 训练
|
||||
|
||||
数据处理完成后,使用如下的脚本训练。
|
||||
|
||||
```bash
|
||||
python train.py --data=<data_path> --output=<output> --device="gpu" --nprocs=1
|
||||
```
|
||||
|
||||
- `--data` 是处理后的数据集路径。
|
||||
- `--output` 是训练结果的保存路径,一般使用 runs 下的一个子目录。保存结果包含 visualdl 的 log 文件,文本 log 记录,运行 config 备份,以及 checkpoints 目录,里面包含参数文件和优化器状态文件。如果指定的 output 路径包含此前的训练结果,训练前会自动加载最近的参数文件和优化器状态文件。
|
||||
- `--device` 是运行设备,目前支持 'cpu' 和 'gpu'.
|
||||
- `--nprocs` 是指定运行进程数。目前仅在使用 'gpu' 是支持多进程训练。可以配合 `CUDA_VISIBLE_DEVICES` 环境变量指定可见卡号。
|
||||
|
||||
另外还有几个选项。
|
||||
|
||||
- `--config` 是用于覆盖默认配置(默认配置可以查看 `config.py`) 的配置文件,为 `.yaml` 文件。
|
||||
- `--opts` 是用命令行参数进一步覆盖配置。这是最后一个传入的命令行选项,用多组空格分隔的 KEY VALUE 对的方式传入。
|
||||
- `--checkpoint_path` 指定从中恢复的 checkpoint, 不需要包含扩展名。同名的参数文件( `.pdparams`) 和优化器文件( `.pdopt`)会被加载以恢复训练。这个参数指定的恢复训练优先级高于自动从 `output` 文件夹中恢复训练。
|
||||
|
||||
## 预训练模型
|
||||
|
||||
预训练模型是在 Librispeech-other-500 和 voxceleb1 上训练到 1560k steps 后用 aidatatang_200h 和 magic_data 训练到 3000k 的结果。
|
||||
|
||||
下载链接 [ge2e_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/ge2e_ckpt_0.3.zip)
|
||||
|
||||
## 预测
|
||||
|
||||
使用训练好的模型进行预测,对一个数据集中的所有 utterance 生成一个 embedding.
|
||||
|
||||
```bash
|
||||
python inference.py --input=<input> --output=<output> --checkpoint_path=<checkpoint_path> --device="gpu"
|
||||
```
|
||||
|
||||
- `--input` 是需要处理的数据集的路径。
|
||||
- `--output` 是处理的结果,它会保持和 `--input` 相同的文件夹结构,对应 input 中的每一个音频文件会有一个同名的 `*.npy` 文件,是从这个音频文件中提取到的 utterance embedding.
|
||||
- `--checkpoint_path` 为用于预测的参数文件路径,不包含扩展名。
|
||||
- `--pattern` 是用于筛选数据集中需要处理的音频文件的通配符模式,默认为 `*.wav`.
|
||||
- `--device` 和 `--opts` 的语义和训练脚本一致。
|
||||
|
||||
## 参考文献
|
||||
|
||||
1. [GENERALIZED END-TO-END LOSS FOR SPEAKER VERIFICATION](https://arxiv.org/pdf/1710.10467.pdf)
|
||||
2. [Transfer Learning from Speaker Verification toMultispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf)
|
|
@ -0,0 +1,237 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from warnings import warn
|
||||
import struct
|
||||
|
||||
from scipy.ndimage.morphology import binary_dilation
|
||||
import numpy as np
|
||||
import librosa
|
||||
|
||||
try:
|
||||
import webrtcvad
|
||||
except ModuleNotFoundError:
|
||||
warn("Unable to import 'webrtcvad'."
|
||||
"This package enables noise removal and is recommended.")
|
||||
webrtcvad = None
|
||||
|
||||
INT16_MAX = (2**15) - 1
|
||||
|
||||
|
||||
def normalize_volume(wav,
|
||||
target_dBFS,
|
||||
increase_only=False,
|
||||
decrease_only=False):
|
||||
# this function implements Loudness normalization, instead of peak
|
||||
# normalization, See https://en.wikipedia.org/wiki/Audio_normalization
|
||||
# dBFS: Decibels relative to full scale
|
||||
# See https://en.wikipedia.org/wiki/DBFS for more details
|
||||
# for 16Bit PCM audio, minimal level is -96dB
|
||||
# compute the mean dBFS and adjust to target dBFS, with by increasing
|
||||
# or decreasing
|
||||
if increase_only and decrease_only:
|
||||
raise ValueError("Both increase only and decrease only are set")
|
||||
dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
|
||||
if ((dBFS_change < 0 and increase_only) or
|
||||
(dBFS_change > 0 and decrease_only)):
|
||||
return wav
|
||||
gain = 10**(dBFS_change / 20)
|
||||
return wav * gain
|
||||
|
||||
|
||||
def trim_long_silences(wav,
|
||||
vad_window_length: int,
|
||||
vad_moving_average_width: int,
|
||||
vad_max_silence_length: int,
|
||||
sampling_rate: int):
|
||||
"""
|
||||
Ensures that segments without voice in the waveform remain no longer than a
|
||||
threshold determined by the VAD parameters in params.py.
|
||||
|
||||
:param wav: the raw waveform as a numpy array of floats
|
||||
:return: the same waveform with silences trimmed away (length <= original wav length)
|
||||
"""
|
||||
# Compute the voice detection window size
|
||||
samples_per_window = (vad_window_length * sampling_rate) // 1000
|
||||
|
||||
# Trim the end of the audio to have a multiple of the window size
|
||||
wav = wav[:len(wav) - (len(wav) % samples_per_window)]
|
||||
|
||||
# Convert the float waveform to 16-bit mono PCM
|
||||
pcm_wave = struct.pack("%dh" % len(wav),
|
||||
*(np.round(wav * INT16_MAX)).astype(np.int16))
|
||||
|
||||
# Perform voice activation detection
|
||||
voice_flags = []
|
||||
vad = webrtcvad.Vad(mode=3)
|
||||
for window_start in range(0, len(wav), samples_per_window):
|
||||
window_end = window_start + samples_per_window
|
||||
voice_flags.append(
|
||||
vad.is_speech(
|
||||
pcm_wave[window_start * 2:window_end * 2],
|
||||
sample_rate=sampling_rate))
|
||||
voice_flags = np.array(voice_flags)
|
||||
|
||||
# Smooth the voice detection with a moving average
|
||||
def moving_average(array, width):
|
||||
array_padded = np.concatenate((np.zeros((width - 1) // 2), array,
|
||||
np.zeros(width // 2)))
|
||||
ret = np.cumsum(array_padded, dtype=float)
|
||||
ret[width:] = ret[width:] - ret[:-width]
|
||||
return ret[width - 1:] / width
|
||||
|
||||
audio_mask = moving_average(voice_flags, vad_moving_average_width)
|
||||
audio_mask = np.round(audio_mask).astype(np.bool)
|
||||
|
||||
# Dilate the voiced regions
|
||||
audio_mask = binary_dilation(audio_mask,
|
||||
np.ones(vad_max_silence_length + 1))
|
||||
audio_mask = np.repeat(audio_mask, samples_per_window)
|
||||
|
||||
return wav[audio_mask]
|
||||
|
||||
|
||||
def compute_partial_slices(n_samples: int,
|
||||
partial_utterance_n_frames: int,
|
||||
hop_length: int,
|
||||
min_pad_coverage: float=0.75,
|
||||
overlap: float=0.5):
|
||||
"""
|
||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
||||
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
||||
spectrogram slices are returned, so as to make each partial utterance waveform correspond to
|
||||
its spectrogram. This function assumes that the mel spectrogram parameters used are those
|
||||
defined in params_data.py.
|
||||
|
||||
The returned ranges may be indexing further than the length of the waveform. It is
|
||||
recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
|
||||
|
||||
:param n_samples: the number of samples in the waveform
|
||||
:param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
|
||||
utterance
|
||||
:param min_pad_coverage: when reaching the last partial utterance, it may or may not have
|
||||
enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
|
||||
then the last partial utterance will be considered, as if we padded the audio. Otherwise,
|
||||
it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
|
||||
utterance, this parameter is ignored so that the function always returns at least 1 slice.
|
||||
:param overlap: by how much the partial utterance should overlap. If set to 0, the partial
|
||||
utterances are entirely disjoint.
|
||||
:return: the waveform slices and mel spectrogram slices as lists of array slices. Index
|
||||
respectively the waveform and the mel spectrogram with these slices to obtain the partial
|
||||
utterances.
|
||||
"""
|
||||
assert 0 <= overlap < 1
|
||||
assert 0 < min_pad_coverage <= 1
|
||||
|
||||
# librosa's function to compute num_frames from num_samples
|
||||
n_frames = int(np.ceil((n_samples + 1) / hop_length))
|
||||
# frame shift between ajacent partials
|
||||
frame_step = max(
|
||||
1, int(np.round(partial_utterance_n_frames * (1 - overlap))))
|
||||
|
||||
# Compute the slices
|
||||
wav_slices, mel_slices = [], []
|
||||
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
|
||||
for i in range(0, steps, frame_step):
|
||||
mel_range = np.array([i, i + partial_utterance_n_frames])
|
||||
wav_range = mel_range * hop_length
|
||||
mel_slices.append(slice(*mel_range))
|
||||
wav_slices.append(slice(*wav_range))
|
||||
|
||||
# Evaluate whether extra padding is warranted or not
|
||||
last_wav_range = wav_slices[-1]
|
||||
coverage = (n_samples - last_wav_range.start) / (
|
||||
last_wav_range.stop - last_wav_range.start)
|
||||
if coverage < min_pad_coverage and len(mel_slices) > 1:
|
||||
mel_slices = mel_slices[:-1]
|
||||
wav_slices = wav_slices[:-1]
|
||||
|
||||
return wav_slices, mel_slices
|
||||
|
||||
|
||||
class SpeakerVerificationPreprocessor(object):
|
||||
def __init__(self,
|
||||
sampling_rate: int,
|
||||
audio_norm_target_dBFS: float,
|
||||
vad_window_length,
|
||||
vad_moving_average_width,
|
||||
vad_max_silence_length,
|
||||
mel_window_length,
|
||||
mel_window_step,
|
||||
n_mels,
|
||||
partial_n_frames: int,
|
||||
min_pad_coverage: float=0.75,
|
||||
partial_overlap_ratio: float=0.5):
|
||||
self.sampling_rate = sampling_rate
|
||||
self.audio_norm_target_dBFS = audio_norm_target_dBFS
|
||||
|
||||
self.vad_window_length = vad_window_length
|
||||
self.vad_moving_average_width = vad_moving_average_width
|
||||
self.vad_max_silence_length = vad_max_silence_length
|
||||
|
||||
self.n_fft = int(mel_window_length * sampling_rate / 1000)
|
||||
self.hop_length = int(mel_window_step * sampling_rate / 1000)
|
||||
self.n_mels = n_mels
|
||||
|
||||
self.partial_n_frames = partial_n_frames
|
||||
self.min_pad_coverage = min_pad_coverage
|
||||
self.partial_overlap_ratio = partial_overlap_ratio
|
||||
|
||||
def preprocess_wav(self, fpath_or_wav, source_sr=None):
|
||||
# Load the wav from disk if needed
|
||||
if isinstance(fpath_or_wav, (str, Path)):
|
||||
wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
|
||||
else:
|
||||
wav = fpath_or_wav
|
||||
|
||||
# Resample if numpy.array is passed and sr does not match
|
||||
if source_sr is not None and source_sr != self.sampling_rate:
|
||||
wav = librosa.resample(wav, source_sr, self.sampling_rate)
|
||||
|
||||
# loudness normalization
|
||||
wav = normalize_volume(
|
||||
wav, self.audio_norm_target_dBFS, increase_only=True)
|
||||
|
||||
# trim long silence
|
||||
if webrtcvad:
|
||||
wav = trim_long_silences(
|
||||
wav, self.vad_window_length, self.vad_moving_average_width,
|
||||
self.vad_max_silence_length, self.sampling_rate)
|
||||
return wav
|
||||
|
||||
def melspectrogram(self, wav):
|
||||
mel = librosa.feature.melspectrogram(
|
||||
wav,
|
||||
sr=self.sampling_rate,
|
||||
n_fft=self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
n_mels=self.n_mels)
|
||||
mel = mel.astype(np.float32).T
|
||||
return mel
|
||||
|
||||
def extract_mel_partials(self, wav):
|
||||
wav_slices, mel_slices = compute_partial_slices(
|
||||
len(wav), self.partial_n_frames, self.hop_length,
|
||||
self.min_pad_coverage, self.partial_overlap_ratio)
|
||||
|
||||
# pad audio if needed
|
||||
max_wave_length = wav_slices[-1].stop
|
||||
if max_wave_length >= len(wav):
|
||||
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
||||
|
||||
# Split the utterance into partials
|
||||
frames = self.melspectrogram(wav)
|
||||
frames_batch = np.array([frames[s] for s in mel_slices])
|
||||
return frames_batch # [B, T, C]
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from yacs.config import CfgNode
|
||||
|
||||
_C = CfgNode()
|
||||
|
||||
data_config = _C.data = CfgNode()
|
||||
|
||||
## Audio volume normalization
|
||||
data_config.audio_norm_target_dBFS = -30
|
||||
|
||||
## Audio sample rate
|
||||
data_config.sampling_rate = 16000 # Hz
|
||||
|
||||
## Voice Activation Detection
|
||||
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
|
||||
# This sets the granularity of the VAD. Should not need to be changed.
|
||||
data_config.vad_window_length = 30 # In milliseconds
|
||||
# Number of frames to average together when performing the moving average smoothing.
|
||||
# The larger this value, the larger the VAD variations must be to not get smoothed out.
|
||||
data_config.vad_moving_average_width = 8
|
||||
# Maximum number of consecutive silent frames a segment can have.
|
||||
data_config.vad_max_silence_length = 6
|
||||
|
||||
## Mel-filterbank
|
||||
data_config.mel_window_length = 25 # In milliseconds
|
||||
data_config.mel_window_step = 10 # In milliseconds
|
||||
data_config.n_mels = 40 # mel bands
|
||||
|
||||
# Number of spectrogram frames in a partial utterance
|
||||
data_config.partial_n_frames = 160 # 1600 ms
|
||||
data_config.min_pad_coverage = 0.75 # at least 75% of the audio is valid in a partial
|
||||
data_config.partial_overlap_ratio = 0.5 # overlap ratio between ajancent partials
|
||||
|
||||
model_config = _C.model = CfgNode()
|
||||
model_config.num_layers = 3
|
||||
model_config.hidden_size = 256
|
||||
model_config.embedding_size = 256 # output size
|
||||
|
||||
training_config = _C.training = CfgNode()
|
||||
training_config.learning_rate_init = 1e-4
|
||||
training_config.speakers_per_batch = 64
|
||||
training_config.utterances_per_speaker = 10
|
||||
training_config.max_iteration = 1560000
|
||||
training_config.save_interval = 10000
|
||||
training_config.valid_interval = 10000
|
||||
|
||||
|
||||
def get_cfg_defaults():
|
||||
return _C.clone()
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
import multiprocessing as mp
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from audio_processor import SpeakerVerificationPreprocessor
|
||||
|
||||
|
||||
def _process_utterance(path_pair, processor: SpeakerVerificationPreprocessor):
|
||||
# Load and preprocess the waveform
|
||||
input_path, output_path = path_pair
|
||||
wav = processor.preprocess_wav(input_path)
|
||||
if len(wav) == 0:
|
||||
return
|
||||
|
||||
# Create the mel spectrogram, discard those that are too short
|
||||
frames = processor.melspectrogram(wav)
|
||||
if len(frames) < processor.partial_n_frames:
|
||||
return
|
||||
|
||||
np.save(output_path, frames)
|
||||
|
||||
|
||||
def _process_speaker(speaker_dir: Path,
|
||||
processor: SpeakerVerificationPreprocessor,
|
||||
datasets_root: Path,
|
||||
output_dir: Path,
|
||||
pattern: str,
|
||||
skip_existing: bool=False):
|
||||
# datastes root: a reference path to compute speaker_name
|
||||
# we prepand dataset name to speaker_id becase we are mixing serveal
|
||||
# multispeaker datasets together
|
||||
speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
|
||||
speaker_output_dir = output_dir / speaker_name
|
||||
speaker_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load exsiting file set
|
||||
sources_fpath = speaker_output_dir / "_sources.txt"
|
||||
if sources_fpath.exists():
|
||||
try:
|
||||
with sources_fpath.open("rt") as sources_file:
|
||||
existing_names = {line.split(",")[0] for line in sources_file}
|
||||
except:
|
||||
existing_names = {}
|
||||
else:
|
||||
existing_names = {}
|
||||
|
||||
sources_file = sources_fpath.open("at" if skip_existing else "wt")
|
||||
for in_fpath in speaker_dir.rglob(pattern):
|
||||
out_name = "_".join(
|
||||
in_fpath.relative_to(speaker_dir).with_suffix(".npy").parts)
|
||||
if skip_existing and out_name in existing_names:
|
||||
continue
|
||||
out_fpath = speaker_output_dir / out_name
|
||||
_process_utterance((in_fpath, out_fpath), processor)
|
||||
sources_file.write(f"{out_name},{in_fpath}\n")
|
||||
|
||||
sources_file.close()
|
||||
|
||||
|
||||
def _process_dataset(processor: SpeakerVerificationPreprocessor,
|
||||
datasets_root: Path,
|
||||
speaker_dirs: List[Path],
|
||||
dataset_name: str,
|
||||
output_dir: Path,
|
||||
pattern: str,
|
||||
skip_existing: bool=False):
|
||||
print(
|
||||
f"{dataset_name}: Preprocessing data for {len(speaker_dirs)} speakers.")
|
||||
|
||||
_func = partial(
|
||||
_process_speaker,
|
||||
processor=processor,
|
||||
datasets_root=datasets_root,
|
||||
output_dir=output_dir,
|
||||
pattern=pattern,
|
||||
skip_existing=skip_existing)
|
||||
|
||||
with mp.Pool(16) as pool:
|
||||
list(
|
||||
tqdm(
|
||||
pool.imap(_func, speaker_dirs),
|
||||
dataset_name,
|
||||
len(speaker_dirs),
|
||||
unit="speakers"))
|
||||
print(f"Done preprocessing {dataset_name}.")
|
||||
|
||||
|
||||
def process_librispeech(processor,
|
||||
datasets_root,
|
||||
output_dir,
|
||||
skip_existing=False):
|
||||
dataset_name = "LibriSpeech/train-other-500"
|
||||
dataset_root = datasets_root / dataset_name
|
||||
speaker_dirs = list(dataset_root.glob("*"))
|
||||
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
|
||||
output_dir, "*.flac", skip_existing)
|
||||
|
||||
|
||||
def process_voxceleb1(processor,
|
||||
datasets_root,
|
||||
output_dir,
|
||||
skip_existing=False):
|
||||
dataset_name = "VoxCeleb1"
|
||||
dataset_root = datasets_root / dataset_name
|
||||
|
||||
anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
|
||||
with dataset_root.joinpath("vox1_meta.csv").open("rt") as metafile:
|
||||
metadata = [line.strip().split("\t") for line in metafile][1:]
|
||||
|
||||
# speaker id -> nationality
|
||||
nationalities = {
|
||||
line[0]: line[3]
|
||||
for line in metadata if line[-1] == "dev"
|
||||
}
|
||||
keep_speaker_ids = [
|
||||
speaker_id for speaker_id, nationality in nationalities.items()
|
||||
if nationality.lower() in anglophone_nationalites
|
||||
]
|
||||
print(
|
||||
"VoxCeleb1: using samples from {} (presumed anglophone) speakers out of {}."
|
||||
.format(len(keep_speaker_ids), len(nationalities)))
|
||||
|
||||
speaker_dirs = list((dataset_root / "wav").glob("*"))
|
||||
speaker_dirs = [
|
||||
speaker_dir for speaker_dir in speaker_dirs
|
||||
if speaker_dir.name in keep_speaker_ids
|
||||
]
|
||||
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
|
||||
output_dir, "*.wav", skip_existing)
|
||||
|
||||
|
||||
def process_voxceleb2(processor,
|
||||
datasets_root,
|
||||
output_dir,
|
||||
skip_existing=False):
|
||||
dataset_name = "VoxCeleb2"
|
||||
dataset_root = datasets_root / dataset_name
|
||||
# There is no nationality in meta data for VoxCeleb2
|
||||
speaker_dirs = list((dataset_root / "wav").glob("*"))
|
||||
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
|
||||
output_dir, "*.wav", skip_existing)
|
||||
|
||||
|
||||
def process_aidatatang_200zh(processor,
|
||||
datasets_root,
|
||||
output_dir,
|
||||
skip_existing=False):
|
||||
dataset_name = "aidatatang_200zh/train"
|
||||
dataset_root = datasets_root / dataset_name
|
||||
|
||||
speaker_dirs = list((dataset_root).glob("*"))
|
||||
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
|
||||
output_dir, "*.wav", skip_existing)
|
||||
|
||||
|
||||
def process_magicdata(processor,
|
||||
datasets_root,
|
||||
output_dir,
|
||||
skip_existing=False):
|
||||
dataset_name = "magicdata/train"
|
||||
dataset_root = datasets_root / dataset_name
|
||||
|
||||
speaker_dirs = list((dataset_root).glob("*"))
|
||||
_process_dataset(processor, datasets_root, speaker_dirs, dataset_name,
|
||||
output_dir, "*.wav", skip_existing)
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder
|
||||
|
||||
from audio_processor import SpeakerVerificationPreprocessor
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
def embed_utterance(processor, model, fpath_or_wav):
|
||||
# audio processor
|
||||
wav = processor.preprocess_wav(fpath_or_wav)
|
||||
mel_partials = processor.extract_mel_partials(wav)
|
||||
|
||||
model.eval()
|
||||
# speaker encoder
|
||||
with paddle.no_grad():
|
||||
mel_partials = paddle.to_tensor(mel_partials)
|
||||
with paddle.no_grad():
|
||||
embed = model.embed_utterance(mel_partials)
|
||||
embed = embed.numpy()
|
||||
return embed
|
||||
|
||||
|
||||
def _process_utterance(ifpath: Path,
|
||||
input_dir: Path,
|
||||
output_dir: Path,
|
||||
processor: SpeakerVerificationPreprocessor,
|
||||
model: LSTMSpeakerEncoder):
|
||||
rel_path = ifpath.relative_to(input_dir)
|
||||
ofpath = (output_dir / rel_path).with_suffix(".npy")
|
||||
ofpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
embed = embed_utterance(processor, model, ifpath)
|
||||
np.save(ofpath, embed)
|
||||
|
||||
|
||||
def main(config, args):
|
||||
paddle.set_device(args.device)
|
||||
|
||||
# load model
|
||||
model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
|
||||
config.model.hidden_size,
|
||||
config.model.embedding_size)
|
||||
weights_fpath = str(Path(args.checkpoint_path).expanduser())
|
||||
model_state_dict = paddle.load(weights_fpath + ".pdparams")
|
||||
model.set_state_dict(model_state_dict)
|
||||
model.eval()
|
||||
print(f"Loaded encoder {weights_fpath}")
|
||||
|
||||
# create audio processor
|
||||
c = config.data
|
||||
processor = SpeakerVerificationPreprocessor(
|
||||
sampling_rate=c.sampling_rate,
|
||||
audio_norm_target_dBFS=c.audio_norm_target_dBFS,
|
||||
vad_window_length=c.vad_window_length,
|
||||
vad_moving_average_width=c.vad_moving_average_width,
|
||||
vad_max_silence_length=c.vad_max_silence_length,
|
||||
mel_window_length=c.mel_window_length,
|
||||
mel_window_step=c.mel_window_step,
|
||||
n_mels=c.n_mels,
|
||||
partial_n_frames=c.partial_n_frames,
|
||||
min_pad_coverage=c.min_pad_coverage,
|
||||
partial_overlap_ratio=c.min_pad_coverage, )
|
||||
|
||||
# input output preparation
|
||||
input_dir = Path(args.input).expanduser()
|
||||
ifpaths = list(input_dir.rglob(args.pattern))
|
||||
print(f"{len(ifpaths)} utterances in total")
|
||||
output_dir = Path(args.output).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for ifpath in tqdm.tqdm(ifpaths, unit="utterance"):
|
||||
_process_utterance(ifpath, input_dir, output_dir, processor, model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
parser = argparse.ArgumentParser(description="compute utterance embed.")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
metavar="FILE",
|
||||
help="path of the config file to overwrite to default config with.")
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="path of the audio_file folder.")
|
||||
parser.add_argument(
|
||||
"--pattern",
|
||||
type=str,
|
||||
default="*.wav",
|
||||
help="pattern to filter audio files.")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
metavar="OUTPUT_DIR",
|
||||
help="path to save checkpoint and logs.")
|
||||
|
||||
# load from saved checkpoint
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="path of the checkpoint to load")
|
||||
|
||||
# running
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
choices=["cpu", "gpu"],
|
||||
help="device type to use, cpu and gpu are supported.")
|
||||
|
||||
# overwrite extra config and default config
|
||||
parser.add_argument(
|
||||
"--opts",
|
||||
nargs=argparse.REMAINDER,
|
||||
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from config import get_cfg_defaults
|
||||
from audio_processor import SpeakerVerificationPreprocessor
|
||||
from dataset_processors import (process_librispeech, process_voxceleb1,
|
||||
process_voxceleb2, process_aidatatang_200zh,
|
||||
process_magicdata)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="preprocess dataset for speaker verification task")
|
||||
parser.add_argument(
|
||||
"--datasets_root",
|
||||
type=Path,
|
||||
help="Path to the directory containing your LibriSpeech, LibriTTS and VoxCeleb datasets."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=Path, help="Path to save processed dataset.")
|
||||
parser.add_argument(
|
||||
"--dataset_names",
|
||||
type=str,
|
||||
default="librispeech_other,voxceleb1,voxceleb2",
|
||||
help="comma-separated list of names of the datasets you want to preprocess. only "
|
||||
"the train set of these datastes will be used. Possible names: librispeech_other, "
|
||||
"voxceleb1, voxceleb2, aidatatang_200zh, magicdata.")
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="Whether to skip ouput files with the same name. Useful if this script was interrupted."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_trim",
|
||||
action="store_true",
|
||||
help="Preprocess audio without trimming silences (not recommended).")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.no_trim:
|
||||
try:
|
||||
import webrtcvad
|
||||
except:
|
||||
raise ModuleNotFoundError(
|
||||
"Package 'webrtcvad' not found. This package enables "
|
||||
"noise removal and is recommended. Please install and "
|
||||
"try again. If installation fails, "
|
||||
"use --no_trim to disable this error message.")
|
||||
del args.no_trim
|
||||
|
||||
args.datasets = [item.strip() for item in args.dataset_names.split(",")]
|
||||
if not hasattr(args, "output_dir"):
|
||||
args.output_dir = args.dataset_root / "SV2TTS" / "encoder"
|
||||
|
||||
args.output_dir = args.output_dir.expanduser()
|
||||
args.datasets_root = args.datasets_root.expanduser()
|
||||
assert args.datasets_root.exists()
|
||||
args.output_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
config = get_cfg_defaults()
|
||||
print(args)
|
||||
|
||||
c = config.data
|
||||
processor = SpeakerVerificationPreprocessor(
|
||||
sampling_rate=c.sampling_rate,
|
||||
audio_norm_target_dBFS=c.audio_norm_target_dBFS,
|
||||
vad_window_length=c.vad_window_length,
|
||||
vad_moving_average_width=c.vad_moving_average_width,
|
||||
vad_max_silence_length=c.vad_max_silence_length,
|
||||
mel_window_length=c.mel_window_length,
|
||||
mel_window_step=c.mel_window_step,
|
||||
n_mels=c.n_mels,
|
||||
partial_n_frames=c.partial_n_frames,
|
||||
min_pad_coverage=c.min_pad_coverage,
|
||||
partial_overlap_ratio=c.min_pad_coverage, )
|
||||
|
||||
preprocess_func = {
|
||||
"librispeech_other": process_librispeech,
|
||||
"voxceleb1": process_voxceleb1,
|
||||
"voxceleb2": process_voxceleb2,
|
||||
"aidatatang_200zh": process_aidatatang_200zh,
|
||||
"magicdata": process_magicdata,
|
||||
}
|
||||
|
||||
for dataset in args.datasets:
|
||||
print("Preprocessing %s" % dataset)
|
||||
preprocess_func[dataset](processor, args.datasets_root,
|
||||
args.output_dir, args.skip_existing)
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
# cycle('ABCD') --> A B C D A B C D A B C D ...
|
||||
saved = []
|
||||
for element in iterable:
|
||||
yield element
|
||||
saved.append(element)
|
||||
while saved:
|
||||
for element in saved:
|
||||
yield element
|
||||
|
||||
|
||||
def random_cycle(iterable):
|
||||
# cycle('ABCD') --> A B C D B C D A A D B C ...
|
||||
saved = []
|
||||
for element in iterable:
|
||||
yield element
|
||||
saved.append(element)
|
||||
random.shuffle(saved)
|
||||
while saved:
|
||||
for element in saved:
|
||||
yield element
|
||||
random.shuffle(saved)
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from paddle.io import Dataset, BatchSampler
|
||||
|
||||
from random_cycle import random_cycle
|
||||
|
||||
|
||||
class MultiSpeakerMelDataset(Dataset):
|
||||
"""A 2 layer directory thatn contains mel spectrograms in *.npy format.
|
||||
An Example file structure tree is shown below. We prefer to preprocess
|
||||
raw datasets and organized them like this.
|
||||
|
||||
dataset_root/
|
||||
speaker1/
|
||||
utterance1.npy
|
||||
utterance2.npy
|
||||
utterance3.npy
|
||||
speaker2/
|
||||
utterance1.npy
|
||||
utterance2.npy
|
||||
utterance3.npy
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_root: Path):
|
||||
self.root = Path(dataset_root).expanduser()
|
||||
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
||||
|
||||
speaker_utterances = {
|
||||
speaker_dir: list(speaker_dir.glob("*.npy"))
|
||||
for speaker_dir in speaker_dirs
|
||||
}
|
||||
|
||||
self.speaker_dirs = speaker_dirs
|
||||
self.speaker_to_utterances = speaker_utterances
|
||||
|
||||
# meta data
|
||||
self.num_speakers = len(self.speaker_dirs)
|
||||
self.num_utterances = np.sum(
|
||||
len(utterances)
|
||||
for speaker, utterances in self.speaker_to_utterances.items())
|
||||
|
||||
def get_example_by_index(self, speaker_index, utterance_index):
|
||||
speaker_dir = self.speaker_dirs[speaker_index]
|
||||
fpath = self.speaker_to_utterances[speaker_dir][utterance_index]
|
||||
return self[fpath]
|
||||
|
||||
def __getitem__(self, fpath):
|
||||
return np.load(fpath)
|
||||
|
||||
def __len__(self):
|
||||
return int(self.num_utterances)
|
||||
|
||||
|
||||
class MultiSpeakerSampler(BatchSampler):
|
||||
"""A multi-stratal sampler designed for speaker verification task.
|
||||
First, N speakers from all speakers are sampled randomly. Then, for each
|
||||
speaker, randomly sample M utterances from their corresponding utterances.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset: MultiSpeakerMelDataset,
|
||||
speakers_per_batch: int,
|
||||
utterances_per_speaker: int):
|
||||
self._speakers = list(dataset.speaker_dirs)
|
||||
self._speaker_to_utterances = dataset.speaker_to_utterances
|
||||
|
||||
self.speakers_per_batch = speakers_per_batch
|
||||
self.utterances_per_speaker = utterances_per_speaker
|
||||
|
||||
def __iter__(self):
|
||||
# yield list of Paths
|
||||
speaker_generator = iter(random_cycle(self._speakers))
|
||||
speaker_utterances_generator = {
|
||||
s: iter(random_cycle(us))
|
||||
for s, us in self._speaker_to_utterances.items()
|
||||
}
|
||||
|
||||
while True:
|
||||
speakers = []
|
||||
for _ in range(self.speakers_per_batch):
|
||||
speakers.append(next(speaker_generator))
|
||||
|
||||
utterances = []
|
||||
for s in speakers:
|
||||
us = speaker_utterances_generator[s]
|
||||
for _ in range(self.utterances_per_speaker):
|
||||
utterances.append(next(us))
|
||||
yield utterances
|
||||
|
||||
|
||||
class RandomClip(object):
|
||||
def __init__(self, frames):
|
||||
self.frames = frames
|
||||
|
||||
def __call__(self, spec):
|
||||
# spec [T, C]
|
||||
T = spec.shape[0]
|
||||
start = random.randint(0, T - self.frames)
|
||||
return spec[start:start + self.frames, :]
|
||||
|
||||
|
||||
class Collate(object):
|
||||
def __init__(self, num_frames):
|
||||
self.random_crop = RandomClip(num_frames)
|
||||
|
||||
def __call__(self, examples):
|
||||
frame_clips = [self.random_crop(mel) for mel in examples]
|
||||
batced_clips = np.stack(frame_clips)
|
||||
return batced_clips
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mydataset = MultiSpeakerMelDataset(
|
||||
Path("/home/chenfeiyu/datasets/SV2TTS/encoder"))
|
||||
print(mydataset.get_example_by_index(0, 10))
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
from paddle import distributed as dist
|
||||
from paddle.optimizer import Adam
|
||||
from paddle import DataParallel
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn.clip import ClipGradByGlobalNorm
|
||||
|
||||
from parakeet.models.lstm_speaker_encoder import LSTMSpeakerEncoder
|
||||
from parakeet.training import ExperimentBase
|
||||
from parakeet.training import default_argument_parser
|
||||
|
||||
from speaker_verification_dataset import MultiSpeakerMelDataset
|
||||
from speaker_verification_dataset import MultiSpeakerSampler
|
||||
from speaker_verification_dataset import Collate
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
class Ge2eExperiment(ExperimentBase):
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
|
||||
config.model.hidden_size,
|
||||
config.model.embedding_size)
|
||||
optimizer = Adam(
|
||||
config.training.learning_rate_init,
|
||||
parameters=model.parameters(),
|
||||
grad_clip=ClipGradByGlobalNorm(3))
|
||||
self.model = DataParallel(model) if self.parallel else model
|
||||
self.model_core = model
|
||||
self.optimizer = optimizer
|
||||
|
||||
def setup_dataloader(self):
|
||||
config = self.config
|
||||
train_dataset = MultiSpeakerMelDataset(self.args.data)
|
||||
sampler = MultiSpeakerSampler(train_dataset,
|
||||
config.training.speakers_per_batch,
|
||||
config.training.utterances_per_speaker)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=Collate(config.data.partial_n_frames),
|
||||
num_workers=16)
|
||||
|
||||
self.train_dataset = train_dataset
|
||||
self.train_loader = train_loader
|
||||
|
||||
def train_batch(self):
|
||||
start = time.time()
|
||||
batch = self.read_batch()
|
||||
data_loader_time = time.time() - start
|
||||
|
||||
self.optimizer.clear_grad()
|
||||
self.model.train()
|
||||
specs = batch
|
||||
loss, eer = self.model(specs, self.config.training.speakers_per_batch)
|
||||
loss.backward()
|
||||
self.model_core.do_gradient_ops()
|
||||
self.optimizer.step()
|
||||
iteration_time = time.time() - start
|
||||
|
||||
# logging
|
||||
loss_value = float(loss)
|
||||
msg = "Rank: {}, ".format(dist.get_rank())
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
|
||||
iteration_time)
|
||||
msg += 'loss: {:>.6f} err: {:>.6f}'.format(loss_value, eer)
|
||||
self.logger.info(msg)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
self.visualizer.add_scalar("train/loss", loss_value,
|
||||
self.iteration)
|
||||
self.visualizer.add_scalar("train/eer", eer, self.iteration)
|
||||
self.visualizer.add_scalar(
|
||||
"param/w",
|
||||
float(self.model_core.similarity_weight), self.iteration)
|
||||
self.visualizer.add_scalar("param/b",
|
||||
float(self.model_core.similarity_bias),
|
||||
self.iteration)
|
||||
|
||||
def valid(self):
|
||||
pass
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Ge2eExperiment(config, args)
|
||||
exp.setup()
|
||||
exp.resume_or_load()
|
||||
exp.run()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
if args.nprocs > 1 and args.device == "gpu":
|
||||
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
else:
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
parser = default_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
|
@ -8,8 +8,9 @@ PaddlePaddle dynamic graph implementation of Tacotron2, a neural network archite
|
|||
├── config.py # default configuration file
|
||||
├── ljspeech.py # dataset and dataloader settings for LJSpeech
|
||||
├── preprocess.py # script to preprocess LJSpeech dataset
|
||||
├── synthesis.py # script to synthesize spectrogram from text
|
||||
├── synthesize.py # script to synthesize spectrogram from text
|
||||
├── train.py # script for tacotron2 model training
|
||||
├── synthesize.ipynb # notebook example for end-to-end TTS
|
||||
```
|
||||
|
||||
## Dataset
|
||||
|
@ -75,3 +76,17 @@ For more help on arguments
|
|||
``python synthesis.py --help``.
|
||||
|
||||
Then you can find the spectrogram files in ``${OUTPUTPATH}``, and then they can be the input of vocoder like [waveflow](../waveflow/README.md#Synthesis) to get audio files.
|
||||
|
||||
|
||||
## Pretrained Models
|
||||
|
||||
Pretrained Models can be downloaded from links below. We provide 2 models with different configurations.
|
||||
|
||||
1. This model use a binary classifier to predict the stop token. [tacotron2_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_ljspeech_ckpt_0.3.zip)
|
||||
|
||||
2. This model does not have a stop token predictor. It uses the attention peak position to decided whether all the contents have been uttered. Also guided attention loss is used to speed up training. This model is trained with `configs/alternative.yaml`.[tacotron2_ljspeech_ckpt_0.3_alternative.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_ljspeech_ckpt_0.3_alternative.zip)
|
||||
|
||||
|
||||
## Notebook: End-to-end TTS
|
||||
|
||||
See [synthesize.ipynb](./synthesize.ipynb) for details about end-to-end TTS with tacotron2 and waveflow.
|
||||
|
|
|
@ -23,23 +23,25 @@ _C.data = CN(
|
|||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
f_max=8000, # Hz, max frequency when converting to mel
|
||||
f_min=0, # Hz, min frequency when converting to mel
|
||||
d_mels=80, # mel bands
|
||||
fmax=8000, # Hz, max frequency when converting to mel
|
||||
fmin=0, # Hz, min frequency when converting to mel
|
||||
n_mels=80, # mel bands
|
||||
padding_idx=0, # text embedding's padding index
|
||||
))
|
||||
|
||||
_C.model = CN(
|
||||
dict(
|
||||
vocab_size=37, # set this according to the frontend's vocab size
|
||||
n_tones=None,
|
||||
reduction_factor=1, # reduction factor
|
||||
d_encoder=512, # embedding & encoder's internal size
|
||||
encoder_conv_layers=3, # number of conv layer in tacotron2 encoder
|
||||
encoder_kernel_size=5, # kernel size of conv layers in tacotron2 encoder
|
||||
d_prenet=256, # hidden size of decoder prenet
|
||||
d_attention_rnn=1024, # hidden size of the first rnn layer in tacotron2 decoder
|
||||
d_decoder_rnn=1024, #hidden size of the second rnn layer in tacotron2 decoder
|
||||
d_decoder_rnn=1024, # hidden size of the second rnn layer in tacotron2 decoder
|
||||
d_attention=128, # hidden size of decoder location linear layer
|
||||
attention_filters=32, # number of filter in decoder location conv layer
|
||||
attention_filters=32, # number of filter in decoder location conv layer
|
||||
attention_kernel_size=31, # kernel size of decoder location conv layer
|
||||
d_postnet=512, # hidden size of decoder postnet
|
||||
postnet_kernel_size=5, # kernel size of conv layers in postnet
|
||||
|
@ -48,7 +50,11 @@ _C.model = CN(
|
|||
p_prenet_dropout=0.5, # droput probability in decoder prenet
|
||||
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
|
||||
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
|
||||
p_postnet_dropout=0.5, #droput probability in decoder postnet
|
||||
p_postnet_dropout=0.5, # droput probability in decoder postnet
|
||||
d_global_condition=None,
|
||||
use_stop_token=True, # wherther to use binary classifier to predict when to stop
|
||||
use_guided_attention_loss=False, # whether to use guided attention loss
|
||||
guided_attention_loss_sigma=0.2 # sigma in guided attention loss
|
||||
))
|
||||
|
||||
_C.training = CN(
|
||||
|
|
|
@ -12,14 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from paddle.io import Dataset, DataLoader
|
||||
from paddle.io import Dataset
|
||||
|
||||
from parakeet.data.batch import batch_spec, batch_text_id
|
||||
from parakeet.data import dataset
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
|
@ -58,7 +57,7 @@ class LJSpeechCollector(object):
|
|||
mels = []
|
||||
text_lens = []
|
||||
mel_lens = []
|
||||
stop_tokens = []
|
||||
|
||||
for data in examples:
|
||||
text, mel = data
|
||||
text = np.array(text, dtype=np.int64)
|
||||
|
@ -66,8 +65,6 @@ class LJSpeechCollector(object):
|
|||
mels.append(mel)
|
||||
texts.append(text)
|
||||
mel_lens.append(mel.shape[1])
|
||||
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
|
||||
stop_tokens.append(np.append(stop_token, 1.0))
|
||||
|
||||
# Sort by text_len in descending order
|
||||
texts = [
|
||||
|
@ -87,20 +84,12 @@ class LJSpeechCollector(object):
|
|||
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
stop_tokens = [
|
||||
i
|
||||
for i, _ in sorted(
|
||||
zip(stop_tokens, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
text_lens = sorted(text_lens, reverse=True)
|
||||
mel_lens = np.array(mel_lens, dtype=np.int64)
|
||||
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts = batch_text_id(texts, pad_id=self.padding_idx)
|
||||
mels = np.transpose(
|
||||
batch_spec(
|
||||
mels, pad_value=self.padding_value), axes=(0, 2, 1))
|
||||
stop_tokens = batch_text_id(
|
||||
stop_tokens, pad_id=self.padding_stop_token, dtype=mels[0].dtype)
|
||||
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
|
||||
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||
mels = np.transpose(mels, axes=(0, 2, 1))
|
||||
|
||||
return (texts, mels, text_lens, mel_lens, stop_tokens)
|
||||
return texts, mels, text_lens, mel_lens
|
||||
|
|
|
@ -13,12 +13,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tqdm
|
||||
import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
import numpy as np
|
||||
|
||||
from parakeet.datasets import LJSpeechMetaData
|
||||
from parakeet.audio import AudioProcessor, LogMagnitude
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
|
@ -37,11 +38,11 @@ def create_dataset(config, source_path, target_path, verbose=False):
|
|||
processor = AudioProcessor(
|
||||
sample_rate=config.data.sample_rate,
|
||||
n_fft=config.data.n_fft,
|
||||
n_mels=config.data.d_mels,
|
||||
n_mels=config.data.n_mels,
|
||||
win_length=config.data.win_length,
|
||||
hop_length=config.data.hop_length,
|
||||
f_max=config.data.f_max,
|
||||
f_min=config.data.f_min)
|
||||
fmax=config.data.fmax,
|
||||
fmin=config.data.fmin)
|
||||
normalizer = LogMagnitude()
|
||||
|
||||
records = []
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -14,12 +14,14 @@
|
|||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
import parakeet
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
from parakeet.models.tacotron2 import Tacotron2
|
||||
from parakeet.utils import display
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
@ -29,7 +31,7 @@ def main(config, args):
|
|||
|
||||
# model
|
||||
frontend = EnglishCharacter()
|
||||
model = Tacotron2.from_pretrained(frontend, config, args.checkpoint_path)
|
||||
model = Tacotron2.from_pretrained(config, args.checkpoint_path)
|
||||
model.eval()
|
||||
|
||||
# inputs
|
||||
|
@ -44,10 +46,15 @@ def main(config, args):
|
|||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
for i, sentence in enumerate(sentences):
|
||||
mel_output, _ = model.predict(sentence)
|
||||
mel_output = mel_output.T
|
||||
sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0)
|
||||
|
||||
outputs = model.infer(sentence)
|
||||
mel_output = outputs["mel_outputs_postnet"][0].numpy().T
|
||||
alignment = outputs["alignments"][0].numpy().T
|
||||
|
||||
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
||||
display.plot_alignment(alignment)
|
||||
plt.savefig(str(output_dir / f"sentence_{i}.png"))
|
||||
if args.verbose:
|
||||
print("spectrogram saved at {}".format(output_dir /
|
||||
f"sentence_{i}.npy"))
|
||||
|
|
|
@ -20,9 +20,8 @@ import paddle
|
|||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
|
||||
import parakeet
|
||||
from parakeet.data import dataset
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
from parakeet.frontend import EnglishCharacter # pylint: disable=unused-import
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
from parakeet.utils import display, mp_tools
|
||||
|
@ -34,14 +33,18 @@ from ljspeech import LJSpeech, LJSpeechCollector
|
|||
|
||||
class Experiment(ExperimentBase):
|
||||
def compute_losses(self, inputs, outputs):
|
||||
_, mel_targets, _, _, stop_tokens = inputs
|
||||
texts, mel_targets, plens, slens = inputs
|
||||
|
||||
mel_outputs = outputs["mel_output"]
|
||||
mel_outputs_postnet = outputs["mel_outputs_postnet"]
|
||||
stop_logits = outputs["stop_logits"]
|
||||
attention_weight = outputs["alignments"]
|
||||
if self.config.model.use_stop_token:
|
||||
stop_logits = outputs["stop_logits"]
|
||||
else:
|
||||
stop_logits = None
|
||||
|
||||
losses = self.criterion(mel_outputs, mel_outputs_postnet, stop_logits,
|
||||
mel_targets, stop_tokens)
|
||||
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
|
||||
attention_weight, slens, plens, stop_logits)
|
||||
return losses
|
||||
|
||||
def train_batch(self):
|
||||
|
@ -51,8 +54,8 @@ class Experiment(ExperimentBase):
|
|||
|
||||
self.optimizer.clear_grad()
|
||||
self.model.train()
|
||||
texts, mels, text_lens, output_lens, stop_tokens = batch
|
||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
||||
texts, mels, text_lens, output_lens = batch
|
||||
outputs = self.model(texts, text_lens, mels, output_lens)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
loss = losses["loss"]
|
||||
loss.backward()
|
||||
|
@ -79,22 +82,24 @@ class Experiment(ExperimentBase):
|
|||
def valid(self):
|
||||
valid_losses = defaultdict(list)
|
||||
for i, batch in enumerate(self.valid_loader):
|
||||
texts, mels, text_lens, output_lens, stop_tokens = batch
|
||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
||||
texts, mels, text_lens, output_lens = batch
|
||||
outputs = self.model(texts, text_lens, mels, output_lens)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
for k, v in losses.items():
|
||||
valid_losses[k].append(float(v))
|
||||
|
||||
attention_weights = outputs["alignments"]
|
||||
display.add_attention_plots(self.visualizer,
|
||||
f"valid_sentence_{i}_alignments",
|
||||
attention_weights[0], self.iteration)
|
||||
display.add_spectrogram_plots(
|
||||
self.visualizer, f"valid_sentence_{i}_target_spectrogram",
|
||||
mels[0], self.iteration)
|
||||
display.add_spectrogram_plots(
|
||||
self.visualizer, f"valid_sentence_{i}_predicted_spectrogram",
|
||||
outputs['mel_outputs_postnet'][0], self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_alignments",
|
||||
display.plot_alignment(attention_weights[0].numpy().T),
|
||||
self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_target_spectrogram",
|
||||
display.plot_spectrogram(mels[0].numpy().T), self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_predicted_spectrogram",
|
||||
display.plot_spectrogram(outputs['mel_outputs_postnet'][0]
|
||||
.numpy().T), self.iteration)
|
||||
|
||||
# write visual log
|
||||
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
||||
|
@ -111,10 +116,9 @@ class Experiment(ExperimentBase):
|
|||
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
frontend = EnglishCharacter()
|
||||
model = Tacotron2(
|
||||
frontend,
|
||||
d_mels=config.data.d_mels,
|
||||
vocab_size=config.model.vocab_size,
|
||||
d_mels=config.data.n_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||
|
@ -132,7 +136,8 @@ class Experiment(ExperimentBase):
|
|||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout)
|
||||
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||
use_stop_token=config.model.use_stop_token)
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model)
|
||||
|
@ -145,7 +150,10 @@ class Experiment(ExperimentBase):
|
|||
weight_decay=paddle.regularizer.L2Decay(
|
||||
config.training.weight_decay),
|
||||
grad_clip=grad_clip)
|
||||
criterion = Tacotron2Loss()
|
||||
criterion = Tacotron2Loss(
|
||||
use_stop_token_loss=config.model.use_stop_token,
|
||||
use_guided_attention_loss=config.model.use_guided_attention_loss,
|
||||
sigma=config.model.guided_attention_loss_sigma)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
|
@ -186,6 +194,7 @@ class Experiment(ExperimentBase):
|
|||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp.setup()
|
||||
exp.resume_or_load()
|
||||
exp.run()
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
## Tacotron2 + AISHELL-3 数据集训练语音克隆模型
|
||||
|
||||
本实验的内容是利用 AISHELL-3 数据集和 Tacotron 2 模型进行语音克隆任务,使用的模型大体结构和论文 [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf) 相同。大致步骤如下:
|
||||
|
||||
1. Speaker Encoder: 我们使用了一个 Speaker Verification 任务训练一个 speaker encoder。这部分任务所用的数据集和训练 Tacotron 2 的数据集不同,因为不需要 transcription 的缘故,我们使用了较多的训练数据,可以参考实现 [ge2e](../ge2e)。
|
||||
2. Synthesizer: 然后使用训练好的 speaker encoder 为 AISHELL-3 数据集中的每个句子生成对应的 utterance embedding. 这个 Embedding 作为 Tacotron 模型中的一个额外输入和 encoder outputs 拼接在一起。
|
||||
3. Vocoder: 我们使用的声码器是 WaveFlow,参考实验 [waveflow](../waveflow).
|
||||
|
||||
## 数据处理
|
||||
|
||||
### utterance embedding 的生成
|
||||
|
||||
使用训练好的 speaker encoder 为 AISHELL-3 数据集中的每个句子生成对应的 utterance embedding. 以和音频文件夹同构的方式存储。存储格式是 `.npy` 文件。
|
||||
|
||||
首先 cd 到 [ge2e](../ge2e) 文件夹。下载训练好的 [模型](https://paddlespeech.bj.bcebos.com/Parakeet/ge2e_ckpt_0.3.zip),然后运行脚本生成每个句子的 utterance embedding.
|
||||
|
||||
```bash
|
||||
python inference.py --input=<intput> --output=<output> --device="gpu" --checkpoint_path=<pretrained checkpoint>
|
||||
```
|
||||
|
||||
其中 input 是只包含音频文件夹的文件。这里可以用 `~/datasets/aishell3/train/wav`,然后 output 是用于存储 utterance embed 的文件夹,这里可以用 `~/datasets/aishell3/train/embed`。Utterance embedding 会以和音频文件夹相同的文件结构存储,格式为 `.npy`.
|
||||
|
||||
utterance embedding 的计算可能会用几个小时的时间,请耐心等待。
|
||||
|
||||
### 音频处理
|
||||
|
||||
因为 AISHELL-3 数据集前后有一些空白,静音片段,而且语音幅值很小,所以我们需要进行空白移除和音量规范化。空白移除可以简单的使用基于音量或者能量的方法,但是效果不是很好,对于不同的句子很难取到一个一致的阈值。我们使用的是先利用 Force Aligner 进行文本和语音的对齐。然后根据对齐结果截除空白。
|
||||
|
||||
我们使用的工具是 Montreal Force Aligner 1.0. 因为 aishell 的标注包含拼音标注,所以我们提供给 Montreal Force Aligner 的是拼音 transcription 而不是汉字 transcription. 而且需要把其中的韵律标记(`$` 和 `%`)去除,并且处理成 Montreal Force Alinger 所需要的文件形式。和音频同名的文本文件,扩展名为 `.lab`.
|
||||
|
||||
此外还需要准备词典文件。其中包含把拼音序列转换为 phone 序列的映射关系。在这里我们只做声母和韵母的切分,而声调则归为韵母的一部分。我们使用的[词典文件](./lexicon.txt)可以下载。
|
||||
|
||||
准备好之后运行训练和对齐。首先下载 [Montreal Force Aligner 1.0](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/tag/v1.0.1).下载之后解压即可运行。cd 到其中的 bin 文件夹运行命令,即可进行训练和对齐。前三个命令行参数分别是音频文件夹的路径,词典路径和对齐文件输出路径。可以通过`-o` 传入训练得到的模型保存路径。
|
||||
|
||||
```bash
|
||||
./mfa_train_and_align \
|
||||
~/datasets/aishell3/train/wav \
|
||||
lexicon.txt \
|
||||
~/datasets/aishell3/train/alignment \
|
||||
-o aishell3_model \
|
||||
-v
|
||||
```
|
||||
|
||||
因为训练和对齐的时间比较长。我们提供了对齐后的 [alignment 文件](https://paddlespeech.bj.bcebos.com/Parakeet/alignment_aishell3.tar.gz),其中每个句子对应的文件为 `.TextGrid` 格式的文本。
|
||||
|
||||
得到了对齐文件之后,可以运行 `process_wav.py` 脚本来处理音频。
|
||||
|
||||
```bash
|
||||
python process_wav.py --input=<input> --output=<output> --alignment=<alignment>
|
||||
```
|
||||
|
||||
默认 input, output, alignment 分别是 `~/datasets/aishell3/train/wav`, `~/datasets/aishell3/train/normalized_wav`, `~/datasets/aishell3/train/alignment`.
|
||||
|
||||
处理结束后,会将处理好的音频保存在 `<output>` 文件夹中。
|
||||
|
||||
### 转录文本处理
|
||||
|
||||
把文本转换成为 phone 和 tone 的形式,并存储起来。值得注意的是,这里我们的处理和用于 montreal force aligner 的不一样。我们把声调分了出来。这是一个处理方式,当然也可以只做声母和韵母的切分。
|
||||
|
||||
运行脚本处理转录文本。
|
||||
|
||||
```bash
|
||||
python preprocess_transcription.py --input=<input> --output=<output>
|
||||
```
|
||||
|
||||
默认的 input 是 `~/datasets/aishell3/train`,其中会包含 `label_train-set.txt` 文件,处理后的结果会 `metadata.yaml` 和 `metadata.pickle`. 前者是文本格式,方便查看,后者是二进制格式,方便直接读取。
|
||||
|
||||
### mel 频谱提取
|
||||
|
||||
对处理后的音频进行 mel 频谱的提取,并且以和音频文件夹同构的方式存储,存储格式是 `.npy` 文件。
|
||||
|
||||
```python
|
||||
python extract_mel.py --input=<intput> --output=<output>
|
||||
```
|
||||
|
||||
input 是处理后的音频所在的文件夹,output 是输出频谱的文件夹。
|
||||
|
||||
## 训练
|
||||
|
||||
运行脚本训练。
|
||||
|
||||
```python
|
||||
python train.py --data=<data> --output=<output> --device="gpu"
|
||||
```
|
||||
|
||||
我们的模型去掉了 tacotron2 模型中的 stop token prediction。因为实践中由于 stop token prediction 是一个正负样例比例极不平衡的问题,每个句子可能有几百帧对应负样例,只有一帧正样例,而且这个 stop token prediction 对音频静音的裁切十分敏感。我们转用 attention 的最高点到达 encoder 侧的最后一个符号为终止条件。
|
||||
|
||||
另外,为了加速模型的收敛,我们加上了 guided attention loss, 诱导 encoder-decoder 之间的 alignment 更快地呈现对角线。
|
||||
|
||||
可以使用 visualdl 查看训练过程的 log。
|
||||
|
||||
```bash
|
||||
visualdl --logdir=<output> --host=$HOSTNAME
|
||||
```
|
||||
|
||||
示例 training loss / validation loss 曲线如下。
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
<img src="images/alignment-step2000.png" alt="alignment-step2000" style="zoom:50%;" />
|
||||
|
||||
大约从训练 2000 步左右就从 validation 过程中产出的 alignement 中可以观察到模糊的对角线。随着训练步数增加,对角线会更加清晰。但因为 validation 也是以 teacher forcing 的方式进行的,所以要在真正的 auto regressive 合成中产出的 alignment 中观察到对角线,需要更长的时间。
|
||||
|
||||
## 预训练模型
|
||||
|
||||
预训练模型下载链接。[tacotron2_aishell3_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_aishell3_ckpt_0.3.zip).
|
||||
|
||||
## 使用
|
||||
|
||||
本实验包含了一个简单的使用示例,用户可以替换作为参考的声音以及文本,用训练好的模型来合成语音。使用方式参考 [notebook](./voice_cloning.ipynb) 上的使用说明。
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from paddle.io import Dataset
|
||||
from parakeet.frontend import Vocab
|
||||
from parakeet.data import batch_text_id, batch_spec
|
||||
|
||||
from preprocess_transcription import _phones, _tones
|
||||
|
||||
voc_phones = Vocab(sorted(list(_phones)))
|
||||
print("vocab_phones:\n", voc_phones)
|
||||
voc_tones = Vocab(sorted(list(_tones)))
|
||||
print("vocab_tones:\n", voc_tones)
|
||||
|
||||
|
||||
class AiShell3(Dataset):
|
||||
"""Processed AiShell3 dataset."""
|
||||
|
||||
def __init__(self, root):
|
||||
super().__init__()
|
||||
self.root = Path(root).expanduser()
|
||||
self.embed_dir = self.root / "embed"
|
||||
self.mel_dir = self.root / "mel"
|
||||
|
||||
with open(self.root / "metadata.pickle", 'rb') as f:
|
||||
self.records = pickle.load(f)
|
||||
|
||||
def __getitem__(self, index):
|
||||
metadatum = self.records[index]
|
||||
sentence_id = metadatum["sentence_id"]
|
||||
speaker_id = sentence_id[:7]
|
||||
phones = metadatum["phones"]
|
||||
tones = metadatum["tones"]
|
||||
phones = np.array(
|
||||
[voc_phones.lookup(item) for item in phones], dtype=np.int64)
|
||||
tones = np.array(
|
||||
[voc_tones.lookup(item) for item in tones], dtype=np.int64)
|
||||
mel = np.load(str(self.mel_dir / speaker_id / (sentence_id + ".npy")))
|
||||
embed = np.load(
|
||||
str(self.embed_dir / speaker_id / (sentence_id + ".npy")))
|
||||
return phones, tones, mel, embed
|
||||
|
||||
def __len__(self):
|
||||
return len(self.records)
|
||||
|
||||
|
||||
def collate_aishell3_examples(examples):
|
||||
phones, tones, mel, embed = list(zip(*examples))
|
||||
|
||||
text_lengths = np.array([item.shape[0] for item in phones], dtype=np.int64)
|
||||
spec_lengths = np.array([item.shape[1] for item in mel], dtype=np.int64)
|
||||
T_dec = np.max(spec_lengths)
|
||||
stop_tokens = (np.arange(T_dec) >= np.expand_dims(spec_lengths, -1)
|
||||
).astype(np.float32)
|
||||
phones, _ = batch_text_id(phones)
|
||||
tones, _ = batch_text_id(tones)
|
||||
mel, _ = batch_spec(mel)
|
||||
mel = np.transpose(mel, (0, 2, 1))
|
||||
embed = np.stack(embed)
|
||||
# 7 fields
|
||||
# (B, T), (B, T), (B, T, C), (B, C), (B,), (B,), (B, T)
|
||||
return phones, tones, mel, embed, text_lengths, spec_lengths, stop_tokens
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = AiShell3("~/datasets/aishell3/train")
|
||||
example = dataset[0]
|
||||
|
||||
examples = [dataset[i] for i in range(10)]
|
||||
batch = collate_aishell3_examples(examples)
|
||||
|
||||
for field in batch:
|
||||
print(field.shape, field.dtype)
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
from preprocess_transcription import split_syllable
|
||||
|
||||
|
||||
def convert_to_pinyin(text: str) -> List[str]:
|
||||
"""convert text into list of syllables, other characters that are not chinese, thus
|
||||
cannot be converted to pinyin are splited.
|
||||
"""
|
||||
syllables = lazy_pinyin(
|
||||
text, style=Style.TONE3, neutral_tone_with_five=True)
|
||||
return syllables
|
||||
|
||||
|
||||
def convert_sentence(text: str) -> List[Tuple[str]]:
|
||||
"""convert a sentence into two list: phones and tones"""
|
||||
syllables = convert_to_pinyin(text)
|
||||
phones = []
|
||||
tones = []
|
||||
for syllable in syllables:
|
||||
p, t = split_syllable(syllable)
|
||||
phones.extend(p)
|
||||
tones.extend(t)
|
||||
|
||||
return phones, tones
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
_C = CN()
|
||||
_C.data = CN(
|
||||
dict(
|
||||
batch_size=32, # batch size
|
||||
valid_size=64, # the first N examples are reserved for validation
|
||||
sample_rate=22050, # Hz, sample rate
|
||||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
fmax=8000, # Hz, max frequency when converting to mel
|
||||
fmin=0, # Hz, min frequency when converting to mel
|
||||
d_mels=80, # mel bands
|
||||
padding_idx=0, # text embedding's padding index
|
||||
))
|
||||
|
||||
_C.model = CN(
|
||||
dict(
|
||||
vocab_size=70,
|
||||
n_tones=10,
|
||||
reduction_factor=1, # reduction factor
|
||||
d_encoder=512, # embedding & encoder's internal size
|
||||
encoder_conv_layers=3, # number of conv layer in tacotron2 encoder
|
||||
encoder_kernel_size=5, # kernel size of conv layers in tacotron2 encoder
|
||||
d_prenet=256, # hidden size of decoder prenet
|
||||
# hidden size of the first rnn layer in tacotron2 decoder
|
||||
d_attention_rnn=1024,
|
||||
# hidden size of the second rnn layer in tacotron2 decoder
|
||||
d_decoder_rnn=1024,
|
||||
d_attention=128, # hidden size of decoder location linear layer
|
||||
attention_filters=32, # number of filter in decoder location conv layer
|
||||
attention_kernel_size=31, # kernel size of decoder location conv layer
|
||||
d_postnet=512, # hidden size of decoder postnet
|
||||
postnet_kernel_size=5, # kernel size of conv layers in postnet
|
||||
postnet_conv_layers=5, # number of conv layer in decoder postnet
|
||||
p_encoder_dropout=0.5, # droput probability in encoder
|
||||
p_prenet_dropout=0.5, # droput probability in decoder prenet
|
||||
|
||||
# droput probability of first rnn layer in decoder
|
||||
p_attention_dropout=0.1,
|
||||
# droput probability of second rnn layer in decoder
|
||||
p_decoder_dropout=0.1,
|
||||
p_postnet_dropout=0.5, # droput probability in decoder postnet
|
||||
guided_attention_loss_sigma=0.2,
|
||||
d_global_condition=256,
|
||||
|
||||
# whether to use a classifier to predict stop probability
|
||||
use_stop_token=False,
|
||||
# whether to use guided attention loss in training
|
||||
use_guided_attention_loss=True, ))
|
||||
|
||||
_C.training = CN(
|
||||
dict(
|
||||
lr=1e-3, # learning rate
|
||||
weight_decay=1e-6, # the coeff of weight decay
|
||||
grad_clip_thresh=1.0, # the clip norm of grad clip.
|
||||
valid_interval=1000, # validation
|
||||
save_interval=1000, # checkpoint
|
||||
max_iteration=500000, # max iteration to train
|
||||
))
|
||||
|
||||
|
||||
def get_cfg_defaults():
|
||||
"""Get a yacs CfgNode object with default values for my_project."""
|
||||
# Return a clone so that the defaults will not be altered
|
||||
# This is for the "local variable" use pattern
|
||||
return _C.clone()
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from parakeet.audio import AudioProcessor
|
||||
from parakeet.audio.spec_normalizer import NormalizerBase, LogMagnitude
|
||||
|
||||
import tqdm
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
def extract_mel(fname: Path,
|
||||
input_dir: Path,
|
||||
output_dir: Path,
|
||||
p: AudioProcessor,
|
||||
n: NormalizerBase):
|
||||
relative_path = fname.relative_to(input_dir)
|
||||
out_path = (output_dir / relative_path).with_suffix(".npy")
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
wav = p.read_wav(fname)
|
||||
mel = p.mel_spectrogram(wav)
|
||||
mel = n.transform(mel)
|
||||
np.save(out_path, mel)
|
||||
|
||||
|
||||
def extract_mel_multispeaker(config, input_dir, output_dir, extension=".wav"):
|
||||
input_dir = Path(input_dir).expanduser()
|
||||
fnames = list(input_dir.rglob(f"*{extension}"))
|
||||
output_dir = Path(output_dir).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
p = AudioProcessor(config.sample_rate, config.n_fft, config.win_length,
|
||||
config.hop_length, config.n_mels, config.fmin,
|
||||
config.fmax)
|
||||
n = LogMagnitude(1e-5)
|
||||
|
||||
func = partial(
|
||||
extract_mel, input_dir=input_dir, output_dir=output_dir, p=p, n=n)
|
||||
|
||||
with mp.Pool(16) as pool:
|
||||
list(
|
||||
tqdm.tqdm(
|
||||
pool.imap(func, fnames), total=len(fnames), unit="utterance"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Extract mel spectrogram from processed wav in AiShell3 training dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="yaml config file to overwrite the default config")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="~/datasets/aishell3/train/normalized_wav",
|
||||
help="path of the processed wav folder")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="~/datasets/aishell3/train/mel",
|
||||
help="path of the folder to save mel spectrograms")
|
||||
parser.add_argument(
|
||||
"--opts",
|
||||
nargs=argparse.REMAINDER,
|
||||
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
|
||||
)
|
||||
default_config = get_cfg_defaults()
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
default_config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
default_config.merge_from_list(args.opts)
|
||||
default_config.freeze()
|
||||
audio_config = default_config.data
|
||||
|
||||
extract_mel_multispeaker(audio_config, args.input, args.output)
|
Binary file not shown.
After Width: | Height: | Size: 221 KiB |
Binary file not shown.
After Width: | Height: | Size: 550 KiB |
Binary file not shown.
After Width: | Height: | Size: 514 KiB |
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,258 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import re
|
||||
import pickle
|
||||
|
||||
import yaml
|
||||
import tqdm
|
||||
|
||||
zh_pattern = re.compile("[\u4e00-\u9fa5]")
|
||||
|
||||
_tones = {'<pad>', '<s>', '</s>', '0', '1', '2', '3', '4', '5'}
|
||||
|
||||
_pauses = {'%', '$'}
|
||||
|
||||
_initials = {
|
||||
'b',
|
||||
'p',
|
||||
'm',
|
||||
'f',
|
||||
'd',
|
||||
't',
|
||||
'n',
|
||||
'l',
|
||||
'g',
|
||||
'k',
|
||||
'h',
|
||||
'j',
|
||||
'q',
|
||||
'x',
|
||||
'zh',
|
||||
'ch',
|
||||
'sh',
|
||||
'r',
|
||||
'z',
|
||||
'c',
|
||||
's',
|
||||
}
|
||||
|
||||
_finals = {
|
||||
'ii',
|
||||
'iii',
|
||||
'a',
|
||||
'o',
|
||||
'e',
|
||||
'ea',
|
||||
'ai',
|
||||
'ei',
|
||||
'ao',
|
||||
'ou',
|
||||
'an',
|
||||
'en',
|
||||
'ang',
|
||||
'eng',
|
||||
'er',
|
||||
'i',
|
||||
'ia',
|
||||
'io',
|
||||
'ie',
|
||||
'iai',
|
||||
'iao',
|
||||
'iou',
|
||||
'ian',
|
||||
'ien',
|
||||
'iang',
|
||||
'ieng',
|
||||
'u',
|
||||
'ua',
|
||||
'uo',
|
||||
'uai',
|
||||
'uei',
|
||||
'uan',
|
||||
'uen',
|
||||
'uang',
|
||||
'ueng',
|
||||
'v',
|
||||
've',
|
||||
'van',
|
||||
'ven',
|
||||
'veng',
|
||||
}
|
||||
|
||||
_ernized_symbol = {'&r'}
|
||||
|
||||
_specials = {'<pad>', '<unk>', '<s>', '</s>'}
|
||||
|
||||
_phones = _initials | _finals | _ernized_symbol | _specials | _pauses
|
||||
|
||||
|
||||
def is_zh(word):
|
||||
global zh_pattern
|
||||
match = zh_pattern.search(word)
|
||||
return match is not None
|
||||
|
||||
|
||||
def ernized(syllable):
|
||||
return syllable[:2] != "er" and syllable[-2] == 'r'
|
||||
|
||||
|
||||
def convert(syllable):
|
||||
# expansion of o -> uo
|
||||
syllable = re.sub(r"([bpmf])o$", r"\1uo", syllable)
|
||||
# syllable = syllable.replace("bo", "buo").replace("po", "puo").replace("mo", "muo").replace("fo", "fuo")
|
||||
# expansion for iong, ong
|
||||
syllable = syllable.replace("iong", "veng").replace("ong", "ueng")
|
||||
|
||||
# expansion for ing, in
|
||||
syllable = syllable.replace("ing", "ieng").replace("in", "ien")
|
||||
|
||||
# expansion for un, ui, iu
|
||||
syllable = syllable.replace("un", "uen").replace(
|
||||
"ui", "uei").replace("iu", "iou")
|
||||
|
||||
# rule for variants of i
|
||||
syllable = syllable.replace("zi", "zii").replace("ci", "cii").replace("si", "sii")\
|
||||
.replace("zhi", "zhiii").replace("chi", "chiii").replace("shi", "shiii")\
|
||||
.replace("ri", "riii")
|
||||
|
||||
# rule for y preceding i, u
|
||||
syllable = syllable.replace("yi", "i").replace("yu", "v").replace("y", "i")
|
||||
|
||||
# rule for w
|
||||
syllable = syllable.replace("wu", "u").replace("w", "u")
|
||||
|
||||
# rule for v following j, q, x
|
||||
syllable = syllable.replace("ju", "jv").replace("qu",
|
||||
"qv").replace("xu", "xv")
|
||||
|
||||
return syllable
|
||||
|
||||
|
||||
def split_syllable(syllable: str):
|
||||
"""Split a syllable in pinyin into a list of phones and a list of tones.
|
||||
Initials have no tone, represented by '0', while finals have tones from
|
||||
'1,2,3,4,5'.
|
||||
|
||||
e.g.
|
||||
|
||||
zhang -> ['zh', 'ang'], ['0', '1']
|
||||
"""
|
||||
if syllable in _pauses:
|
||||
# syllable, tone
|
||||
return [syllable], ['0']
|
||||
|
||||
tone = syllable[-1]
|
||||
syllable = convert(syllable[:-1])
|
||||
|
||||
phones = []
|
||||
tones = []
|
||||
|
||||
global _initials
|
||||
if syllable[:2] in _initials:
|
||||
phones.append(syllable[:2])
|
||||
tones.append('0')
|
||||
phones.append(syllable[2:])
|
||||
tones.append(tone)
|
||||
elif syllable[0] in _initials:
|
||||
phones.append(syllable[0])
|
||||
tones.append('0')
|
||||
phones.append(syllable[1:])
|
||||
tones.append(tone)
|
||||
else:
|
||||
phones.append(syllable)
|
||||
tones.append(tone)
|
||||
return phones, tones
|
||||
|
||||
|
||||
def load_aishell3_transcription(line: str):
|
||||
sentence_id, pinyin, text = line.strip().split("|")
|
||||
syllables = pinyin.strip().split()
|
||||
|
||||
results = []
|
||||
|
||||
for syllable in syllables:
|
||||
if syllable in _pauses:
|
||||
results.append(syllable)
|
||||
elif not ernized(syllable):
|
||||
results.append(syllable)
|
||||
else:
|
||||
results.append(syllable[:-2] + syllable[-1])
|
||||
results.append('&r5')
|
||||
|
||||
phones = []
|
||||
tones = []
|
||||
for syllable in results:
|
||||
p, t = split_syllable(syllable)
|
||||
phones.extend(p)
|
||||
tones.extend(t)
|
||||
for p in phones:
|
||||
assert p in _phones, p
|
||||
return {
|
||||
"sentence_id": sentence_id,
|
||||
"text": text,
|
||||
"syllables": results,
|
||||
"phones": phones,
|
||||
"tones": tones
|
||||
}
|
||||
|
||||
|
||||
def process_aishell3(dataset_root, output_dir):
|
||||
dataset_root = Path(dataset_root).expanduser()
|
||||
output_dir = Path(output_dir).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prosody_label_path = dataset_root / "label_train-set.txt"
|
||||
with open(prosody_label_path, 'rt') as f:
|
||||
lines = [line.strip() for line in f]
|
||||
|
||||
records = lines[5:]
|
||||
|
||||
processed_records = []
|
||||
for record in tqdm.tqdm(records):
|
||||
new_record = load_aishell3_transcription(record)
|
||||
processed_records.append(new_record)
|
||||
print(new_record)
|
||||
|
||||
with open(output_dir / "metadata.pickle", 'wb') as f:
|
||||
pickle.dump(processed_records, f)
|
||||
|
||||
with open(output_dir / "metadata.yaml", 'wt', encoding="utf-8") as f:
|
||||
yaml.safe_dump(
|
||||
processed_records, f, default_flow_style=None, allow_unicode=True)
|
||||
|
||||
print("metadata done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocess transcription of AiShell3 and save them in a compact file(yaml and pickle)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="~/datasets/aishell3/train",
|
||||
help="path of the training dataset,(contains a label_train-set.txt).")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
help="the directory to save the processed transcription."
|
||||
"If not provided, it would be the same as the input.")
|
||||
args = parser.parse_args()
|
||||
if args.output is None:
|
||||
args.output = args.input
|
||||
|
||||
process_aishell3(args.input, args.output)
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from tqdm import tqdm
|
||||
from praatio import tgio
|
||||
|
||||
|
||||
def get_valid_part(fpath):
|
||||
f = tgio.openTextgrid(fpath)
|
||||
|
||||
start = 0
|
||||
phone_entry_list = f.tierDict['phones'].entryList
|
||||
first_entry = phone_entry_list[0]
|
||||
if first_entry.label == "sil":
|
||||
start = first_entry.end
|
||||
|
||||
last_entry = phone_entry_list[-1]
|
||||
if last_entry.label == "sp":
|
||||
end = last_entry.start
|
||||
else:
|
||||
end = last_entry.end
|
||||
return start, end
|
||||
|
||||
|
||||
def process_utterance(fpath, source_dir, target_dir, alignment_dir):
|
||||
rel_path = fpath.relative_to(source_dir)
|
||||
opath = target_dir / rel_path
|
||||
apath = (alignment_dir / rel_path).with_suffix(".TextGrid")
|
||||
opath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
start, end = get_valid_part(apath)
|
||||
wav, _ = librosa.load(fpath, sr=22050, offset=start, duration=end - start)
|
||||
normalized_wav = wav / np.max(wav) * 0.999
|
||||
sf.write(opath, normalized_wav, samplerate=22050, subtype='PCM_16')
|
||||
# print(f"{fpath} => {opath}")
|
||||
|
||||
|
||||
def preprocess_aishell3(source_dir, target_dir, alignment_dir):
|
||||
source_dir = Path(source_dir).expanduser()
|
||||
target_dir = Path(target_dir).expanduser()
|
||||
alignment_dir = Path(alignment_dir).expanduser()
|
||||
|
||||
wav_paths = list(source_dir.rglob("*.wav"))
|
||||
print(f"there are {len(wav_paths)} audio files in total")
|
||||
fx = partial(
|
||||
process_utterance,
|
||||
source_dir=source_dir,
|
||||
target_dir=target_dir,
|
||||
alignment_dir=alignment_dir)
|
||||
with Pool(16) as p:
|
||||
list(
|
||||
tqdm(
|
||||
p.imap(fx, wav_paths), total=len(wav_paths), unit="utterance"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Process audio in AiShell3, trim silence according to the alignment "
|
||||
"files generated by MFA, and normalize volume by peak.")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="~/datasets/aishell3/train/wav",
|
||||
help="path of the original audio folder in aishell3.")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="~/datasets/aishell3/train/normalized_wav",
|
||||
help="path of the folder to save the processed audio files.")
|
||||
parser.add_argument(
|
||||
"--alignment",
|
||||
type=str,
|
||||
default="~/datasets/aishell3/train/alignment",
|
||||
help="path of the alignment files.")
|
||||
args = parser.parse_args()
|
||||
|
||||
preprocess_aishell3(args.input, args.output, args.alignment)
|
|
@ -0,0 +1,263 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
|
||||
from parakeet.data import dataset
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
from parakeet.utils import display, mp_tools
|
||||
from parakeet.models.tacotron2 import Tacotron2, Tacotron2Loss
|
||||
|
||||
from config import get_cfg_defaults
|
||||
from aishell3 import AiShell3, collate_aishell3_examples
|
||||
|
||||
|
||||
class Experiment(ExperimentBase):
|
||||
def compute_losses(self, inputs, outputs):
|
||||
texts, tones, mel_targets, utterance_embeds, text_lens, output_lens, stop_tokens = inputs
|
||||
|
||||
mel_outputs = outputs["mel_output"]
|
||||
mel_outputs_postnet = outputs["mel_outputs_postnet"]
|
||||
alignments = outputs["alignments"]
|
||||
|
||||
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
|
||||
alignments, output_lens, text_lens)
|
||||
return losses
|
||||
|
||||
def train_batch(self):
|
||||
start = time.time()
|
||||
batch = self.read_batch()
|
||||
data_loader_time = time.time() - start
|
||||
|
||||
self.optimizer.clear_grad()
|
||||
self.model.train()
|
||||
texts, tones, mels, utterance_embeds, text_lens, output_lens, stop_tokens = batch
|
||||
outputs = self.model(
|
||||
texts,
|
||||
text_lens,
|
||||
mels,
|
||||
output_lens,
|
||||
tones=tones,
|
||||
global_condition=utterance_embeds)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
loss = losses["loss"]
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
iteration_time = time.time() - start
|
||||
|
||||
losses_np = {k: float(v) for k, v in losses.items()}
|
||||
# logging
|
||||
msg = "Rank: {}, ".format(dist.get_rank())
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
|
||||
iteration_time)
|
||||
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_np.items())
|
||||
self.logger.info(msg)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
for key, value in losses_np.items():
|
||||
self.visualizer.add_scalar(f"train_loss/{key}", value,
|
||||
self.iteration)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
valid_losses = defaultdict(list)
|
||||
for i, batch in enumerate(self.valid_loader):
|
||||
texts, tones, mels, utterance_embeds, text_lens, output_lens, stop_tokens = batch
|
||||
outputs = self.model(
|
||||
texts,
|
||||
text_lens,
|
||||
mels,
|
||||
output_lens,
|
||||
tones=tones,
|
||||
global_condition=utterance_embeds)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
for key, value in losses.items():
|
||||
valid_losses[key].append(float(value))
|
||||
|
||||
attention_weights = outputs["alignments"]
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_alignments",
|
||||
display.plot_alignment(attention_weights[0].numpy().T),
|
||||
self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_target_spectrogram",
|
||||
display.plot_spectrogram(mels[0].numpy().T), self.iteration)
|
||||
mel_pred = outputs['mel_outputs_postnet']
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_predicted_spectrogram",
|
||||
display.plot_spectrogram(mel_pred[0].numpy().T),
|
||||
self.iteration)
|
||||
|
||||
# write visual log
|
||||
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
||||
|
||||
# logging
|
||||
msg = "Valid: "
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in valid_losses.items())
|
||||
self.logger.info(msg)
|
||||
|
||||
for key, value in valid_losses.items():
|
||||
self.visualizer.add_scalar(f"valid/{key}", value, self.iteration)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def eval(self):
|
||||
"""Evaluation of Tacotron2 in autoregressive manner."""
|
||||
self.model.eval()
|
||||
mel_dir = Path(self.output_dir / ("eval_{}".format(self.iteration)))
|
||||
mel_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, batch in enumerate(self.test_loader):
|
||||
texts, tones, mels, utterance_embeds, *_ = batch
|
||||
outputs = self.model.infer(
|
||||
texts, tones=tones, global_condition=utterance_embeds)
|
||||
|
||||
display.plot_alignment(outputs["alignments"][0].numpy().T)
|
||||
plt.savefig(mel_dir / f"sentence_{i}.png")
|
||||
plt.close()
|
||||
np.save(mel_dir / f"sentence_{i}",
|
||||
outputs["mel_outputs_postnet"][0].numpy().T)
|
||||
print(f"sentence_{i}")
|
||||
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
model = Tacotron2(
|
||||
vocab_size=config.model.vocab_size,
|
||||
n_tones=config.model.n_tones,
|
||||
d_mels=config.data.d_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||
d_prenet=config.model.d_prenet,
|
||||
d_attention_rnn=config.model.d_attention_rnn,
|
||||
d_decoder_rnn=config.model.d_decoder_rnn,
|
||||
attention_filters=config.model.attention_filters,
|
||||
attention_kernel_size=config.model.attention_kernel_size,
|
||||
d_attention=config.model.d_attention,
|
||||
d_postnet=config.model.d_postnet,
|
||||
postnet_kernel_size=config.model.postnet_kernel_size,
|
||||
postnet_conv_layers=config.model.postnet_conv_layers,
|
||||
reduction_factor=config.model.reduction_factor,
|
||||
p_encoder_dropout=config.model.p_encoder_dropout,
|
||||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||
d_global_condition=config.model.d_global_condition,
|
||||
use_stop_token=config.model.use_stop_token, )
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
grad_clip = paddle.nn.ClipGradByGlobalNorm(
|
||||
config.training.grad_clip_thresh)
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
learning_rate=config.training.lr,
|
||||
parameters=model.parameters(),
|
||||
weight_decay=paddle.regularizer.L2Decay(
|
||||
config.training.weight_decay),
|
||||
grad_clip=grad_clip)
|
||||
criterion = Tacotron2Loss(
|
||||
use_stop_token_loss=config.model.use_stop_token,
|
||||
use_guided_attention_loss=config.model.use_guided_attention_loss,
|
||||
sigma=config.model.guided_attention_loss_sigma)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
|
||||
def setup_dataloader(self):
|
||||
args = self.args
|
||||
config = self.config
|
||||
ljspeech_dataset = AiShell3(args.data)
|
||||
|
||||
valid_set, train_set = dataset.split(ljspeech_dataset,
|
||||
config.data.valid_size)
|
||||
batch_fn = collate_aishell3_examples
|
||||
|
||||
if not self.parallel:
|
||||
self.train_loader = DataLoader(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=batch_fn)
|
||||
else:
|
||||
sampler = DistributedBatchSampler(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
self.train_loader = DataLoader(
|
||||
train_set, batch_sampler=sampler, collate_fn=batch_fn)
|
||||
|
||||
self.valid_loader = DataLoader(
|
||||
valid_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
collate_fn=batch_fn)
|
||||
|
||||
self.test_loader = DataLoader(
|
||||
valid_set,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
collate_fn=batch_fn)
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp.setup()
|
||||
exp.resume_or_load()
|
||||
if not args.test:
|
||||
exp.run()
|
||||
else:
|
||||
exp.eval()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
if args.nprocs > 1 and args.device == "gpu":
|
||||
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
else:
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
parser = default_argument_parser()
|
||||
parser.add_argument("--test", action="store_true")
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
File diff suppressed because one or more lines are too long
|
@ -14,7 +14,7 @@ wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
|
|||
tar xjvf LJSpeech-1.1.tar.bz2
|
||||
```
|
||||
|
||||
### Preprocess the dataset.
|
||||
### Preprocess the dataset.
|
||||
|
||||
Assume the path to save the preprocessed dataset is `ljspeech_transformer_tts`. Run the command below to preprocess the dataset.
|
||||
|
||||
|
@ -45,4 +45,8 @@ Synthesize waveform. We assume the `--input` is a text file, one sentence per li
|
|||
|
||||
```bash
|
||||
python synthesize.py --input=sentence.txt --output=mels/ --checkpoint_path='step-310000' --device="gpu" --verbose
|
||||
```
|
||||
```
|
||||
|
||||
## Pretrained Model
|
||||
|
||||
Pretrained model can be downloaded here. [transformer_tts_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/transformer_tts_ljspeech_ckpt_0.3.zip).
|
||||
|
|
|
@ -23,8 +23,9 @@ _C.data = CN(
|
|||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
f_max=8000, # Hz, max frequency when converting to mel
|
||||
d_mel=80, # mel bands
|
||||
fmin=0, # Hz, min frequency when converting to mel
|
||||
fmax=8000, # Hz, max frequency when converting to mel
|
||||
n_mels=80, # mel bands
|
||||
padding_idx=0, # text embedding's padding index
|
||||
mel_start_value=0.5, # value for starting frame
|
||||
mel_end_value=-0.5, # # value for ending frame
|
||||
|
@ -56,7 +57,7 @@ _C.training = CN(
|
|||
plot_interval=1000, # plot attention and spectrogram
|
||||
valid_interval=1000, # validation
|
||||
save_interval=10000, # checkpoint
|
||||
max_iteration=900000, # max iteration to train
|
||||
max_iteration=500000, # max iteration to train
|
||||
))
|
||||
|
||||
|
||||
|
|
|
@ -12,14 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from paddle.io import Dataset, DataLoader
|
||||
from paddle.io import Dataset
|
||||
|
||||
from parakeet.data.batch import batch_spec, batch_text_id
|
||||
from parakeet.data import dataset
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
|
@ -54,10 +53,10 @@ class Transform(object):
|
|||
ids, mel = example # ids already have <s> and </s>
|
||||
ids = np.array(ids, dtype=np.int64)
|
||||
# add start and end frame
|
||||
mel = np.pad(
|
||||
mel, [(0, 0), (1, 1)],
|
||||
mode='constant',
|
||||
constant_values=[(0, 0), (self.start_value, self.end_value)])
|
||||
mel = np.pad(mel, [(0, 0), (1, 1)],
|
||||
mode='constant',
|
||||
constant_values=[(0, 0),
|
||||
(self.start_value, self.end_value)])
|
||||
stop_labels = np.ones([mel.shape[1]], dtype=np.int64)
|
||||
stop_labels[-1] = 2
|
||||
# actually this thing can also be done within the model
|
||||
|
@ -76,30 +75,7 @@ class LJSpeechCollector(object):
|
|||
mels = [example[1] for example in examples]
|
||||
stop_probs = [example[2] for example in examples]
|
||||
|
||||
ids = batch_text_id(ids, pad_id=self.padding_idx)
|
||||
mels = batch_spec(mels, pad_value=self.padding_value)
|
||||
stop_probs = batch_text_id(stop_probs, pad_id=self.padding_idx)
|
||||
ids, _ = batch_text_id(ids, pad_id=self.padding_idx)
|
||||
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||
stop_probs, _ = batch_text_id(stop_probs, pad_id=self.padding_idx)
|
||||
return ids, np.transpose(mels, [0, 2, 1]), stop_probs
|
||||
|
||||
|
||||
def create_dataloader(config, source_path):
|
||||
lj = LJSpeech(source_path)
|
||||
transform = Transform(config.data.mel_start_value,
|
||||
config.data.mel_end_value)
|
||||
lj = dataset.TransformDataset(lj, transform)
|
||||
|
||||
valid_set, train_set = dataset.split(lj, config.data.valid_size)
|
||||
data_collator = LJSpeechCollector(padding_idx=config.data.padding_idx)
|
||||
train_loader = DataLoader(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator)
|
||||
valid_loader = DataLoader(
|
||||
valid_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
collate_fn=data_collator)
|
||||
return train_loader, valid_loader
|
||||
|
|
|
@ -13,12 +13,13 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tqdm
|
||||
import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
import numpy as np
|
||||
|
||||
from parakeet.datasets import LJSpeechMetaData
|
||||
from parakeet.audio import AudioProcessor, LogMagnitude
|
||||
from parakeet.frontend import English
|
||||
|
@ -40,7 +41,8 @@ def create_dataset(config, source_path, target_path, verbose=False):
|
|||
n_mels=config.data.d_mel,
|
||||
win_length=config.data.win_length,
|
||||
hop_length=config.data.hop_length,
|
||||
f_max=config.data.f_max)
|
||||
fmax=config.data.fmax,
|
||||
fmin=config.data.fmin)
|
||||
normalizer = LogMagnitude()
|
||||
|
||||
records = []
|
||||
|
|
|
@ -13,22 +13,19 @@
|
|||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import parakeet
|
||||
from parakeet.frontend import English
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.utils import scheduler
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.utils.display import add_attention_plots
|
||||
from parakeet.utils import display
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
@paddle.fluid.dygraph.no_grad
|
||||
def main(config, args):
|
||||
paddle.set_device(args.device)
|
||||
|
||||
|
@ -47,9 +44,22 @@ def main(config, args):
|
|||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, sentence in enumerate(sentences):
|
||||
outputs = model.predict(sentence, verbose=args.verbose)
|
||||
mel_output = outputs["mel_output"]
|
||||
# cross_attention_weights = outputs["cross_attention_weights"]
|
||||
if args.verbose:
|
||||
print("text: ", sentence)
|
||||
print("phones: ", frontend.phoneticize(sentence))
|
||||
text_ids = paddle.to_tensor(frontend(sentence))
|
||||
text_ids = paddle.unsqueeze(text_ids, 0) # (1, T)
|
||||
|
||||
with paddle.no_grad():
|
||||
outputs = model.infer(text_ids, verbose=args.verbose)
|
||||
|
||||
mel_output = outputs["mel_output"][0].numpy()
|
||||
cross_attention_weights = outputs["cross_attention_weights"]
|
||||
attns = np.stack([attn[0].numpy() for attn in cross_attention_weights])
|
||||
attns = np.transpose(attns, [0, 1, 3, 2])
|
||||
display.plot_multilayer_multihead_alignments(attns)
|
||||
plt.savefig(str(output_dir / f"sentence_{i}.png"))
|
||||
|
||||
mel_output = mel_output.T #(C, T)
|
||||
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
||||
if args.verbose:
|
||||
|
|
|
@ -13,20 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from collections import defaultdict
|
||||
|
||||
import parakeet
|
||||
from parakeet.data import dataset
|
||||
from parakeet.frontend import English
|
||||
from parakeet.models.transformer_tts import TransformerTTS, TransformerTTSLoss
|
||||
from parakeet.utils import scheduler, checkpoint, mp_tools, display
|
||||
from parakeet.utils import scheduler, mp_tools, display
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
|
||||
|
@ -34,7 +31,7 @@ from config import get_cfg_defaults
|
|||
from ljspeech import LJSpeech, LJSpeechCollector, Transform
|
||||
|
||||
|
||||
class Experiment(ExperimentBase):
|
||||
class TransformerTTSExperiment(ExperimentBase):
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
frontend = English()
|
||||
|
@ -42,7 +39,7 @@ class Experiment(ExperimentBase):
|
|||
frontend,
|
||||
d_encoder=config.model.d_encoder,
|
||||
d_decoder=config.model.d_decoder,
|
||||
d_mel=config.data.d_mel,
|
||||
d_mel=config.data.n_mels,
|
||||
n_heads=config.model.n_heads,
|
||||
d_ffn=config.model.d_ffn,
|
||||
encoder_layers=config.model.encoder_layers,
|
||||
|
@ -109,13 +106,12 @@ class Experiment(ExperimentBase):
|
|||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
|
||||
def compute_outputs(self, text, mel, stop_label):
|
||||
def compute_outputs(self, text, mel):
|
||||
model_core = self.model._layers if self.parallel else self.model
|
||||
model_core.set_constants(
|
||||
self.reduction_factor(self.iteration),
|
||||
self.drop_n_heads(self.iteration))
|
||||
|
||||
# TODO(chenfeiyu): we can combine these 2 slices
|
||||
mel_input = mel[:, :-1, :]
|
||||
reduced_mel_input = mel_input[:, ::model_core.r, :]
|
||||
outputs = self.model(text, reduced_mel_input)
|
||||
|
@ -144,7 +140,7 @@ class Experiment(ExperimentBase):
|
|||
self.optimizer.clear_grad()
|
||||
self.model.train()
|
||||
text, mel, stop_label = batch
|
||||
outputs = self.compute_outputs(text, mel, stop_label)
|
||||
outputs = self.compute_outputs(text, mel)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
loss = losses["loss"]
|
||||
loss.backward()
|
||||
|
@ -169,20 +165,26 @@ class Experiment(ExperimentBase):
|
|||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
self.model.eval()
|
||||
valid_losses = defaultdict(list)
|
||||
for i, batch in enumerate(self.valid_loader):
|
||||
text, mel, stop_label = batch
|
||||
outputs = self.compute_outputs(text, mel, stop_label)
|
||||
outputs = self.compute_outputs(text, mel)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
for k, v in losses.items():
|
||||
valid_losses[k].append(float(v))
|
||||
|
||||
if i < 2:
|
||||
attention_weights = outputs["cross_attention_weights"]
|
||||
display.add_multi_attention_plots(
|
||||
self.visualizer,
|
||||
attention_weights = [
|
||||
np.transpose(item[0].numpy(), [0, 2, 1])
|
||||
for item in attention_weights
|
||||
]
|
||||
attention_weights = np.stack(attention_weights)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_cross_attention_weights",
|
||||
attention_weights, self.iteration)
|
||||
display.plot_multilayer_multihead_alignments(
|
||||
attention_weights), self.iteration)
|
||||
|
||||
# write visual log
|
||||
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
|
||||
|
@ -191,8 +193,9 @@ class Experiment(ExperimentBase):
|
|||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp = TransformerTTSExperiment(config, args)
|
||||
exp.setup()
|
||||
exp.resume_or_load()
|
||||
exp.run()
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
|
|||
tar xjvf LJSpeech-1.1.tar.bz2
|
||||
```
|
||||
|
||||
### Preprocess the dataset.
|
||||
### Preprocess the dataset.
|
||||
|
||||
Assume the path to save the preprocessed dataset is `ljspeech_waveflow`. Run the command below to preprocess the dataset.
|
||||
|
||||
|
@ -45,4 +45,8 @@ Synthesize waveform. We assume the `--input` is a directory containing several m
|
|||
|
||||
```bash
|
||||
python synthesize.py --input=mels/ --output=wavs/ --checkpoint_path='step-2000000' --device="gpu" --verbose
|
||||
```
|
||||
```
|
||||
|
||||
## Pretrained Model
|
||||
|
||||
Pretrained Model with residual channel equals 128 can be downloaded here. [waveflow_ljspeech_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/waveflow_ljspeech_ckpt_0.3.zip).
|
||||
|
|
|
@ -23,7 +23,8 @@ _C.data = CN(
|
|||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
f_max=8000, # Hz, max frequency when converting to mel
|
||||
fmin=0,
|
||||
fmax=8000, # Hz, max frequency when converting to mel
|
||||
n_mels=80, # mel bands
|
||||
clip_frames=65, # mel clip frames
|
||||
))
|
||||
|
|
|
@ -12,16 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import pandas
|
||||
from paddle.io import Dataset, DataLoader
|
||||
from paddle.io import Dataset
|
||||
|
||||
from parakeet.data.batch import batch_spec, batch_wav
|
||||
from parakeet.data import dataset
|
||||
from parakeet.audio import AudioProcessor
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
|
@ -61,8 +58,8 @@ class LJSpeechCollector(object):
|
|||
def __call__(self, examples):
|
||||
mels = [example[0] for example in examples]
|
||||
wavs = [example[1] for example in examples]
|
||||
mels = batch_spec(mels, pad_value=self.padding_value)
|
||||
wavs = batch_wav(wavs, pad_value=self.padding_value)
|
||||
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||
wavs, _ = batch_wav(wavs, pad_value=self.padding_value)
|
||||
return mels, wavs
|
||||
|
||||
|
||||
|
|
|
@ -13,29 +13,30 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tqdm
|
||||
import csv
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import tqdm
|
||||
import numpy as np
|
||||
import librosa
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
from paddle.io import Dataset
|
||||
from parakeet.data import batch_spec, batch_wav
|
||||
from parakeet.datasets import LJSpeechMetaData
|
||||
from parakeet.audio import AudioProcessor, LogMagnitude
|
||||
from parakeet.audio import LogMagnitude
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
class Transform(object):
|
||||
def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
|
||||
def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels,
|
||||
fmin, fmax):
|
||||
self.sample_rate = sample_rate
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.n_mels = n_mels
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
|
||||
self.spec_normalizer = LogMagnitude(min=1e-5)
|
||||
|
||||
|
@ -47,6 +48,8 @@ class Transform(object):
|
|||
win_length = self.win_length
|
||||
hop_length = self.hop_length
|
||||
n_mels = self.n_mels
|
||||
fmin = self.fmin
|
||||
fmax = self.fmax
|
||||
|
||||
wav, loaded_sr = librosa.load(wav_path, sr=None)
|
||||
assert loaded_sr == sr, "sample rate does not match, resampling applied"
|
||||
|
@ -78,9 +81,10 @@ class Transform(object):
|
|||
# Compute mel-spectrograms.
|
||||
mel_filter_bank = librosa.filters.mel(sr=sr,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels)
|
||||
n_mels=n_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax)
|
||||
mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
|
||||
mel_spectrogram = mel_spectrogram
|
||||
|
||||
# log scale mel_spectrogram.
|
||||
mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram)
|
||||
|
@ -93,7 +97,7 @@ class Transform(object):
|
|||
return audio, mel_spectrogram
|
||||
|
||||
|
||||
def create_dataset(config, input_dir, output_dir, verbose=True):
|
||||
def create_dataset(config, input_dir, output_dir):
|
||||
input_dir = Path(input_dir).expanduser()
|
||||
dataset = LJSpeechMetaData(input_dir)
|
||||
|
||||
|
@ -101,7 +105,8 @@ def create_dataset(config, input_dir, output_dir, verbose=True):
|
|||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
transform = Transform(config.sample_rate, config.n_fft, config.win_length,
|
||||
config.hop_length, config.n_mels)
|
||||
config.hop_length, config.n_mels, config.fmin,
|
||||
config.fmax)
|
||||
file_names = []
|
||||
|
||||
for example in tqdm.tqdm(dataset):
|
||||
|
@ -157,4 +162,4 @@ if __name__ == "__main__":
|
|||
print(config.data)
|
||||
print(args)
|
||||
|
||||
create_dataset(config.data, args.input, args.output, args.verbose)
|
||||
create_dataset(config.data, args.input, args.output)
|
||||
|
|
|
@ -12,15 +12,16 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import os
|
||||
from pathlib import Path
|
||||
import paddle
|
||||
import parakeet
|
||||
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow
|
||||
from parakeet.utils import layer_tools, checkpoint
|
||||
|
||||
from parakeet.models.waveflow import ConditionalWaveFlow
|
||||
from parakeet.utils import layer_tools
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
@ -34,9 +35,10 @@ def main(config, args):
|
|||
mel_dir = Path(args.input).expanduser()
|
||||
output_dir = Path(args.output).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
for file_path in mel_dir.iterdir():
|
||||
for file_path in mel_dir.glob("*.npy"):
|
||||
mel = np.load(str(file_path))
|
||||
audio = model.predict(mel)
|
||||
with paddle.amp.auto_cast():
|
||||
audio = model.predict(mel)
|
||||
audio_path = output_dir / (
|
||||
os.path.splitext(file_path.name)[0] + ".wav")
|
||||
sf.write(audio_path, audio, config.data.sample_rate)
|
||||
|
|
|
@ -13,22 +13,17 @@
|
|||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from collections import defaultdict
|
||||
|
||||
import parakeet
|
||||
from parakeet.data import dataset
|
||||
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow, WaveFlowLoss
|
||||
from parakeet.audio import AudioProcessor
|
||||
from parakeet.utils import scheduler, mp_tools
|
||||
from parakeet.models.waveflow import ConditionalWaveFlow, WaveFlowLoss
|
||||
from parakeet.utils import mp_tools
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
from parakeet.utils.mp_tools import rank_zero_only
|
||||
|
||||
from config import get_cfg_defaults
|
||||
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
|
||||
|
@ -119,8 +114,8 @@ class Experiment(ExperimentBase):
|
|||
msg += "loss: {:>.6f}".format(loss_value)
|
||||
self.logger.info(msg)
|
||||
if dist.get_rank() == 0:
|
||||
self.visualizer.add_scalar(
|
||||
"train/loss", loss_value, global_step=self.iteration)
|
||||
self.visualizer.add_scalar("train/loss", loss_value,
|
||||
self.iteration)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
|
@ -132,13 +127,13 @@ class Experiment(ExperimentBase):
|
|||
loss = self.criterion(z, log_det_jocobian)
|
||||
valid_losses.append(float(loss))
|
||||
valid_loss = np.mean(valid_losses)
|
||||
self.visualizer.add_scalar(
|
||||
"valid/loss", valid_loss, global_step=self.iteration)
|
||||
self.visualizer.add_scalar("valid/loss", valid_loss, self.iteration)
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp.setup()
|
||||
exp.resume_or_load()
|
||||
exp.run()
|
||||
|
||||
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
# WaveNet with LJSpeech
|
||||
|
||||
## Dataset
|
||||
|
||||
### Download the datasaet.
|
||||
|
||||
```bash
|
||||
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
|
||||
```
|
||||
|
||||
### Extract the dataset.
|
||||
|
||||
```bash
|
||||
tar xjvf LJSpeech-1.1.tar.bz2
|
||||
```
|
||||
|
||||
### Preprocess the dataset.
|
||||
|
||||
Assume the path to save the preprocessed dataset is `ljspeech_wavenet`. Run the command below to preprocess the dataset.
|
||||
|
||||
```bash
|
||||
python preprocess.py --input=LJSpeech-1.1/ --output=ljspeech_wavenet
|
||||
```
|
||||
|
||||
## Train the model
|
||||
|
||||
The training script requires 4 command line arguments.
|
||||
`--data` is the path of the training dataset, `--output` is the path of the output directory (we recommend to use a subdirectory in `runs` to manage different experiments.)
|
||||
|
||||
`--device` should be "cpu" or "gpu", `--nprocs` is the number of processes to train the model in parallel.
|
||||
|
||||
```bash
|
||||
python train.py --data=ljspeech_wavenet/ --output=runs/test --device="gpu" --nprocs=1
|
||||
```
|
||||
|
||||
If you want distributed training, set a larger `--nprocs` (e.g. 4). Note that distributed training with cpu is not supported yet.
|
||||
|
||||
## Synthesize
|
||||
|
||||
Synthesize waveform. We assume the `--input` is a directory containing several mel spectrograms(normalized into range[0, 1)) in `.npy` format. The output would be saved in `--output` directory, containing several `.wav` files, each with the same name as the mel spectrogram does.
|
||||
|
||||
`--checkpoint_path` should be the path of the parameter file (`.pdparams`) to load. Note that the extention name `.pdparmas` is not included here.
|
||||
|
||||
`--device` specifies to device to run synthesis on. Due to the autoregressiveness of wavenet, using cpu may be faster.
|
||||
|
||||
```bash
|
||||
python synthesize.py --input=mels/ --output=wavs/ --checkpoint_path='step-2450000' --device="cpu" --verbose
|
||||
```
|
|
@ -1,58 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
_C = CN()
|
||||
_C.data = CN(
|
||||
dict(
|
||||
batch_size=8, # batch size
|
||||
valid_size=16, # the first N examples are reserved for validation
|
||||
sample_rate=22050, # Hz, sample rate
|
||||
n_fft=2048, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
# f_max=8000, # Hz, max frequency when converting to mel
|
||||
n_mels=80, # mel bands
|
||||
train_clip_seconds=0.5, # audio clip length(in seconds)
|
||||
))
|
||||
|
||||
_C.model = CN(
|
||||
dict(
|
||||
upsample_factors=[16, 16],
|
||||
n_stack=3,
|
||||
n_loop=10,
|
||||
filter_size=2,
|
||||
residual_channels=128, # resiaudal channel in each flow
|
||||
loss_type="mog",
|
||||
output_dim=3, # single gaussian
|
||||
log_scale_min=-9.0, ))
|
||||
|
||||
_C.training = CN(
|
||||
dict(
|
||||
lr=1e-3, # learning rates
|
||||
anneal_rate=0.5, # learning rate decay rate
|
||||
anneal_interval=200000, # decrese lr by annel_rate every anneal_interval steps
|
||||
valid_interval=1000, # validation
|
||||
save_interval=10000, # checkpoint
|
||||
max_iteration=3000000, # max iteration to train
|
||||
gradient_max_norm=100.0 # global norm of gradients
|
||||
))
|
||||
|
||||
|
||||
def get_cfg_defaults():
|
||||
"""Get a yacs CfgNode object with default values for my_project."""
|
||||
# Return a clone so that the defaults will not be altered
|
||||
# This is for the "local variable" use pattern
|
||||
return _C.clone()
|
|
@ -1,151 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas
|
||||
from paddle.io import Dataset, DataLoader
|
||||
|
||||
from parakeet.data.batch import batch_spec, batch_wav
|
||||
from parakeet.data import dataset
|
||||
from parakeet.audio import AudioProcessor
|
||||
|
||||
|
||||
class LJSpeech(Dataset):
|
||||
"""A simple dataset adaptor for the processed ljspeech dataset."""
|
||||
|
||||
def __init__(self, root):
|
||||
self.root = Path(root).expanduser()
|
||||
meta_data = pandas.read_csv(
|
||||
str(self.root / "metadata.csv"),
|
||||
sep="\t",
|
||||
header=None,
|
||||
names=["fname", "frames", "samples"])
|
||||
|
||||
records = []
|
||||
for row in meta_data.itertuples():
|
||||
mel_path = str(self.root / "mel" / (row.fname + ".npy"))
|
||||
wav_path = str(self.root / "wav" / (row.fname + ".npy"))
|
||||
records.append((mel_path, wav_path))
|
||||
self.records = records
|
||||
|
||||
def __getitem__(self, i):
|
||||
mel_name, wav_name = self.records[i]
|
||||
mel = np.load(mel_name)
|
||||
wav = np.load(wav_name)
|
||||
return mel, wav
|
||||
|
||||
def __len__(self):
|
||||
return len(self.records)
|
||||
|
||||
|
||||
class LJSpeechCollector(object):
|
||||
"""A simple callable to batch LJSpeech examples."""
|
||||
|
||||
def __init__(self, padding_value=0.):
|
||||
self.padding_value = padding_value
|
||||
|
||||
def __call__(self, examples):
|
||||
batch_size = len(examples)
|
||||
mels = [example[0] for example in examples]
|
||||
wavs = [example[1] for example in examples]
|
||||
mels = batch_spec(mels, pad_value=self.padding_value)
|
||||
wavs = batch_wav(wavs, pad_value=self.padding_value)
|
||||
audio_starts = np.zeros((batch_size, ), dtype=np.int64)
|
||||
return mels, wavs, audio_starts
|
||||
|
||||
|
||||
class LJSpeechClipCollector(object):
|
||||
def __init__(self, clip_frames=65, hop_length=256):
|
||||
self.clip_frames = clip_frames
|
||||
self.hop_length = hop_length
|
||||
|
||||
def __call__(self, examples):
|
||||
mels = []
|
||||
wavs = []
|
||||
starts = []
|
||||
for example in examples:
|
||||
mel, wav_clip, start = self.clip(example)
|
||||
mels.append(mel)
|
||||
wavs.append(wav_clip)
|
||||
starts.append(start)
|
||||
mels = batch_spec(mels)
|
||||
wavs = np.stack(wavs)
|
||||
starts = np.array(starts, dtype=np.int64)
|
||||
return mels, wavs, starts
|
||||
|
||||
def clip(self, example):
|
||||
mel, wav = example
|
||||
frames = mel.shape[-1]
|
||||
start = np.random.randint(0, frames - self.clip_frames)
|
||||
wav_clip = wav[start * self.hop_length:(start + self.clip_frames) *
|
||||
self.hop_length]
|
||||
return mel, wav_clip, start
|
||||
|
||||
|
||||
class DataCollector(object):
|
||||
def __init__(self,
|
||||
context_size,
|
||||
sample_rate,
|
||||
hop_length,
|
||||
train_clip_seconds,
|
||||
valid=False):
|
||||
frames_per_second = sample_rate // hop_length
|
||||
train_clip_frames = int(
|
||||
np.ceil(train_clip_seconds * frames_per_second))
|
||||
context_frames = context_size // hop_length
|
||||
self.num_frames = train_clip_frames + context_frames
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.hop_length = hop_length
|
||||
self.valid = valid
|
||||
|
||||
def random_crop(self, sample):
|
||||
audio, mel_spectrogram = sample
|
||||
audio_frames = int(audio.size) // self.hop_length
|
||||
max_start_frame = audio_frames - self.num_frames
|
||||
assert max_start_frame >= 0, "audio is too short to be cropped"
|
||||
|
||||
frame_start = np.random.randint(0, max_start_frame)
|
||||
# frame_start = 0 # norandom
|
||||
frame_end = frame_start + self.num_frames
|
||||
|
||||
audio_start = frame_start * self.hop_length
|
||||
audio_end = frame_end * self.hop_length
|
||||
|
||||
audio = audio[audio_start:audio_end]
|
||||
return audio, mel_spectrogram, audio_start
|
||||
|
||||
def __call__(self, samples):
|
||||
# transform them first
|
||||
if self.valid:
|
||||
samples = [(audio, mel_spectrogram, 0)
|
||||
for audio, mel_spectrogram in samples]
|
||||
else:
|
||||
samples = [self.random_crop(sample) for sample in samples]
|
||||
# batch them
|
||||
audios = [sample[0] for sample in samples]
|
||||
audio_starts = [sample[2] for sample in samples]
|
||||
mels = [sample[1] for sample in samples]
|
||||
|
||||
mels = batch_spec(mels)
|
||||
|
||||
if self.valid:
|
||||
audios = batch_wav(audios, dtype=np.float32)
|
||||
else:
|
||||
audios = np.array(audios, dtype=np.float32)
|
||||
audio_starts = np.array(audio_starts, dtype=np.int64)
|
||||
return audios, mels, audio_starts
|
|
@ -1,161 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import tqdm
|
||||
import csv
|
||||
import argparse
|
||||
import numpy as np
|
||||
import librosa
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
from paddle.io import Dataset
|
||||
from parakeet.data import batch_spec, batch_wav
|
||||
from parakeet.datasets import LJSpeechMetaData
|
||||
from parakeet.audio import AudioProcessor
|
||||
from parakeet.audio.spec_normalizer import UnitMagnitude
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
class Transform(object):
|
||||
def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
|
||||
self.sample_rate = sample_rate
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.n_mels = n_mels
|
||||
|
||||
self.spec_normalizer = UnitMagnitude(min=1e-5)
|
||||
|
||||
def __call__(self, example):
|
||||
wav_path, _, _ = example
|
||||
|
||||
sr = self.sample_rate
|
||||
n_fft = self.n_fft
|
||||
win_length = self.win_length
|
||||
hop_length = self.hop_length
|
||||
n_mels = self.n_mels
|
||||
|
||||
wav, loaded_sr = librosa.load(wav_path, sr=None)
|
||||
assert loaded_sr == sr, "sample rate does not match, resampling applied"
|
||||
|
||||
# Pad audio to the right size.
|
||||
frames = int(np.ceil(float(wav.size) / hop_length))
|
||||
fft_padding = (n_fft - hop_length) // 2 # sound
|
||||
desired_length = frames * hop_length + fft_padding * 2
|
||||
pad_amount = (desired_length - wav.size) // 2
|
||||
|
||||
if wav.size % 2 == 0:
|
||||
wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect')
|
||||
else:
|
||||
wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect')
|
||||
|
||||
# Normalize audio.
|
||||
wav = wav / np.abs(wav).max() * 0.999
|
||||
|
||||
# Compute mel-spectrogram.
|
||||
# Turn center to False to prevent internal padding.
|
||||
spectrogram = librosa.core.stft(
|
||||
wav,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_fft=n_fft,
|
||||
center=False)
|
||||
spectrogram_magnitude = np.abs(spectrogram)
|
||||
|
||||
# Compute mel-spectrograms.
|
||||
mel_filter_bank = librosa.filters.mel(sr=sr,
|
||||
n_fft=n_fft,
|
||||
n_mels=n_mels)
|
||||
mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
|
||||
mel_spectrogram = mel_spectrogram
|
||||
|
||||
# log scale mel_spectrogram.
|
||||
mel_spectrogram = self.spec_normalizer.transform(mel_spectrogram)
|
||||
|
||||
# Extract the center of audio that corresponds to mel spectrograms.
|
||||
audio = wav[fft_padding:-fft_padding]
|
||||
assert mel_spectrogram.shape[1] * hop_length == audio.size
|
||||
|
||||
# there is no clipping here
|
||||
return audio, mel_spectrogram
|
||||
|
||||
|
||||
def create_dataset(config, input_dir, output_dir, verbose=True):
|
||||
input_dir = Path(input_dir).expanduser()
|
||||
dataset = LJSpeechMetaData(input_dir)
|
||||
|
||||
output_dir = Path(output_dir).expanduser()
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
transform = Transform(config.sample_rate, config.n_fft, config.win_length,
|
||||
config.hop_length, config.n_mels)
|
||||
file_names = []
|
||||
|
||||
for example in tqdm.tqdm(dataset):
|
||||
fname, _, _ = example
|
||||
base_name = os.path.splitext(os.path.basename(fname))[0]
|
||||
wav_dir = output_dir / "wav"
|
||||
mel_dir = output_dir / "mel"
|
||||
wav_dir.mkdir(exist_ok=True)
|
||||
mel_dir.mkdir(exist_ok=True)
|
||||
|
||||
audio, mel = transform(example)
|
||||
np.save(str(wav_dir / base_name), audio)
|
||||
np.save(str(mel_dir / base_name), mel)
|
||||
|
||||
file_names.append((base_name, mel.shape[-1], audio.shape[-1]))
|
||||
|
||||
meta_data = pd.DataFrame.from_records(file_names)
|
||||
meta_data.to_csv(
|
||||
str(output_dir / "metadata.csv"), sep="\t", index=None, header=None)
|
||||
print("saved meta data in to {}".format(
|
||||
os.path.join(output_dir, "metadata.csv")))
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="create dataset")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
metavar="FILE",
|
||||
help="extra config to overwrite the default config")
|
||||
parser.add_argument(
|
||||
"--input", type=str, help="path of the ljspeech dataset")
|
||||
parser.add_argument(
|
||||
"--output", type=str, help="path to save output dataset")
|
||||
parser.add_argument(
|
||||
"--opts",
|
||||
nargs=argparse.REMAINDER,
|
||||
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="print msg")
|
||||
|
||||
config = get_cfg_defaults()
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
if args.verbose:
|
||||
print(config.data)
|
||||
print(args)
|
||||
|
||||
create_dataset(config.data, args.input, args.output, args.verbose)
|
|
@ -1,82 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import os
|
||||
from pathlib import Path
|
||||
import paddle
|
||||
import parakeet
|
||||
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
|
||||
from parakeet.utils import layer_tools, checkpoint
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
||||
def main(config, args):
|
||||
paddle.set_device(args.device)
|
||||
model = ConditionalWaveNet.from_pretrained(config, args.checkpoint_path)
|
||||
layer_tools.recursively_remove_weight_norm(model)
|
||||
model.eval()
|
||||
|
||||
mel_dir = Path(args.input).expanduser()
|
||||
output_dir = Path(args.output).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
for file_path in mel_dir.iterdir():
|
||||
mel = np.load(str(file_path))
|
||||
audio = model.predict(mel)
|
||||
audio_path = output_dir / (
|
||||
os.path.splitext(file_path.name)[0] + ".wav")
|
||||
sf.write(audio_path, audio, config.data.sample_rate)
|
||||
print("[synthesize] {} -> {}".format(file_path, audio_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="generate mel spectrogram with TransformerTTS.")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
metavar="FILE",
|
||||
help="extra config to overwrite the default config")
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="path of the checkpoint to load.")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
help="path of directory containing mel spectrogram (in .npy format)")
|
||||
parser.add_argument("--output", type=str, help="path to save outputs")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="cpu", help="device type to use.")
|
||||
parser.add_argument(
|
||||
"--opts",
|
||||
nargs=argparse.REMAINDER,
|
||||
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_true", help="print msg")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
|
@ -1,177 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from collections import defaultdict
|
||||
|
||||
import parakeet
|
||||
from parakeet.data import dataset
|
||||
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWaveNet
|
||||
from parakeet.audio import AudioProcessor
|
||||
from parakeet.utils import scheduler, mp_tools
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.training.experiment import ExperimentBase
|
||||
from parakeet.utils.mp_tools import rank_zero_only
|
||||
|
||||
from config import get_cfg_defaults
|
||||
from ljspeech import LJSpeech, LJSpeechClipCollector, LJSpeechCollector
|
||||
|
||||
|
||||
class Experiment(ExperimentBase):
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
model = ConditionalWaveNet(
|
||||
upsample_factors=config.model.upsample_factors,
|
||||
n_stack=config.model.n_stack,
|
||||
n_loop=config.model.n_loop,
|
||||
residual_channels=config.model.residual_channels,
|
||||
output_dim=config.model.output_dim,
|
||||
n_mels=config.data.n_mels,
|
||||
filter_size=config.model.filter_size,
|
||||
loss_type=config.model.loss_type,
|
||||
log_scale_min=config.model.log_scale_min)
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
lr_scheduler = paddle.optimizer.lr.StepDecay(
|
||||
config.training.lr, config.training.anneal_interval,
|
||||
config.training.anneal_rate)
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
lr_scheduler,
|
||||
parameters=model.parameters(),
|
||||
grad_clip=paddle.nn.ClipGradByGlobalNorm(
|
||||
config.training.gradient_max_norm))
|
||||
|
||||
self.model = model
|
||||
self.model_core = model._layers if self.parallel else model
|
||||
self.optimizer = optimizer
|
||||
|
||||
def setup_dataloader(self):
|
||||
config = self.config
|
||||
args = self.args
|
||||
|
||||
ljspeech_dataset = LJSpeech(args.data)
|
||||
valid_set, train_set = dataset.split(ljspeech_dataset,
|
||||
config.data.valid_size)
|
||||
|
||||
# convolutional net's causal padding size
|
||||
context_size = config.model.n_stack \
|
||||
* sum([(config.model.filter_size - 1) * 2**i for i in range(config.model.n_loop)]) \
|
||||
+ 1
|
||||
context_frames = context_size // config.data.hop_length
|
||||
|
||||
# frames used to compute loss
|
||||
frames_per_second = config.data.sample_rate // config.data.hop_length
|
||||
train_clip_frames = math.ceil(config.data.train_clip_seconds *
|
||||
frames_per_second)
|
||||
|
||||
num_frames = train_clip_frames + context_frames
|
||||
batch_fn = LJSpeechClipCollector(num_frames, config.data.hop_length)
|
||||
if not self.parallel:
|
||||
train_loader = DataLoader(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=batch_fn)
|
||||
else:
|
||||
sampler = DistributedBatchSampler(
|
||||
train_set,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
train_loader = DataLoader(
|
||||
train_set, batch_sampler=sampler, collate_fn=batch_fn)
|
||||
|
||||
valid_batch_fn = LJSpeechCollector()
|
||||
valid_loader = DataLoader(
|
||||
valid_set, batch_size=1, collate_fn=valid_batch_fn)
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.valid_loader = valid_loader
|
||||
|
||||
def train_batch(self):
|
||||
start = time.time()
|
||||
batch = self.read_batch()
|
||||
data_loader_time = time.time() - start
|
||||
|
||||
self.model.train()
|
||||
self.optimizer.clear_grad()
|
||||
mel, wav, audio_starts = batch
|
||||
|
||||
y = self.model(wav, mel, audio_starts)
|
||||
loss = self.model_core.loss(y, wav)
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
iteration_time = time.time() - start
|
||||
|
||||
loss_value = float(loss)
|
||||
msg = "Rank: {}, ".format(dist.get_rank())
|
||||
msg += "step: {}, ".format(self.iteration)
|
||||
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
|
||||
iteration_time)
|
||||
msg += "loss: {:>.6f}".format(loss_value)
|
||||
self.logger.info(msg)
|
||||
if dist.get_rank() == 0:
|
||||
self.visualizer.add_scalar(
|
||||
"train/loss", loss_value, global_step=self.iteration)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
valid_iterator = iter(self.valid_loader)
|
||||
valid_losses = []
|
||||
mel, wav, audio_starts = next(valid_iterator)
|
||||
y = self.model(wav, mel, audio_starts)
|
||||
loss = self.model_core.loss(y, wav)
|
||||
valid_losses.append(float(loss))
|
||||
valid_loss = np.mean(valid_losses)
|
||||
self.visualizer.add_scalar(
|
||||
"valid/loss", valid_loss, global_step=self.iteration)
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp.setup()
|
||||
exp.run()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
if args.nprocs > 1 and args.device == "gpu":
|
||||
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
else:
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
parser = default_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
|
@ -26,13 +26,15 @@ class AudioProcessor(object):
|
|||
win_length: int,
|
||||
hop_length: int,
|
||||
n_mels: int=80,
|
||||
f_min: int=0,
|
||||
f_max: int=None,
|
||||
fmin: int=0,
|
||||
fmax: int=None,
|
||||
window="hann",
|
||||
center=True,
|
||||
pad_mode="reflect"):
|
||||
pad_mode="reflect",
|
||||
normalize=True):
|
||||
# read & write
|
||||
self.sample_rate = sample_rate
|
||||
self.normalize = normalize
|
||||
|
||||
# stft
|
||||
self.n_fft = n_fft
|
||||
|
@ -44,8 +46,8 @@ class AudioProcessor(object):
|
|||
|
||||
# mel
|
||||
self.n_mels = n_mels
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
|
||||
self.mel_filter = self._create_mel_filter()
|
||||
self.inv_mel_filter = np.linalg.pinv(self.mel_filter)
|
||||
|
@ -54,13 +56,17 @@ class AudioProcessor(object):
|
|||
mel_filter = librosa.filters.mel(self.sample_rate,
|
||||
self.n_fft,
|
||||
n_mels=self.n_mels,
|
||||
fmin=self.f_min,
|
||||
fmax=self.f_max)
|
||||
fmin=self.fmin,
|
||||
fmax=self.fmax)
|
||||
return mel_filter
|
||||
|
||||
def read_wav(self, filename):
|
||||
# resampling may occur
|
||||
wav, _ = librosa.load(filename, sr=self.sample_rate)
|
||||
|
||||
# normalize the volume
|
||||
if self.normalize:
|
||||
wav = wav / np.max(np.abs(wav)) * 0.999
|
||||
return wav
|
||||
|
||||
def write_wav(self, path, wav):
|
||||
|
|
|
@ -11,7 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This modules contains normalizers for spectrogram magnitude.
|
||||
Normalizers are invertible transformations. They can be used to process
|
||||
|
@ -42,7 +41,7 @@ class LogMagnitude(NormalizerBase):
|
|||
This is a simple normalizer used in Waveglow, Waveflow, tacotron2...
|
||||
"""
|
||||
|
||||
def __init__(self, min=1e-7):
|
||||
def __init__(self, min=1e-5):
|
||||
self.min = min
|
||||
|
||||
def transform(self, x):
|
||||
|
|
|
@ -1,3 +1,16 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Parakeet's infrastructure for data processing.
|
||||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
|
|
@ -65,7 +65,7 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
|
|||
mode='constant',
|
||||
constant_values=pad_id))
|
||||
|
||||
return np.array(batch, dtype=dtype)
|
||||
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
|
||||
|
||||
|
||||
class WavBatcher(object):
|
||||
|
@ -106,7 +106,7 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
|
|||
np.pad(example, [(0, pad_len)],
|
||||
mode='constant',
|
||||
constant_values=pad_value))
|
||||
return np.array(batch, dtype=dtype)
|
||||
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
|
||||
|
||||
|
||||
class SpecBatcher(object):
|
||||
|
@ -160,4 +160,4 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32):
|
|||
np.pad(example, [(0, 0), (0, pad_len)],
|
||||
mode='constant',
|
||||
constant_values=pad_value))
|
||||
return np.array(batch, dtype=dtype)
|
||||
return np.array(batch, dtype=dtype), np.array(lengths, dtype=np.int64)
|
||||
|
|
|
@ -15,24 +15,79 @@
|
|||
from paddle.io import Dataset
|
||||
import os
|
||||
import librosa
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
__all__ = ["AudioFolderDataset"]
|
||||
__all__ = ["AudioSegmentDataset", "AudioDataset", "AudioFolderDataset"]
|
||||
|
||||
|
||||
class AudioFolderDataset(Dataset):
|
||||
def __init__(self, path, sample_rate, extension="wav"):
|
||||
self.root = os.path.expanduser(path)
|
||||
self.sample_rate = sample_rate
|
||||
self.extension = extension
|
||||
self.file_names = [
|
||||
os.path.join(self.root, x) for x in os.listdir(self.root) \
|
||||
if os.path.splitext(x)[-1] == self.extension]
|
||||
self.length = len(self.file_names)
|
||||
class AudioSegmentDataset(Dataset):
|
||||
"""A simple dataset adaptor for audio files to train vocoders.
|
||||
Read -> trim silence -> normalize -> extract a segment
|
||||
"""
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
def __init__(self,
|
||||
file_paths: List[Path],
|
||||
sample_rate: int,
|
||||
length: int,
|
||||
top_db: float):
|
||||
self.file_paths = file_paths
|
||||
self.sr = sample_rate
|
||||
self.top_db = top_db
|
||||
self.length = length # samples in the clip
|
||||
|
||||
def __getitem__(self, i):
|
||||
file_name = self.file_names[i]
|
||||
y, _ = librosa.load(file_name, sr=self.sample_rate) # pylint: disable=unused-variable
|
||||
fpath = self.file_paths[i]
|
||||
y, sr = librosa.load(fpath, self.sr)
|
||||
y, _ = librosa.effects.trim(y, top_db=self.top_db)
|
||||
y = librosa.util.normalize(y)
|
||||
y = y.astype(np.float32)
|
||||
|
||||
# pad or trim
|
||||
if y.size <= self.length:
|
||||
y = np.pad(y, [0, self.length - len(y)], mode='constant')
|
||||
else:
|
||||
start = np.random.randint(0, 1 + len(y) - self.length)
|
||||
y = y[start:start + self.length]
|
||||
return y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
|
||||
class AudioDataset(Dataset):
|
||||
"""A simple dataset adaptor for the audio files.
|
||||
Read -> trim silence -> normalize
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
file_paths: List[Path],
|
||||
sample_rate: int,
|
||||
top_db: float=60):
|
||||
self.file_paths = file_paths
|
||||
self.sr = sample_rate
|
||||
self.top_db = top_db
|
||||
|
||||
def __getitem__(self, i):
|
||||
fpath = self.file_paths[i]
|
||||
y, sr = librosa.load(fpath, self.sr)
|
||||
y, _ = librosa.effects.trim(y, top_db=self.top_db)
|
||||
y = librosa.util.normalize(y)
|
||||
y = y.astype(np.float32)
|
||||
return y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_paths)
|
||||
|
||||
|
||||
class AudioFolderDataset(AudioDataset):
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
sample_rate,
|
||||
top_db=60,
|
||||
extension=".wav", ):
|
||||
root = Path(root).expanduser()
|
||||
file_paths = sorted(list(root.rglob("*{}".format(extension))))
|
||||
super().__init__(file_paths, sample_rate, top_db)
|
||||
|
|
|
@ -0,0 +1,305 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from parakeet.frontend.phonectic import Phonetics
|
||||
"""
|
||||
A phonology system with ARPABET symbols and limited punctuations. The G2P
|
||||
conversion is done by g2p_en.
|
||||
|
||||
Note that g2p_en does not handle words with hypen well. So make sure the input
|
||||
sentence is first normalized.
|
||||
"""
|
||||
from parakeet.frontend.vocab import Vocab
|
||||
from g2p_en import G2p
|
||||
|
||||
|
||||
class ARPABET(Phonetics):
|
||||
"""A phonology for English that uses ARPABET as the phoneme vocabulary.
|
||||
See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for more details.
|
||||
Phoneme Example Translation
|
||||
------- ------- -----------
|
||||
AA odd AA D
|
||||
AE at AE T
|
||||
AH hut HH AH T
|
||||
AO ought AO T
|
||||
AW cow K AW
|
||||
AY hide HH AY D
|
||||
B be B IY
|
||||
CH cheese CH IY Z
|
||||
D dee D IY
|
||||
DH thee DH IY
|
||||
EH Ed EH D
|
||||
ER hurt HH ER T
|
||||
EY ate EY T
|
||||
F fee F IY
|
||||
G green G R IY N
|
||||
HH he HH IY
|
||||
IH it IH T
|
||||
IY eat IY T
|
||||
JH gee JH IY
|
||||
K key K IY
|
||||
L lee L IY
|
||||
M me M IY
|
||||
N knee N IY
|
||||
NG ping P IH NG
|
||||
OW oat OW T
|
||||
OY toy T OY
|
||||
P pee P IY
|
||||
R read R IY D
|
||||
S sea S IY
|
||||
SH she SH IY
|
||||
T tea T IY
|
||||
TH theta TH EY T AH
|
||||
UH hood HH UH D
|
||||
UW two T UW
|
||||
V vee V IY
|
||||
W we W IY
|
||||
Y yield Y IY L D
|
||||
Z zee Z IY
|
||||
ZH seizure S IY ZH ER
|
||||
"""
|
||||
phonemes = [
|
||||
'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER',
|
||||
'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
|
||||
'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UW', 'UH', 'V', 'W', 'Y', 'Z',
|
||||
'ZH'
|
||||
]
|
||||
punctuations = [',', '.', '?', '!']
|
||||
symbols = phonemes + punctuations
|
||||
_stress_to_no_stress_ = {
|
||||
'AA0': 'AA',
|
||||
'AA1': 'AA',
|
||||
'AA2': 'AA',
|
||||
'AE0': 'AE',
|
||||
'AE1': 'AE',
|
||||
'AE2': 'AE',
|
||||
'AH0': 'AH',
|
||||
'AH1': 'AH',
|
||||
'AH2': 'AH',
|
||||
'AO0': 'AO',
|
||||
'AO1': 'AO',
|
||||
'AO2': 'AO',
|
||||
'AW0': 'AW',
|
||||
'AW1': 'AW',
|
||||
'AW2': 'AW',
|
||||
'AY0': 'AY',
|
||||
'AY1': 'AY',
|
||||
'AY2': 'AY',
|
||||
'EH0': 'EH',
|
||||
'EH1': 'EH',
|
||||
'EH2': 'EH',
|
||||
'ER0': 'ER',
|
||||
'ER1': 'ER',
|
||||
'ER2': 'ER',
|
||||
'EY0': 'EY',
|
||||
'EY1': 'EY',
|
||||
'EY2': 'EY',
|
||||
'IH0': 'IH',
|
||||
'IH1': 'IH',
|
||||
'IH2': 'IH',
|
||||
'IY0': 'IY',
|
||||
'IY1': 'IY',
|
||||
'IY2': 'IY',
|
||||
'OW0': 'OW',
|
||||
'OW1': 'OW',
|
||||
'OW2': 'OW',
|
||||
'OY0': 'OY',
|
||||
'OY1': 'OY',
|
||||
'OY2': 'OY',
|
||||
'UH0': 'UH',
|
||||
'UH1': 'UH',
|
||||
'UH2': 'UH',
|
||||
'UW0': 'UW',
|
||||
'UW1': 'UW',
|
||||
'UW2': 'UW'
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.backend = G2p()
|
||||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||
|
||||
def _remove_vowels(self, phone):
|
||||
return self._stress_to_no_stress_.get(phone, phone)
|
||||
|
||||
def phoneticize(self, sentence, add_start_end=False):
|
||||
""" Normalize the input text sequence and convert it into pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
phonemes = [
|
||||
self._remove_vowels(item) for item in self.backend(sentence)
|
||||
]
|
||||
if add_start_end:
|
||||
start = self.vocab.start_symbol
|
||||
end = self.vocab.end_symbol
|
||||
phonemes = [start] + phonemes + [end]
|
||||
phonemes = [item for item in phonemes if item in self.vocab.stoi]
|
||||
return phonemes
|
||||
|
||||
def numericalize(self, phonemes):
|
||||
""" Convert pronunciation sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
phonemes: List[str]
|
||||
The list of pronunciation sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
ids = [self.vocab.lookup(item) for item in phonemes]
|
||||
return ids
|
||||
|
||||
def reverse(self, ids):
|
||||
""" Reverse the list of pronunciation id sequence to a list of pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
ids: List[int]
|
||||
The list of pronunciation id sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
return [self.vocab.reverse(i) for i in ids]
|
||||
|
||||
def __call__(self, sentence, add_start_end=False):
|
||||
""" Convert the input text sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
return self.numericalize(
|
||||
self.phoneticize(
|
||||
sentence, add_start_end=add_start_end))
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
# 47 = 39 phones + 4 punctuations + 4 special tokens
|
||||
return len(self.vocab)
|
||||
|
||||
|
||||
class ARPABETWithStress(Phonetics):
|
||||
phonemes = [
|
||||
'AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
|
||||
'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D',
|
||||
'DH', 'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2',
|
||||
'F', 'G', 'HH', 'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K',
|
||||
'L', 'M', 'N', 'NG', 'OW0', 'OW1', 'OW2', 'OY0', 'OY1', 'OY2', 'P',
|
||||
'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW0', 'UW1', 'UW2',
|
||||
'V', 'W', 'Y', 'Z', 'ZH'
|
||||
]
|
||||
punctuations = [',', '.', '?', '!']
|
||||
symbols = phonemes + punctuations
|
||||
|
||||
def __init__(self):
|
||||
self.backend = G2p()
|
||||
self.vocab = Vocab(self.phonemes + self.punctuations)
|
||||
|
||||
def phoneticize(self, sentence, add_start_end=False):
|
||||
""" Normalize the input text sequence and convert it into pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
phonemes = self.backend(sentence)
|
||||
if add_start_end:
|
||||
start = self.vocab.start_symbol
|
||||
end = self.vocab.end_symbol
|
||||
phonemes = [start] + phonemes + [end]
|
||||
phonemes = [item for item in phonemes if item in self.vocab.stoi]
|
||||
return phonemes
|
||||
|
||||
def numericalize(self, phonemes):
|
||||
""" Convert pronunciation sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
phonemes: List[str]
|
||||
The list of pronunciation sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
ids = [self.vocab.lookup(item) for item in phonemes]
|
||||
return ids
|
||||
|
||||
def reverse(self, ids):
|
||||
""" Reverse the list of pronunciation id sequence to a list of pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
ids: List[int]
|
||||
The list of pronunciation id sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
return [self.vocab.reverse(i) for i in ids]
|
||||
|
||||
def __call__(self, sentence, add_start_end=False):
|
||||
""" Convert the input text sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
return self.numericalize(
|
||||
self.phoneticize(
|
||||
sentence, add_start_end=add_start_end))
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
# 77 = 69 phones + 4 punctuations + 4 special tokens
|
||||
return len(self.vocab)
|
|
@ -11,4 +11,3 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
|
|
@ -11,4 +11,3 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def full2half_width(ustr):
|
||||
half = []
|
||||
for u in ustr:
|
||||
|
|
|
@ -67,6 +67,7 @@ class English(Phonetics):
|
|||
phonemes = ([] if start is None else [start]) \
|
||||
+ self.backend(sentence) \
|
||||
+ ([] if end is None else [end])
|
||||
phonemes = [item for item in phonemes if item in self.vocab.stoi]
|
||||
return phonemes
|
||||
|
||||
def numericalize(self, phonemes):
|
||||
|
|
|
@ -0,0 +1,331 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A Simple Chinese Phonology using pinyin symbols.
|
||||
The G2P conversion converts pinyin string to symbols. Also it can handle string
|
||||
in Chinese chracters, but due to the complexity of chinese G2P, we can leave
|
||||
text -> pinyin to other part of a TTS system. Other NLP techniques may be used
|
||||
(e.g. tokenization, tagging, NER...)
|
||||
"""
|
||||
import re
|
||||
from parakeet.frontend.phonectic import Phonetics
|
||||
from parakeet.frontend.vocab import Vocab
|
||||
import pypinyin
|
||||
from pypinyin.core import Pinyin, Style
|
||||
from pypinyin.core import DefaultConverter
|
||||
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
|
||||
from itertools import product
|
||||
|
||||
_punctuations = [',', '。', '?', '!']
|
||||
_initials = [
|
||||
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'j', 'q', 'x', 'zh',
|
||||
'ch', 'sh', 'r', 'z', 'c', 's'
|
||||
]
|
||||
_finals = [
|
||||
'ii', 'iii', 'a', 'o', 'e', 'ea', 'ai', 'ei', 'ao', 'ou', 'an', 'en',
|
||||
'ang', 'eng', 'er', 'i', 'ia', 'io', 'ie', 'iai', 'iao', 'iou', 'ian',
|
||||
'ien', 'iang', 'ieng', 'u', 'ua', 'uo', 'uai', 'uei', 'uan', 'uen', 'uang',
|
||||
'ueng', 'v', 've', 'van', 'ven', 'veng'
|
||||
]
|
||||
_ernized_symbol = ['&r']
|
||||
_phones = _initials + _finals + _ernized_symbol + _punctuations
|
||||
_tones = ['0', '1', '2', '3', '4', '5']
|
||||
|
||||
_toned_finals = [final + tone for final, tone in product(_finals, _tones[1:])]
|
||||
_toned_phonems = _initials + _toned_finals + _ernized_symbol + _punctuations
|
||||
|
||||
|
||||
class ParakeetConverter(NeutralToneWith5Mixin, DefaultConverter):
|
||||
pass
|
||||
|
||||
|
||||
class ParakeetPinyin(Phonetics):
|
||||
def __init__(self):
|
||||
self.vocab_phonemes = Vocab(_phones)
|
||||
self.vocab_tones = Vocab(_tones)
|
||||
self.pinyin_backend = Pinyin(ParakeetConverter())
|
||||
|
||||
def convert_pypinyin_tone3(self, syllables, add_start_end=False):
|
||||
phonemes, tones = _convert_to_parakeet_style_pinyin(syllables)
|
||||
|
||||
if add_start_end:
|
||||
start = self.vocab_phonemes.start_symbol
|
||||
end = self.vocab_phonemes.end_symbol
|
||||
phonemes = [start] + phonemes + [end]
|
||||
|
||||
start = self.vocab_tones.start_symbol
|
||||
end = self.vocab_tones.end_symbol
|
||||
phonemes = [start] + tones + [end]
|
||||
|
||||
phonemes = [
|
||||
item for item in phonemes if item in self.vocab_phonemes.stoi
|
||||
]
|
||||
tones = [item for item in tones if item in self.vocab_tones.stoi]
|
||||
return phonemes, tones
|
||||
|
||||
def phoneticize(self, sentence, add_start_end=False):
|
||||
""" Normalize the input text sequence and convert it into pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
syllables = self.pinyin_backend.lazy_pinyin(
|
||||
sentence, style=Style.TONE3, strict=True)
|
||||
phonemes, tones = self.convert_pypinyin_tone3(
|
||||
syllables, add_start_end=add_start_end)
|
||||
return phonemes, tones
|
||||
|
||||
def numericalize(self, phonemes, tones):
|
||||
""" Convert pronunciation sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
phonemes: List[str]
|
||||
The list of pronunciation sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
phoneme_ids = [self.vocab_phonemes.lookup(item) for item in phonemes]
|
||||
tone_ids = [self.vocab_tones.lookup(item) for item in tones]
|
||||
return phoneme_ids, tone_ids
|
||||
|
||||
def __call__(self, sentence, add_start_end=False):
|
||||
""" Convert the input text sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
phonemes, tones = self.phoneticize(
|
||||
sentence, add_start_end=add_start_end)
|
||||
phoneme_ids, tone_ids = self.numericalize(phonemes, tones)
|
||||
return phoneme_ids, tone_ids
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
# 70 = 62 phones + 4 punctuations + 4 special tokens
|
||||
return len(self.vocab_phonemes)
|
||||
|
||||
@property
|
||||
def tone_vocab_size(self):
|
||||
# 10 = 1 non tone + 5 tone + 4 special tokens
|
||||
return len(self.vocab_tones)
|
||||
|
||||
|
||||
class ParakeetPinyinWithTone(Phonetics):
|
||||
def __init__(self):
|
||||
self.vocab = Vocab(_toned_phonems)
|
||||
self.pinyin_backend = Pinyin(ParakeetConverter())
|
||||
|
||||
def convert_pypinyin_tone3(self, syllables, add_start_end=False):
|
||||
phonemes = _convert_to_parakeet_style_pinyin_with_tone(syllables)
|
||||
|
||||
if add_start_end:
|
||||
start = self.vocab_phonemes.start_symbol
|
||||
end = self.vocab_phonemes.end_symbol
|
||||
phonemes = [start] + phonemes + [end]
|
||||
|
||||
phonemes = [item for item in phonemes if item in self.vocab.stoi]
|
||||
return phonemes
|
||||
|
||||
def phoneticize(self, sentence, add_start_end=False):
|
||||
""" Normalize the input text sequence and convert it into pronunciation sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation sequence.
|
||||
"""
|
||||
syllables = self.pinyin_backend.lazy_pinyin(
|
||||
sentence, style=Style.TONE3, strict=True)
|
||||
phonemes = self.convert_pypinyin_tone3(
|
||||
syllables, add_start_end=add_start_end)
|
||||
return phonemes
|
||||
|
||||
def numericalize(self, phonemes):
|
||||
""" Convert pronunciation sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
phonemes: List[str]
|
||||
The list of pronunciation sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[int]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
phoneme_ids = [self.vocab.lookup(item) for item in phonemes]
|
||||
return phoneme_ids
|
||||
|
||||
def __call__(self, sentence, add_start_end=False):
|
||||
""" Convert the input text sequence into pronunciation id sequence.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
sentence: str
|
||||
The input text sequence.
|
||||
|
||||
Returns
|
||||
----------
|
||||
List[str]
|
||||
The list of pronunciation id sequence.
|
||||
"""
|
||||
phonemes = self.phoneticize(sentence, add_start_end=add_start_end)
|
||||
phoneme_ids = self.numericalize(phonemes)
|
||||
return phoneme_ids
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
""" Vocab size.
|
||||
"""
|
||||
# 230 = 222 phones + 4 punctuations + 4 special tokens
|
||||
return len(self.vocab)
|
||||
|
||||
|
||||
def _convert_to_parakeet_convension(syllable):
|
||||
# from pypinyin.Style.TONE3 to parakeet convension
|
||||
tone = syllable[-1]
|
||||
syllable = syllable[:-1]
|
||||
|
||||
# expansion of o -> uo
|
||||
syllable = re.sub(r"([bpmf])o$", r"\1uo", syllable)
|
||||
|
||||
# expansion for iong, ong
|
||||
syllable = syllable.replace("iong", "veng").replace("ong", "ueng")
|
||||
|
||||
# expansion for ing, in
|
||||
syllable = syllable.replace("ing", "ieng").replace("in", "ien")
|
||||
|
||||
# expansion for un, ui, iu
|
||||
syllable = syllable.replace("un","uen")\
|
||||
.replace("ui", "uei")\
|
||||
.replace("iu", "iou")
|
||||
|
||||
# rule for variants of i
|
||||
syllable = syllable.replace("zi", "zii")\
|
||||
.replace("ci", "cii")\
|
||||
.replace("si", "sii")\
|
||||
.replace("zhi", "zhiii")\
|
||||
.replace("chi", "chiii")\
|
||||
.replace("shi", "shiii")\
|
||||
.replace("ri", "riii")
|
||||
|
||||
# rule for y preceding i, u
|
||||
syllable = syllable.replace("yi", "i").replace("yu", "v").replace("y", "i")
|
||||
|
||||
# rule for w
|
||||
syllable = syllable.replace("wu", "u").replace("w", "u")
|
||||
|
||||
# rule for v following j, q, x
|
||||
syllable = syllable.replace("ju", "jv")\
|
||||
.replace("qu", "qv")\
|
||||
.replace("xu", "xv")
|
||||
|
||||
return syllable + tone
|
||||
|
||||
|
||||
def _split_syllable(syllable: str):
|
||||
global _punctuations
|
||||
|
||||
if syllable in _punctuations:
|
||||
# syllables, tones
|
||||
return [syllable], ['0']
|
||||
|
||||
syllable = _convert_to_parakeet_convension(syllable)
|
||||
|
||||
tone = syllable[-1]
|
||||
syllable = syllable[:-1]
|
||||
|
||||
phones = []
|
||||
tones = []
|
||||
|
||||
global _initials
|
||||
if syllable[:2] in _initials:
|
||||
phones.append(syllable[:2])
|
||||
tones.append('0')
|
||||
phones.append(syllable[2:])
|
||||
tones.append(tone)
|
||||
elif syllable[0] in _initials:
|
||||
phones.append(syllable[0])
|
||||
tones.append('0')
|
||||
phones.append(syllable[1:])
|
||||
tones.append(tone)
|
||||
else:
|
||||
phones.append(syllable)
|
||||
tones.append(tone)
|
||||
return phones, tones
|
||||
|
||||
|
||||
def _convert_to_parakeet_style_pinyin(syllables):
|
||||
phones, tones = [], []
|
||||
for syllable in syllables:
|
||||
p, t = _split_syllable(syllable)
|
||||
phones.extend(p)
|
||||
tones.extend(t)
|
||||
return phones, tones
|
||||
|
||||
|
||||
def _split_syllable_with_tone(syllable: str):
|
||||
global _punctuations
|
||||
|
||||
if syllable in _punctuations:
|
||||
# syllables
|
||||
return [syllable]
|
||||
|
||||
syllable = _convert_to_parakeet_convension(syllable)
|
||||
|
||||
phones = []
|
||||
|
||||
global _initials
|
||||
if syllable[:2] in _initials:
|
||||
phones.append(syllable[:2])
|
||||
phones.append(syllable[2:])
|
||||
elif syllable[0] in _initials:
|
||||
phones.append(syllable[0])
|
||||
phones.append(syllable[1:])
|
||||
else:
|
||||
phones.append(syllable)
|
||||
return phones
|
||||
|
||||
|
||||
def _convert_to_parakeet_style_pinyin_with_tone(syllables):
|
||||
phones = []
|
||||
for syllable in syllables:
|
||||
p = _split_syllable_with_tone(syllable)
|
||||
phones.extend(p)
|
||||
return phones
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
#from parakeet.models.clarinet import *
|
||||
from parakeet.models.waveflow import *
|
||||
from parakeet.models.wavenet import *
|
||||
#from parakeet.models.wavenet import *
|
||||
|
||||
from parakeet.models.transformer_tts import *
|
||||
#from parakeet.models.deepvoice3 import *
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
from paddle.nn import initializer as I
|
||||
|
||||
from scipy.interpolate import interp1d
|
||||
from sklearn.metrics import roc_curve
|
||||
from scipy.optimize import brentq
|
||||
|
||||
|
||||
class LSTMSpeakerEncoder(nn.Layer):
|
||||
def __init__(self, n_mels, num_layers, hidden_size, output_size):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(n_mels, hidden_size, num_layers)
|
||||
self.linear = nn.Linear(hidden_size, output_size)
|
||||
self.similarity_weight = self.create_parameter(
|
||||
[1], default_initializer=I.Constant(10.))
|
||||
self.similarity_bias = self.create_parameter(
|
||||
[1], default_initializer=I.Constant(-5.))
|
||||
|
||||
def forward(self, utterances, num_speakers, initial_states=None):
|
||||
normalized_embeds = self.embed_sequences(utterances, initial_states)
|
||||
embeds = normalized_embeds.reshape([num_speakers, -1, num_speakers])
|
||||
loss, eer = self.loss(embeds)
|
||||
return loss, eer
|
||||
|
||||
def embed_sequences(self, utterances, initial_states=None, reduce=False):
|
||||
out, (h, c) = self.lstm(utterances, initial_states)
|
||||
embeds = F.relu(self.linear(h[-1]))
|
||||
normalized_embeds = F.normalize(embeds)
|
||||
if reduce:
|
||||
embed = paddle.mean(normalized_embeds, 0)
|
||||
embed = F.normalize(embed, axis=0)
|
||||
return embed
|
||||
return normalized_embeds
|
||||
|
||||
def embed_utterance(self, utterances, initial_states=None):
|
||||
# utterances: [B, T, C] -> embed [C']
|
||||
embed = self.embed_sequences(utterances, initial_states, reduce=True)
|
||||
return embed
|
||||
|
||||
def similarity_matrix(self, embeds):
|
||||
# (N, M, C)
|
||||
speakers_per_batch, utterances_per_speaker, embed_dim = embeds.shape
|
||||
|
||||
# Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
|
||||
centroids_incl = paddle.mean(embeds, axis=1)
|
||||
centroids_incl_norm = paddle.norm(
|
||||
centroids_incl, p=2, axis=1, keepdim=True)
|
||||
normalized_centroids_incl = centroids_incl / centroids_incl_norm
|
||||
|
||||
# Exclusive centroids (1 per utterance)
|
||||
centroids_excl = paddle.broadcast_to(
|
||||
paddle.sum(embeds, axis=1, keepdim=True), embeds.shape) - embeds
|
||||
centroids_excl /= (utterances_per_speaker - 1)
|
||||
centroids_excl_norm = paddle.norm(
|
||||
centroids_excl, p=2, axis=2, keepdim=True)
|
||||
normalized_centroids_excl = centroids_excl / centroids_excl_norm
|
||||
|
||||
p1 = paddle.matmul(
|
||||
embeds.reshape([-1, embed_dim]),
|
||||
normalized_centroids_incl,
|
||||
transpose_y=True) # (NMN)
|
||||
p1 = p1.reshape([-1])
|
||||
# print("p1: ", p1.shape)
|
||||
p2 = paddle.bmm(
|
||||
embeds.reshape([-1, 1, embed_dim]),
|
||||
normalized_centroids_excl.reshape(
|
||||
[-1, embed_dim, 1])) # (NM, 1, 1)
|
||||
p2 = p2.reshape([-1]) # (NM)
|
||||
|
||||
# begin: alternative implementation for scatter
|
||||
with paddle.no_grad():
|
||||
index = paddle.arange(
|
||||
0, speakers_per_batch * utterances_per_speaker,
|
||||
dtype="int64").reshape(
|
||||
[speakers_per_batch, utterances_per_speaker])
|
||||
index = index * speakers_per_batch + paddle.arange(
|
||||
0, speakers_per_batch, dtype="int64").unsqueeze(-1)
|
||||
index = paddle.reshape(index, [-1])
|
||||
ones = paddle.ones([
|
||||
speakers_per_batch * utterances_per_speaker * speakers_per_batch
|
||||
])
|
||||
zeros = paddle.zeros_like(index, dtype=ones.dtype)
|
||||
mask_p1 = paddle.scatter(ones, index, zeros)
|
||||
p = p1 * mask_p1 + (1 - mask_p1) * paddle.scatter(ones, index, p2)
|
||||
# end: alternative implementation for scatter
|
||||
# p = paddle.scatter(p1, index, p2)
|
||||
|
||||
p = p * self.similarity_weight + self.similarity_bias # neg
|
||||
p = p.reshape(
|
||||
[speakers_per_batch * utterances_per_speaker, speakers_per_batch])
|
||||
return p, p1, p2
|
||||
|
||||
def do_gradient_ops(self):
|
||||
for p in [self.similarity_weight, self.similarity_bias]:
|
||||
g = p._grad_ivar()
|
||||
g[...] = g * 0.01
|
||||
|
||||
def loss(self, embeds):
|
||||
"""
|
||||
Computes the softmax loss according the section 2.1 of GE2E.
|
||||
|
||||
:param embeds: the embeddings as a tensor of shape (speakers_per_batch,
|
||||
utterances_per_speaker, embedding_size)
|
||||
:return: the loss and the EER for this batch of embeddings.
|
||||
"""
|
||||
speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
|
||||
|
||||
# Loss
|
||||
sim_matrix, *_ = self.similarity_matrix(embeds)
|
||||
sim_matrix = sim_matrix.reshape(
|
||||
[speakers_per_batch * utterances_per_speaker, speakers_per_batch])
|
||||
target = paddle.arange(
|
||||
0, speakers_per_batch, dtype="int64").unsqueeze(-1)
|
||||
target = paddle.expand(target,
|
||||
[speakers_per_batch, utterances_per_speaker])
|
||||
target = paddle.reshape(target, [-1])
|
||||
|
||||
loss = nn.CrossEntropyLoss()(sim_matrix, target)
|
||||
|
||||
# EER (not backpropagated)
|
||||
with paddle.no_grad():
|
||||
ground_truth = target.numpy()
|
||||
inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
|
||||
labels = np.array([inv_argmax(i) for i in ground_truth])
|
||||
preds = sim_matrix.numpy()
|
||||
|
||||
# Snippet from https://yangcha.github.io/EER-ROC/
|
||||
fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
|
||||
eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
|
||||
|
||||
return loss, eer
|
|
@ -13,15 +13,18 @@
|
|||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
import parakeet
|
||||
from paddle.nn import initializer as I
|
||||
from paddle.fluid.layers import sequence_mask
|
||||
|
||||
from parakeet.modules.conv import Conv1dBatchNorm
|
||||
from parakeet.modules.attention import LocationSensitiveAttention
|
||||
from parakeet.modules import masking
|
||||
from parakeet.modules.losses import guided_attention_loss
|
||||
from parakeet.utils import checkpoint
|
||||
from tqdm import trange
|
||||
|
||||
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||
|
||||
|
@ -63,7 +66,7 @@ class DecoderPreNet(nn.Layer):
|
|||
----------
|
||||
x: Tensor [shape=(B, T_mel, C)]
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: Tensor [shape=(B, T_mel, C)]
|
||||
|
@ -110,7 +113,7 @@ class DecoderPostNet(nn.Layer):
|
|||
self.dropout = dropout
|
||||
self.num_layers = num_layers
|
||||
|
||||
padding = int((kernel_size - 1) / 2),
|
||||
padding = int((kernel_size - 1) / 2)
|
||||
|
||||
self.conv_batchnorms = nn.LayerList()
|
||||
k = math.sqrt(1.0 / (d_mels * kernel_size))
|
||||
|
@ -120,8 +123,7 @@ class DecoderPostNet(nn.Layer):
|
|||
d_hidden,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||
low=-k, high=k)),
|
||||
bias_attr=I.Uniform(-k, k),
|
||||
data_format='NLC'))
|
||||
|
||||
k = math.sqrt(1.0 / (d_hidden * kernel_size))
|
||||
|
@ -131,8 +133,7 @@ class DecoderPostNet(nn.Layer):
|
|||
d_hidden,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||
low=-k, high=k)),
|
||||
bias_attr=I.Uniform(-k, k),
|
||||
data_format='NLC') for i in range(1, num_layers - 1)
|
||||
])
|
||||
|
||||
|
@ -142,18 +143,17 @@ class DecoderPostNet(nn.Layer):
|
|||
d_mels,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||
low=-k, high=k)),
|
||||
bias_attr=I.Uniform(-k, k),
|
||||
data_format='NLC'))
|
||||
|
||||
def forward(self, input):
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input: Tensor [shape=(B, T_mel, C)]
|
||||
x: Tensor [shape=(B, T_mel, C)]
|
||||
Output sequence of features from decoder.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: Tensor [shape=(B, T_mel, C)]
|
||||
|
@ -162,12 +162,12 @@ class DecoderPostNet(nn.Layer):
|
|||
"""
|
||||
|
||||
for i in range(len(self.conv_batchnorms) - 1):
|
||||
input = F.dropout(
|
||||
F.tanh(self.conv_batchnorms[i](input)),
|
||||
x = F.dropout(
|
||||
F.tanh(self.conv_batchnorms[i](x)),
|
||||
self.dropout,
|
||||
training=self.training)
|
||||
output = F.dropout(
|
||||
self.conv_batchnorms[self.num_layers - 1](input),
|
||||
self.conv_batchnorms[self.num_layers - 1](x),
|
||||
self.dropout,
|
||||
training=self.training)
|
||||
return output
|
||||
|
@ -180,13 +180,13 @@ class Tacotron2Encoder(nn.Layer):
|
|||
----------
|
||||
d_hidden: int
|
||||
The hidden size in encoder module.
|
||||
|
||||
|
||||
conv_layers: int
|
||||
The number of conv layers.
|
||||
|
||||
kernel_size: int
|
||||
The kernel size of conv layers.
|
||||
|
||||
|
||||
p_dropout: float
|
||||
The droput probability.
|
||||
"""
|
||||
|
@ -206,8 +206,7 @@ class Tacotron2Encoder(nn.Layer):
|
|||
kernel_size,
|
||||
stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||
low=-k, high=k)),
|
||||
bias_attr=I.Uniform(-k, k),
|
||||
data_format='NLC') for i in range(conv_layers)
|
||||
])
|
||||
self.p_dropout = p_dropout
|
||||
|
@ -221,12 +220,12 @@ class Tacotron2Encoder(nn.Layer):
|
|||
|
||||
Parameters
|
||||
----------
|
||||
x: Tensor [shape=(B, T)]
|
||||
Batch of the sequencees of padded character ids.
|
||||
|
||||
x: Tensor [shape=(B, T, C)]
|
||||
Input embeddings.
|
||||
|
||||
text_lens: Tensor [shape=(B,)], optional
|
||||
Batch of lengths of each text input batch. Defaults to None.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : Tensor [shape=(B, T, C)]
|
||||
|
@ -253,7 +252,7 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
reduction_factor: int
|
||||
The reduction factor of tacotron.
|
||||
|
||||
|
||||
d_encoder: int
|
||||
The hidden size of encoder.
|
||||
|
||||
|
@ -265,13 +264,13 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
d_decoder_rnn: int
|
||||
The decoder rnn layer hidden size.
|
||||
|
||||
|
||||
d_attention: int
|
||||
The hidden size of the linear layer in location sensitive attention.
|
||||
|
||||
attention_filters: int
|
||||
The filter size of the conv layer in location sensitive attention.
|
||||
|
||||
|
||||
attention_kernel_size: int
|
||||
The kernel size of the conv layer in location sensitive attention.
|
||||
|
||||
|
@ -283,6 +282,10 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
p_decoder_dropout: float
|
||||
The droput probability in decoder.
|
||||
|
||||
use_stop_token: bool
|
||||
Whether to use a binary classifier for stop token prediction.
|
||||
Defaults to False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -297,7 +300,8 @@ class Tacotron2Decoder(nn.Layer):
|
|||
attention_kernel_size: int,
|
||||
p_prenet_dropout: float,
|
||||
p_attention_dropout: float,
|
||||
p_decoder_dropout: float):
|
||||
p_decoder_dropout: float,
|
||||
use_stop_token: bool=False):
|
||||
super().__init__()
|
||||
self.d_mels = d_mels
|
||||
self.reduction_factor = reduction_factor
|
||||
|
@ -313,22 +317,44 @@ class Tacotron2Decoder(nn.Layer):
|
|||
d_prenet,
|
||||
dropout_rate=p_prenet_dropout)
|
||||
|
||||
# attention_rnn takes attention's context vector has an
|
||||
# auxiliary input
|
||||
self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn)
|
||||
|
||||
self.attention_layer = LocationSensitiveAttention(
|
||||
d_attention_rnn, d_encoder, d_attention, attention_filters,
|
||||
attention_kernel_size)
|
||||
|
||||
# decoder_rnn takes prenet's output and attention_rnn's input
|
||||
# as input
|
||||
self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder,
|
||||
d_decoder_rnn)
|
||||
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
|
||||
d_mels * reduction_factor)
|
||||
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||
|
||||
self.use_stop_token = use_stop_token
|
||||
if use_stop_token:
|
||||
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||
|
||||
# states - temporary attributes
|
||||
self.attention_hidden = None
|
||||
self.attention_cell = None
|
||||
|
||||
self.decoder_hidden = None
|
||||
self.decoder_cell = None
|
||||
|
||||
self.attention_weights = None
|
||||
self.attention_weights_cum = None
|
||||
self.attention_context = None
|
||||
|
||||
self.key = None
|
||||
self.mask = None
|
||||
self.processed_key = None
|
||||
|
||||
def _initialize_decoder_states(self, key):
|
||||
"""init states be used in decoder
|
||||
"""
|
||||
batch_size = key.shape[0]
|
||||
MAX_TIME = key.shape[1]
|
||||
batch_size, encoder_steps, _ = key.shape
|
||||
|
||||
self.attention_hidden = paddle.zeros(
|
||||
shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
|
||||
|
@ -341,21 +367,22 @@ class Tacotron2Decoder(nn.Layer):
|
|||
shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)
|
||||
|
||||
self.attention_weights = paddle.zeros(
|
||||
shape=[batch_size, MAX_TIME], dtype=key.dtype)
|
||||
shape=[batch_size, encoder_steps], dtype=key.dtype)
|
||||
self.attention_weights_cum = paddle.zeros(
|
||||
shape=[batch_size, MAX_TIME], dtype=key.dtype)
|
||||
shape=[batch_size, encoder_steps], dtype=key.dtype)
|
||||
self.attention_context = paddle.zeros(
|
||||
shape=[batch_size, self.d_encoder], dtype=key.dtype)
|
||||
|
||||
self.key = key #[B, T, C]
|
||||
self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
|
||||
self.key = key # [B, T, C]
|
||||
# pre-compute projected keys to improve efficiency
|
||||
self.processed_key = self.attention_layer.key_layer(key) # [B, T, C]
|
||||
|
||||
def _decode(self, query):
|
||||
"""decode one time step
|
||||
"""
|
||||
cell_input = paddle.concat([query, self.attention_context], axis=-1)
|
||||
|
||||
# The first lstm layer
|
||||
# The first lstm layer (or spec encoder lstm)
|
||||
_, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_hidden = F.dropout(
|
||||
|
@ -371,7 +398,7 @@ class Tacotron2Decoder(nn.Layer):
|
|||
attention_weights_cat, self.mask)
|
||||
self.attention_weights_cum += self.attention_weights
|
||||
|
||||
# The second lstm layer
|
||||
# The second lstm layer (or spec decoder lstm)
|
||||
decoder_input = paddle.concat(
|
||||
[self.attention_hidden, self.attention_context], axis=-1)
|
||||
_, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
|
||||
|
@ -386,8 +413,10 @@ class Tacotron2Decoder(nn.Layer):
|
|||
[self.decoder_hidden, self.attention_context], axis=-1)
|
||||
decoder_output = self.linear_projection(
|
||||
decoder_hidden_attention_context)
|
||||
stop_logit = self.stop_layer(decoder_hidden_attention_context)
|
||||
return decoder_output, stop_logit, self.attention_weights
|
||||
if self.use_stop_token:
|
||||
stop_logit = self.stop_layer(decoder_hidden_attention_context)
|
||||
return decoder_output, self.attention_weights, stop_logit
|
||||
return decoder_output, self.attention_weights
|
||||
|
||||
def forward(self, keys, querys, mask):
|
||||
"""Calculate forward propagation of tacotron2 decoder.
|
||||
|
@ -396,131 +425,148 @@ class Tacotron2Decoder(nn.Layer):
|
|||
----------
|
||||
keys: Tensor[shape=(B, T_key, C)]
|
||||
Batch of the sequences of padded output from encoder.
|
||||
|
||||
|
||||
querys: Tensor[shape(B, T_query, C)]
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
|
||||
mask: Tensor
|
||||
Mask generated with text length. Shape should be (B, T_key, T_query) or broadcastable shape.
|
||||
|
||||
Mask generated with text length. Shape should be (B, T_key, 1).
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_output: Tensor [shape=(B, T_query, C)]
|
||||
Output sequence of features.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_query)]
|
||||
Output sequence of stop logits.
|
||||
|
||||
alignments: Tensor [shape=(B, T_query, T_key)]
|
||||
Attention weights.
|
||||
"""
|
||||
querys = paddle.reshape(
|
||||
querys,
|
||||
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
|
||||
querys = paddle.concat(
|
||||
[
|
||||
paddle.zeros(
|
||||
shape=[querys.shape[0], 1, querys.shape[-1]],
|
||||
dtype=querys.dtype), querys
|
||||
],
|
||||
axis=1)
|
||||
querys = self.prenet(querys)
|
||||
|
||||
self._initialize_decoder_states(keys)
|
||||
self.mask = mask
|
||||
|
||||
mel_outputs, stop_logits, alignments = [], [], []
|
||||
while len(mel_outputs) < querys.shape[
|
||||
1] - 1: # Ignore the last time step
|
||||
querys = paddle.reshape(
|
||||
querys,
|
||||
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
|
||||
start_step = paddle.zeros(
|
||||
shape=[querys.shape[0], 1, querys.shape[-1]], dtype=querys.dtype)
|
||||
querys = paddle.concat([start_step, querys], axis=1)
|
||||
|
||||
querys = self.prenet(querys)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_logits = []
|
||||
# Ignore the last time step
|
||||
while len(mel_outputs) < querys.shape[1] - 1:
|
||||
query = querys[:, len(mel_outputs), :]
|
||||
mel_output, stop_logit, attention_weights = self._decode(query)
|
||||
mel_outputs += [mel_output]
|
||||
stop_logits += [stop_logit]
|
||||
alignments += [attention_weights]
|
||||
if self.use_stop_token:
|
||||
mel_output, attention_weights, stop_logit = self._decode(query)
|
||||
else:
|
||||
mel_output, attention_weights = self._decode(query)
|
||||
mel_outputs.append(mel_output)
|
||||
alignments.append(attention_weights)
|
||||
if self.use_stop_token:
|
||||
stop_logits.append(stop_logit)
|
||||
|
||||
alignments = paddle.stack(alignments, axis=1)
|
||||
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||
if self.use_stop_token:
|
||||
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||
return mel_outputs, alignments, stop_logits
|
||||
return mel_outputs, alignments
|
||||
|
||||
return mel_outputs, stop_logits, alignments
|
||||
|
||||
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
def infer(self, key, max_decoder_steps=1000):
|
||||
"""Calculate forward propagation of tacotron2 decoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
keys: Tensor [shape=(B, T_key, C)]
|
||||
Batch of the sequences of padded output from encoder.
|
||||
|
||||
stop_threshold: float, optional
|
||||
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||
|
||||
|
||||
max_decoder_steps: int, optional
|
||||
Number of max step when synthesize. Defaults to 1000.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_output: Tensor [shape=(B, T_mel, C)]
|
||||
Output sequence of features.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_mel)]
|
||||
Output sequence of stop logits.
|
||||
|
||||
alignments: Tensor [shape=(B, T_mel, T_key)]
|
||||
Attention weights.
|
||||
|
||||
"""
|
||||
query = paddle.zeros(
|
||||
shape=[key.shape[0], self.d_mels * self.reduction_factor],
|
||||
dtype=key.dtype) #[B, C]
|
||||
|
||||
self._initialize_decoder_states(key)
|
||||
self.mask = None
|
||||
self.mask = None # mask is not needed for single instance inference
|
||||
encoder_steps = key.shape[1]
|
||||
|
||||
mel_outputs, stop_logits, alignments = [], [], []
|
||||
while True:
|
||||
# [B, C]
|
||||
start_step = paddle.zeros(
|
||||
shape=[key.shape[0], self.d_mels * self.reduction_factor],
|
||||
dtype=key.dtype)
|
||||
query = start_step # [B, C]
|
||||
first_hit_end = None
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_logits = []
|
||||
for i in trange(max_decoder_steps):
|
||||
query = self.prenet(query)
|
||||
mel_output, stop_logit, alignment = self._decode(query)
|
||||
if self.use_stop_token:
|
||||
mel_output, alignment, stop_logit = self._decode(query)
|
||||
else:
|
||||
mel_output, alignment = self._decode(query)
|
||||
|
||||
mel_outputs += [mel_output]
|
||||
stop_logits += [stop_logit]
|
||||
alignments += [alignment]
|
||||
mel_outputs.append(mel_output)
|
||||
alignments.append(alignment) # (B=1, T)
|
||||
if self.use_stop_token:
|
||||
stop_logits.append(stop_logit)
|
||||
|
||||
if F.sigmoid(stop_logit) > stop_threshold:
|
||||
break
|
||||
elif len(mel_outputs) == max_decoder_steps:
|
||||
if self.use_stop_token:
|
||||
if F.sigmoid(stop_logit) > 0.5:
|
||||
print("hit stop condition!")
|
||||
break
|
||||
else:
|
||||
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
|
||||
if first_hit_end is None:
|
||||
first_hit_end = i
|
||||
elif i > (first_hit_end + 20):
|
||||
print("content exhausted!")
|
||||
break
|
||||
if len(mel_outputs) == max_decoder_steps:
|
||||
print("Warning! Reached max decoder steps!!!")
|
||||
break
|
||||
|
||||
query = mel_output
|
||||
|
||||
alignments = paddle.stack(alignments, axis=1)
|
||||
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||
|
||||
return mel_outputs, stop_logits, alignments
|
||||
if self.use_stop_token:
|
||||
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||
return mel_outputs, alignments, stop_logits
|
||||
return mel_outputs, alignments
|
||||
|
||||
|
||||
class Tacotron2(nn.Layer):
|
||||
"""Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
|
||||
|
||||
This is a model of Spectrogram prediction network in Tacotron2 described
|
||||
in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
|
||||
<https://arxiv.org/abs/1712.05884>`_,
|
||||
in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram
|
||||
Predictions <https://arxiv.org/abs/1712.05884>`_,
|
||||
which converts the sequence of characters
|
||||
into the sequence of mel spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frontend : parakeet.frontend.Phonetics
|
||||
Frontend used to preprocess text.
|
||||
vocab_size : int
|
||||
Vocabulary size of phons of the model.
|
||||
|
||||
n_tones: int
|
||||
Vocabulary size of tones of the model. Defaults to None. If provided,
|
||||
the model has an extra tone embedding.
|
||||
|
||||
d_mels: int
|
||||
Number of mel bands.
|
||||
|
||||
|
||||
d_encoder: int
|
||||
Hidden size in encoder module.
|
||||
|
||||
|
||||
encoder_conv_layers: int
|
||||
Number of conv layers in encoder.
|
||||
|
||||
|
@ -538,7 +584,7 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
attention_filters: int
|
||||
Filter size of the conv layer in location sensitive attention.
|
||||
|
||||
|
||||
attention_kernel_size: int
|
||||
Kernel size of the conv layer in location sensitive attention.
|
||||
|
||||
|
@ -572,10 +618,16 @@ class Tacotron2(nn.Layer):
|
|||
p_postnet_dropout: float
|
||||
Droput probability in postnet.
|
||||
|
||||
d_global_condition: int
|
||||
Feature size of global condition. Defaults to None. If provided, The
|
||||
model assumes a global condition that is concatenated to the encoder
|
||||
outputs.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
frontend: parakeet.frontend.Phonetics,
|
||||
vocab_size,
|
||||
n_tones=None,
|
||||
d_mels: int=80,
|
||||
d_encoder: int=512,
|
||||
encoder_conv_layers: int=3,
|
||||
|
@ -594,24 +646,43 @@ class Tacotron2(nn.Layer):
|
|||
p_prenet_dropout: float=0.5,
|
||||
p_attention_dropout: float=0.1,
|
||||
p_decoder_dropout: float=0.1,
|
||||
p_postnet_dropout: float=0.5):
|
||||
p_postnet_dropout: float=0.5,
|
||||
d_global_condition=None,
|
||||
use_stop_token=False):
|
||||
super().__init__()
|
||||
|
||||
self.frontend = frontend
|
||||
std = math.sqrt(2.0 / (self.frontend.vocab_size + d_encoder))
|
||||
std = math.sqrt(2.0 / (vocab_size + d_encoder))
|
||||
val = math.sqrt(3.0) * std # uniform bounds for std
|
||||
self.embedding = nn.Embedding(
|
||||
self.frontend.vocab_size,
|
||||
d_encoder,
|
||||
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||
low=-val, high=val)))
|
||||
vocab_size, d_encoder, weight_attr=I.Uniform(-val, val))
|
||||
if n_tones:
|
||||
self.embedding_tones = nn.Embedding(
|
||||
n_tones,
|
||||
d_encoder,
|
||||
padding_idx=0,
|
||||
weight_attr=I.Uniform(-0.1 * val, 0.1 * val))
|
||||
self.toned = n_tones is not None
|
||||
|
||||
self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
|
||||
encoder_kernel_size, p_encoder_dropout)
|
||||
|
||||
# input augmentation scheme: concat global condition to the encoder output
|
||||
if d_global_condition is not None:
|
||||
d_encoder += d_global_condition
|
||||
self.decoder = Tacotron2Decoder(
|
||||
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
|
||||
d_decoder_rnn, d_attention, attention_filters,
|
||||
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
|
||||
p_decoder_dropout)
|
||||
d_mels,
|
||||
reduction_factor,
|
||||
d_encoder,
|
||||
d_prenet,
|
||||
d_attention_rnn,
|
||||
d_decoder_rnn,
|
||||
d_attention,
|
||||
attention_filters,
|
||||
attention_kernel_size,
|
||||
p_prenet_dropout,
|
||||
p_attention_dropout,
|
||||
p_decoder_dropout,
|
||||
use_stop_token=use_stop_token)
|
||||
self.postnet = DecoderPostNet(
|
||||
d_mels=d_mels * reduction_factor,
|
||||
d_hidden=d_postnet,
|
||||
|
@ -619,79 +690,109 @@ class Tacotron2(nn.Layer):
|
|||
num_layers=postnet_conv_layers,
|
||||
dropout=p_postnet_dropout)
|
||||
|
||||
def forward(self, text_inputs, mels, text_lens, output_lens=None):
|
||||
def forward(self,
|
||||
text_inputs,
|
||||
text_lens,
|
||||
mels,
|
||||
output_lens=None,
|
||||
tones=None,
|
||||
global_condition=None):
|
||||
"""Calculate forward propagation of tacotron2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text_inputs: Tensor [shape=(B, T_text)]
|
||||
Batch of the sequencees of padded character ids.
|
||||
|
||||
mels: Tensor [shape(B, T_mel, C)]
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
|
||||
text_lens: Tensor [shape=(B,)]
|
||||
Batch of lengths of each text input batch.
|
||||
|
||||
|
||||
mels: Tensor [shape(B, T_mel, C)]
|
||||
Batch of the sequences of padded mel spectrogram.
|
||||
|
||||
output_lens: Tensor [shape=(B,)], optional
|
||||
Batch of lengths of each mels batch. Defaults to None.
|
||||
|
||||
|
||||
tones: Tensor [shape=(B, T_text)]
|
||||
Batch of sequences of padded tone ids.
|
||||
|
||||
global_condition: Tensor [shape(B, C)]
|
||||
Batch of global conditions. Defaults to None. If the
|
||||
`d_global_condition` of the model is not None, this input should be
|
||||
provided.
|
||||
|
||||
use_stop_token: bool
|
||||
Whether to include a binary classifier to predict the stop token.
|
||||
Defaults to False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
|
||||
mel_output: output sequence of features (B, T_mel, C);
|
||||
|
||||
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C);
|
||||
|
||||
stop_logits: output sequence of stop logits (B, T_mel);
|
||||
alignments: attention weights (B, T_mel, T_text);
|
||||
|
||||
alignments: attention weights (B, T_mel, T_text).
|
||||
stop_logits: output sequence of stop logits (B, T_mel)
|
||||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
if self.toned:
|
||||
embedded_inputs += self.embedding_tones(tones)
|
||||
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
||||
|
||||
mask = paddle.tensor.unsqueeze(
|
||||
paddle.fluid.layers.sequence_mask(
|
||||
x=text_lens, dtype=encoder_outputs.dtype), [-1])
|
||||
mel_outputs, stop_logits, alignments = self.decoder(
|
||||
encoder_outputs, mels, mask=mask)
|
||||
if global_condition is not None:
|
||||
global_condition = global_condition.unsqueeze(1)
|
||||
global_condition = paddle.expand(
|
||||
global_condition, [-1, encoder_outputs.shape[1], -1])
|
||||
encoder_outputs = paddle.concat(
|
||||
[encoder_outputs, global_condition], -1)
|
||||
|
||||
# [B, T_enc, 1]
|
||||
mask = sequence_mask(
|
||||
text_lens, dtype=encoder_outputs.dtype).unsqueeze(-1)
|
||||
if self.decoder.use_stop_token:
|
||||
mel_outputs, alignments, stop_logits = self.decoder(
|
||||
encoder_outputs, mels, mask=mask)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder(
|
||||
encoder_outputs, mels, mask=mask)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
||||
if output_lens is not None:
|
||||
mask = paddle.tensor.unsqueeze(
|
||||
paddle.fluid.layers.sequence_mask(x=output_lens),
|
||||
[-1]) #[B, T, 1]
|
||||
mel_outputs = mel_outputs * mask #[B, T, C]
|
||||
mel_outputs_postnet = mel_outputs_postnet * mask #[B, T, C]
|
||||
stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0]
|
||||
) * 1e3 #[B, T]
|
||||
# [B, T_dec, 1]
|
||||
mask = sequence_mask(output_lens).unsqueeze(-1)
|
||||
mel_outputs = mel_outputs * mask # [B, T, C]
|
||||
mel_outputs_postnet = mel_outputs_postnet * mask # [B, T, C]
|
||||
outputs = {
|
||||
"mel_output": mel_outputs,
|
||||
"mel_outputs_postnet": mel_outputs_postnet,
|
||||
"stop_logits": stop_logits,
|
||||
"alignments": alignments
|
||||
}
|
||||
if self.decoder.use_stop_token:
|
||||
outputs["stop_logits"] = stop_logits
|
||||
|
||||
return outputs
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
def infer(self,
|
||||
text_inputs,
|
||||
max_decoder_steps=1000,
|
||||
tones=None,
|
||||
global_condition=None):
|
||||
"""Generate the mel sepctrogram of features given the sequences of character ids.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text_inputs: Tensor [shape=(B, T_text)]
|
||||
Batch of the sequencees of padded character ids.
|
||||
|
||||
stop_threshold: float, optional
|
||||
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||
|
||||
|
||||
max_decoder_steps: int, optional
|
||||
Number of max step when synthesize. Defaults to 1000.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
@ -702,14 +803,26 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
stop_logits: output sequence of stop logits (B, T_mel);
|
||||
|
||||
alignments: attention weights (B, T_mel, T_text).
|
||||
alignments: attention weights (B, T_mel, T_text). This key is only
|
||||
present when `use_stop_token` is True.
|
||||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
if self.toned:
|
||||
embedded_inputs += self.embedding_tones(tones)
|
||||
encoder_outputs = self.encoder(embedded_inputs)
|
||||
mel_outputs, stop_logits, alignments = self.decoder.infer(
|
||||
encoder_outputs,
|
||||
stop_threshold=stop_threshold,
|
||||
max_decoder_steps=max_decoder_steps)
|
||||
|
||||
if global_condition is not None:
|
||||
global_condition = global_condition.unsqueeze(1)
|
||||
global_condition = paddle.expand(
|
||||
global_condition, [-1, encoder_outputs.shape[1], -1])
|
||||
encoder_outputs = paddle.concat(
|
||||
[encoder_outputs, global_condition], -1)
|
||||
if self.decoder.use_stop_token:
|
||||
mel_outputs, alignments, stop_logits = self.decoder.infer(
|
||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder.infer(
|
||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
@ -717,63 +830,33 @@ class Tacotron2(nn.Layer):
|
|||
outputs = {
|
||||
"mel_output": mel_outputs,
|
||||
"mel_outputs_postnet": mel_outputs_postnet,
|
||||
"stop_logits": stop_logits,
|
||||
"alignments": alignments
|
||||
}
|
||||
if self.decoder.use_stop_token:
|
||||
outputs["stop_logits"] = stop_logits
|
||||
|
||||
return outputs
|
||||
|
||||
@paddle.no_grad()
|
||||
def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
"""Generate the mel sepctrogram of features given the sequenc of characters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
Sequence of characters.
|
||||
|
||||
stop_threshold: float, optional
|
||||
Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
|
||||
|
||||
max_decoder_steps: int, optional
|
||||
Number of max step when synthesize. Defaults to 1000.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C);
|
||||
|
||||
alignments: attention weights (T_mel, T_text).
|
||||
"""
|
||||
ids = np.asarray(self.frontend(text))
|
||||
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
|
||||
outputs = self.infer(ids, stop_threshold, max_decoder_steps)
|
||||
return outputs['mel_outputs_postnet'][0].numpy(), outputs[
|
||||
'alignments'][0].numpy()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, frontend, config, checkpoint_path):
|
||||
"""Build a tacotron2 model from a pretrained model.
|
||||
def from_pretrained(cls, config, checkpoint_path):
|
||||
"""Build a Tacotron2 model from a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frontend: parakeet.frontend.Phonetics
|
||||
Frontend used to preprocess text.
|
||||
|
||||
config: yacs.config.CfgNode
|
||||
Model configs.
|
||||
|
||||
model configs
|
||||
|
||||
checkpoint_path: Path or str
|
||||
The path of pretrained model checkpoint, without extension name.
|
||||
|
||||
the path of pretrained model checkpoint, without extension name
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tacotron2
|
||||
The model build from pretrined result.
|
||||
ConditionalWaveFlow
|
||||
The model built from pretrained result.
|
||||
"""
|
||||
model = cls(frontend,
|
||||
d_mels=config.data.d_mels,
|
||||
model = cls(vocab_size=config.model.vocab_size,
|
||||
n_tones=config.model.n_tones,
|
||||
d_mels=config.data.n_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||
|
@ -791,8 +874,9 @@ class Tacotron2(nn.Layer):
|
|||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout)
|
||||
|
||||
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||
d_global_condition=config.model.d_global_condition,
|
||||
use_stop_token=config.model.use_stop_token)
|
||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
return model
|
||||
|
||||
|
@ -801,49 +885,96 @@ class Tacotron2Loss(nn.Layer):
|
|||
""" Tacotron2 Loss module
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self,
|
||||
use_stop_token_loss=True,
|
||||
use_guided_attention_loss=False,
|
||||
sigma=0.2):
|
||||
"""Tacotron 2 Criterion.
|
||||
|
||||
def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
|
||||
mel_targets, stop_tokens):
|
||||
Args:
|
||||
use_stop_token_loss (bool, optional): Whether to use a loss for stop token prediction. Defaults to True.
|
||||
use_guided_attention_loss (bool, optional): Whether to use a loss for attention weights. Defaults to False.
|
||||
sigma (float, optional): Hyper-parameter sigma for guided attention loss. Defaults to 0.2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.spec_criterion = nn.MSELoss()
|
||||
self.use_stop_token_loss = use_stop_token_loss
|
||||
self.use_guided_attention_loss = use_guided_attention_loss
|
||||
self.attn_criterion = guided_attention_loss
|
||||
self.stop_criterion = paddle.nn.BCEWithLogitsLoss()
|
||||
self.sigma = sigma
|
||||
|
||||
def forward(self,
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mel_targets,
|
||||
attention_weights=None,
|
||||
slens=None,
|
||||
plens=None,
|
||||
stop_logits=None):
|
||||
"""Calculate tacotron2 loss.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mel_outputs: Tensor [shape=(B, T_mel, C)]
|
||||
Output mel spectrogram sequence.
|
||||
|
||||
|
||||
mel_outputs_postnet: Tensor [shape(B, T_mel, C)]
|
||||
Output mel spectrogram sequence after postnet.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_mel)]
|
||||
Output sequence of stop logits befor sigmoid.
|
||||
|
||||
|
||||
mel_targets: Tensor [shape=(B, T_mel, C)]
|
||||
Target mel spectrogram sequence.
|
||||
|
||||
attention_weights: Tensor [shape=(B, T_mel, T_enc)]
|
||||
Attention weights. This should be provided when
|
||||
`use_guided_attention_loss` is True.
|
||||
|
||||
stop_tokens: Tensor [shape=(B,)]
|
||||
Target stop token.
|
||||
slens: Tensor [shape=(B,)]
|
||||
Number of frames of mel spectrograms. This should be provided when
|
||||
`use_guided_attention_loss` is True.
|
||||
|
||||
plens: Tensor [shape=(B, )]
|
||||
Number of text or phone ids of each utterance. This should be
|
||||
provided when `use_guided_attention_loss` is True.
|
||||
|
||||
stop_logits: Tensor [shape=(B, T_mel)]
|
||||
Stop logits of each mel spectrogram frame. This should be provided
|
||||
when `use_stop_token_loss` is True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
losses : Dict[str, Tensor]
|
||||
|
||||
|
||||
loss: the sum of the other three losses;
|
||||
|
||||
mel_loss: MSE loss compute by mel_targets and mel_outputs;
|
||||
|
||||
post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet;
|
||||
|
||||
stop_loss: stop loss computed by stop_logits and stop token.
|
||||
guided_attn_loss: Guided attention loss for attention weights;
|
||||
|
||||
stop_loss: Binary cross entropy loss for stop token prediction.
|
||||
"""
|
||||
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
||||
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
||||
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
|
||||
total_loss = mel_loss + post_mel_loss + stop_loss
|
||||
losses = dict(
|
||||
loss=total_loss,
|
||||
mel_loss=mel_loss,
|
||||
post_mel_loss=post_mel_loss,
|
||||
stop_loss=stop_loss)
|
||||
mel_loss = self.spec_criterion(mel_outputs, mel_targets)
|
||||
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
|
||||
total_loss = mel_loss + post_mel_loss
|
||||
if self.use_guided_attention_loss:
|
||||
gal_loss = self.attn_criterion(attention_weights, slens, plens,
|
||||
self.sigma)
|
||||
total_loss += gal_loss
|
||||
if self.use_stop_token_loss:
|
||||
T_dec = mel_targets.shape[1]
|
||||
stop_labels = F.one_hot(slens - 1, num_classes=T_dec)
|
||||
stop_token_loss = self.stop_criterion(stop_logits, stop_labels)
|
||||
total_loss += stop_token_loss
|
||||
|
||||
losses = {
|
||||
"loss": total_loss,
|
||||
"mel_loss": mel_loss,
|
||||
"post_mel_loss": post_mel_loss
|
||||
}
|
||||
if self.use_guided_attention_loss:
|
||||
losses["guided_attn_loss"] = gal_loss
|
||||
if self.use_stop_token_loss:
|
||||
losses["stop_loss"] = stop_token_loss
|
||||
return losses
|
||||
|
|
|
@ -321,10 +321,8 @@ class MLPPreNet(nn.Layer):
|
|||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, dropout):
|
||||
l1 = F.dropout(
|
||||
F.relu(self.lin1(x)), self.dropout, training=True)
|
||||
l2 = F.dropout(
|
||||
F.relu(self.lin2(l1)), self.dropout, training=True)
|
||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=True)
|
||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=True)
|
||||
l3 = self.lin3(l2)
|
||||
return l3
|
||||
|
||||
|
@ -345,8 +343,10 @@ class CNNPostNet(nn.Layer):
|
|||
c_out,
|
||||
kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding=padding))
|
||||
self.last_bn = nn.BatchNorm1D(d_output)
|
||||
padding=padding,
|
||||
momentum=0.99,
|
||||
epsilon=1e-03))
|
||||
self.last_bn = nn.BatchNorm1D(d_output, momentum=0.99, epsilon=1e-3)
|
||||
# for a layer that ends with a normalization layer that is targeted to
|
||||
# output a non zero-central output, it may take a long time to
|
||||
# train the scale and bias
|
||||
|
@ -358,6 +358,8 @@ class CNNPostNet(nn.Layer):
|
|||
x = layer(x)
|
||||
if i != (len(self.convs) - 1):
|
||||
x = F.tanh(x)
|
||||
# TODO: check it
|
||||
# x = x_in + x
|
||||
x = self.last_bn(x_in + x)
|
||||
return x
|
||||
|
||||
|
@ -378,7 +380,8 @@ class TransformerTTS(nn.Layer):
|
|||
postnet_kernel_size: int,
|
||||
max_reduction_factor: int,
|
||||
decoder_prenet_dropout: float,
|
||||
dropout: float):
|
||||
dropout: float,
|
||||
n_tones=None):
|
||||
super(TransformerTTS, self).__init__()
|
||||
|
||||
# text frontend (text normalization and g2p)
|
||||
|
@ -390,6 +393,15 @@ class TransformerTTS(nn.Layer):
|
|||
d_encoder,
|
||||
padding_idx=frontend.vocab.padding_index,
|
||||
weight_attr=I.Uniform(-0.05, 0.05))
|
||||
if n_tones:
|
||||
self.toned = True
|
||||
self.tone_embed = nn.Embedding(
|
||||
n_tones,
|
||||
d_encoder,
|
||||
padding_idx=0,
|
||||
weight_attr=I.Uniform(-0.005, 0.005))
|
||||
else:
|
||||
self.toned = False
|
||||
# position encoding matrix may be extended later
|
||||
self.encoder_pe = pe.sinusoid_positional_encoding(0, 1000, d_encoder)
|
||||
self.encoder_pe_scalar = self.create_parameter(
|
||||
|
@ -434,8 +446,9 @@ class TransformerTTS(nn.Layer):
|
|||
self.r = max_reduction_factor # set it every call
|
||||
self.drop_n_heads = 0
|
||||
|
||||
def forward(self, text, mel):
|
||||
encoded, encoder_attention_weights, encoder_mask = self.encode(text)
|
||||
def forward(self, text, mel, tones=None):
|
||||
encoded, encoder_attention_weights, encoder_mask = self.encode(
|
||||
text, tones=tones)
|
||||
mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
|
||||
encoded, mel, encoder_mask)
|
||||
outputs = {
|
||||
|
@ -447,9 +460,11 @@ class TransformerTTS(nn.Layer):
|
|||
}
|
||||
return outputs
|
||||
|
||||
def encode(self, text):
|
||||
def encode(self, text, tones=None):
|
||||
T_enc = text.shape[-1]
|
||||
embed = self.encoder_prenet(text)
|
||||
if self.toned:
|
||||
embed += self.tone_embed(tones)
|
||||
if embed.shape[1] > self.encoder_pe.shape[0]:
|
||||
new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
|
||||
self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
|
||||
|
@ -473,7 +488,8 @@ class TransformerTTS(nn.Layer):
|
|||
# twice its length if needed
|
||||
if x.shape[1] * self.r > self.decoder_pe.shape[0]:
|
||||
new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
|
||||
self.decoder_pe = pe.sinusoid_positional_encoding(0, new_T, self.d_decoder)
|
||||
self.decoder_pe = pe.sinusoid_positional_encoding(0, new_T,
|
||||
self.d_decoder)
|
||||
pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
|
||||
x = x.scale(math.sqrt(
|
||||
self.d_decoder)) + pos_enc * self.decoder_pe_scalar
|
||||
|
@ -483,7 +499,7 @@ class TransformerTTS(nn.Layer):
|
|||
decoder_padding_mask = masking.feature_mask(
|
||||
input, axis=-1, dtype=input.dtype)
|
||||
decoder_mask = masking.combine_mask(
|
||||
decoder_padding_mask.unsqueeze(-1), no_future_mask)
|
||||
decoder_padding_mask.unsqueeze(1), no_future_mask)
|
||||
decoder_output, _, cross_attention_weights = self.decoder(
|
||||
x, encoder_output, encoder_output, encoder_padding_mask,
|
||||
decoder_mask, self.drop_n_heads)
|
||||
|
@ -502,7 +518,7 @@ class TransformerTTS(nn.Layer):
|
|||
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, input, max_length=1000, verbose=True):
|
||||
def infer(self, input, max_length=1000, verbose=True, tones=None):
|
||||
"""Predict log scale magnitude mel spectrogram from text input.
|
||||
|
||||
Args:
|
||||
|
@ -515,7 +531,7 @@ class TransformerTTS(nn.Layer):
|
|||
|
||||
# encoder the text sequence
|
||||
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
|
||||
input)
|
||||
input, tones=tones)
|
||||
for _ in trange(int(max_length // self.r) + 1):
|
||||
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
||||
encoder_output, decoder_input, encoder_padding_mask)
|
||||
|
@ -528,6 +544,7 @@ class TransformerTTS(nn.Layer):
|
|||
[decoder_output, mel_output[:, -self.r:, :]], 1)
|
||||
|
||||
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
|
||||
# import pdb; pdb.set_trace()
|
||||
if paddle.any(
|
||||
paddle.argmax(
|
||||
stop_logits[0, -self.r:, :], axis=-1) ==
|
||||
|
@ -544,14 +561,6 @@ class TransformerTTS(nn.Layer):
|
|||
}
|
||||
return outputs
|
||||
|
||||
@paddle.no_grad()
|
||||
def predict(self, input, max_length=1000, verbose=True):
|
||||
text_ids = paddle.to_tensor(self.frontend(input))
|
||||
input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
||||
outputs = self.infer(input, max_length=max_length, verbose=verbose)
|
||||
outputs = {k: v[0].numpy() for k, v in outputs.items()}
|
||||
return outputs
|
||||
|
||||
def set_constants(self, reduction_factor, drop_n_heads):
|
||||
self.r = reduction_factor
|
||||
self.drop_n_heads = drop_n_heads
|
||||
|
|
|
@ -12,9 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
@ -33,7 +35,7 @@ def fold(x, n_group):
|
|||
----------
|
||||
x : Tensor [shape=(\*, time_steps)
|
||||
The input tensor.
|
||||
|
||||
|
||||
n_group : int
|
||||
The size of a group.
|
||||
|
||||
|
@ -48,37 +50,37 @@ def fold(x, n_group):
|
|||
|
||||
|
||||
class UpsampleNet(nn.LayerList):
|
||||
"""Layer to upsample mel spectrogram to the same temporal resolution with
|
||||
the corresponding waveform.
|
||||
|
||||
It consists of several conv2dtranspose layers which perform deconvolution
|
||||
"""Layer to upsample mel spectrogram to the same temporal resolution with
|
||||
the corresponding waveform.
|
||||
|
||||
It consists of several conv2dtranspose layers which perform deconvolution
|
||||
on mel and time dimension.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
upscale_factors : List[int], optional
|
||||
Time upsampling factors for each Conv2DTranspose Layer.
|
||||
|
||||
The ``UpsampleNet`` contains ``len(upscale_factor)`` Conv2DTranspose
|
||||
Layers. Each upscale_factor is used as the ``stride`` for the
|
||||
corresponding Conv2DTranspose. Defaults to [16, 16], this the default
|
||||
Time upsampling factors for each Conv2DTranspose Layer.
|
||||
|
||||
The ``UpsampleNet`` contains ``len(upscale_factor)`` Conv2DTranspose
|
||||
Layers. Each upscale_factor is used as the ``stride`` for the
|
||||
corresponding Conv2DTranspose. Defaults to [16, 16], this the default
|
||||
upsampling factor is 256.
|
||||
|
||||
|
||||
Notes
|
||||
------
|
||||
``np.prod(upscale_factors)`` should equals the ``hop_length`` of the stft
|
||||
transformation used to extract spectrogram features from audio.
|
||||
|
||||
For example, ``16 * 16 = 256``, then the spectrogram extracted with a stft
|
||||
transformation whose ``hop_length`` equals 256 is suitable.
|
||||
|
||||
``np.prod(upscale_factors)`` should equals the ``hop_length`` of the stft
|
||||
transformation used to extract spectrogram features from audio.
|
||||
|
||||
For example, ``16 * 16 = 256``, then the spectrogram extracted with a stft
|
||||
transformation whose ``hop_length`` equals 256 is suitable.
|
||||
|
||||
See Also
|
||||
---------
|
||||
``librosa.core.stft``
|
||||
"""
|
||||
|
||||
def __init__(self, upsample_factors):
|
||||
super(UpsampleNet, self).__init__()
|
||||
super().__init__()
|
||||
for factor in upsample_factors:
|
||||
std = math.sqrt(1 / (3 * 2 * factor))
|
||||
init = I.Uniform(-std, std)
|
||||
|
@ -98,12 +100,12 @@ class UpsampleNet(nn.LayerList):
|
|||
|
||||
def forward(self, x, trim_conv_artifact=False):
|
||||
r"""Forward pass of the ``UpsampleNet``.
|
||||
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch_size, input_channels, time_steps)]
|
||||
The input spectrogram.
|
||||
|
||||
|
||||
trim_conv_artifact : bool, optional
|
||||
Trim deconvolution artifact at each layer. Defaults to False.
|
||||
|
||||
|
@ -111,10 +113,10 @@ class UpsampleNet(nn.LayerList):
|
|||
--------
|
||||
Tensor: [shape=(batch_size, input_channels, time_steps \* upsample_factor)]
|
||||
The upsampled spectrogram.
|
||||
|
||||
|
||||
Notes
|
||||
--------
|
||||
If trim_conv_artifact is ``True``, the output time steps is less
|
||||
If trim_conv_artifact is ``True``, the output time steps is less
|
||||
than ``time_steps \* upsample_factors``.
|
||||
"""
|
||||
x = paddle.unsqueeze(x, 1) #(B, C, T) -> (B, 1, C, T)
|
||||
|
@ -128,31 +130,30 @@ class UpsampleNet(nn.LayerList):
|
|||
return x
|
||||
|
||||
|
||||
#TODO write doc
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""ResidualBlock, the basic unit of ResidualNet used in WaveFlow.
|
||||
|
||||
It has a conv2d layer, which has causal padding in height dimension and
|
||||
same paddign in width dimension. It also has projection for the condition
|
||||
"""ResidualBlock, the basic unit of ResidualNet used in WaveFlow.
|
||||
|
||||
It has a conv2d layer, which has causal padding in height dimension and
|
||||
same paddign in width dimension. It also has projection for the condition
|
||||
and output.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
channels : int
|
||||
Feature size of the input.
|
||||
|
||||
|
||||
cond_channels : int
|
||||
Featuer size of the condition.
|
||||
|
||||
|
||||
kernel_size : Tuple[int]
|
||||
Kernel size of the Convolution2d applied to the input.
|
||||
|
||||
|
||||
dilations : int
|
||||
Dilations of the Convolution2d applied to the input.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, cond_channels, kernel_size, dilations):
|
||||
super(ResidualBlock, self).__init__()
|
||||
super().__init__()
|
||||
# input conv
|
||||
std = math.sqrt(1 / channels * np.prod(kernel_size))
|
||||
init = I.Uniform(-std, std)
|
||||
|
@ -193,12 +194,12 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
def forward(self, x, condition):
|
||||
"""Compute output for a whole folded sequence.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The input.
|
||||
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
|
@ -206,7 +207,7 @@ class ResidualBlock(nn.Layer):
|
|||
-------
|
||||
res : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The residual output.
|
||||
|
||||
|
||||
skip : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The skip output.
|
||||
"""
|
||||
|
@ -223,8 +224,8 @@ class ResidualBlock(nn.Layer):
|
|||
return res, skip
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the layer for incremental computation of causal
|
||||
convolution. Reset the buffer for causal convolution.
|
||||
"""Prepare the layer for incremental computation of causal
|
||||
convolution. Reset the buffer for causal convolution.
|
||||
|
||||
Raises:
|
||||
ValueError: If not in evaluation mode.
|
||||
|
@ -233,11 +234,11 @@ class ResidualBlock(nn.Layer):
|
|||
raise ValueError("Only use start sequence at evaluation mode.")
|
||||
self._conv_buffer = None
|
||||
|
||||
# NOTE: call self.conv's weight norm hook expliccitly since
|
||||
# its weight will be visited directly in `add_input` without
|
||||
# calling its `__call__` method. If we do not trigger the weight
|
||||
# norm hook, the weight may be outdated. e.g. after loading from
|
||||
# a saved checkpoint
|
||||
# NOTE: call self.conv's weight norm hook expliccitly since
|
||||
# its weight will be visited directly in `add_input` without
|
||||
# calling its `__call__` method. If we do not trigger the weight
|
||||
# norm hook, the weight may be outdated. e.g. after loading from
|
||||
# a saved checkpoint
|
||||
# see also: https://github.com/pytorch/pytorch/issues/47588
|
||||
for hook in self.conv._forward_pre_hooks.values():
|
||||
hook(self.conv, None)
|
||||
|
@ -249,7 +250,7 @@ class ResidualBlock(nn.Layer):
|
|||
----------
|
||||
x_row : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the input.
|
||||
|
||||
|
||||
condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
|
||||
A row of the condition.
|
||||
|
||||
|
@ -257,7 +258,7 @@ class ResidualBlock(nn.Layer):
|
|||
-------
|
||||
res : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the the residual output.
|
||||
|
||||
|
||||
skip : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the skip output.
|
||||
"""
|
||||
|
@ -295,21 +296,21 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
class ResidualNet(nn.LayerList):
|
||||
"""A stack of several ResidualBlocks. It merges condition at each layer.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_layer : int
|
||||
Number of ResidualBlocks in the ResidualNet.
|
||||
|
||||
|
||||
residual_channels : int
|
||||
Feature size of each ResidualBlocks.
|
||||
|
||||
|
||||
condition_channels : int
|
||||
Feature size of the condition.
|
||||
|
||||
|
||||
kernel_size : Tuple[int]
|
||||
Kernel size of each ResidualBlock.
|
||||
|
||||
|
||||
dilations_h : List[int]
|
||||
Dilation in height dimension of every ResidualBlock.
|
||||
|
||||
|
@ -328,7 +329,7 @@ class ResidualNet(nn.LayerList):
|
|||
if len(dilations_h) != n_layer:
|
||||
raise ValueError(
|
||||
"number of dilations_h should equals num of layers")
|
||||
super(ResidualNet, self).__init__()
|
||||
super().__init__()
|
||||
for i in range(n_layer):
|
||||
dilation = (dilations_h[i], 2**i)
|
||||
layer = ResidualBlock(residual_channels, condition_channels,
|
||||
|
@ -342,8 +343,8 @@ class ResidualNet(nn.LayerList):
|
|||
-----------
|
||||
x : Tensor [shape=(batch_size, channel, height, width)]
|
||||
The input.
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
Returns
|
||||
|
@ -371,7 +372,7 @@ class ResidualNet(nn.LayerList):
|
|||
----------
|
||||
x_row : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the input.
|
||||
|
||||
|
||||
condition_row : Tensor [shape=(batch_size, condition_channel, 1, width)]
|
||||
A row of the condition.
|
||||
|
||||
|
@ -379,7 +380,7 @@ class ResidualNet(nn.LayerList):
|
|||
-------
|
||||
res : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the the residual output.
|
||||
|
||||
|
||||
skip : Tensor [shape=(batch_size, channel, 1, width)]
|
||||
A row of the skip output.
|
||||
"""
|
||||
|
@ -392,27 +393,27 @@ class ResidualNet(nn.LayerList):
|
|||
|
||||
|
||||
class Flow(nn.Layer):
|
||||
"""A bijection (Reversable layer) that transform a density of latent
|
||||
"""A bijection (Reversable layer) that transform a density of latent
|
||||
variables p(Z) into a complex data distribution p(X).
|
||||
|
||||
It's an auto regressive flow. The ``forward`` method implements the
|
||||
probability density estimation. The ``inverse`` method implements the
|
||||
It's an auto regressive flow. The ``forward`` method implements the
|
||||
probability density estimation. The ``inverse`` method implements the
|
||||
sampling.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_layers : int
|
||||
Number of ResidualBlocks in the Flow.
|
||||
|
||||
|
||||
channels : int
|
||||
Feature size of the ResidualBlocks.
|
||||
|
||||
|
||||
mel_bands : int
|
||||
Feature size of the mel spectrogram (mel bands).
|
||||
|
||||
|
||||
kernel_size : Tuple[int]
|
||||
Kernel size of each ResisualBlocks in the Flow.
|
||||
|
||||
|
||||
n_group : int
|
||||
Number of timesteps to the folded into a group.
|
||||
"""
|
||||
|
@ -425,7 +426,7 @@ class Flow(nn.Layer):
|
|||
}
|
||||
|
||||
def __init__(self, n_layers, channels, mel_bands, kernel_size, n_group):
|
||||
super(Flow, self).__init__()
|
||||
super().__init__()
|
||||
# input projection
|
||||
self.input_proj = nn.utils.weight_norm(
|
||||
nn.Conv2D(
|
||||
|
@ -462,28 +463,28 @@ class Flow(nn.Layer):
|
|||
return z_out
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Probability density estimation. It is done by inversely transform
|
||||
"""Probability density estimation. It is done by inversely transform
|
||||
a sample from p(X) into a sample from p(Z).
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch, 1, height, width)]
|
||||
A input sample of the distribution p(X).
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, height, width)]
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
Returns
|
||||
--------
|
||||
z (Tensor): shape(batch, 1, height, width), the transformed sample.
|
||||
|
||||
|
||||
Tuple[Tensor, Tensor]
|
||||
The parameter of the transformation.
|
||||
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
|
||||
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
|
||||
of the transformation from x to z.
|
||||
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
|
||||
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
|
||||
transformation from x to z.
|
||||
"""
|
||||
# (B, C, H-1, W)
|
||||
|
@ -512,14 +513,14 @@ class Flow(nn.Layer):
|
|||
self.resnet.start_sequence()
|
||||
|
||||
def inverse(self, z, condition):
|
||||
"""Sampling from the the distrition p(X). It is done by sample form
|
||||
"""Sampling from the the distrition p(X). It is done by sample form
|
||||
p(Z) and transform the sample. It is a auto regressive transformation.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
z : Tensor [shape=(batch, 1, height, width)]
|
||||
A sample of the distribution p(Z).
|
||||
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, height, width)]
|
||||
The local condition.
|
||||
|
||||
|
@ -527,14 +528,14 @@ class Flow(nn.Layer):
|
|||
---------
|
||||
x : Tensor [shape=(batch, 1, height, width)]
|
||||
The transformed sample.
|
||||
|
||||
|
||||
Tuple[Tensor, Tensor]
|
||||
The parameter of the transformation.
|
||||
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
|
||||
|
||||
logs (Tensor): shape(batch, 1, height - 1, width), the log scale
|
||||
of the transformation from x to z.
|
||||
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
|
||||
|
||||
b (Tensor): shape(batch, 1, height - 1, width), the shift of the
|
||||
transformation from x to z.
|
||||
"""
|
||||
z_0 = z[:, :, :1, :]
|
||||
|
@ -562,26 +563,26 @@ class Flow(nn.Layer):
|
|||
|
||||
|
||||
class WaveFlow(nn.LayerList):
|
||||
"""An Deep Reversible layer that is composed of severel auto regressive
|
||||
"""An Deep Reversible layer that is composed of severel auto regressive
|
||||
flows.
|
||||
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
n_flows : int
|
||||
Number of flows in the WaveFlow model.
|
||||
|
||||
|
||||
n_layers : int
|
||||
Number of ResidualBlocks in each Flow.
|
||||
|
||||
|
||||
n_group : int
|
||||
Number of timesteps to fold as a group.
|
||||
|
||||
|
||||
channels : int
|
||||
Feature size of each ResidualBlock.
|
||||
|
||||
|
||||
mel_bands : int
|
||||
Feature size of mel spectrogram (mel bands).
|
||||
|
||||
|
||||
kernel_size : Union[int, List[int]]
|
||||
Kernel size of the convolution layer in each ResidualBlock.
|
||||
"""
|
||||
|
@ -592,7 +593,7 @@ class WaveFlow(nn.LayerList):
|
|||
raise ValueError(
|
||||
"number of flows and number of group must be even "
|
||||
"since a permutation along group among flows is used.")
|
||||
super(WaveFlow, self).__init__()
|
||||
super().__init__()
|
||||
for _ in range(n_flows):
|
||||
self.append(
|
||||
Flow(n_layers, channels, mel_bands, kernel_size, n_group))
|
||||
|
@ -628,14 +629,14 @@ class WaveFlow(nn.LayerList):
|
|||
return x, condition
|
||||
|
||||
def forward(self, x, condition):
|
||||
"""Probability density estimation of random variable x given the
|
||||
"""Probability density estimation of random variable x given the
|
||||
condition.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(batch_size, time_steps)]
|
||||
The audio.
|
||||
|
||||
|
||||
condition : Tensor [shape=(batch_size, condition channel, time_steps)]
|
||||
The local condition (mel spectrogram here).
|
||||
|
||||
|
@ -643,9 +644,9 @@ class WaveFlow(nn.LayerList):
|
|||
--------
|
||||
z : Tensor [shape=(batch_size, time_steps)]
|
||||
The transformed random variable.
|
||||
|
||||
|
||||
log_det_jacobian: Tensor [shape=(1,)]
|
||||
The log determinant of the jacobian of the transformation from x
|
||||
The log determinant of the jacobian of the transformation from x
|
||||
to z.
|
||||
"""
|
||||
# x: (B, T)
|
||||
|
@ -675,17 +676,17 @@ class WaveFlow(nn.LayerList):
|
|||
return z, log_det_jacobian
|
||||
|
||||
def inverse(self, z, condition):
|
||||
"""Sampling from the the distrition p(X).
|
||||
|
||||
It is done by sample a ``z`` form p(Z) and transform it into ``x``.
|
||||
Each Flow transform .. math:: `z_{i-1}` to .. math:: `z_{i}` in an
|
||||
"""Sampling from the the distrition p(X).
|
||||
|
||||
It is done by sample a ``z`` form p(Z) and transform it into ``x``.
|
||||
Each Flow transform .. math:: `z_{i-1}` to .. math:: `z_{i}` in an
|
||||
autoregressive manner.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor [shape=(batch, 1, time_steps]
|
||||
A sample of the distribution p(Z).
|
||||
|
||||
|
||||
condition : Tensor [shape=(batch, condition_channel, time_steps)]
|
||||
The local condition.
|
||||
|
||||
|
@ -721,22 +722,22 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
----------
|
||||
upsample_factors : List[int]
|
||||
Upsample factors for the upsample net.
|
||||
|
||||
|
||||
n_flows : int
|
||||
Number of flows in the WaveFlow model.
|
||||
|
||||
|
||||
n_layers : int
|
||||
Number of ResidualBlocks in each Flow.
|
||||
|
||||
|
||||
n_group : int
|
||||
Number of timesteps to fold as a group.
|
||||
|
||||
|
||||
channels : int
|
||||
Feature size of each ResidualBlock.
|
||||
|
||||
|
||||
n_mels : int
|
||||
Feature size of mel spectrogram (mel bands).
|
||||
|
||||
|
||||
kernel_size : Union[int, List[int]]
|
||||
Kernel size of the convolution layer in each ResidualBlock.
|
||||
"""
|
||||
|
@ -749,7 +750,7 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
channels: int,
|
||||
n_mels: int,
|
||||
kernel_size: Union[int, List[int]]):
|
||||
super(ConditionalWaveFlow, self).__init__()
|
||||
super().__init__()
|
||||
self.encoder = UpsampleNet(upsample_factors)
|
||||
self.decoder = WaveFlow(
|
||||
n_flows=n_flows,
|
||||
|
@ -760,14 +761,14 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
kernel_size=kernel_size)
|
||||
|
||||
def forward(self, audio, mel):
|
||||
"""Compute the transformed random variable z (x to z) and the log of
|
||||
"""Compute the transformed random variable z (x to z) and the log of
|
||||
the determinant of the jacobian of the transformation from x to z.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : Tensor [shape=(B, T)]
|
||||
The audio.
|
||||
|
||||
|
||||
mel : Tensor [shape=(B, C_mel, T_mel)]
|
||||
The mel spectrogram.
|
||||
|
||||
|
@ -775,9 +776,9 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
-------
|
||||
z : Tensor [shape=(B, T)]
|
||||
The inversely transformed random variable z (x to z)
|
||||
|
||||
|
||||
log_det_jacobian: Tensor [shape=(1,)]
|
||||
the log of the determinant of the jacobian of the transformation
|
||||
the log of the determinant of the jacobian of the transformation
|
||||
from x to z.
|
||||
"""
|
||||
condition = self.encoder(mel)
|
||||
|
@ -795,13 +796,16 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
|
||||
Returns
|
||||
-------
|
||||
Tensor : [shape=(B, T)]
|
||||
Tensor : [shape=(B, T)]
|
||||
The synthesized audio, where``T <= T_mel \* upsample_factors``.
|
||||
"""
|
||||
start = time.time()
|
||||
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
|
||||
batch_size, _, time_steps = condition.shape
|
||||
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
|
||||
x = self.decoder.inverse(z, condition)
|
||||
end = time.time()
|
||||
print("time: {}s".format(end - start))
|
||||
return x
|
||||
|
||||
@paddle.no_grad()
|
||||
|
@ -811,7 +815,7 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
Parameters
|
||||
----------
|
||||
mel : np.ndarray [shape=(C_mel, T_mel)]
|
||||
Mel spectrogram of an utterance(in log-magnitude).
|
||||
Mel spectrogram of an utterance(in log-magnitude).
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -829,13 +833,13 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
"""Build a ConditionalWaveFlow model from a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
----------
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
|
||||
|
||||
checkpoint_path: Path or str
|
||||
the path of pretrained model checkpoint, without extension name
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
ConditionalWaveFlow
|
||||
|
@ -858,26 +862,26 @@ class WaveFlowLoss(nn.Layer):
|
|||
Parameters
|
||||
----------
|
||||
sigma : float
|
||||
The standard deviation of the gaussian noise used in WaveFlow, by
|
||||
The standard deviation of the gaussian noise used in WaveFlow, by
|
||||
default 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self, sigma=1.0):
|
||||
super(WaveFlowLoss, self).__init__()
|
||||
super().__init__()
|
||||
self.sigma = sigma
|
||||
self.const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
|
||||
|
||||
def forward(self, z, log_det_jacobian):
|
||||
"""Compute the loss given the transformed random variable z and the
|
||||
"""Compute the loss given the transformed random variable z and the
|
||||
log_det_jacobian of transformation from x to z.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor [shape=(B, T)]
|
||||
The transformed random variable (x to z).
|
||||
|
||||
|
||||
log_det_jacobian : Tensor [shape=(1,)]
|
||||
The log of the determinant of the jacobian matrix of the
|
||||
The log of the determinant of the jacobian matrix of the
|
||||
transformation from x to z.
|
||||
|
||||
Returns
|
||||
|
|
|
@ -1,977 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Union, Sequence, List
|
||||
from tqdm import trange
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
import paddle.fluid.initializer as I
|
||||
import paddle.fluid.layers.distributions as D
|
||||
|
||||
from parakeet.modules.conv import Conv1dCell
|
||||
from parakeet.modules.audio import quantize, dequantize, STFT
|
||||
from parakeet.utils import checkpoint, layer_tools
|
||||
|
||||
__all__ = ["WaveNet", "ConditionalWaveNet"]
|
||||
|
||||
|
||||
def crop(x, audio_start, audio_length):
|
||||
"""Crop the upsampled condition to match audio_length.
|
||||
|
||||
The upsampled condition has the same time steps as the whole audio does.
|
||||
But since audios are sliced to 0.5 seconds randomly while conditions are
|
||||
not, upsampled conditions should also be sliced to extactly match the time
|
||||
steps of the audio slice.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(B, C, T)]
|
||||
The upsampled condition.
|
||||
audio_start : Tensor [shape=(B,), dtype:int]
|
||||
The index of the starting point of the audio clips.
|
||||
audio_length : int
|
||||
The length of the audio clip(number of samples it contaions).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor [shape=(B, C, audio_length)]
|
||||
Cropped condition.
|
||||
"""
|
||||
# crop audio
|
||||
slices = [] # for each example
|
||||
# paddle now supports Tensor of shape [1] in slice
|
||||
# starts = audio_start.numpy()
|
||||
for i in range(x.shape[0]):
|
||||
start = audio_start[i]
|
||||
end = start + audio_length
|
||||
slice = paddle.slice(x[i], axes=[1], starts=[start], ends=[end])
|
||||
slices.append(slice)
|
||||
out = paddle.stack(slices)
|
||||
return out
|
||||
|
||||
|
||||
class UpsampleNet(nn.LayerList):
|
||||
"""A network used to upsample mel spectrogram to match the time steps of
|
||||
audio.
|
||||
|
||||
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose
|
||||
layer upsamples the time dimension by its `stride` times.
|
||||
|
||||
Also, each Conv2DTranspose's filter_size at frequency dimension is 3.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
upscale_factors : List[int], optional
|
||||
Time upsampling factors for each Conv2DTranspose Layer.
|
||||
|
||||
The ``UpsampleNet`` contains ``len(upscale_factor)`` Conv2DTranspose
|
||||
Layers. Each upscale_factor is used as the ``stride`` for the
|
||||
corresponding Conv2DTranspose. Defaults to [16, 16], this the default
|
||||
upsampling factor is 256.
|
||||
|
||||
Notes
|
||||
------
|
||||
``np.prod(upscale_factors)`` should equals the ``hop_length`` of the stft
|
||||
transformation used to extract spectrogram features from audio.
|
||||
|
||||
For example, ``16 * 16 = 256``, then the spectrogram extracted with a stft
|
||||
transformation whose ``hop_length`` equals 256 is suitable.
|
||||
|
||||
See Also
|
||||
---------
|
||||
``librosa.core.stft``
|
||||
"""
|
||||
|
||||
def __init__(self, upscale_factors=[16, 16]):
|
||||
super(UpsampleNet, self).__init__()
|
||||
self.upscale_factors = list(upscale_factors)
|
||||
self.upscale_factor = 1
|
||||
for item in upscale_factors:
|
||||
self.upscale_factor *= item
|
||||
|
||||
for factor in self.upscale_factors:
|
||||
self.append(
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv2DTranspose(
|
||||
1,
|
||||
1,
|
||||
kernel_size=(3, 2 * factor),
|
||||
stride=(1, factor),
|
||||
padding=(1, factor // 2))))
|
||||
|
||||
def forward(self, x):
|
||||
r"""Compute the upsampled condition.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(B, F, T)]
|
||||
The condition (mel spectrogram here). ``F`` means the frequency
|
||||
bands, which is the feature size of the input.
|
||||
|
||||
In the internal Conv2DTransposes, the frequency dimension
|
||||
is treated as ``height`` dimension instead of ``in_channels``.
|
||||
|
||||
Returns:
|
||||
Tensor [shape=(B, F, T \* upscale_factor)]
|
||||
The upsampled condition.
|
||||
"""
|
||||
x = paddle.unsqueeze(x, 1)
|
||||
for sublayer in self:
|
||||
x = F.leaky_relu(sublayer(x), 0.4)
|
||||
x = paddle.squeeze(x, 1)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""A Residual block used in wavenet. Conv1D-gated-tanh Block.
|
||||
|
||||
It consists of a Conv1DCell and an Conv1D(kernel_size = 1) to integrate
|
||||
information of the condition.
|
||||
|
||||
Notes
|
||||
--------
|
||||
It does not have parametric residual or skip connection.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
residual_channels : int
|
||||
The feature size of the input. It is also the feature size of the
|
||||
residual output and skip output.
|
||||
|
||||
condition_dim : int
|
||||
The feature size of the condition.
|
||||
|
||||
filter_size : int
|
||||
Kernel size of the internal convolution cells.
|
||||
|
||||
dilation :int
|
||||
Dilation of the internal convolution cells.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
residual_channels: int,
|
||||
condition_dim: int,
|
||||
filter_size: Union[int, Sequence[int]],
|
||||
dilation: int):
|
||||
|
||||
super(ResidualBlock, self).__init__()
|
||||
dilated_channels = 2 * residual_channels
|
||||
# following clarinet's implementation, we do not have parametric residual
|
||||
# & skip connection.
|
||||
|
||||
_filter_size = filter_size[0] if isinstance(filter_size, (
|
||||
list, tuple)) else filter_size
|
||||
std = math.sqrt(1 / (_filter_size * residual_channels))
|
||||
conv = Conv1dCell(
|
||||
residual_channels,
|
||||
dilated_channels,
|
||||
filter_size,
|
||||
dilation=dilation,
|
||||
weight_attr=I.Normal(scale=std))
|
||||
self.conv = nn.utils.weight_norm(conv)
|
||||
|
||||
std = math.sqrt(1 / condition_dim)
|
||||
condition_proj = Conv1dCell(
|
||||
condition_dim,
|
||||
dilated_channels, (1, ),
|
||||
weight_attr=I.Normal(scale=std))
|
||||
self.condition_proj = nn.utils.weight_norm(condition_proj)
|
||||
|
||||
self.filter_size = filter_size
|
||||
self.dilation = dilation
|
||||
self.dilated_channels = dilated_channels
|
||||
self.residual_channels = residual_channels
|
||||
self.condition_dim = condition_dim
|
||||
|
||||
def forward(self, x, condition=None):
|
||||
"""Forward pass of the ResidualBlock.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(B, C, T)]
|
||||
The input tensor.
|
||||
|
||||
condition : Tensor, optional [shape(B, C_cond, T)]
|
||||
The condition.
|
||||
|
||||
It has been upsampled in time steps, so it has the same time steps
|
||||
as the input does.(C_cond stands for the condition's channels).
|
||||
Defaults to None.
|
||||
|
||||
Returns
|
||||
-----------
|
||||
residual : Tensor [shape=(B, C, T)]
|
||||
The residual, which is used as the input to the next ResidualBlock.
|
||||
|
||||
skip_connection : Tensor [shape=(B, C, T)]
|
||||
Tthe skip connection. This output is accumulated with that of
|
||||
other ResidualBlocks.
|
||||
"""
|
||||
h = x
|
||||
|
||||
# dilated conv
|
||||
h = self.conv(h)
|
||||
|
||||
# condition
|
||||
if condition is not None:
|
||||
h += self.condition_proj(condition)
|
||||
|
||||
# gated tanh
|
||||
content, gate = paddle.split(h, 2, axis=1)
|
||||
z = F.sigmoid(gate) * paddle.tanh(content)
|
||||
|
||||
# projection
|
||||
residual = paddle.scale(z + x, math.sqrt(.5))
|
||||
skip_connection = z
|
||||
return residual, skip_connection
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the ResidualBlock to generate a new sequence.
|
||||
|
||||
Warnings
|
||||
---------
|
||||
This method should be called before calling ``add_input`` multiple times.
|
||||
"""
|
||||
self.conv.start_sequence()
|
||||
self.condition_proj.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""Take a step input and return a step output.
|
||||
|
||||
This method works similarily with ``forward`` but in a
|
||||
``step-in-step-out`` fashion.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(B, C)]
|
||||
Input for a step.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond)]
|
||||
Condition for a step. Defaults to None.
|
||||
|
||||
Returns
|
||||
----------
|
||||
residual : Tensor [shape=(B, C)]
|
||||
The residual for a step, which is used as the input to the next
|
||||
layer of ResidualBlock.
|
||||
|
||||
skip_connection : Tensor [shape=(B, C)]
|
||||
T he skip connection for a step. This output is accumulated with
|
||||
that of other ResidualBlocks.
|
||||
"""
|
||||
h = x
|
||||
|
||||
# dilated conv
|
||||
h = self.conv.add_input(h)
|
||||
|
||||
# condition
|
||||
if condition is not None:
|
||||
h += self.condition_proj.add_input(condition)
|
||||
|
||||
# gated tanh
|
||||
content, gate = paddle.split(h, 2, axis=1)
|
||||
z = F.sigmoid(gate) * paddle.tanh(content)
|
||||
|
||||
# projection
|
||||
residual = paddle.scale(z + x, math.sqrt(0.5))
|
||||
skip_connection = z
|
||||
return residual, skip_connection
|
||||
|
||||
|
||||
class ResidualNet(nn.LayerList):
|
||||
"""The residual network in wavenet.
|
||||
|
||||
It consists of ``n_stack`` stacks, each of which consists of ``n_loop``
|
||||
ResidualBlocks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_stack : int
|
||||
Number of stacks in the ``ResidualNet``.
|
||||
|
||||
n_loop : int
|
||||
Number of ResidualBlocks in a stack.
|
||||
|
||||
residual_channels : int
|
||||
Input feature size of each ``ResidualBlock``'s input.
|
||||
|
||||
condition_dim : int
|
||||
Feature size of the condition.
|
||||
|
||||
filter_size : int
|
||||
Kernel size of the internal ``Conv1dCell`` of each ``ResidualBlock``.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_stack: int,
|
||||
n_loop: int,
|
||||
residual_channels: int,
|
||||
condition_dim: int,
|
||||
filter_size: int):
|
||||
super(ResidualNet, self).__init__()
|
||||
# double the dilation at each layer in a stack
|
||||
dilations = [2**i for i in range(n_loop)] * n_stack
|
||||
self.context_size = 1 + sum(dilations)
|
||||
for dilation in dilations:
|
||||
self.append(
|
||||
ResidualBlock(residual_channels, condition_dim, filter_size,
|
||||
dilation))
|
||||
|
||||
def forward(self, x, condition=None):
|
||||
"""Forward pass of ``ResidualNet``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(B, C, T)]
|
||||
The input.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond, T)]
|
||||
The condition, it has been upsampled in time steps, so it has the
|
||||
same time steps as the input does. Defaults to None.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, C, T)]
|
||||
The output.
|
||||
"""
|
||||
for i, func in enumerate(self):
|
||||
x, skip = func(x, condition)
|
||||
if i == 0:
|
||||
skip_connections = skip
|
||||
else:
|
||||
skip_connections = paddle.scale(skip_connections + skip,
|
||||
math.sqrt(0.5))
|
||||
return skip_connections
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the ResidualNet to generate a new sequence. This method
|
||||
should be called before starting calling ``add_input`` multiple times.
|
||||
"""
|
||||
for block in self:
|
||||
block.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""Take a step input and return a step output.
|
||||
|
||||
This method works similarily with ``forward`` but in a
|
||||
``step-in-step-out`` fashion.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor [shape=(B, C)]
|
||||
Input for a step.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond)]
|
||||
Condition for a step. Defaults to None.
|
||||
|
||||
Returns
|
||||
----------
|
||||
Tensor [shape=(B, C)]
|
||||
The skip connection for a step. This output is accumulated with
|
||||
that of other ResidualBlocks.
|
||||
"""
|
||||
for i, func in enumerate(self):
|
||||
x, skip = func.add_input(x, condition)
|
||||
if i == 0:
|
||||
skip_connections = skip
|
||||
else:
|
||||
skip_connections = paddle.scale(skip_connections + skip,
|
||||
math.sqrt(0.5))
|
||||
return skip_connections
|
||||
|
||||
|
||||
class WaveNet(nn.Layer):
|
||||
"""Wavenet that transform upsampled mel spectrogram into waveform.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
n_stack : int
|
||||
``n_stack`` for the internal ``ResidualNet``.
|
||||
|
||||
n_loop : int
|
||||
``n_loop`` for the internal ``ResidualNet``.
|
||||
|
||||
residual_channels : int
|
||||
Feature size of the input.
|
||||
|
||||
output_dim : int
|
||||
Feature size of the input.
|
||||
|
||||
condition_dim : int
|
||||
Feature size of the condition (mel spectrogram bands).
|
||||
|
||||
filter_size : int
|
||||
Kernel size of the internal ``ResidualNet``.
|
||||
|
||||
loss_type : str, optional ["mog" or "softmax"]
|
||||
The output type and loss type of the model, by default "mog".
|
||||
|
||||
If "softmax", the model input is first quantized audio and the model
|
||||
outputs a discret categorical distribution.
|
||||
|
||||
If "mog", the model input is audio in floating point format, and the
|
||||
model outputs parameters for a mixture of gaussian distributions.
|
||||
Namely, the weight, mean and log scale of each gaussian distribution.
|
||||
Thus, the ``output_size`` should be a multiple of 3.
|
||||
|
||||
log_scale_min : float, optional
|
||||
Minimum value of the log scale of gaussian distributions, by default
|
||||
-9.0.
|
||||
|
||||
This is only used for computing loss when ``loss_type`` is "mog", If
|
||||
the predicted log scale is less than -9.0, it is clipped at -9.0.
|
||||
"""
|
||||
|
||||
def __init__(self, n_stack, n_loop, residual_channels, output_dim,
|
||||
condition_dim, filter_size, loss_type, log_scale_min):
|
||||
|
||||
super(WaveNet, self).__init__()
|
||||
if loss_type not in ["softmax", "mog"]:
|
||||
raise ValueError("loss_type {} is not supported".format(loss_type))
|
||||
if loss_type == "softmax":
|
||||
self.embed = nn.Embedding(output_dim, residual_channels)
|
||||
else:
|
||||
if (output_dim % 3 != 0):
|
||||
raise ValueError(
|
||||
"with Mixture of Gaussians(mog) output, the output dim must be divisible by 3, but get {}".
|
||||
format(output_dim))
|
||||
self.embed = nn.utils.weight_norm(
|
||||
nn.Linear(1, residual_channels), dim=1)
|
||||
|
||||
self.resnet = ResidualNet(n_stack, n_loop, residual_channels,
|
||||
condition_dim, filter_size)
|
||||
self.context_size = self.resnet.context_size
|
||||
|
||||
skip_channels = residual_channels # assume the same channel
|
||||
self.proj1 = nn.utils.weight_norm(
|
||||
nn.Linear(skip_channels, skip_channels), dim=1)
|
||||
self.proj2 = nn.utils.weight_norm(
|
||||
nn.Linear(skip_channels, skip_channels), dim=1)
|
||||
# if loss_type is softmax, output_dim is n_vocab of waveform magnitude.
|
||||
# if loss_type is mog, output_dim is 3 * gaussian, (weight, mean and stddev)
|
||||
self.proj3 = nn.utils.weight_norm(
|
||||
nn.Linear(skip_channels, output_dim), dim=1)
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.output_dim = output_dim
|
||||
self.input_dim = 1
|
||||
self.skip_channels = skip_channels
|
||||
self.log_scale_min = log_scale_min
|
||||
|
||||
def forward(self, x, condition=None):
|
||||
"""Forward pass of ``WaveNet``.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(B, T)]
|
||||
The input waveform.
|
||||
condition : Tensor, optional [shape=(B, C_cond, T)]
|
||||
the upsampled condition. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor: [shape=(B, T, C_output)]
|
||||
The parameters of the output distributions.
|
||||
"""
|
||||
|
||||
# Causal Conv
|
||||
if self.loss_type == "softmax":
|
||||
x = paddle.clip(x, min=-1., max=0.99999)
|
||||
x = quantize(x, self.output_dim)
|
||||
x = self.embed(x) # (B, T, C)
|
||||
else:
|
||||
x = paddle.unsqueeze(x, -1) # (B, T, 1)
|
||||
x = self.embed(x) # (B, T, C)
|
||||
x = paddle.transpose(x, perm=[0, 2, 1]) # (B, C, T)
|
||||
|
||||
# Residual & Skip-conenection & linears
|
||||
z = self.resnet(x, condition)
|
||||
|
||||
z = paddle.transpose(z, [0, 2, 1])
|
||||
z = F.relu(self.proj2(F.relu(self.proj1(z))))
|
||||
|
||||
y = self.proj3(z)
|
||||
return y
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the WaveNet to generate a new sequence. This method should
|
||||
be called before starting calling ``add_input`` multiple times.
|
||||
"""
|
||||
self.resnet.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""Compute the output distribution (represented by its parameters) for
|
||||
a step. It works similarily with the ``forward`` method but in a
|
||||
``step-in-step-out`` fashion.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
x : Tensor [shape=(B,)]
|
||||
A step of the input waveform.
|
||||
|
||||
condition : Tensor, optional [shape=(B, C_cond)]
|
||||
A step of the upsampled condition. Defaults to None.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(B, C_output)]
|
||||
A step of the parameters of the output distributions.
|
||||
"""
|
||||
# Causal Conv
|
||||
if self.loss_type == "softmax":
|
||||
x = paddle.clip(x, min=-1., max=0.99999)
|
||||
x = quantize(x, self.output_dim)
|
||||
x = self.embed(x) # (B, C)
|
||||
else:
|
||||
x = paddle.unsqueeze(x, -1) # (B, 1)
|
||||
x = self.embed(x) # (B, C)
|
||||
|
||||
# Residual & Skip-conenection & linears
|
||||
z = self.resnet.add_input(x, condition)
|
||||
z = F.relu(self.proj2(F.relu(self.proj1(z)))) # (B, C)
|
||||
|
||||
# Output
|
||||
y = self.proj3(z)
|
||||
return y
|
||||
|
||||
def compute_softmax_loss(self, y, t):
|
||||
"""Compute the loss when output distributions are categorial
|
||||
distributions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The logits of the output distributions.
|
||||
|
||||
t : Tensor [shape=(B, T)]
|
||||
The target audio. The audio is first quantized then used as the
|
||||
target.
|
||||
|
||||
Notes
|
||||
-------
|
||||
Output distributions whose input contains padding is neglected in
|
||||
loss computation. So the first ``context_size`` steps does not
|
||||
contribute to the loss.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
# context size is not taken into account
|
||||
y = y[:, self.context_size:, :]
|
||||
t = t[:, self.context_size:]
|
||||
t = paddle.clip(t, min=-1.0, max=0.99999)
|
||||
quantized = quantize(t, n_bands=self.output_dim)
|
||||
label = paddle.unsqueeze(quantized, -1)
|
||||
|
||||
loss = F.softmax_with_cross_entropy(y, label)
|
||||
reduced_loss = paddle.mean(loss)
|
||||
return reduced_loss
|
||||
|
||||
def sample_from_softmax(self, y):
|
||||
"""Sample from the output distribution when the output distributions
|
||||
are categorical distriobutions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The logits of the output distributions.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, T)]
|
||||
Waveform sampled from the output distribution.
|
||||
"""
|
||||
# dequantize
|
||||
batch_size, time_steps, output_dim, = y.shape
|
||||
y = paddle.reshape(y, (batch_size * time_steps, output_dim))
|
||||
prob = F.softmax(y)
|
||||
quantized = paddle.fluid.layers.sampling_id(prob)
|
||||
samples = dequantize(quantized, n_bands=self.output_dim)
|
||||
samples = paddle.reshape(samples, (batch_size, -1))
|
||||
return samples
|
||||
|
||||
def compute_mog_loss(self, y, t):
|
||||
"""Compute the loss where output distributions is a mixture of
|
||||
Gaussians distributions.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution. It is the concatenation
|
||||
of 3 parts, the logits of every distribution, the mean of each
|
||||
distribution and the log standard deviation of each distribution.
|
||||
|
||||
Each part's shape is (B, T, n_mixture), where ``n_mixture`` means
|
||||
the number of Gaussians in the mixture.
|
||||
|
||||
t : Tensor [shape=(B, T)]
|
||||
The target audio.
|
||||
|
||||
Notes
|
||||
-------
|
||||
Output distributions whose input contains padding is neglected in
|
||||
loss computation. So the first ``context_size`` steps does not
|
||||
contribute to the loss.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
n_mixture = self.output_dim // 3
|
||||
|
||||
# context size is not taken in to account
|
||||
y = y[:, self.context_size:, :]
|
||||
t = t[:, self.context_size:]
|
||||
|
||||
w, mu, log_std = paddle.split(y, 3, axis=2)
|
||||
# 100.0 is just a large float
|
||||
log_std = paddle.clip(log_std, min=self.log_scale_min, max=100.)
|
||||
inv_std = paddle.exp(-log_std)
|
||||
p_mixture = F.softmax(w, -1)
|
||||
|
||||
t = paddle.unsqueeze(t, -1)
|
||||
if n_mixture > 1:
|
||||
# t = F.expand_as(t, log_std)
|
||||
t = paddle.expand(t, [-1, -1, n_mixture])
|
||||
|
||||
x_std = inv_std * (t - mu)
|
||||
exponent = paddle.exp(-0.5 * x_std * x_std)
|
||||
pdf_x = 1.0 / math.sqrt(2.0 * math.pi) * inv_std * exponent
|
||||
|
||||
pdf_x = p_mixture * pdf_x
|
||||
# pdf_x: [bs, len]
|
||||
pdf_x = paddle.sum(pdf_x, -1)
|
||||
per_sample_loss = -paddle.log(pdf_x + 1e-9)
|
||||
|
||||
loss = paddle.mean(per_sample_loss)
|
||||
return loss
|
||||
|
||||
def sample_from_mog(self, y):
|
||||
"""Sample from the output distribution when the output distribution
|
||||
is a mixture of Gaussian distributions.
|
||||
|
||||
Parameters
|
||||
------------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution. It is the concatenation
|
||||
of 3 parts, the logits of every distribution, the mean of each
|
||||
distribution and the log standard deviation of each distribution.
|
||||
|
||||
Each part's shape is (B, T, n_mixture), where ``n_mixture`` means
|
||||
the number of Gaussians in the mixture.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(B, T)]
|
||||
Waveform sampled from the output distribution.
|
||||
"""
|
||||
batch_size, time_steps, output_dim = y.shape
|
||||
n_mixture = output_dim // 3
|
||||
|
||||
w, mu, log_std = paddle.split(y, 3, -1)
|
||||
|
||||
reshaped_w = paddle.reshape(w, (batch_size * time_steps, n_mixture))
|
||||
prob_ids = paddle.fluid.layers.sampling_id(F.softmax(reshaped_w))
|
||||
prob_ids = paddle.reshape(prob_ids, (batch_size, time_steps))
|
||||
prob_ids = prob_ids.numpy()
|
||||
|
||||
# do it
|
||||
index = np.array([[[b, t, prob_ids[b, t]] for t in range(time_steps)]
|
||||
for b in range(batch_size)]).astype("int32")
|
||||
index_var = paddle.to_tensor(index)
|
||||
|
||||
mu_ = paddle.gather_nd(mu, index_var)
|
||||
log_std_ = paddle.gather_nd(log_std, index_var)
|
||||
|
||||
dist = D.Normal(mu_, paddle.exp(log_std_))
|
||||
samples = dist.sample(shape=[])
|
||||
samples = paddle.clip(samples, min=-1., max=1.)
|
||||
return samples
|
||||
|
||||
def sample(self, y):
|
||||
"""Sample from the output distribution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameterd of the output distribution.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, T)]
|
||||
Waveform sampled from the output distribution.
|
||||
"""
|
||||
if self.loss_type == "softmax":
|
||||
return self.sample_from_softmax(y)
|
||||
else:
|
||||
return self.sample_from_mog(y)
|
||||
|
||||
def loss(self, y, t):
|
||||
"""Compute the loss given the output distribution and the target.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
The parameters of the output distribution.
|
||||
|
||||
t : Tensor [shape=(B, T)]
|
||||
The target audio.
|
||||
|
||||
Returns
|
||||
---------
|
||||
Tensor: [shape=(1,)]
|
||||
The loss.
|
||||
"""
|
||||
if self.loss_type == "softmax":
|
||||
return self.compute_softmax_loss(y, t)
|
||||
else:
|
||||
return self.compute_mog_loss(y, t)
|
||||
|
||||
|
||||
class ConditionalWaveNet(nn.Layer):
|
||||
r"""Conditional Wavenet. An implementation of
|
||||
`WaveNet: A Generative Model for Raw Audio <http://arxiv.org/abs/1609.03499>`_.
|
||||
|
||||
It contains an UpsampleNet as the encoder and a WaveNet as the decoder.
|
||||
It is an autoregressive model that generate raw audio.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
upsample_factors : List[int]
|
||||
The upsampling factors of the UpsampleNet.
|
||||
|
||||
n_stack : int
|
||||
Number of convolution stacks in the WaveNet.
|
||||
|
||||
n_loop : int
|
||||
Number of convolution layers in a convolution stack.
|
||||
|
||||
Convolution layers in a stack have exponentially growing dilations,
|
||||
from 1 to .. math:: `k^{n_{loop} - 1}`, where k is the kernel size.
|
||||
|
||||
residual_channels : int
|
||||
Feature size of each ResidualBlocks.
|
||||
|
||||
output_dim : int
|
||||
Feature size of the output. See ``loss_type`` for details.
|
||||
|
||||
n_mels : int
|
||||
The number of bands of mel spectrogram.
|
||||
|
||||
filter_size : int, optional
|
||||
Convolution kernel size of each ResidualBlock, by default 2.
|
||||
|
||||
loss_type : str, optional ["mog" or "softmax"]
|
||||
The output type and loss type of the model, by default "mog".
|
||||
|
||||
If "softmax", the model input should be quantized audio and the model
|
||||
outputs a discret distribution.
|
||||
|
||||
If "mog", the model input is audio in floating point format, and the
|
||||
model outputs parameters for a mixture of gaussian distributions.
|
||||
Namely, the weight, mean and logscale of each gaussian distribution.
|
||||
Thus, the ``output_size`` should be a multiple of 3.
|
||||
|
||||
log_scale_min : float, optional
|
||||
Minimum value of the log scale of gaussian distributions, by default
|
||||
-9.0.
|
||||
|
||||
This is only used for computing loss when ``loss_type`` is "mog", If
|
||||
the predicted log scale is less than -9.0, it is clipped at -9.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
upsample_factors: List[int],
|
||||
n_stack: int,
|
||||
n_loop: int,
|
||||
residual_channels: int,
|
||||
output_dim: int,
|
||||
n_mels: int,
|
||||
filter_size: int=2,
|
||||
loss_type: str="mog",
|
||||
log_scale_min: float=-9.0):
|
||||
super(ConditionalWaveNet, self).__init__()
|
||||
self.encoder = UpsampleNet(upsample_factors)
|
||||
self.decoder = WaveNet(
|
||||
n_stack=n_stack,
|
||||
n_loop=n_loop,
|
||||
residual_channels=residual_channels,
|
||||
output_dim=output_dim,
|
||||
condition_dim=n_mels,
|
||||
filter_size=filter_size,
|
||||
loss_type=loss_type,
|
||||
log_scale_min=log_scale_min)
|
||||
|
||||
def forward(self, audio, mel, audio_start):
|
||||
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
audio : Tensor [shape=(B, T_audio)]
|
||||
ground truth waveform, used for teacher force training.
|
||||
|
||||
mel : Tensor [shape(B, F, T_mel)]
|
||||
Mel spectrogram. Note that it is the spectrogram for the whole
|
||||
utterance.
|
||||
|
||||
audio_start : Tensor [shape=(B,), dtype: int]
|
||||
Audio slices' start positions for each utterance.
|
||||
|
||||
Returns
|
||||
----------
|
||||
Tensor [shape(B, T_audio - 1, C_output)]
|
||||
Parameters for the output distribution, where ``C_output`` is the
|
||||
``output_dim`` of the decoder.)
|
||||
"""
|
||||
audio_length = audio.shape[1] # audio clip's length
|
||||
condition = self.encoder(mel)
|
||||
condition_slice = crop(condition, audio_start, audio_length)
|
||||
|
||||
# shifting 1 step
|
||||
audio = audio[:, :-1]
|
||||
condition_slice = condition_slice[:, :, 1:]
|
||||
|
||||
y = self.decoder(audio, condition_slice)
|
||||
return y
|
||||
|
||||
def loss(self, y, t):
|
||||
"""Compute loss with respect to the output distribution and the target
|
||||
audio.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
y : Tensor [shape=(B, T - 1, C_output)]
|
||||
Parameters of the output distribution.
|
||||
|
||||
t : Tensor [shape(B, T)]
|
||||
target waveform.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor: [shape=(1,)]
|
||||
the loss.
|
||||
"""
|
||||
t = t[:, 1:]
|
||||
loss = self.decoder.loss(y, t)
|
||||
return loss
|
||||
|
||||
def sample(self, y):
|
||||
"""Sample from the output distribution.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
y : Tensor [shape=(B, T, C_output)]
|
||||
Parameters of the output distribution.
|
||||
|
||||
Returns
|
||||
--------
|
||||
Tensor [shape=(B, T)]
|
||||
Sampled waveform from the output distribution.
|
||||
"""
|
||||
samples = self.decoder.sample(y)
|
||||
return samples
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, mel):
|
||||
r"""Synthesize waveform from mel spectrogram.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
mel : Tensor [shape=(B, F, T)]
|
||||
The ondition (mel spectrogram here).
|
||||
|
||||
Returns
|
||||
-----------
|
||||
Tensor [shape=(B, T \* upsacle_factor)]
|
||||
Synthesized waveform.
|
||||
|
||||
``upscale_factor`` is the ``upscale_factor`` of the encoder
|
||||
``UpsampleNet``.
|
||||
"""
|
||||
condition = self.encoder(mel)
|
||||
batch_size, _, time_steps = condition.shape
|
||||
samples = []
|
||||
|
||||
self.decoder.start_sequence()
|
||||
x_t = paddle.zeros((batch_size, ), dtype=mel.dtype)
|
||||
for i in trange(time_steps):
|
||||
c_t = condition[:, :, i] # (B, C)
|
||||
y_t = self.decoder.add_input(x_t, c_t) #(B, C)
|
||||
y_t = paddle.unsqueeze(y_t, 1)
|
||||
x_t = self.sample(y_t) # (B, 1)
|
||||
x_t = paddle.squeeze(x_t, 1) #(B,)
|
||||
samples.append(x_t)
|
||||
samples = paddle.stack(samples, -1)
|
||||
return samples
|
||||
|
||||
@paddle.no_grad()
|
||||
def predict(self, mel):
|
||||
r"""Synthesize audio from mel spectrogram.
|
||||
|
||||
The output and input are numpy arrays without batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mel : np.ndarray [shape=(C, T)]
|
||||
Mel spectrogram of an utterance.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor : np.ndarray [shape=(C, T \* upsample_factor)]
|
||||
The synthesized waveform of an utterance.
|
||||
"""
|
||||
mel = paddle.to_tensor(mel)
|
||||
mel = paddle.unsqueeze(mel, 0)
|
||||
audio = self.infer(mel)
|
||||
audio = audio[0].numpy()
|
||||
return audio
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, config, checkpoint_path):
|
||||
"""Build a ConditionalWaveNet model from a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
|
||||
checkpoint_path: Path or str
|
||||
the path of pretrained model checkpoint, without extension name
|
||||
|
||||
Returns
|
||||
-------
|
||||
ConditionalWaveNet
|
||||
The model built from pretrained result.
|
||||
"""
|
||||
model = cls(upsample_factors=config.model.upsample_factors,
|
||||
n_stack=config.model.n_stack,
|
||||
n_loop=config.model.n_loop,
|
||||
residual_channels=config.model.residual_channels,
|
||||
output_dim=config.model.output_dim,
|
||||
n_mels=config.data.n_mels,
|
||||
filter_size=config.model.filter_size,
|
||||
loss_type=config.model.loss_type,
|
||||
log_scale_min=config.model.log_scale_min)
|
||||
layer_tools.summary(model)
|
||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
return model
|
|
@ -13,10 +13,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
from parakeet.modules.attention import *
|
||||
from parakeet.modules.audio import *
|
||||
from parakeet.modules.conv import *
|
||||
from parakeet.modules.geometry import *
|
||||
from parakeet.modules.losses import *
|
||||
from parakeet.modules.masking import *
|
||||
from parakeet.modules.positional_encoding import *
|
||||
from parakeet.modules.transformer import *
|
||||
from parakeet.modules.transformer import *
|
||||
|
|
|
@ -316,14 +316,12 @@ class LocationSensitiveAttention(nn.Layer):
|
|||
self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
|
||||
self.value = nn.Linear(d_attention, 1, bias_attr=False)
|
||||
|
||||
#Location Layer
|
||||
# Location Layer
|
||||
self.location_conv = nn.Conv1D(
|
||||
2,
|
||||
location_filters,
|
||||
location_kernel_size,
|
||||
1,
|
||||
int((location_kernel_size - 1) / 2),
|
||||
1,
|
||||
kernel_size=location_kernel_size,
|
||||
padding=int((location_kernel_size - 1) / 2),
|
||||
bias_attr=False,
|
||||
data_format='NLC')
|
||||
self.location_layer = nn.Linear(
|
||||
|
@ -352,21 +350,22 @@ class LocationSensitiveAttention(nn.Layer):
|
|||
Attention weights concat.
|
||||
|
||||
mask : Tensor, optional
|
||||
The mask. Shape should be (batch_size, times_steps_q, time_steps_k) or broadcastable shape.
|
||||
The mask. Shape should be (batch_size, times_steps_k, 1).
|
||||
Defaults to None.
|
||||
|
||||
Returns
|
||||
----------
|
||||
attention_context : Tensor [shape=(batch_size, time_steps_q, d_attention)]
|
||||
attention_context : Tensor [shape=(batch_size, d_attention)]
|
||||
The context vector.
|
||||
|
||||
attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)]
|
||||
attention_weights : Tensor [shape=(batch_size, time_steps_k)]
|
||||
The attention weights.
|
||||
"""
|
||||
|
||||
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
||||
processed_attention_weights = self.location_layer(
|
||||
self.location_conv(attention_weights_cat))
|
||||
# (B, T_enc, 1)
|
||||
alignment = self.value(
|
||||
paddle.tanh(processed_attention_weights + processed_key +
|
||||
processed_query))
|
||||
|
@ -378,7 +377,7 @@ class LocationSensitiveAttention(nn.Layer):
|
|||
attention_context = paddle.matmul(
|
||||
attention_weights, value, transpose_x=True)
|
||||
|
||||
attention_weights = paddle.squeeze(attention_weights, axis=[-1])
|
||||
attention_context = paddle.squeeze(attention_context, axis=[1])
|
||||
attention_weights = paddle.squeeze(attention_weights, axis=-1)
|
||||
attention_context = paddle.squeeze(attention_context, axis=1)
|
||||
|
||||
return attention_context, attention_weights
|
||||
|
|
|
@ -16,6 +16,8 @@ import paddle
|
|||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from scipy import signal
|
||||
import librosa
|
||||
from librosa.util import pad_center
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["quantize", "dequantize", "STFT"]
|
||||
|
@ -88,6 +90,19 @@ class STFT(nn.Layer):
|
|||
Name of window function, see `scipy.signal.get_window` for more
|
||||
details. Defaults to "hanning".
|
||||
|
||||
center : bool
|
||||
If True, the signal y is padded so that frame D[:, t] is centered
|
||||
at y[t * hop_length]. If False, then D[:, t] begins at y[t * hop_length].
|
||||
Defaults to True.
|
||||
|
||||
pad_mode : string or function
|
||||
If center=True, this argument is passed to np.pad for padding the edges
|
||||
of the signal y. By default (pad_mode="reflect"), y is padded on both
|
||||
sides with its own reflection, mirrored around its first and last
|
||||
sample respectively. If center=False, this argument is ignored.
|
||||
|
||||
|
||||
|
||||
Notes
|
||||
-----------
|
||||
It behaves like ``librosa.core.stft``. See ``librosa.core.stft`` for more
|
||||
|
@ -101,29 +116,53 @@ class STFT(nn.Layer):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
|
||||
super(STFT, self).__init__()
|
||||
def __init__(self,
|
||||
n_fft,
|
||||
hop_length=None,
|
||||
win_length=None,
|
||||
window="hanning",
|
||||
center=True,
|
||||
pad_mode="reflect"):
|
||||
super().__init__()
|
||||
# By default, use the entire frame
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
# Set the default hop, if it's not already specified
|
||||
if hop_length is None:
|
||||
hop_length = int(win_length // 4)
|
||||
|
||||
self.hop_length = hop_length
|
||||
self.n_bin = 1 + n_fft // 2
|
||||
self.n_fft = n_fft
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
# calculate window
|
||||
window = signal.get_window(window, win_length)
|
||||
window = signal.get_window(window, win_length, fftbins=True)
|
||||
|
||||
# pad window to n_fft size
|
||||
if n_fft != win_length:
|
||||
pad = (n_fft - win_length) // 2
|
||||
window = np.pad(window, ((pad, pad), ), 'constant')
|
||||
window = pad_center(window, n_fft, mode="constant")
|
||||
#lpad = (n_fft - win_length) // 2
|
||||
#rpad = n_fft - win_length - lpad
|
||||
#window = np.pad(window, ((lpad, pad), ), 'constant')
|
||||
|
||||
# calculate weights
|
||||
r = np.arange(0, n_fft)
|
||||
M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
|
||||
w_real = np.reshape(window *
|
||||
np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
(self.n_bin, 1, 1, self.n_fft))
|
||||
w_imag = np.reshape(window *
|
||||
np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
(self.n_bin, 1, 1, self.n_fft))
|
||||
|
||||
#r = np.arange(0, n_fft)
|
||||
#M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
|
||||
#w_real = np.reshape(window *
|
||||
#np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
#(self.n_bin, 1, self.n_fft))
|
||||
#w_imag = np.reshape(window *
|
||||
#np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
|
||||
#(self.n_bin, 1, self.n_fft))
|
||||
weight = np.fft.fft(np.eye(n_fft))[:self.n_bin]
|
||||
w_real = weight.real
|
||||
w_imag = weight.imag
|
||||
w = np.concatenate([w_real, w_imag], axis=0)
|
||||
w = w * window
|
||||
w = np.expand_dims(w, 1)
|
||||
self.weight = paddle.cast(
|
||||
paddle.to_tensor(w), paddle.get_default_dtype())
|
||||
|
||||
|
@ -137,23 +176,21 @@ class STFT(nn.Layer):
|
|||
|
||||
Returns
|
||||
------------
|
||||
real : Tensor [shape=(B, C, 1, frames)]
|
||||
real : Tensor [shape=(B, C, frames)]
|
||||
The real part of the spectrogram.
|
||||
|
||||
imag : Tensor [shape=(B, C, 1, frames)]
|
||||
imag : Tensor [shape=(B, C, frames)]
|
||||
The image part of the spectrogram.
|
||||
"""
|
||||
# x(batch_size, time_steps)
|
||||
# pad it first with reflect mode
|
||||
# TODO(chenfeiyu): report an issue on paddle.flip
|
||||
pad_start = paddle.reverse(x[:, 1:1 + self.n_fft // 2], axis=[1])
|
||||
pad_stop = paddle.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=[1])
|
||||
x = paddle.concat([pad_start, x, pad_stop], axis=-1)
|
||||
x = paddle.unsqueeze(x, axis=1)
|
||||
if self.center:
|
||||
x = F.pad(x, [self.n_fft // 2, self.n_fft // 2],
|
||||
data_format='NCL',
|
||||
mode=self.pad_mode)
|
||||
|
||||
# to BC1T, C=1
|
||||
x = paddle.unsqueeze(x, axis=[1, 2])
|
||||
out = F.conv2d(x, self.weight, stride=(1, self.hop_length))
|
||||
real, imag = paddle.chunk(out, 2, axis=1) # BC1T
|
||||
# to BCT, C=1
|
||||
out = F.conv1d(x, self.weight, stride=self.hop_length)
|
||||
real, imag = paddle.chunk(out, 2, axis=1) # BCT
|
||||
return real, imag
|
||||
|
||||
def power(self, x):
|
||||
|
@ -166,7 +203,7 @@ class STFT(nn.Layer):
|
|||
|
||||
Returns
|
||||
------------
|
||||
Tensor [shape=(B, C, 1, T)]
|
||||
Tensor [shape=(B, C, T)]
|
||||
The power spectrum.
|
||||
"""
|
||||
real, imag = self(x)
|
||||
|
@ -183,9 +220,21 @@ class STFT(nn.Layer):
|
|||
|
||||
Returns
|
||||
------------
|
||||
Tensor [shape=(B, C, 1, T)]
|
||||
Tensor [shape=(B, C, T)]
|
||||
The magnitude of the spectrum.
|
||||
"""
|
||||
power = self.power(x)
|
||||
magnitude = paddle.sqrt(power)
|
||||
return magnitude
|
||||
|
||||
|
||||
class MelScale(nn.Layer):
|
||||
def __init__(self, sr, n_fft, n_mels, fmin, fmax):
|
||||
super().__init__()
|
||||
mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax)
|
||||
self.weight = paddle.to_tensor(mel_basis)
|
||||
|
||||
def forward(self, spec):
|
||||
# (n_mels, n_freq) * (batch_size, n_freq, n_frames)
|
||||
mel = paddle.matmul(self.weight, spec)
|
||||
return mel
|
||||
|
|
|
@ -17,15 +17,50 @@ import numpy as np
|
|||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle.fluid.layers import sequence_mask
|
||||
|
||||
__all__ = [
|
||||
"guided_attention_loss",
|
||||
"weighted_mean",
|
||||
"masked_l1_loss",
|
||||
"masked_softmax_with_cross_entropy",
|
||||
"diagonal_loss",
|
||||
]
|
||||
|
||||
|
||||
def attention_guide(dec_lens, enc_lens, N, T, g, dtype=None):
|
||||
"""Build that W matrix. shape(B, T_dec, T_enc)
|
||||
W[i, n, t] = 1 - exp(-(n/dec_lens[i] - t/enc_lens[i])**2 / (2g**2))
|
||||
|
||||
See also:
|
||||
Tachibana, Hideyuki, Katsuya Uenoyama, and Shunsuke Aihara. 2017. “Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention.” ArXiv:1710.08969 [Cs, Eess], October. http://arxiv.org/abs/1710.08969.
|
||||
"""
|
||||
dtype = dtype or paddle.get_default_dtype()
|
||||
dec_pos = paddle.arange(0, N).astype(dtype) / dec_lens.unsqueeze(
|
||||
-1) # n/N # shape(B, T_dec)
|
||||
enc_pos = paddle.arange(0, T).astype(dtype) / enc_lens.unsqueeze(
|
||||
-1) # t/T # shape(B, T_enc)
|
||||
W = 1 - paddle.exp(-(dec_pos.unsqueeze(-1) - enc_pos.unsqueeze(1))**2 /
|
||||
(2 * g**2))
|
||||
|
||||
dec_mask = sequence_mask(dec_lens, maxlen=N)
|
||||
enc_mask = sequence_mask(enc_lens, maxlen=T)
|
||||
mask = dec_mask.unsqueeze(-1) * enc_mask.unsqueeze(1)
|
||||
mask = paddle.cast(mask, W.dtype)
|
||||
|
||||
W *= mask
|
||||
return W
|
||||
|
||||
|
||||
def guided_attention_loss(attention_weight, dec_lens, enc_lens, g):
|
||||
"""Guided attention loss, masked to excluded padding parts."""
|
||||
_, N, T = attention_weight.shape
|
||||
W = attention_guide(dec_lens, enc_lens, N, T, g, attention_weight.dtype)
|
||||
|
||||
total_tokens = (dec_lens * enc_lens).astype(W.dtype)
|
||||
loss = paddle.mean(paddle.sum(W * attention_weight, [1, 2]) / total_tokens)
|
||||
return loss
|
||||
|
||||
|
||||
def weighted_mean(input, weight):
|
||||
"""Weighted mean. It can also be used as masked mean.
|
||||
|
||||
|
@ -40,14 +75,10 @@ def weighted_mean(input, weight):
|
|||
----------
|
||||
Tensor [shape=(1,)]
|
||||
Weighted mean tensor with the same dtype as input.
|
||||
|
||||
Warnings
|
||||
---------
|
||||
This is not a mathematical weighted mean. It performs weighted sum and
|
||||
simple average.
|
||||
"""
|
||||
weight = paddle.cast(weight, input.dtype)
|
||||
return paddle.mean(input * weight)
|
||||
broadcast_ratio = input.size / weight.size
|
||||
return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_ratio)
|
||||
|
||||
|
||||
def masked_l1_loss(prediction, target, mask):
|
||||
|
@ -101,70 +132,3 @@ def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
|
|||
ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
|
||||
loss = weighted_mean(ce, mask)
|
||||
return loss
|
||||
|
||||
|
||||
def diagonal_loss(attentions,
|
||||
input_lengths,
|
||||
target_lengths,
|
||||
g=0.2,
|
||||
multihead=False):
|
||||
"""A metric to evaluate how diagonal a attention distribution is.
|
||||
|
||||
It is computed for batch attention distributions. For each attention
|
||||
distribution, the valid decoder time steps and encoder time steps may
|
||||
differ.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_dec)]
|
||||
The attention weights from an encoder-decoder structure.
|
||||
|
||||
input_lengths : Tensor [shape=(B,)]
|
||||
The valid length for each encoder output.
|
||||
|
||||
target_lengths : Tensor [shape=(B,)]
|
||||
The valid length for each decoder output.
|
||||
|
||||
g : float, optional
|
||||
[description], by default 0.2.
|
||||
|
||||
multihead : bool, optional
|
||||
A flag indicating whether ``attentions`` is a multihead attention's
|
||||
attention distribution.
|
||||
|
||||
If ``True``, the shape of attention is ``(B, H, T_dec, T_dec)``, by
|
||||
default False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor [shape=(1,)]
|
||||
The diagonal loss.
|
||||
"""
|
||||
W = guided_attentions(input_lengths, target_lengths, g)
|
||||
W_tensor = paddle.to_tensor(W)
|
||||
if not multihead:
|
||||
return paddle.mean(attentions * W_tensor)
|
||||
else:
|
||||
return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))
|
||||
|
||||
|
||||
@numba.jit(nopython=True)
|
||||
def guided_attention(N, max_N, T, max_T, g):
|
||||
W = np.zeros((max_T, max_N), dtype=np.float32)
|
||||
for t in range(T):
|
||||
for n in range(N):
|
||||
W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
|
||||
# (T_dec, T_enc)
|
||||
return W
|
||||
|
||||
|
||||
def guided_attentions(input_lengths, target_lengths, g=0.2):
|
||||
B = len(input_lengths)
|
||||
max_input_len = input_lengths.max()
|
||||
max_target_len = target_lengths.max()
|
||||
W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
|
||||
for b in range(B):
|
||||
W[b] = guided_attention(input_lengths[b], max_input_len,
|
||||
target_lengths[b], max_target_len, g)
|
||||
# (B, T_dec, T_enc)
|
||||
return W
|
||||
|
|
|
@ -12,18 +12,15 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from collections import defaultdict
|
||||
import time
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from visualdl import LogWriter
|
||||
|
||||
import parakeet
|
||||
from parakeet.utils import checkpoint, mp_tools
|
||||
|
||||
__all__ = ["ExperimentBase"]
|
||||
|
@ -31,40 +28,40 @@ __all__ = ["ExperimentBase"]
|
|||
|
||||
class ExperimentBase(object):
|
||||
"""
|
||||
An experiment template in order to structure the training code and take
|
||||
care of saving, loading, logging, visualization stuffs. It's intended to
|
||||
be flexible and simple.
|
||||
|
||||
So it only handles output directory (create directory for the output,
|
||||
create a checkpoint directory, dump the config in use and create
|
||||
An experiment template in order to structure the training code and take
|
||||
care of saving, loading, logging, visualization stuffs. It's intended to
|
||||
be flexible and simple.
|
||||
|
||||
So it only handles output directory (create directory for the output,
|
||||
create a checkpoint directory, dump the config in use and create
|
||||
visualizer and logger) in a standard way without enforcing any
|
||||
input-output protocols to the model and dataloader. It leaves the main
|
||||
part for the user to implement their own (setup the model, criterion,
|
||||
optimizer, define a training step, define a validation function and
|
||||
input-output protocols to the model and dataloader. It leaves the main
|
||||
part for the user to implement their own (setup the model, criterion,
|
||||
optimizer, define a training step, define a validation function and
|
||||
customize all the text and visual logs).
|
||||
|
||||
It does not save too much boilerplate code. The users still have to write
|
||||
the forward/backward/update mannually, but they are free to add
|
||||
It does not save too much boilerplate code. The users still have to write
|
||||
the forward/backward/update mannually, but they are free to add
|
||||
non-standard behaviors if needed.
|
||||
|
||||
We have some conventions to follow.
|
||||
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
|
||||
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
|
||||
``valid_loader``, ``config`` and ``args`` attributes.
|
||||
2. The config should have a ``training`` field, which has
|
||||
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
|
||||
used as the trigger to invoke validation, checkpointing and stop of the
|
||||
2. The config should have a ``training`` field, which has
|
||||
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
|
||||
used as the trigger to invoke validation, checkpointing and stop of the
|
||||
experiment.
|
||||
3. There are four methods, namely ``train_batch``, ``valid``,
|
||||
3. There are four methods, namely ``train_batch``, ``valid``,
|
||||
``setup_model`` and ``setup_dataloader`` that should be implemented.
|
||||
|
||||
Feel free to add/overwrite other methods and standalone functions if you
|
||||
Feel free to add/overwrite other methods and standalone functions if you
|
||||
need.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config: yacs.config.CfgNode
|
||||
The configuration used for the experiment.
|
||||
|
||||
|
||||
args: argparse.Namespace
|
||||
The parsed command line arguments.
|
||||
|
||||
|
@ -73,17 +70,18 @@ class ExperimentBase(object):
|
|||
>>> def main_sp(config, args):
|
||||
>>> exp = Experiment(config, args)
|
||||
>>> exp.setup()
|
||||
>>> exe.resume_or_load()
|
||||
>>> exp.run()
|
||||
>>>
|
||||
>>>
|
||||
>>> config = get_cfg_defaults()
|
||||
>>> parser = default_argument_parser()
|
||||
>>> args = parser.parse_args()
|
||||
>>> if args.config:
|
||||
>>> if args.config:
|
||||
>>> config.merge_from_file(args.config)
|
||||
>>> if args.opts:
|
||||
>>> config.merge_from_list(args.opts)
|
||||
>>> config.freeze()
|
||||
>>>
|
||||
>>>
|
||||
>>> if args.nprocs > 1 and args.device == "gpu":
|
||||
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
>>> else:
|
||||
|
@ -94,6 +92,18 @@ class ExperimentBase(object):
|
|||
self.config = config
|
||||
self.args = args
|
||||
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
self.iteration = 0
|
||||
self.epoch = 0
|
||||
self.train_loader = None
|
||||
self.valid_loader = None
|
||||
self.iterator = None
|
||||
self.logger = None
|
||||
self.visualizer = None
|
||||
self.output_dir = None
|
||||
self.checkpoint_dir = None
|
||||
|
||||
def setup(self):
|
||||
"""Setup the experiment.
|
||||
"""
|
||||
|
@ -115,7 +125,7 @@ class ExperimentBase(object):
|
|||
|
||||
@property
|
||||
def parallel(self):
|
||||
"""A flag indicating whether the experiment should run with
|
||||
"""A flag indicating whether the experiment should run with
|
||||
multiprocessing.
|
||||
"""
|
||||
return self.args.device == "gpu" and self.args.nprocs > 1
|
||||
|
@ -133,9 +143,9 @@ class ExperimentBase(object):
|
|||
self.model, self.optimizer)
|
||||
|
||||
def resume_or_load(self):
|
||||
"""Resume from latest checkpoint at checkpoints in the output
|
||||
"""Resume from latest checkpoint at checkpoints in the output
|
||||
directory or load a specified checkpoint.
|
||||
|
||||
|
||||
If ``args.checkpoint_path`` is not None, load the checkpoint, else
|
||||
resume training.
|
||||
"""
|
||||
|
@ -165,14 +175,15 @@ class ExperimentBase(object):
|
|||
"""Reset the train loader and increment ``epoch``.
|
||||
"""
|
||||
self.epoch += 1
|
||||
if self.parallel:
|
||||
if self.parallel and isinstance(self.train_loader.batch_sampler,
|
||||
DistributedBatchSampler):
|
||||
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
||||
self.iterator = iter(self.train_loader)
|
||||
|
||||
def train(self):
|
||||
"""The training process.
|
||||
|
||||
It includes forward/backward/update and periodical validation and
|
||||
|
||||
It includes forward/backward/update and periodical validation and
|
||||
saving.
|
||||
"""
|
||||
self.new_epoch()
|
||||
|
@ -190,14 +201,14 @@ class ExperimentBase(object):
|
|||
"""The routine of the experiment after setup. This method is intended
|
||||
to be used by the user.
|
||||
"""
|
||||
self.resume_or_load()
|
||||
try:
|
||||
self.train()
|
||||
except KeyboardInterrupt:
|
||||
except KeyboardInterrupt as exception:
|
||||
self.save()
|
||||
exit(-1)
|
||||
self.close()
|
||||
sys.exit(exception)
|
||||
finally:
|
||||
self.destory()
|
||||
self.close()
|
||||
|
||||
def setup_output_dir(self):
|
||||
"""Create a directory used for output.
|
||||
|
@ -205,12 +216,12 @@ class ExperimentBase(object):
|
|||
# output dir
|
||||
output_dir = Path(self.args.output).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
self.output_dir = output_dir
|
||||
|
||||
def setup_checkpointer(self):
|
||||
"""Create a directory used to save checkpoints into.
|
||||
|
||||
|
||||
It is "checkpoints" inside the output directory.
|
||||
"""
|
||||
# checkpoint dir
|
||||
|
@ -220,33 +231,34 @@ class ExperimentBase(object):
|
|||
self.checkpoint_dir = checkpoint_dir
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def destory(self):
|
||||
def close(self):
|
||||
"""Close visualizer to avoid hanging after training"""
|
||||
# https://github.com/pytorch/fairseq/issues/2357
|
||||
self.visualizer.close()
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def setup_visualizer(self):
|
||||
"""Initialize a visualizer to log the experiment.
|
||||
|
||||
|
||||
The visual log is saved in the output directory.
|
||||
|
||||
|
||||
Notes
|
||||
------
|
||||
Only the main process has a visualizer with it. Use multiple
|
||||
visualizers in multiprocess to write to a same log file may cause
|
||||
Only the main process has a visualizer with it. Use multiple
|
||||
visualizers in multiprocess to write to a same log file may cause
|
||||
unexpected behaviors.
|
||||
"""
|
||||
# visualizer
|
||||
visualizer = SummaryWriter(logdir=str(self.output_dir))
|
||||
visualizer = LogWriter(logdir=str(self.output_dir))
|
||||
|
||||
self.visualizer = visualizer
|
||||
|
||||
def setup_logger(self):
|
||||
"""Initialize a text logger to log the experiment.
|
||||
|
||||
Each process has its own text logger. The logging message is write to
|
||||
the standard output and a text file named ``worker_n.log`` in the
|
||||
output directory, where ``n`` means the rank of the process.
|
||||
|
||||
Each process has its own text logger. The logging message is write to
|
||||
the standard output and a text file named ``worker_n.log`` in the
|
||||
output directory, where ``n`` means the rank of the process.
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("INFO")
|
||||
|
@ -258,9 +270,9 @@ class ExperimentBase(object):
|
|||
|
||||
@mp_tools.rank_zero_only
|
||||
def dump_config(self):
|
||||
"""Save the configuration used for this experiment.
|
||||
|
||||
It is saved in to ``config.yaml`` in the output directory at the
|
||||
"""Save the configuration used for this experiment.
|
||||
|
||||
It is saved in to ``config.yaml`` in the output directory at the
|
||||
beginning of the experiment.
|
||||
"""
|
||||
with open(self.output_dir / "config.yaml", 'wt') as f:
|
||||
|
@ -279,13 +291,13 @@ class ExperimentBase(object):
|
|||
raise NotImplementedError("valid should be implemented.")
|
||||
|
||||
def setup_model(self):
|
||||
"""Setup model, criterion and optimizer, etc. A subclass should
|
||||
"""Setup model, criterion and optimizer, etc. A subclass should
|
||||
implement this method.
|
||||
"""
|
||||
raise NotImplementedError("setup_model should be implemented.")
|
||||
|
||||
def setup_dataloader(self):
|
||||
"""Setup training dataloader and validation dataloader. A subclass
|
||||
"""Setup training dataloader and validation dataloader. A subclass
|
||||
should implement this method.
|
||||
"""
|
||||
raise NotImplementedError("setup_dataloader should be implemented.")
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
|
||||
OBSERVATIONS = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def scope(observations):
|
||||
# make `observation` the target to report to.
|
||||
# it is basically a dictionary that stores temporary observations
|
||||
global OBSERVATIONS
|
||||
old = OBSERVATIONS
|
||||
OBSERVATIONS = observations
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
OBSERVATIONS = old
|
||||
|
||||
|
||||
def get_observations():
|
||||
global OBSERVATIONS
|
||||
return OBSERVATIONS
|
||||
|
||||
|
||||
def report(name, value):
|
||||
# a simple function to report named value
|
||||
# you can use it everywhere, it will get the default target and writ to it
|
||||
# you can think of it as std.out
|
||||
observations = get_observations()
|
||||
if observations is None:
|
||||
return
|
||||
else:
|
||||
observations[name] = value
|
|
@ -0,0 +1,91 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
import tqdm
|
||||
from dataclasses import dataclass
|
||||
|
||||
from parakeet.training.trigger import get_trigger, IntervalTrigger
|
||||
from parakeet.training.updater import UpdaterBase
|
||||
from parakeet.training.reporter import scope
|
||||
|
||||
|
||||
class ExtensionEntry(object):
|
||||
def __init__(self, extension, trigger, priority):
|
||||
self.extension = extension
|
||||
self.trigger = trigger
|
||||
self.priority = priority
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self,
|
||||
updater: UpdaterBase,
|
||||
stop_trigger=None,
|
||||
out='result',
|
||||
extensions=None):
|
||||
self.updater = updater
|
||||
self.extensions = {}
|
||||
self.stop_trigger = get_trigger(stop_trigger)
|
||||
self.out = Path(out)
|
||||
self.observation = {}
|
||||
|
||||
def setup(self):
|
||||
pass
|
||||
|
||||
def extend(self, extension, name=None, trigger=None, priority=None):
|
||||
trigger = get_trigger(trigger)
|
||||
|
||||
ordinal = 0
|
||||
modified_name = name
|
||||
while name in self.extensions:
|
||||
ordinal += 1
|
||||
modified_name = f"{name}_{ordinal}"
|
||||
|
||||
self.extensions[modified_name] = ExtensionEntry(extension, trigger,
|
||||
priority)
|
||||
|
||||
def run(self):
|
||||
# sort extensions by priorities once
|
||||
extension_order = sorted(
|
||||
self.extensions.keys(),
|
||||
key=lambda name: self.extensions[name].priority,
|
||||
reverse=True)
|
||||
extensions = [(name, self.extensions[name])
|
||||
for name in extension_order]
|
||||
|
||||
update = self.updater.update
|
||||
stop_trigger = self.stop_trigger
|
||||
|
||||
# TODO(chenfeiyu): display progress bar correctly
|
||||
# if the trainer is controlled by epoch: use 2 progressbars
|
||||
# if the trainer is controlled by iteration: use 1 progressbar
|
||||
if isinstance(stop_trigger, IntervalTrigger):
|
||||
if stop_trigger.unit is 'epoch':
|
||||
max_epoch = self.stop_trigger.period
|
||||
else:
|
||||
max_iteration = self.stop_trigger.period
|
||||
|
||||
while not stop_trigger(self):
|
||||
self.observation = {}
|
||||
# set observation as the report target
|
||||
# you can use report freely in Updater.update()
|
||||
|
||||
# updating parameters and state
|
||||
with scope(self.observation):
|
||||
update()
|
||||
|
||||
# execute extension when necessary
|
||||
for name, entry in extensions:
|
||||
if entry.trigger(self):
|
||||
entry.extension(self)
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class IntervalTrigger(object):
|
||||
def __init__(self, period: int, unit: str):
|
||||
if unit not in ("iteration", "epoch"):
|
||||
raise ValueError("unit should be 'iteration' or 'epoch'")
|
||||
self.period = period
|
||||
self.unit = unit
|
||||
|
||||
def __call__(self, trainer):
|
||||
state = trainer.updater.state
|
||||
if self.unit == "epoch":
|
||||
fire = not (state.epoch % self.period)
|
||||
else:
|
||||
fire = not (state.iteration % self.iteration)
|
||||
return fire
|
||||
|
||||
|
||||
def never_file_trigger(trainer):
|
||||
return False
|
||||
|
||||
|
||||
def get_trigger(trigger):
|
||||
if trigger is None:
|
||||
return never_file_trigger
|
||||
if callable(trigger):
|
||||
return trigger
|
||||
else:
|
||||
trigger = IntervalTrigger(*trigger)
|
||||
return trigger
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
from paddle.io import DataLoader
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdaterState:
|
||||
iteration: int = 0
|
||||
epoch: int = 0
|
||||
|
||||
|
||||
class UpdaterBase(object):
|
||||
"""An updater is the abstraction of how a model is trained given the
|
||||
dataloader and the optimizer.
|
||||
|
||||
The `update_core` method is a step in the training loop with only necessary
|
||||
operations (get a batch, forward and backward, update the parameters).
|
||||
|
||||
Other stuffs are made extensions. Visualization, saving, loading and
|
||||
periodical validation and evaluation are not considered here.
|
||||
|
||||
But even in such simplist case, things are not that simple. There is an
|
||||
attempt to standardize this process and requires only the model and
|
||||
dataset and do all the stuffs automatically. But this may hurt flexibility.
|
||||
|
||||
If we assume a batch yield from the dataloader is just the input to the
|
||||
model, we will find that some model requires more arguments, or just some
|
||||
keyword arguments. But this prevents us from over-simplifying it.
|
||||
|
||||
From another perspective, the batch may includes not just the input, but
|
||||
also the target. But the model's forward method may just need the input.
|
||||
We can pass a dict or a super-long tuple to the model and let it pick what
|
||||
it really needs. But this is an abuse of lazy interface.
|
||||
|
||||
After all, we care about how a model is trained. But just how the model is
|
||||
used for inference. We want to control how a model is trained. We just
|
||||
don't want to be messed up with other auxiliary code.
|
||||
|
||||
So the best practice is to define a model and define a updater for it.
|
||||
"""
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def update_core(self):
|
||||
pass
|
||||
|
||||
|
||||
class StandardUpdater(UpdaterBase):
|
||||
"""An example of over-simplification. Things may not be that simple, but
|
||||
you can subclass it to fit your need.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: Layer,
|
||||
dataloader: DataLoader,
|
||||
optimizer: Optimizer,
|
||||
loss_func=None,
|
||||
auto_new_epoch: bool=True,
|
||||
init_state: Optional[UpdaterState]=None):
|
||||
self.model = model
|
||||
self.dataloader = dataloader
|
||||
self.optimizer = optimizer
|
||||
self.loss_func = loss_func
|
||||
self.auto_new_epoch = auto_new_epoch
|
||||
self.iterator = iter(dataloader)
|
||||
|
||||
if init_state is None:
|
||||
self.state = UpdaterState()
|
||||
else:
|
||||
self.state = init_state
|
||||
|
||||
def update(self):
|
||||
self.update_core()
|
||||
self.state.iteration += 1
|
||||
|
||||
def new_epoch(self):
|
||||
self.iterator = iter(self.dataloader)
|
||||
self.state.epoch += 1
|
||||
|
||||
def update_core(self):
|
||||
model = self.model
|
||||
optimizer = self.optimizer
|
||||
loss_func = self.loss_func
|
||||
|
||||
model.train()
|
||||
optimizer.clear_grad()
|
||||
|
||||
# fetch a batch
|
||||
try:
|
||||
batch = next(self.iterator)
|
||||
except StopIteration as e:
|
||||
if self.auto_new_epoch:
|
||||
self.new_epoch()
|
||||
|
||||
# forward
|
||||
if self.loss_func is not None:
|
||||
loss = loss_func(batch)
|
||||
else:
|
||||
loss = model(batch)
|
||||
|
||||
# backward
|
||||
loss.backward()
|
||||
|
||||
# update parameters
|
||||
optimizer.step()
|
|
@ -14,51 +14,22 @@
|
|||
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import librosa
|
||||
import librosa.display
|
||||
import matplotlib.pylab as plt
|
||||
from matplotlib import cm, pyplot
|
||||
|
||||
__all__ = [
|
||||
"pack_attention_images",
|
||||
"add_attention_plots",
|
||||
"plot_alignment",
|
||||
"min_max_normalize",
|
||||
"add_spectrogram_plots",
|
||||
"plot_spectrogram",
|
||||
"plot_waveform",
|
||||
"plot_multihead_alignments",
|
||||
"plot_multilayer_multihead_alignments",
|
||||
]
|
||||
|
||||
|
||||
def pack_attention_images(attention_weights, rotate=False):
|
||||
# add a box
|
||||
attention_weights = np.pad(attention_weights, [(0, 0), (1, 1), (1, 1)],
|
||||
mode="constant",
|
||||
constant_values=1.)
|
||||
if rotate:
|
||||
attention_weights = np.rot90(attention_weights, axes=(1, 2))
|
||||
n, h, w = attention_weights.shape
|
||||
|
||||
ratio = h / w
|
||||
if ratio < 1:
|
||||
cols = max(int(np.sqrt(n / ratio)), 1)
|
||||
rows = int(np.ceil(n / cols))
|
||||
else:
|
||||
rows = max(int(np.sqrt(n / ratio)), 1)
|
||||
cols = int(np.ceil(n / rows))
|
||||
extras = rows * cols - n
|
||||
#print(rows, cols, extras)
|
||||
total = np.append(attention_weights, np.zeros([extras, h, w]), axis=0)
|
||||
total = np.reshape(total, [rows, cols, h, w])
|
||||
img = np.block([[total[i, j] for j in range(cols)] for i in range(rows)])
|
||||
return img
|
||||
|
||||
|
||||
def save_figure_to_numpy(fig):
|
||||
# save it to a numpy array.
|
||||
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
||||
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
|
||||
return data
|
||||
|
||||
|
||||
def plot_alignment(alignment, title=None):
|
||||
# alignment: [encoder_steps, decoder_steps)
|
||||
fig, ax = plt.subplots(figsize=(6, 4))
|
||||
im = ax.imshow(
|
||||
alignment, aspect='auto', origin='lower', interpolation='none')
|
||||
|
@ -69,43 +40,76 @@ def plot_alignment(alignment, title=None):
|
|||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = save_figure_to_numpy(fig)
|
||||
plt.close()
|
||||
return data
|
||||
return fig
|
||||
|
||||
|
||||
def add_attention_plots(writer, tag, attention_weights, global_step):
|
||||
img = plot_alignment(attention_weights.numpy().T)
|
||||
writer.add_image(tag, img, global_step, dataformats="HWC")
|
||||
def plot_multihead_alignments(alignments, title=None):
|
||||
# alignments: [N, encoder_steps, decoder_steps)
|
||||
num_subplots = alignments.shape[0]
|
||||
|
||||
fig, axes = plt.subplots(
|
||||
figsize=(6 * num_subplots, 4),
|
||||
ncols=num_subplots,
|
||||
sharey=True,
|
||||
squeeze=True)
|
||||
for i, ax in enumerate(axes):
|
||||
im = ax.imshow(
|
||||
alignments[i], aspect='auto', origin='lower', interpolation='none')
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Decoder timestep'
|
||||
if title is not None:
|
||||
xlabel += '\n\n' + title
|
||||
ax.set_xlabel(xlabel)
|
||||
if i == 0:
|
||||
ax.set_ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
|
||||
def add_multi_attention_plots(writer, tag, attention_weights, global_step):
|
||||
attns = [attn[0].numpy() for attn in attention_weights]
|
||||
for i, attn in enumerate(attns):
|
||||
img = pack_attention_images(attn)
|
||||
writer.add_image(
|
||||
f"{tag}/{i}",
|
||||
cm.plasma(img),
|
||||
global_step=global_step,
|
||||
dataformats="HWC")
|
||||
def plot_multilayer_multihead_alignments(alignments, title=None):
|
||||
# alignments: [num_layers, num_heads, encoder_steps, decoder_steps)
|
||||
num_layers, num_heads, *_ = alignments.shape
|
||||
|
||||
fig, axes = plt.subplots(
|
||||
figsize=(6 * num_heads, 4 * num_layers),
|
||||
nrows=num_layers,
|
||||
ncols=num_heads,
|
||||
sharex=True,
|
||||
sharey=True,
|
||||
squeeze=True)
|
||||
for i, row in enumerate(axes):
|
||||
for j, ax in enumerate(row):
|
||||
im = ax.imshow(
|
||||
alignments[i, j],
|
||||
aspect='auto',
|
||||
origin='lower',
|
||||
interpolation='none')
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Decoder timestep'
|
||||
if title is not None:
|
||||
xlabel += '\n\n' + title
|
||||
if i == num_layers - 1:
|
||||
ax.set_xlabel(xlabel)
|
||||
if j == 0:
|
||||
ax.set_ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
|
||||
def add_spectrogram_plots(writer, tag, spec, global_step):
|
||||
spec = spec.numpy().T
|
||||
def plot_spectrogram(spec):
|
||||
# spec: [C, T] librosa convention
|
||||
fig, ax = plt.subplots(figsize=(12, 3))
|
||||
im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none')
|
||||
plt.colorbar(im, ax=ax)
|
||||
plt.xlabel("Frames")
|
||||
plt.ylabel("Channels")
|
||||
plt.tight_layout()
|
||||
|
||||
fig.canvas.draw()
|
||||
data = save_figure_to_numpy(fig)
|
||||
plt.close()
|
||||
writer.add_image(tag, data, global_step, dataformats="HWC")
|
||||
return fig
|
||||
|
||||
|
||||
def min_max_normalize(v):
|
||||
return (v - v.min()) / (v.max() - v.min())
|
||||
def plot_waveform(wav, sr=22050):
|
||||
fig, ax = plt.subplots(figsize=(12, 3))
|
||||
im = librosa.display.waveplot(wav, sr=22050)
|
||||
plt.colorbar(im, ax=ax)
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
|
|
@ -20,7 +20,6 @@ __all__ = ["rank_zero_only"]
|
|||
|
||||
|
||||
def rank_zero_only(func):
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if dist.get_rank() != 0:
|
||||
|
|
14
setup.py
14
setup.py
|
@ -38,7 +38,6 @@ def find_version(*file_paths):
|
|||
VERSION = find_version('parakeet', '__init__.py')
|
||||
long_description = read("README.md")
|
||||
|
||||
|
||||
setup_info = dict(
|
||||
# Metadata
|
||||
name='paddle-parakeet',
|
||||
|
@ -57,9 +56,9 @@ setup_info = dict(
|
|||
'inflect',
|
||||
'librosa',
|
||||
'unidecode',
|
||||
'numba==0.47.0',
|
||||
'tqdm==4.19.8',
|
||||
'llvmlite==0.31.0',
|
||||
'numba',
|
||||
'tqdm',
|
||||
'llvmlite',
|
||||
'matplotlib',
|
||||
'visualdl>=2.0.1',
|
||||
'scipy',
|
||||
|
@ -68,9 +67,12 @@ setup_info = dict(
|
|||
# 'opencc',
|
||||
'soundfile',
|
||||
'g2p_en',
|
||||
'g2pM',
|
||||
'yacs',
|
||||
'tensorboardX',
|
||||
'visualdl',
|
||||
'pypinyin',
|
||||
'webrtcvad',
|
||||
'g2pM',
|
||||
'praatio',
|
||||
],
|
||||
extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ import subprocess
|
|||
import platform
|
||||
|
||||
COPYRIGHT = '''
|
||||
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
|
Loading…
Reference in New Issue