Init deepvoice3 commit

This commit is contained in:
liuyibing01 2019-11-13 14:22:46 +00:00
commit b843f185ff
43 changed files with 7150 additions and 0 deletions

2
.gitingore Normal file
View File

@ -0,0 +1,2 @@
*.pyc
*.tar.*

18
LICENSE Normal file
View File

@ -0,0 +1,18 @@
Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Part of code was copied or adpated from https://github.com/r9y9/deepvoice3_pytorch/
Copyright (c) 2017: Ryuichi Yamamoto, whose applies.

22
README.md Normal file
View File

@ -0,0 +1,22 @@
# Parakeet
Parakeet aims to provide a flexible, efficient and state-of-the-art text-to-speech toolkit for the open-source community. It is built on Paddle Fluid dynamic graph, with the support of many influential TTS models proposed by [Baidu Research](http://research.baidu.com) and other academic institutions.
## Installation
### Install paddlepaddle
For faster training speed and better support, it is recommended that you install the lasted develop version of paddlepaddle. Please refer to the [quick installation guide](https://paddlepaddle.org.cn/install/quick).
### Other Requirements
Install other requirements with pip.
```bash
pip install -r requirements.txt
```
## Supported models
- [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](./deepvoice3)

0
data/__init__.py Normal file
View File

328
data/data.py Normal file
View File

@ -0,0 +1,328 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import io
import platform
from os.path import dirname, join
from nnmnkwii.datasets import FileSourceDataset, FileDataSource
from os.path import join, expanduser
import random
# import global hyper parameters
from hparams import hparams
from modules import frontend
import builder
_frontend = getattr(frontend, hparams.frontend)
def _pad(seq, max_len, constant_values=0):
return np.pad(seq, (0, max_len - len(seq)),
mode="constant",
constant_values=constant_values)
def _pad_2d(x, max_len, b_pad=0):
x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)],
mode="constant",
constant_values=0)
return x
class TextDataSource(FileDataSource):
def __init__(self, data_root, speaker_id=None):
self.data_root = data_root
self.speaker_ids = None
self.multi_speaker = False
# If not None, filter by speaker_id
self.speaker_id = speaker_id
def collect_files(self):
meta = join(self.data_root, "train.txt")
with io.open(meta, "rt", encoding="utf-8") as f:
lines = f.readlines()
l = lines[0].split("|")
assert len(l) == 4 or len(l) == 5
self.multi_speaker = len(l) == 5
texts = list(map(lambda l: l.split("|")[3], lines))
if self.multi_speaker:
speaker_ids = list(map(lambda l: int(l.split("|")[-1]), lines))
# Filter by speaker_id
# using multi-speaker dataset as a single speaker dataset
if self.speaker_id is not None:
indices = np.array(speaker_ids) == self.speaker_id
texts = list(np.array(texts)[indices])
self.multi_speaker = False
return texts
return texts, speaker_ids
else:
return texts
def collect_features(self, *args):
if self.multi_speaker:
text, speaker_id = args
else:
text = args[0]
global _frontend
if _frontend is None:
_frontend = getattr(frontend, hparams.frontend)
seq = _frontend.text_to_sequence(
text, p=hparams.replace_pronunciation_prob)
if platform.system() == "Windows":
if hasattr(hparams, "gc_probability"):
_frontend = None # memory leaking prevention in Windows
if np.random.rand() < hparams.gc_probability:
gc.collect() # garbage collection enforced
print("GC done")
if self.multi_speaker:
return np.asarray(seq, dtype=np.int32), int(speaker_id)
else:
return np.asarray(seq, dtype=np.int32)
class _NPYDataSource(FileDataSource):
def __init__(self, data_root, col, speaker_id=None):
self.data_root = data_root
self.col = col
self.frame_lengths = []
self.speaker_id = speaker_id
def collect_files(self):
meta = join(self.data_root, "train.txt")
with io.open(meta, "rt", encoding="utf-8") as f:
lines = f.readlines()
l = lines[0].split("|")
assert len(l) == 4 or len(l) == 5
multi_speaker = len(l) == 5
self.frame_lengths = list(map(lambda l: int(l.split("|")[2]), lines))
paths = list(map(lambda l: l.split("|")[self.col], lines))
paths = list(map(lambda f: join(self.data_root, f), paths))
if multi_speaker and self.speaker_id is not None:
speaker_ids = list(map(lambda l: int(l.split("|")[-1]), lines))
# Filter by speaker_id
# using multi-speaker dataset as a single speaker dataset
indices = np.array(speaker_ids) == self.speaker_id
paths = list(np.array(paths)[indices])
self.frame_lengths = list(np.array(self.frame_lengths)[indices])
# aha, need to cast numpy.int64 to int
self.frame_lengths = list(map(int, self.frame_lengths))
return paths
def collect_features(self, path):
return np.load(path)
class MelSpecDataSource(_NPYDataSource):
def __init__(self, data_root, speaker_id=None):
super(MelSpecDataSource, self).__init__(data_root, 1, speaker_id)
class LinearSpecDataSource(_NPYDataSource):
def __init__(self, data_root, speaker_id=None):
super(LinearSpecDataSource, self).__init__(data_root, 0, speaker_id)
class PartialyRandomizedSimilarTimeLengthSampler(object):
"""Partially randmoized sampler
1. Sort by lengths
2. Pick a small patch and randomize it
3. Permutate mini-batchs
"""
def __init__(self,
lengths,
batch_size=16,
batch_group_size=None,
permutate=True):
self.sorted_indices = np.argsort(lengths)
self.lengths = np.array(lengths)[self.sorted_indices]
self.batch_size = batch_size
if batch_group_size is None:
batch_group_size = min(batch_size * 32, len(self.lengths))
if batch_group_size % batch_size != 0:
batch_group_size -= batch_group_size % batch_size
self.batch_group_size = batch_group_size
assert batch_group_size % batch_size == 0
self.permutate = permutate
def __iter__(self):
indices = self.sorted_indices.copy()
batch_group_size = self.batch_group_size
s, e = 0, 0
for i in range(len(indices) // batch_group_size):
s = i * batch_group_size
e = s + batch_group_size
random.shuffle(indices[s:e])
# Permutate batches
if self.permutate:
perm = np.arange(len(indices[:e]) // self.batch_size)
random.shuffle(perm)
indices[:e] = indices[:e].reshape(
-1, self.batch_size)[perm, :].reshape(-1)
# Handle last elements
s += batch_group_size
if s < len(indices):
random.shuffle(indices[s:])
return iter(indices)
def __len__(self):
return len(self.sorted_indices)
class Dataset(object):
def __init__(self, X, Mel, Y):
self.X = X
self.Mel = Mel
self.Y = Y
# alias
self.multi_speaker = X.file_data_source.multi_speaker
def __getitem__(self, idx):
if self.multi_speaker:
text, speaker_id = self.X[idx]
return text, self.Mel[idx], self.Y[idx], speaker_id
else:
return self.X[idx], self.Mel[idx], self.Y[idx]
def __len__(self):
return len(self.X)
def make_loader(dataset, batch_size, shuffle, sampler, create_batch_fn,
trainer_count, local_rank):
assert not (
shuffle and
sampler), "shuffle and sampler should not be valid in the same time."
num_samples = len(dataset)
def wrapper():
if sampler is None:
ids = range(num_samples)
if shuffle:
random.shuffle(ids)
else:
ids = sampler
batch, batches = [], []
for idx in ids:
batch.append(dataset[idx])
if len(batch) >= batch_size:
batches.append(batch)
batch = []
if len(batches) >= trainer_count:
yield create_batch_fn(batches[local_rank])
batches = []
if len(batch) > 0:
batches.append(batch)
if len(batches) >= trainer_count:
yield create_batch_fn(batches[local_rank])
return wrapper
def create_batch(batch):
"""Create batch"""
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
multi_speaker = len(batch[0]) == 4
# Lengths
input_lengths = [len(x[0]) for x in batch]
max_input_len = max(input_lengths)
input_lengths = np.array(input_lengths, dtype=np.int64)
target_lengths = [len(x[1]) for x in batch]
max_target_len = max(target_lengths)
target_lengths = np.array(target_lengths, dtype=np.int64)
if max_target_len % (r * downsample_step) != 0:
max_target_len += (r * downsample_step) - max_target_len % (
r * downsample_step)
assert max_target_len % (r * downsample_step) == 0
# Set 0 for zero beginning padding
# imitates initial decoder states
b_pad = r
max_target_len += b_pad * downsample_step
x_batch = np.array(
[_pad(x[0], max_input_len) for x in batch], dtype=np.int64)
x_batch = np.expand_dims(x_batch, axis=-1)
mel_batch = np.array(
[_pad_2d(
x[1], max_target_len, b_pad=b_pad) for x in batch],
dtype=np.float32)
# down sampling is done here
if downsample_step > 1:
mel_batch = mel_batch[:, 0::downsample_step, :]
mel_batch = np.expand_dims(np.transpose(mel_batch, axes=[0, 2, 1]), axis=2)
y_batch = np.array(
[_pad_2d(
x[2], max_target_len, b_pad=b_pad) for x in batch],
dtype=np.float32)
y_batch = np.expand_dims(np.transpose(y_batch, axes=[0, 2, 1]), axis=2)
# text positions
text_positions = np.array(
[_pad(np.arange(1, len(x[0]) + 1), max_input_len) for x in batch],
dtype=np.int64)
text_positions = np.expand_dims(text_positions, axis=-1)
max_decoder_target_len = max_target_len // r // downsample_step
# frame positions
s, e = 1, max_decoder_target_len + 1
frame_positions = np.tile(
np.expand_dims(
np.arange(
s, e, dtype=np.int64), axis=0), (len(batch), 1))
frame_positions = np.expand_dims(frame_positions, axis=-1)
# done flags
done = np.array([
_pad(
np.zeros(
len(x[1]) // r // downsample_step - 1, dtype=np.float32),
max_decoder_target_len,
constant_values=1) for x in batch
])
done = np.expand_dims(np.expand_dims(done, axis=1), axis=1)
if multi_speaker:
speaker_ids = np.expand_dims(np.array([x[3] for x in batch]), axis=-1)
return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
frame_positions, done, target_lengths, speaker_ids)
else:
speaker_ids = None
return (x_batch, input_lengths, mel_batch, y_batch, text_positions,
frame_positions, done, target_lengths)

209
deepvoice3/README.md Normal file
View File

@ -0,0 +1,209 @@
# Deep Voice 3 with Paddle Fluid
[中文版](README_cn.md)
Paddle fluid implementation of DeepVoice 3, a convolutional network based text-to-speech synthesis model. The implementation is based on [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654).
We implement Deepvoice3 model in paddle fluid with dynamic graph, which is convenient for flexible network architectures.
## Installation
You additionally need to download punkt and cmudict for nltk, because we tokenize text with `punkt` and convert text into phonemes with `cmudict`.
```python
import nltk
nltk.download("punkt")
nltk.download("cmudict")
```
## Model Architecture
![DeepVoice3 model architecture](./_images/model_architecture.png)
The model consists of an encoder, a decoder and a converter (and a speaker embedding for multispeaker models). The encoder, together with the decoder forms the seq2seq part of the model, and the converter forms the postnet part.
## Project Structure
```text
├── audio.py # audio processing
├── compute_timestamp_ratio.py # script to compute position rate
├── conversion # parameter conversion from pytorch model
├── requirements.txt # requirements
├── hparams.py # HParam class for deepvoice3
├── hparam_tf # hyper parameter related stuffs
├── ljspeech.py # functions for ljspeech preprocessing
├── preprocess.py # preprocrssing script
├── presets # preset hyperparameters
├── deepvoice3_paddle # DeepVoice3 model implementation
├── eval_model.py # functions for model evaluation
├── synthesis.py # script for speech synthesis
├── train_model.py # functions for model training
└── train.py # script for model training
```
## Usage
There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on. Hyperparameters that are known to work good are provided in the repository. See `presets` directory for details. Now we only provide preset with LJSpeech dataset (`deepvoice3_ljspeech.json`). Support for more models and datasets is pending.
Note that `preprocess.py`, `train.py` and `synthesis.py` all accept a `--preset` parameter. To ensure consistency, you should use the same preset for preprocessing, training and synthesizing.
Note that you can overwrite preset hyperparameters with command line argument `--hparams`, just pass several key-value pair in `${key}=${value}` format seperated by comma `,`. For example `--hparams="batch_size=8, nepochs=500"` can overwrite default values in the preset json file.
Some hyperparameters are only related to training, like `batch_size`, `checkpoint_interval` and you can use different values for preprocessing and training. But hyperparameters related to data preprocessing, like `num_mels` and `ref_level_db`, should be kept the same for preprocessing and training.
For more details about hyperparameters, see `hparams.py`, which contains the definition of `hparams`. Priority order of hyperparameters is command line option `--hparams` > `--preset` json configuration file > definition of hparams in `hparams.py`.
### Dataset
Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
Preprocessing with `preprocess.py`.
```bash
python preprocess.py \
--preset=${preset_json_path} \
--hparams="hyper parameters you want to overwrite" \
${name} ${in_dir} ${out_dir}
```
Now `${name}$` only supports `ljspeech`. Support for other datasets is pending.
Assuming that you use `presers/deepvoice3_ljspeech.json` for LJSpeech and the path of the unziped dataset is `./data/LJSpeech-1.1`, then you can preprocess data with the following command.
```bash
python preprocess.py \
--preset=presets/deepvoice3_ljspeech.json \
ljspeech ./data/LJSpeech-1.1/ ./data/ljspeech
```
When this is done, you will see extracted features in `./data/ljspeech` including:
1. text and corresponding file names for the extracted features in `train.txt`.
2. mel-spectrogram in `ljspeech-mel-*.npy` .
3. linear-spectrogram in `ljspeech-spec-*.npy`.
### Train on single GPU
Training the whole model on one single GPU:
```bash
export PYTHONPATH=../:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
For more details about `train.py`, see `python train.py --help`.
#### load checkpoints
We provide a trained model ([dv3.single_frame](https://paddlespeech.bj.bcebos.com/Parakeet/dv3.single_frame.tar.gz)) for downloading, which is trained with the default preset. Unzip the downloaded file with `tar xzvf dv3.single_frame.tar.gz`, you will get `config.json`, `model.pdparams` and `README.md`. `config.json` is the preset json with which the model is trained, `model.pdparams` is the parameter file, and `README.md` is a brief introduction of the model.
You can load saved checkpoint and resume training with `--checkpoint` (You only need to provide the base name of the parameter file, eg. if you want to load `model.pdparams`, just use `--checkpoint=model`). If there is also a file with the same basename and extension name `.pdopt` in the same folder with the model file (i.e. `model.pdopt`, which is the optimizer file), it is also loaded automatically. If you wan to reset optimizer states, pass `--reset-optimizer` in addition.
#### train a part of the model
You can also train parts of the model while freezing other parts, by passing `--train-seq2seq-only` or `--train-postnet-only`. When training only parts of the model, other parts should be loaded from saved checkpoint.
To train only the `seq2seq` or `postnet`, you should load from a whole model with `--checkpoint` and keep the same configurations with which the checkpoint is trained. Note that when training only the `postnet`, you should set `use_decoder_state_for_postnet_input=false`, because when train only the postnet, the postnet takes the ground truth mel-spectrogram as input. Note that the default value for `use_decoder_state_for_postnet_input` is `True`.
example:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--train-seq2seq-only \
--output=${directory_to_save_results}
```
### Training on multiple GPUs
Training on multiple GPUs with data parallel is enabled. You can run `train.py` with `paddle.distributed.launch` module. Here is the command line usage.
```bash
python -m paddle.distributed.launch \
--started_port ${port_of_the_first_worker} \
--selected_gpus ${logical_gpu_ids_to_choose} \
--log_dir ${path_of_write_log} \
training_script ...
```
`paddle.distributed.launch` parallelizes training in multiprocessing mode.`--selected_gpus` means the logical ids of the selected GPUs, and `started_port` means the port used by the first worker. Outputs of each process are saved in `--log_dir.` Then follows the command for training on a single GPU, except that you should pass `--use-data-paralle` in addition.
```bash
export CUDA_VISIBLE_DEVICES=2,3,4,5 # The IDs of visible physical devices
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3 --log_dir ${multi_gpu_log_dir} \
train.py --data-root=${data-root} \
--use-gpu --use-data-parallel \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
In the example above, we set only GPU `2, 3, 4, 5` to be visible. Then `--selected_gpus="0, 1, 2, 3"` means the logical ids of the selected gpus, which correponds to GPU `2, 3, 4, 5`.
Model checkpoints (`*.pdparams` for the model and `*.pdopt` for the optimizer) are saved in `${directory_to_save_results}/checkpoints` per 10000 steps by default. Layer-wise averaged attention alignments (.png) are saved in `${directory_to_save_results}/checkpoints/alignment_ave`. And alignments for each attention layer are saved in `${directory_to_save_results}/checkpoints/alignment_layer{attention_layer_num}` per 10000 steps for inspection.
Synthesis results of 6 sentences (hardcoded in `eval_model.py`) are saved in `${directory_to_save_results}/checkpoints/eval`, including `step{step_num}_text{text_id}_single_alignment.png` for averaged alignments and `step{step_num}_text{text_id}_single_predicted.wav` for the predicted waveforms.
### Monitor with Tensorboard
Logs with tensorboard are saved in `${directory_to_save_results}/log/` directory by default. You can monitor logs by tensorboard.
```bash
tensorboard --logdir=${log_dir} --host=$HOSTNAME --port=8888
```
### Synthesize from a checkpoint
Given a list of text, `synthesis.py` synthesize audio signals from a trained model.
```bash
python synthesis.py --use-gpu --preset=${preset_json_path} \
--hparams="parameters you may want to override" \
${checkpoint} ${text_list_file} ${dst_dir}
```
Example test_list.txt:
```text
Generative adversarial network or variational auto-encoder.
Once upon a time there was a dear little girl who was loved by every one who looked at her, but most of all by her grandmother, and there was nothing that she would not have given to the child.
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```
generated waveform files and alignment files are saved in `${dst_dir}`.
### Compute position ratio
According to [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654), the position rate is different for different datasets. There are 2 position rates, one for the query and the other for the key, which are referred to as $\omega_1$ and $\omega_2$ in th paper, and the corresponding names in preset json are `query_position_rate` and `key_position_rate`.
For example, the `query_position_rate` and `key_position_rate` for LJSpeech are `1.0` and `1.385`, respectively. Fix the `query_position_rate` as 1.0, the `key_position_rate` is computed with `compute_timestamp_ratio.py`. Run the command below, where `${data_root}` means the path of the preprocessed dataset.
```bash
python compute_timestamp_ratio.py --preset=${preset_json_path} \
--hparams="parameters you may want to override" ${data_root}
```
You will get outputs like this.
```text
100%|██████████████████████████████████████████████████████████| 13047/13047 [00:12<00:00, 1058.19it/s]
1345587 1863884.0 1.3851828235558161
```
Then set the `key_position_rate=1.385` and `query_position_rate=1.0` in the preset.
## Acknowledgement
We thankfully included and adapted some files from r9y9's [deepvoice3_pytorch](https://github.com/r9y9/deepvoice3_pytorch).

224
deepvoice3/README_cn.md Normal file
View File

@ -0,0 +1,224 @@
# Deep Voice 3 with Paddle Fluid
[English](README.md)
Paddle 实现的 Deepvoice3一个基于卷积神经网络的语音合成 (Text to Speech) 模型。本实现基于 [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654) 。
本 Deepvoice3 实现使用 Paddle 动态图模式,这对于灵活的网络结构更为方便。
## 安装
### 安装 paddlepaddle 框架
为了更快的训练速度和更好的支持,我们推荐使用最新的开发版 paddle。用户可以最新编译的开发版 whl 包,也可以选择从源码编译 Paddle。
1. 下载最新编译的开发版 whl 包。可以从 [**多版本 wheel 包列表-dev**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/Tables.html#whl-dev) 页面中选择合适的版本。
2. 从源码编译 Paddle. 参考[**从源码编译**](https://www.paddlepaddle.org.cn/documentation/docs/zh/beginners_guide/install/compile/fromsource.html) 页面。注意,如果你需要使用多卡训练,那么编译前需要设置选项 `-DWITH_DISTRIBUTE=ON`
### 其他依赖
使用 pip 安装其他依赖。
```bash
pip install -r requirements.txt
```
另外需要下载 nltk 的两个库,因为使用了 `punkt` 对文本进行 tokenization并且使用了 `cmudict` 来将文本转为音位。
```python
import nltk
nltk.download("punkt")
nltk.download("cmudict")
```
## 模型结构
![DeepVoice3 模型结构](./_images/model_architecture.png)
模型包含 encoder, decoder, converter 几个部分,对于 multispeaker 数据集,还有一个 speaker embedding。其中 encoder 和 decoder 构成 seq2seq 部分converter 构成 postnet 部分。
## 项目结构
```text
├── audio.py # 用于处理处理音频的函数
├── compute_timestamp_ratio.py # 计算 position rate 的脚本
├── conversion # 用于转换 pytorch 实现的参数
├── requirements.txt # 项目依赖
├── hparams.py # DeepVoice3 运行超参数配置类的定义
├── hparam_tf # 超参数相关
├── ljspeech.py # ljspeech 数据集预处理
├── preprocess.py # 通用预处理脚本
├── presets # 预设超参数配置
├── deepvoice3_paddle # DeepVoice3 模型实现的主要文件
├── eval_model.py # 模型测评相关函数
├── synthesis.py # 用于语音合成的脚本
├── train_model.py # 模型训练相关函数
└── train.py # 用于模型训练的脚本
```
## 使用方法
根据所使用的模型配置和数据集的不同,有不少超参数需要进行调节。我们提供已知结果较好的超参数设置,详见 `presets` 文件夹。目前我们只提供 LJSpeech 的预设配置 `deepvoice3_ljspeech.json`)。后续将提供更多模型和数据集的预设配置。
`preprocess.py``train.py``synthesis.py` 都接受 `--preset` 参数。为了保持一致性,最好在数据预处理,模型训练和语音合成时使用相同的预设配置。
可以通过 `--hparams` 参数来覆盖预设的超参数配置,参数格式是逗号分隔的键值对 `${key}=${value}`,例如 `--hparams="batch_size=8, nepochs=500"`
部分参数只和训练有关,如 `batch_size`, `checkpoint_interval`, 用户在训练时可以使用不同的值。但部分参数和数据预处理相关,如 `num_mels``ref_level_db`, 这些参数在数据预处理和训练时候应该保持一致。
关于超参数设置更多细节可以参考 `hparams.py` ,其中定义了 hparams。超参数的优先级序列是通过命令行参数 `--hparams` 传入的参数优先级高于通过 `--preset` 参数传入的 json 配置文件,高于 `hparams.py` 中的定义。
### 数据集
下载并解压 [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) 数据集。
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
使用 `preprocess.py`进行预处理。
```bash
python preprocess.py \
--preset=${preset_json_path} \
--hparams="hyper parameters you want to overwrite" \
${name} ${in_dir} ${out_dir}
```
目前 `${name}$` 只支持 `ljspeech`。未来将会支持更多数据集。
假设你使用 `presers/deepvoice3_ljspeech.json` 作为处理 LJSpeech 的预设配置文件,并且解压后的数据集位于 `./data/LJSpeech-1.1`, 那么使用如下的命令进行数据预处理。
```bash
python preprocess.py \
--preset=presets/deepvoice3_ljspeech.json \
ljspeech ./data/LJSpeech-1.1/ ./data/ljspeech
```
数据处理完成后,你会在 `./data/ljspeech` 看到提取的特征,包含如下文件。
1. `train.txt`,包含文本和对应的音频特征的文件名。
2. `ljspeech-mel-*.npy`,包含 mel 频谱。
3. `ljspeech-spec-*.npy`,包含线性频谱。
### 使用 GPU 单卡训练
在单个 GPU 上训练整个模型的使用方法如下。
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override"
```
用于可以通过 `python train.py --help` 查看 `train.py` 的详细使用方法。
#### 加载保存的模型
我们提供了使用默认的配置文件训练的模型 [dv3.single_frame](https://paddlespeech.bj.bcebos.com/Parakeet/dv3.single_frame.tar.gz) 供用户下载。使用 `tar xzvf dv3.single_frame.tar.gz` 解压下载的文件,会得到 `config.json`, `model.pdparams` and `README.md`。其中 `config.json` 是模型训练时使用的配置文件,`model.pdparams` 是参数文件,`README.md` 是模型的简要说明。
用户可以通过 `--checkpoint` 参数加载保存的模型并恢复训练(注意:只需要传基础文件名,不需要扩展名,例如需要加载 `model.pdparams` 那么,只需要使用 `--checkpoint=model`)。如果同一个文件夹内有一个和参数文件基础文件名相同,而后缀为 `.pdopt` 的文件,(如 `model.pdopt`,即优化器文件),那么该文件也会被自动加载。如果你想要重置优化器的状态,在训练脚本加入 `--reset-optimizer` 参数。
#### 训练模型的一部分
用户可以通过 `--train-seq2seq-only` 或者 `--train-postnet-only` 来实现固定模型的其他部分,只训练需要训练的部分。但当只训练模型的一部分时,其他的部分需要从保存的模型中加载。
当只训练模型的 `seq2seq` 部分或者 `postnet` 部分时,需要使用 `--checkpoint` 加载整个模型并保持相同的配置。注意,当只训练 `postnet` 的时候,需要保证配置中的`use_decoder_state_for_postnet_input=false`因为在这种情况下postnet 使用真实的 mel 频谱作为输入。注意,`use_decoder_state_for_postnet_input` 的默认值是 `True`
示例:
```bash
export CUDA_VISIBLE_DEVICES=0
python train.py --data-root=${data-root} --use-gpu \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--train-seq2seq-only \
--output=${directory_to_save_results}
```
### 使用 GPU 多卡训练
本模型支持使用多个 GPU 通过数据并行的方式训练。方法是使用 `paddle.distributed.launch` 模块来启动 `train.py`
```bash
python -m paddle.distributed.launch \
--started_port ${port_of_the_first_worker} \
--selected_gpus ${logical_gpu_ids_to_choose} \
--log_dir ${path_to_write_log} \
training_script ...
```
paddle.distributed.launch 通过多进程的方式进行并行训练。`--selected_gpus` 指的是选择的 GPU 的逻辑序号,`started_port` 指的是 0 号显卡的使用的端口号,`--log_dir` 是日志保存的目录,每个进程的输出会在这个文件夹中保存为单独的文件。再在后面接上需要启动的脚本文件及其参数即可。这部分和单卡训练的脚本一致,但是需要传入 `--use-data-paralle` 以使用数据并行训练。示例命令如下。
```bash
export CUDA_VISIBLE_DEVICES=2,3,4,5 # The IDs of visible physical devices
python -m paddle.distributed.launch \
--selected_gpus=0,1,2,3 --log_dir ${multi_gpu_log_dir} \
train.py --data-root=${data-root} \
--use-gpu --use-data-parallel \
--preset=${preset_json_path} \
--hparams="parameters you may want to override" \
--output=${directory_to_save_results}
```
上述的示例中,设置了 `2, 3, 4, 5` 号显卡为可见的 GPU。然后 `--selected_gpus=0,1,2,3` 选择的是 GPU 的逻辑序号,分别对应于 `2, 3, 4, 5` 号卡。
模型 (模型参数保存为`*.pdparams` 文件,优化器被保存为 `*.pdopt` 文件)保存在 `${directory_to_save_results}/checkpoints` 文件夹中。多层平均的注意力机制对齐结果被保存为 `.png` 图片,默认保存在 `${directory_to_save_results}/checkpoints/alignment_ave` 中。每一层的注意力机制对齐结果默认被保存在 `${directory_to_save_results}/checkpoints/alignment_layer{attention_layer_num}`文件夹中。默认每 10000 步保存一次用于查看。
对 6 个给定的句子的语音合成结果保存在 `${directory_to_save_results}/checkpoints/eval` 中,包含多层平均平均的注意力机制对齐结果,这被保存为名为 `step{step_num}_text{text_id}_single_alignment.png` 的图片;以及合成的音频文件,保存为名为 `step{step_num}_text{text_id}_single_predicted.wav` 的音频。
### 使用 Tensorboard 查看训练
Tensorboard 训练日志被保存在 `${directory_to_save_results}/log/` 文件夹,可以通过 tensorboard 查看。使用方法如下。
```bash
tensorboard --logdir=${log_dir} --host=$HOSTNAME --port=8888
```
### 从保存的模型合成语音
给定一组文本,使用 `synthesis.py` 从一个训练好的模型来合成语音,使用方法如下。
```bash
python synthesis.py --use-gpu --preset=${preset_json_path} \
--hparams="parameters you may want to override" \
${checkpoint} ${text_list_file} ${dst_dir}
```
示例文本文件如下:
```text
Generative adversarial network or variational auto-encoder.
Once upon a time there was a dear little girl who was loved by every one who looked at her, but most of all by her grandmother, and there was nothing that she would not have given to the child.
A text-to-speech synthesis system typically consists of multiple stages, such as a text analysis frontend, an acoustic model and an audio synthesis module.
```
合成的结果包含注意力机制对齐结果和音频文件,保存于 `${dst_dir}`
### 计算 position rate
根据 [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654), 对于不同的数据集,会有不同的 position rate. 有两个不同的 position rate一个用于 query 一个用于 key 这在论文中称为 $\omega_1$ 和 $\omega_2$ ,在预设配置文件中的名字分别为 `query_position_rate``key_position_rate`
比如 LJSpeech 数据集的 `query_position_rate``key_position_rate` 分别为 `1.0``1.385`。固定 `query_position_rate` 为 1.0`key_position_rate` 可以使用 `compute_timestamp_ratio.py` 计算,命令如下,其中 `${data_root}` 是预处理后的数据集路径。
```bash
python compute_timestamp_ratio.py --preset=${preset_json_path} \
--hparams="parameters you may want to override" ${data_root}
```
可以得到如下的结果。
```text
100%|██████████████████████████████████████████████████████████| 13047/13047 [00:12<00:00, 1058.19it/s]
1345587 1863884.0 1.3851828235558161
```
然后在预设配置文件中设置 `key_position_rate=1.385` 以及 `query_position_rate=1.0`
## 致谢
本实现包含及改写了 r9y9's 的 [deepvoice3_pytorch](https://github.com/r9y9/deepvoice3_pytorch) 中的部分文件,在此表示感谢。

62
deepvoice3/_ce.py Normal file
View File

@ -0,0 +1,62 @@
# this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
from kpi import AccKpi
each_epoch_duration_frame1_card1 = DurationKpi("each_epoch_duration_frame1_card1", 0.02, actived=True)
train_cost_frame1_card1 = CostKpi("train_cost_frame1_card1", 0.02, actived=True)
each_epoch_duration_frame4_card1 = DurationKpi("each_epoch_duration_frame4_card1", 0.05, actived=True)
train_cost_frame4_card1 = CostKpi("train_cost_frame4_card1", 0.02, actived=True)
tracking_kpis = [
each_epoch_duration_frame1_card1,
train_cost_frame1_card1,
each_epoch_duration_frame4_card1,
train_cost_frame4_card1,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_cost\t1.0
test_cost\t1.0
train_cost\t1.0
train_cost\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)

Binary file not shown.

After

Width:  |  Height:  |  Size: 447 KiB

98
deepvoice3/audio.py Normal file
View File

@ -0,0 +1,98 @@
# This file was copied from https://github.com/r9y9/deepvoice3_pytorch/tree/master/audio.py
# Copyright (c) 2017: Ryuichi Yamamoto.
import librosa
import librosa.filters
import math
import numpy as np
from scipy import signal
from hparams import hparams
from scipy.io import wavfile
import lws
def load_wav(path):
return librosa.core.load(path, sr=hparams.sample_rate)[0]
def save_wav(wav, path):
wav = wav * 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, hparams.sample_rate, wav.astype(np.int16))
def preemphasis(x):
from nnmnkwii.preprocessing import preemphasis
return preemphasis(x, hparams.preemphasis)
def inv_preemphasis(x):
from nnmnkwii.preprocessing import inv_preemphasis
return inv_preemphasis(x, hparams.preemphasis)
def spectrogram(y):
D = _lws_processor().stft(preemphasis(y)).T
S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
return _normalize(S)
def inv_spectrogram(spectrogram):
'''Converts spectrogram to waveform using librosa'''
S = _db_to_amp(_denormalize(spectrogram) +
hparams.ref_level_db) # Convert back to linear
processor = _lws_processor()
D = processor.run_lws(S.astype(np.float64).T**hparams.power)
y = processor.istft(D).astype(np.float32)
return inv_preemphasis(y)
def melspectrogram(y):
D = _lws_processor().stft(preemphasis(y)).T
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db
if not hparams.allow_clipping_in_normalization:
assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
return _normalize(S)
def _lws_processor():
return lws.lws(hparams.fft_size, hparams.hop_size, mode="speech")
# Conversions:
_mel_basis = None
def _linear_to_mel(spectrogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectrogram)
def _build_mel_basis():
if hparams.fmax is not None:
assert hparams.fmax <= hparams.sample_rate // 2
return librosa.filters.mel(hparams.sample_rate,
hparams.fft_size,
fmin=hparams.fmin,
fmax=hparams.fmax,
n_mels=hparams.num_mels)
def _amp_to_db(x):
min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _normalize(S):
return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
def _denormalize(S):
return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db

137
deepvoice3/builder.py Normal file
View File

@ -0,0 +1,137 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepvoice3 import DeepVoiceTTS, ConvSpec, WindowRange
def deepvoice3(n_vocab,
embed_dim=256,
mel_dim=80,
linear_dim=513,
r=4,
downsample_step=1,
n_speakers=1,
speaker_dim=16,
padding_idx=0,
dropout=(1 - 0.96),
filter_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_range=WindowRange(-1, 3),
key_projection=False,
value_projection=False):
time_upsampling = max(downsample_step, 1)
h = encoder_channels
k = filter_size
encoder_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3))
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3))
attentive_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1))
attention = [True, False, False, False, True]
h = converter_channels
postnet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1), ConvSpec(2 * h, k, 3))
model = DeepVoiceTTS(
"dv3", n_speakers, speaker_dim, speaker_embedding_weight_std, n_vocab,
embed_dim, padding_idx, embedding_weight_std, freeze_embedding,
encoder_convolutions, max_positions, padding_idx,
trainable_positional_encodings, mel_dim, r, prenet_convolutions,
attentive_convolutions, attention, use_memory_mask,
force_monotonic_attention, query_position_rate, key_position_rate,
window_range, key_projection, value_projection, linear_dim,
postnet_convolutions, time_upsampling, dropout,
use_decoder_state_for_postnet_input, "float32")
return model
def deepvoice3_multispeaker(n_vocab,
embed_dim=256,
mel_dim=80,
linear_dim=513,
r=4,
downsample_step=1,
n_speakers=1,
speaker_dim=16,
padding_idx=0,
dropout=(1 - 0.96),
filter_size=5,
encoder_channels=128,
decoder_channels=256,
converter_channels=256,
query_position_rate=1.0,
key_position_rate=1.29,
use_memory_mask=False,
trainable_positional_encodings=False,
force_monotonic_attention=True,
use_decoder_state_for_postnet_input=True,
max_positions=512,
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
freeze_embedding=False,
window_range=WindowRange(-1, 3),
key_projection=False,
value_projection=False):
time_upsampling = max(downsample_step, 1)
h = encoder_channels
k = filter_size
encoder_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1), ConvSpec(h, k, 3))
h = decoder_channels
prenet_convolutions = (ConvSpec(h, k, 1))
attentive_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(h, k, 9), ConvSpec(h, k, 27),
ConvSpec(h, k, 1))
attention = [True, False, False, False, False]
h = converter_channels
postnet_convolutions = (ConvSpec(h, k, 1), ConvSpec(h, k, 3),
ConvSpec(2 * h, k, 1), ConvSpec(2 * h, k, 3))
model = DeepVoiceTTS(
"dv3", n_speakers, speaker_dim, speaker_embedding_weight_std, n_vocab,
embed_dim, padding_idx, embedding_weight_std, freeze_embedding,
encoder_convolutions, max_positions, padding_idx,
trainable_positional_encodings, mel_dim, r, prenet_convolutions,
attentive_convolutions, attention, use_memory_mask,
force_monotonic_attention, query_position_rate, key_position_rate,
window_range, key_projection, value_projection, linear_dim,
postnet_convolutions, time_upsampling, dropout,
use_decoder_state_for_postnet_input, "float32")
return model

View File

@ -0,0 +1,71 @@
# Part of code was adpated from https://github.com/r9y9/deepvoice3_pytorch/tree/master/compute_timestamp_ratio.py
# Copyright (c) 2017: Ryuichi Yamamoto.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import io
import numpy as np
sys.path.append("../")
from hparams import hparams, hparams_debug_string
from data.data import TextDataSource, MelSpecDataSource
from nnmnkwii.datasets import FileSourceDataset
from tqdm import trange
from modules import frontend
def build_parser():
parser = argparse.ArgumentParser(
description="Compute output/input timestamp ratio.")
parser.add_argument(
"--hparams", type=str, default="", help="Hyper parameters.")
parser.add_argument(
"--preset",
type=str,
required=True,
help="Path of preset parameters (json).")
parser.add_argument("data_root", type=str, help="path of the dataset.")
return parser
if __name__ == "__main__":
parser = build_parser()
args, _ = parser.parse_known_args()
data_root = args.data_root
preset = args.preset
# Load preset if specified
if preset is not None:
with io.open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args.hparams)
assert hparams.name == "deepvoice3"
# Code below
X = FileSourceDataset(TextDataSource(data_root))
Mel = FileSourceDataset(MelSpecDataSource(data_root))
in_sizes = []
out_sizes = []
for i in trange(len(X)):
x, m = X[i], Mel[i]
if X.file_data_source.multi_speaker:
x = x[0]
in_sizes.append(x.shape[0])
out_sizes.append(m.shape[0])
in_sizes = np.array(in_sizes)
out_sizes = np.array(out_sizes)
input_timestamps = np.sum(in_sizes)
output_timestamps = np.sum(
out_sizes) / hparams.outputs_per_step / hparams.downsample_step
print(input_timestamps, output_timestamps,
output_timestamps / input_timestamps)
sys.exit(0)

1497
deepvoice3/deepvoice3.py Normal file

File diff suppressed because it is too large Load Diff

113
deepvoice3/dry_run.py Normal file
View File

@ -0,0 +1,113 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
from hparams import hparams, hparams_debug_string
from modules import frontend
from deepvoice3 import DeepVoiceTTS
def dry_run(model):
"""
Run the model once, just to get it initialized.
"""
model.train()
_frontend = getattr(frontend, hparams.frontend)
batch_size = 4
enc_length = 157
snd_sample_length = 500
r = hparams.outputs_per_step
downsample_step = hparams.downsample_step
n_speakers = hparams.n_speakers
# make sure snd_sample_length can be divided by r * downsample_step
linear_shift = r * downsample_step
snd_sample_length += linear_shift - snd_sample_length % linear_shift
decoder_length = snd_sample_length // downsample_step // r
mel_length = snd_sample_length // downsample_step
n_vocab = _frontend.n_vocab
max_pos = hparams.max_positions
spker_embed = hparams.speaker_embed_dim
linear_dim = model.linear_dim
mel_dim = hparams.num_mels
x = np.random.randint(
low=0, high=n_vocab, size=(batch_size, enc_length, 1), dtype="int64")
input_lengths = np.arange(
enc_length - batch_size + 1, enc_length + 1, dtype="int64")
mel = np.random.randn(batch_size, mel_dim, 1, mel_length).astype("float32")
y = np.random.randn(batch_size, linear_dim, 1,
snd_sample_length).astype("float32")
text_positions = np.tile(
np.arange(
0, enc_length, dtype="int64"), (batch_size, 1))
text_mask = text_positions > np.expand_dims(input_lengths, 1)
text_positions[text_mask] = 0
text_positions = np.expand_dims(text_positions, axis=-1)
frame_positions = np.tile(
np.arange(
1, decoder_length + 1, dtype="int64"), (batch_size, 1))
frame_positions = np.expand_dims(frame_positions, axis=-1)
done = np.zeros(shape=(batch_size, 1, 1, decoder_length), dtype="float32")
target_lengths = np.array([snd_sample_length] * batch_size).astype("int64")
speaker_ids = np.random.randint(
low=0, high=n_speakers, size=(batch_size, 1),
dtype="int64") if n_speakers > 1 else None
ismultispeaker = speaker_ids is not None
x = dg.to_variable(x)
input_lengths = dg.to_variable(input_lengths)
mel = dg.to_variable(mel)
y = dg.to_variable(y)
text_positions = dg.to_variable(text_positions)
frame_positions = dg.to_variable(frame_positions)
done = dg.to_variable(done)
target_lengths = dg.to_variable(target_lengths)
speaker_ids = dg.to_variable(
speaker_ids) if speaker_ids is not None else None
# these two fields are used as numpy ndarray
text_lengths = input_lengths.numpy()
decoder_lengths = target_lengths.numpy() // r // downsample_step
max_seq_len = max(text_lengths.max(), decoder_lengths.max())
if max_seq_len >= hparams.max_positions:
raise RuntimeError(
"max_seq_len ({}) >= max_posision ({})\n"
"Input text or decoder targget length exceeded the maximum length.\n"
"Please set a larger value for ``max_position`` in hyper parameters."
.format(max_seq_len, hparams.max_positions))
# cause paddle's embedding layer expect shape[-1] == 1
# first dry run runs the whole model
mel_outputs, linear_outputs, attn, done_hat = model(
x, input_lengths, mel, speaker_ids, text_positions, frame_positions)
num_parameters = 0
for k, v in model.state_dict().items():
print("{}|{}|{}".format(k, v.shape, np.prod(v.shape)))
num_parameters += np.prod(v.shape)
print("now model has {} parameters".format(len(model.state_dict())))
print("now model has {} elements".format(num_parameters))

321
deepvoice3/eval_model.py Normal file
View File

@ -0,0 +1,321 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
from os.path import join, expanduser
from warnings import warn
from datetime import datetime
import matplotlib
# Force matplotlib not to use any Xwindows backend.
matplotlib.use("Agg")
from matplotlib import pyplot as plt
from matplotlib import cm
import audio
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import librosa.display
from tensorboardX import SummaryWriter
# import global hyper parameters
from hparams import hparams
from modules import frontend
_frontend = getattr(frontend, hparams.frontend)
def tts(model, text, p=0., speaker_id=None):
"""
Convert text to speech waveform given a deepvoice3 model.
Args:
model (DeepVoiceTTS): Model used to synthesize waveform.
text (str) : Input text to be synthesized
p (float) : Replace word to pronounciation if p > 0. Default is 0.
Returns:
waveform (numpy.ndarray): Shape(T_wav, ), predicted wave form, where
T_wav means the length of the synthesized wave form.
alignment (numpy.ndarray): Shape(T_dec, T_enc), predicted alignment
matrix, where T_dec means the time steps of decoder outputs, T_enc
means the time steps of encoder outoputs.
spectrogram (numpy.ndarray): Shape(T_lin, C_lin), predicted linear
spectrogram, where T__lin means the time steps of linear
spectrogram and C_lin mean sthe channels of linear spectrogram.
mel (numpy.ndarray): Shape(T_mel, C_mel), predicted mel spectrogram,
where T_mel means the time steps of mel spectrogram and C_mel means
the channels of mel spectrogram.
"""
model.eval()
sequence = np.array(_frontend.text_to_sequence(text, p=p)).astype("int64")
sequence = np.reshape(sequence, (1, -1, 1))
text_positions = np.arange(1, sequence.shape[1] + 1, dtype="int64")
text_positions = np.reshape(text_positions, (1, -1, 1))
sequence = dg.to_variable(sequence)
text_positions = dg.to_variable(text_positions)
speaker_ids = None if speaker_id is None else fluid.layers.fill_constant(
shape=[1, 1], value=speaker_id)
# sequence: shape(1, input_length, 1)
# text_positions: shape(1, input_length, 1)
# Greedy decoding
mel_outputs, linear_outputs, alignments, done = model.transduce(
sequence, text_positions, speaker_ids)
# reshape to the desired shape
linear_output = linear_outputs.numpy().squeeze().T
spectrogram = audio._denormalize(linear_output)
alignment = alignments.numpy()[0]
mel = mel_outputs.numpy().squeeze().T
mel = audio._denormalize(mel)
# Predicted audio signal
waveform = audio.inv_spectrogram(linear_output.T)
return waveform, alignment, spectrogram, mel
def prepare_spec_image(spectrogram):
"""
Prepare an image from spectrogram to be written to tensorboardX
summary writer.
Args:
spectrogram (numpy.ndarray): Shape(T, C), spectrogram to be
visualized, where T means the time steps of the spectrogram,
and C means the channels of the spectrogram.
Return:
np.ndarray: Shape(C, T, 4), the generated image of the spectrogram,
where T means the time steps of the spectrogram. It is treated
as the width of the image. And C means the channels of the
spectrogram, which is treated as the height of the image. And 4
means it is a 'ARGB' format.
"""
# [0, 1]
spectrogram = (spectrogram - np.min(spectrogram)) / (
np.max(spectrogram) - np.min(spectrogram))
spectrogram = np.flip(spectrogram, axis=1) # flip against freq axis
return np.uint8(cm.magma(spectrogram.T) * 255)
def plot_alignment(alignment, path, info=None):
fig, ax = plt.subplots()
im = ax.imshow(
alignment, aspect="auto", origin="lower", interpolation="none")
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
plt.savefig(path, format="png")
plt.close()
def time_string():
return datetime.now().strftime("%Y-%m-%d %H:%M")
def save_alignment(global_step, path, attn):
plot_alignment(
attn.T,
path,
info="{}, {}, step={}".format(hparams.builder,
time_string(), global_step))
def eval_model(global_step, writer, model, checkpoint_dir, ismultispeaker):
# hard coded text sequences
texts = [
"Scientists at the CERN laboratory say they have discovered a new particle.",
"There's a way to measure the acute emotional intelligence that has never gone out of style.",
"President Trump met with other leaders at the Group of 20 conference.",
"Generative adversarial network or variational auto-encoder.",
"Please call Stella.",
"Some have accepted this as a miracle without any physical explanation.",
]
eval_output_dir = join(checkpoint_dir, "eval")
if not os.path.exists(eval_output_dir):
os.makedirs(eval_output_dir)
print("[eval] Evaluating the model, results are saved in {}".format(
eval_output_dir))
model.eval()
# hard coded
speaker_ids = [0, 1, 10] if ismultispeaker else [None]
for speaker_id in speaker_ids:
speaker_str = ("multispeaker{}".format(speaker_id)
if speaker_id is not None else "single")
for idx, text in enumerate(texts):
signal, alignment, _, mel = tts(model,
text,
p=0,
speaker_id=speaker_id)
signal /= np.max(np.abs(signal))
# Alignment
path = join(eval_output_dir,
"step{:09d}_text{}_{}_alignment.png".format(
global_step, idx, speaker_str))
save_alignment(global_step, path, alignment)
tag = "eval_averaged_alignment_{}_{}".format(idx, speaker_str)
writer.add_image(
tag,
np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255),
global_step,
dataformats='HWC')
# Mel
writer.add_image(
"(Eval) Predicted mel spectrogram text{}_{}".format(
idx, speaker_str),
prepare_spec_image(mel),
global_step,
dataformats='HWC')
# Audio
path = join(eval_output_dir,
"step{:09d}_text{}_{}_predicted.wav".format(
global_step, idx, speaker_str))
audio.save_wav(signal, path)
try:
writer.add_audio(
"(Eval) Predicted audio signal {}_{}".format(idx,
speaker_str),
signal,
global_step,
sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
def save_states(global_step,
writer,
mel_outputs,
linear_outputs,
attn,
mel,
y,
input_lengths,
checkpoint_dir=None):
"""
Save states for the trainning process.
"""
print("[train] Saving intermediate states at step {}".format(global_step))
idx = min(1, len(input_lengths) - 1)
input_length = input_lengths[idx]
# Alignment, Multi-hop attention
if attn is not None and len(attn.shape) == 4:
attn = attn.numpy()
for i in range(attn.shape[0]):
alignment = attn[i]
alignment = alignment[idx]
tag = "alignment_layer{}".format(i + 1)
writer.add_image(
tag,
np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255),
global_step,
dataformats='HWC')
alignment_dir = join(checkpoint_dir,
"alignment_layer{}".format(i + 1))
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
path = join(
alignment_dir,
"step{:09d}_layer_{}_alignment.png".format(global_step, i + 1))
save_alignment(global_step, path, alignment)
alignment_dir = join(checkpoint_dir, "alignment_ave")
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
path = join(alignment_dir,
"step{:09d}_alignment.png".format(global_step))
alignment = np.mean(attn, axis=0)[idx]
save_alignment(global_step, path, alignment)
tag = "averaged_alignment"
writer.add_image(
tag,
np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255),
global_step,
dataformats="HWC")
if mel_outputs is not None:
mel_output = mel_outputs[idx].numpy().squeeze().T
mel_output = prepare_spec_image(audio._denormalize(mel_output))
writer.add_image(
"Predicted mel spectrogram",
mel_output,
global_step,
dataformats="HWC")
if linear_outputs is not None:
linear_output = linear_outputs[idx].numpy().squeeze().T
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
writer.add_image(
"Predicted linear spectrogram",
spectrogram,
global_step,
dataformats="HWC")
signal = audio.inv_spectrogram(linear_output.T)
signal /= np.max(np.abs(signal))
path = join(checkpoint_dir,
"step{:09d}_predicted.wav".format(global_step))
try:
writer.add_audio(
"Predicted audio signal",
signal,
global_step,
sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
audio.save_wav(signal, path)
if mel_outputs is not None:
mel_output = mel[idx].numpy().squeeze().T
mel_output = prepare_spec_image(audio._denormalize(mel_output))
writer.add_image(
"Target mel spectrogram",
mel_output,
global_step,
dataformats="HWC")
if linear_outputs is not None:
linear_output = y[idx].numpy().squeeze().T
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
writer.add_image(
"Target linear spectrogram",
spectrogram,
global_step,
dataformats="HWC")

150
deepvoice3/hparams.py Normal file
View File

@ -0,0 +1,150 @@
# Part of code was adpated from https://github.com/r9y9/deepvoice3_pytorch/tree/master/hparams.py
# Copyright (c) 2017: Ryuichi Yamamoto.
import hparam_tf.hparam
# NOTE: If you want full control for model architecture. please take a look
# at the code and change whatever you want. Some hyper parameters are hardcoded.
# Default hyperparameters:
hparams = hparam_tf.hparam.HParams(
name="deepvoice3",
# Text:
# [en, jp]
frontend='en',
# Replace words to its pronunciation with fixed probability.
# e.g., 'hello' to 'HH AH0 L OW1'
# [en, jp]
# en: Word -> pronunciation using CMUDict
# jp: Word -> pronounciation usnig MeCab
# [0 ~ 1.0]: 0 means no replacement happens.
replace_pronunciation_prob=0.5,
# Convenient model builder
# [deepvoice3, deepvoice3_multispeaker, nyanko]
# Definitions can be found at deepvoice3_pytorch/builder.py
# deepvoice3: DeepVoice3 https://arxiv.org/abs/1710.07654
# deepvoice3_multispeaker: Multi-speaker version of DeepVoice3
# nyanko: https://arxiv.org/abs/1710.08969
builder="deepvoice3",
# Must be configured depends on the dataset and model you use
n_speakers=1,
speaker_embed_dim=16,
# Audio:
num_mels=80,
fmin=125,
fmax=7600,
fft_size=1024,
hop_size=256,
sample_rate=22050,
preemphasis=0.97,
min_level_db=-100,
ref_level_db=20,
# whether to rescale waveform or not.
# Let x is an input waveform, rescaled waveform y is given by:
# y = x / np.abs(x).max() * rescaling_max
rescaling=False,
rescaling_max=0.999,
# mel-spectrogram is normalized to [0, 1] for each utterance and clipping may
# happen depends on min_level_db and ref_level_db, causing clipping noise.
# If False, assertion is added to ensure no clipping happens.
allow_clipping_in_normalization=True,
# Model:
downsample_step=4, # must be 4 when builder="nyanko"
outputs_per_step=1, # must be 1 when builder="nyanko"
embedding_weight_std=0.1,
speaker_embedding_weight_std=0.01,
padding_idx=0,
# Maximum number of input text length
# try setting larger value if you want to give very long text input
max_positions=512,
dropout=1 - 0.95,
kernel_size=3,
text_embed_dim=128,
encoder_channels=256,
decoder_channels=256,
# Note: large converter channels requires significant computational cost
converter_channels=256,
query_position_rate=1.0,
# can be computed by `compute_timestamp_ratio.py`.
key_position_rate=1.385, # 2.37 for jsut
key_projection=False,
value_projection=False,
use_memory_mask=True,
trainable_positional_encodings=False,
freeze_embedding=False,
# If True, use decoder's internal representation for postnet inputs,
# otherwise use mel-spectrogram.
use_decoder_state_for_postnet_input=True,
# Data loader
random_seed=1234,
pin_memory=True,
# Set it to 1 when in Windows (MemoryError, THAllocator.c 0x5)
num_workers=2,
# Loss
masked_loss_weight=0.5, # (1-w)*loss + w * masked_loss
# heuristic: priotrize [0 ~ priotiry_freq] for linear loss
priority_freq=3000,
priority_freq_weight=0.0, # (1-w)*linear_loss + w*priority_linear_loss
# https://arxiv.org/pdf/1710.08969.pdf
# Adding the divergence to the loss stabilizes training, expecially for
# very deep (> 10 layers) networks.
# Binary div loss seems has approx 10x scale compared to L1 loss, so I choose 0.1.
binary_divergence_weight=0.1, # set 0 to disable
use_guided_attention=True,
guided_attention_sigma=0.2,
# Training:
batch_size=16,
adam_beta1=0.5,
adam_beta2=0.9,
adam_eps=1e-6,
amsgrad=False,
initial_learning_rate=5e-4, # 0.001,
lr_schedule="noam_learning_rate_decay",
lr_schedule_kwargs={},
nepochs=2000,
weight_decay=0.0,
clip_thresh=0.1,
# Save
checkpoint_interval=10000,
eval_interval=10000,
save_optimizer_state=True,
# Eval:
# this can be list for multple layers of attention
# e.g., [True, False, False, False, True]
force_monotonic_attention=True,
# Attention constraint for incremental decoding
window_ahead=3,
# 0 tends to prevent word repretetion, but sometime causes skip words
window_backward=1,
power=1.4, # Power to raise magnitudes to prior to phase retrieval
# GC:
# Forced garbage collection probability
# Use only when MemoryError continues in Windows (Disabled by default)
#gc_probability = 0.001,
# json_meta mode only
# 0: "use all",
# 1: "ignore only unmatched_alignment",
# 2: "fully ignore recognition",
ignore_recognition_level=2,
# when dealing with non-dedicated speech dataset(e.g. movie excerpts), setting min_text above 15 is desirable. Can be adjusted by dataset.
min_text=20,
# if true, data without phoneme alignment file(.lab) will be ignored
process_only_htk_aligned=False)
def hparams_debug_string():
values = hparams.values()
hp = [' %s: %s' % (name, values[name]) for name in sorted(values)]
return 'Hyperparameters:\n' + '\n'.join(hp)

89
deepvoice3/ljspeech.py Normal file
View File

@ -0,0 +1,89 @@
# This file is copied from https://github.com/r9y9/deepvoice3_pytorch/tree/master/ljspeech.py
# Copyright (c) 2017: Ryuichi Yamamoto.
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import numpy as np
import io
import os
import audio
from hparams import hparams
def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x):
'''Preprocesses the LJ Speech dataset from a given input path into a given output directory.
Args:
in_dir: The directory where you have downloaded the LJ Speech dataset
out_dir: The directory to write the output into
num_workers: Optional number of worker processes to parallelize across
tqdm: You can optionally pass tqdm to get a nice progress bar
Returns:
A list of tuples describing the training examples. This should be written to train.txt
'''
# We use ProcessPoolExecutor to parallize across processes. This is just an optimization and you
# can omit it and just call _process_utterance on each input if you want.
executor = ProcessPoolExecutor(max_workers=num_workers)
futures = []
index = 1
with io.open(
os.path.join(in_dir, 'metadata.csv'), "rt", encoding='utf-8') as f:
for line in f:
parts = line.strip().split('|')
wav_path = os.path.join(in_dir, 'wavs', '%s.wav' % parts[0])
text = parts[2]
if len(text) < hparams.min_text:
continue
futures.append(
executor.submit(
partial(_process_utterance, out_dir, index, wav_path,
text)))
index += 1
return [future.result() for future in tqdm(futures)]
def _process_utterance(out_dir, index, wav_path, text):
'''Preprocesses a single utterance audio/text pair.
This writes the mel and linear scale spectrograms to disk and returns a tuple to write
to the train.txt file.
Args:
out_dir: The directory to write the spectrograms into
index: The numeric index to use in the spectrogram filenames.
wav_path: Path to the audio file containing the speech input
text: The text spoken in the input audio file
Returns:
A (spectrogram_filename, mel_filename, n_frames, text) tuple to write to train.txt
'''
# Load the audio to a numpy array:
wav = audio.load_wav(wav_path)
if hparams.rescaling:
wav = wav / np.abs(wav).max() * hparams.rescaling_max
# Compute the linear-scale spectrogram from the wav:
spectrogram = audio.spectrogram(wav).astype(np.float32)
n_frames = spectrogram.shape[1]
# Compute a mel-scale spectrogram from the wav:
mel_spectrogram = audio.melspectrogram(wav).astype(np.float32)
# Write the spectrograms to disk:
spectrogram_filename = 'ljspeech-spec-%05d.npy' % index
mel_filename = 'ljspeech-mel-%05d.npy' % index
np.save(
os.path.join(out_dir, spectrogram_filename),
spectrogram.T,
allow_pickle=False)
np.save(
os.path.join(out_dir, mel_filename),
mel_spectrogram.T,
allow_pickle=False)
# Return a tuple describing this training example:
return (spectrogram_filename, mel_filename, n_frames, text)

89
deepvoice3/preprocess.py Normal file
View File

@ -0,0 +1,89 @@
# Part of code was adpated from https://github.com/r9y9/deepvoice3_pytorch/tree/master/preprocess.py
# Copyright (c) 2017: Ryuichi Yamamoto.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import io
import six
import os
from multiprocessing import cpu_count
from tqdm import tqdm
import importlib
from hparams import hparams, hparams_debug_string
def build_parser():
parser = argparse.ArgumentParser(description="Data Preprocessing")
parser.add_argument("--num-workers", type=int, help="Num workers.")
parser.add_argument(
"--hparams",
type=str,
default="",
help="Hyper parameters to overwrite.")
parser.add_argument(
"--preset",
type=str,
required=True,
help="Path of preset parameters (json)")
parser.add_argument("name", type=str, help="Dataset name")
parser.add_argument("in_dir", type=str, help="Dataset path.")
parser.add_argument(
"out_dir", type=str, help="Path of preprocessed dataset.")
return parser
def preprocess(mod, in_dir, out_root, num_workers):
if not os.path.exists(out_dir):
os.makedirs(out_dir)
metadata = mod.build_from_path(in_dir, out_dir, num_workers, tqdm=tqdm)
write_metadata(metadata, out_dir)
def write_metadata(metadata, out_dir):
if six.PY3:
string_type = str
elif six.PY2:
string_type = unicode
else:
raise ValueError("Not running on Python2 or Python 3?")
with io.open(
os.path.join(out_dir, 'train.txt'), 'wt', encoding='utf-8') as f:
for m in metadata:
f.write(u'|'.join([string_type(x) for x in m]) + '\n')
frames = sum([m[2] for m in metadata])
frame_shift_ms = hparams.hop_size / hparams.sample_rate * 1000
hours = frames * frame_shift_ms / (3600 * 1000)
print('Wrote %d utterances, %d frames (%.2f hours)' %
(len(metadata), frames, hours))
print('Max input length: %d' % max(len(m[3]) for m in metadata))
print('Max output length: %d' % max(m[2] for m in metadata))
if __name__ == "__main__":
parser = build_parser()
args, _ = parser.parse_known_args()
name = args.name
in_dir = args.in_dir
out_dir = args.out_dir
num_workers = args.num_workers
if num_workers is None:
num_workers = cpu_count()
preset = args.preset
# Load preset if specified
if preset is not None:
with io.open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args.hparams)
assert hparams.name == "deepvoice3"
print(hparams_debug_string())
assert name in ["ljspeech"], "now we only supports ljspeech"
mod = importlib.import_module(name)
preprocess(mod, in_dir, out_dir, num_workers)

View File

@ -0,0 +1,65 @@
{
"name": "deepvoice3",
"frontend": "en",
"replace_pronunciation_prob": 0.5,
"builder": "deepvoice3",
"n_speakers": 1,
"speaker_embed_dim": 16,
"num_mels": 80,
"fmin": 125,
"fmax": 7600,
"fft_size": 1024,
"hop_size": 256,
"sample_rate": 22050,
"preemphasis": 0.97,
"min_level_db": -100,
"ref_level_db": 20,
"rescaling": false,
"rescaling_max": 0.999,
"allow_clipping_in_normalization": true,
"downsample_step": 4,
"outputs_per_step": 1,
"embedding_weight_std": 0.1,
"speaker_embedding_weight_std": 0.01,
"padding_idx": 0,
"max_positions": 512,
"dropout": 0.050000000000000044,
"kernel_size": 3,
"text_embed_dim": 256,
"encoder_channels": 512,
"decoder_channels": 256,
"converter_channels": 256,
"query_position_rate": 1.0,
"key_position_rate": 1.385,
"key_projection": true,
"value_projection": true,
"use_memory_mask": true,
"trainable_positional_encodings": false,
"freeze_embedding": false,
"use_decoder_state_for_postnet_input": true,
"pin_memory": true,
"num_workers": 2,
"masked_loss_weight": 0.5,
"priority_freq": 3000,
"priority_freq_weight": 0.0,
"binary_divergence_weight": 0.1,
"use_guided_attention": true,
"guided_attention_sigma": 0.2,
"batch_size": 16,
"adam_beta1": 0.5,
"adam_beta2": 0.9,
"adam_eps": 1e-06,
"initial_learning_rate": 0.0005,
"lr_schedule": "noam_learning_rate_decay",
"lr_schedule_kwargs": {},
"nepochs": 2000,
"weight_decay": 0.0,
"clip_thresh": 0.1,
"checkpoint_interval": 10000,
"eval_interval": 10000,
"save_optimizer_state": true,
"force_monotonic_attention": true,
"window_ahead": 3,
"window_backward": 1,
"power": 1.4
}

167
deepvoice3/synthesis.py Normal file
View File

@ -0,0 +1,167 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import os
import io
from os.path import dirname, join, basename, splitext, exists
from tqdm import tqdm
import numpy as np
import nltk
from paddle import fluid
import paddle.fluid.dygraph as dg
sys.path.append("../")
import audio
from modules import frontend
import dry_run
from hparams import hparams
from train import make_deepvoice3_from_hparams
from eval_model import tts, plot_alignment
def build_parser():
parser = argparse.ArgumentParser(
description="Synthesis waveform from trained model.")
parser.add_argument(
"--hparams", type=str, default="", help="Hyper parameters.")
parser.add_argument(
"--preset",
type=str,
required=True,
help="Path of preset parameters (json).")
parser.add_argument(
"--use-gpu",
action="store_true",
help="Whether to use gpu for generation.")
parser.add_argument(
"--file-name-suffix", type=str, default="", help="File name suffix.")
parser.add_argument(
"--max-decoder-steps", type=int, default=500, help="Max decoder steps.")
parser.add_argument(
"--replace_pronunciation_prob",
type=float,
default=0.,
help="Probility to replace text with pronunciation.")
parser.add_argument(
"--speaker-id", type=int, help="Speaker ID (for multi-speaker model).")
parser.add_argument(
"--output-html", action="store_true", help="Output html for blog post.")
parser.add_argument(
"checkpoint", type=str, help="The checkpoint used for synthesis")
parser.add_argument(
"text_list_file",
type=str,
help="Text file to synthesis, a sentence per line.")
parser.add_argument(
"dst_dir", type=str, help="Directory to save synthesis results.")
return parser
if __name__ == "__main__":
parser = build_parser()
args, _ = parser.parse_known_args()
checkpoint_path = args.checkpoint
text_list_file_path = args.text_list_file
dst_dir = args.dst_dir
use_gpu = args.use_gpu
max_decoder_steps = args.max_decoder_steps
file_name_suffix = args.file_name_suffix
replace_pronunciation_prob = args.replace_pronunciation_prob
output_html = args.output_html
speaker_id = args.speaker_id
preset = args.preset
print("Command Line Args:")
for k, v in vars(args).items():
print(" {}: {}".format(k, v))
# Load preset if specified
if preset is not None:
with io.open(preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args.hparams)
assert hparams.name == "deepvoice3"
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
with dg.guard(place):
# Model
model = make_deepvoice3_from_hparams(hparams)
dry_run(model)
model_dict, _ = dg.load_dygraph(args.checkpoint)
model.set_dict(model_dict)
checkpoint_name = splitext(basename(checkpoint_path))[0]
model.seq2seq.decoder.max_decoder_steps = max_decoder_steps
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
with io.open(text_list_file_path, "rt", encoding="utf-8") as f:
lines = f.readlines()
for idx, line in enumerate(lines):
text = line[:-1]
words = nltk.word_tokenize(text)
waveform, alignment, _, _ = tts(model,
text,
p=replace_pronunciation_prob,
speaker_id=speaker_id)
dst_wav_path = join(dst_dir, "{}_{}{}.wav".format(
idx, checkpoint_name, file_name_suffix))
dst_alignment_path = join(
dst_dir, "{}_{}{}_alignment.png".format(
idx, checkpoint_name, file_name_suffix))
plot_alignment(
alignment.T,
dst_alignment_path,
info="{}, {}".format(hparams.builder,
basename(checkpoint_path)))
audio.save_wav(waveform, dst_wav_path)
name = splitext(basename(text_list_file_path))[0]
if output_html:
print("""
{}
({} chars, {} words)
<audio controls="controls" >
<source src="/audio/{}/{}/{}" autoplay/>
Your browser does not support the audio element.
</audio>
<div align="center"><img src="/audio/{}/{}/{}" /></div>
""".format(text,
len(text),
len(words), hparams.builder, name,
basename(dst_wav_path), hparams.builder, name,
basename(dst_alignment_path)))
else:
print(idx, ": {}\n ({} chars, {} words)".format(text,
len(text),
len(words)))
print("Finished! Check out {} for generated audio samples.".format(
dst_dir))
sys.exit(0)

250
deepvoice3/train.py Normal file
View File

@ -0,0 +1,250 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import io
from paddle import fluid
import paddle.fluid.dygraph as dg
import sys
sys.path.append("../")
from argparse import ArgumentParser
from hparams import hparams, hparams_debug_string
from nnmnkwii.datasets import FileSourceDataset
from data.data import (TextDataSource, MelSpecDataSource,
LinearSpecDataSource,
PartialyRandomizedSimilarTimeLengthSampler,
Dataset, make_loader, create_batch)
from modules import frontend
from builder import deepvoice3, WindowRange
from dry_run import dry_run
from train_model import train_model
from modules.loss import TTSLoss
from tensorboardX import SummaryWriter
def build_arg_parser():
parser = ArgumentParser(description="Train deepvoice 3 model.")
parser.add_argument(
"--data-root",
type=str,
required=True,
help="Directory contains preprocessed features.")
parser.add_argument(
"--use-data-parallel",
action="store_true",
help="Whether to use data parallel training.")
parser.add_argument(
"--use-gpu", action="store_true", help="Whether to use gpu training.")
parser.add_argument(
"--output",
type=str,
default="result",
help="Directory to save results")
parser.add_argument(
"--preset",
type=str,
required=True,
help="Path of preset parameters in json format.")
parser.add_argument(
"--hparams",
type=str,
default="",
help="Hyper parameters to override preset.")
parser.add_argument(
"--checkpoint",
type=str,
help="Restore model from checkpoint path if given.")
parser.add_argument(
"--reset-optimizer", action="store_true", help="Reset optimizer.")
# mutually exclusive option
train_opt = parser.add_mutually_exclusive_group()
train_opt.add_argument(
"--train-seq2seq-only",
action="store_true",
help="Train only seq2seq model")
train_opt.add_argument(
"--train-postnet-only",
action="store_true",
help="Train only postnet model.")
parser.add_argument(
"--speaker-id",
type=int,
help="Use specific speaker of data in case for multi-speaker datasets.",
)
return parser
def make_deepvoice3_from_hparams(hparams):
n_vocab = getattr(frontend, hparams.frontend).n_vocab
model = deepvoice3(
n_vocab, hparams.text_embed_dim, hparams.num_mels,
hparams.fft_size // 2 + 1, hparams.outputs_per_step,
hparams.downsample_step, hparams.n_speakers, hparams.speaker_embed_dim,
hparams.padding_idx, hparams.dropout, hparams.kernel_size,
hparams.encoder_channels, hparams.decoder_channels,
hparams.converter_channels, hparams.query_position_rate,
hparams.key_position_rate, hparams.use_memory_mask,
hparams.trainable_positional_encodings,
hparams.force_monotonic_attention,
hparams.use_decoder_state_for_postnet_input, hparams.max_positions,
hparams.embedding_weight_std, hparams.speaker_embedding_weight_std,
hparams.freeze_embedding,
WindowRange(-hparams.window_backward, hparams.window_ahead),
hparams.key_projection, hparams.value_projection)
return model
def noam_learning_rate_decay(init_lr, warmup_steps=4000):
# Noam scheme from tensor2tensor:
warmup_steps = float(warmup_steps)
return dg.NoamDecay(1 / (warmup_steps * (init_lr**2)), warmup_steps)
def make_optimizer_from_hparams(hparams):
if hparams.lr_schedule is not None:
learning_rate = noam_learning_rate_decay(hparams.initial_learning_rate,
**hparams.lr_schedule_kwargs)
else:
learning_rate = hparams.initial_learning_rate
if hparams.weight_decay > 0.0:
regularization = fluid.regularizer.L2DecayRegularizer(
hparams.weight_decay)
else:
regularization = None
optim = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=hparams.adam_beta1,
beta2=hparams.adam_beta2,
regularization=regularization)
if hparams.clip_thresh > 0.0:
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
hparams.clip_thresh)
else:
clipper = None
return optim, clipper
def make_loss_from_hparams(hparams):
criterion = TTSLoss(
hparams.masked_loss_weight, hparams.priority_freq_weight,
hparams.binary_divergence_weight, hparams.guided_attention_sigma)
return criterion
class MyDataParallel(dg.parallel.DataParallel):
"""
A data parallel proxy for model.
"""
def __init__(self, layers, strategy):
super(MyDataParallel, self).__init__(layers, strategy)
def __getattr__(self, key):
if key in self.__dict__:
return object.__getattribute__(self, key)
elif key is "_layers":
return object.__getattribute__(self, "_sub_layers")["_layers"]
else:
return getattr(
object.__getattribute__(self, "_sub_layers")["_layers"], key)
if __name__ == "__main__":
parser = build_arg_parser()
args, _ = parser.parse_known_args()
print("Command Line Args:")
for k, v in vars(args).items():
print(" {}: {}".format(k, v))
# Load preset if specified
if args.preset is not None:
with io.open(args.preset) as f:
hparams.parse_json(f.read())
# Override hyper parameters
hparams.parse(args.hparams)
print(hparams_debug_string())
checkpoint_dir = os.path.join(args.output, "checkpoints")
tensorboard_dir = os.path.join(args.output, "log")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
if not os.path.exists(tensorboard_dir):
os.makedirs(tensorboard_dir)
data_root = args.data_root
speaker_id = args.speaker_id
X = FileSourceDataset(TextDataSource(data_root, speaker_id))
Mel = FileSourceDataset(MelSpecDataSource(data_root, speaker_id))
Y = FileSourceDataset(LinearSpecDataSource(data_root, speaker_id))
frame_lengths = Mel.file_data_source.frame_lengths
sampler = PartialyRandomizedSimilarTimeLengthSampler(
frame_lengths, batch_size=hparams.batch_size)
dataset = Dataset(X, Mel, Y)
n_trainers = dg.parallel.Env().nranks
local_rank = dg.parallel.Env().local_rank
data_loader = make_loader(
dataset,
batch_size=hparams.batch_size,
shuffle=False,
sampler=sampler,
create_batch_fn=create_batch,
trainer_count=n_trainers,
local_rank=local_rank)
place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
if args.use_data_parallel else fluid.CUDAPlace(0)
if args.use_gpu else fluid.CPUPlace())
with dg.guard(place) as g:
pyreader = fluid.io.PyReader(capacity=10, return_list=True)
pyreader.decorate_batch_generator(data_loader, place)
model = make_deepvoice3_from_hparams(hparams)
optimizer, clipper = make_optimizer_from_hparams(hparams)
print("Log event path: {}".format(tensorboard_dir))
writer = SummaryWriter(tensorboard_dir) if local_rank == 0 else None
criterion = make_loss_from_hparams(hparams)
# loading saved model
if args.train_postnet_only or args.train_seq2seq_only:
assert args.checkpoint is not None, \
"you must train part of the model from a trained whole model"
if args.train_postnet_only:
assert hparams.use_decoder_state_for_postnet_input is False, \
"when training only the postnet, there is no decoder states"
if args.checkpoint is not None:
model_dict, optimizer_dict = dg.load_dygraph(args.checkpoint)
if args.use_data_parallel:
strategy = dg.parallel.prepare_context()
model = MyDataParallel(model, strategy)
train_model(model, pyreader, criterion, optimizer, clipper, writer,
args, hparams)
print("Done!")

13
deepvoice3/train.sh Normal file
View File

@ -0,0 +1,13 @@
export LD_LIBRARY_PATH=/fluid13_workspace/cuda-9.0/lib64/:/fluid13_workspace/cudnnv7.5_cuda9.0/lib64/:$LD_LIBRARY_PATH
#export PYTHONPATH=/dv3_workspace/paddle_for_dv3/build/python/
export PYTHONPATH=/fluid13_workspace/paddle_cherry_pick/build/python/:../
export CUDA_VISIBLE_DEVICES=7
GLOG_v=0 python -u train.py \
--use-gpu \
--reset-optimizer \
--preset=presets/deepvoice3_ljspeech.json \
--checkpoint-dir=checkpoint_single_1014 \
--data-root="/fluid13_workspace/dv3_workspace/deepvoice3_pytorch/data/ljspeech/" \
--hparams="batch_size=16"

258
deepvoice3/train_model.py Normal file
View File

@ -0,0 +1,258 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from itertools import chain
from paddle import fluid
import paddle.fluid.dygraph as dg
from tqdm import tqdm
from eval_model import eval_model, save_states
def train_model(model, loader, criterion, optimizer, clipper, writer, args,
hparams):
assert fluid.framework.in_dygraph_mode(
), "this function must be run within dygraph guard"
n_trainers = dg.parallel.Env().nranks
local_rank = dg.parallel.Env().local_rank
# amount of shifting when compute losses
linear_shift = hparams.outputs_per_step
mel_shift = hparams.outputs_per_step
global_step = 0
global_epoch = 0
ismultispeaker = model.n_speakers > 1
checkpoint_dir = os.path.join(args.output, "checkpoints")
tensorboard_dir = os.path.join(args.output, "log")
ce_loss = 0
start_time = time.time()
for epoch in range(hparams.nepochs):
epoch_loss = 0.
for step, inputs in tqdm(enumerate(loader())):
if len(inputs) == 9:
(text, input_lengths, mel, linear, text_positions,
frame_positions, done, target_lengths, speaker_ids) = inputs
else:
(text, input_lengths, mel, linear, text_positions,
frame_positions, done, target_lengths) = inputs
speaker_ids = None
model.train()
if not (args.train_seq2seq_only or args.train_postnet_only):
results = model(text, input_lengths, mel, speaker_ids,
text_positions, frame_positions)
mel_outputs, linear_outputs, alignments, done_hat = results
elif args.train_seq2seq_only:
if speaker_ids is not None:
speaker_embed = model.speaker_embedding(speaker_ids)
else:
speaker_embed = None
results = model.seq2seq(text, input_lengths, mel, speaker_embed,
text_positions, frame_positions)
mel_outputs, alignments, done_hat, decoder_states = results
if model.r > 1:
mel_outputs = fluid.layers.transpose(mel_outputs,
[0, 3, 2, 1])
mel_outputs = fluid.layers.reshape(
mel_outputs,
[mel_outputs.shape[0], -1, 1, model.mel_dim])
mel_outputs = fluid.layers.transpose(mel_outputs,
[0, 3, 2, 1])
linear_outputs = None
else:
assert (
model.use_decoder_state_for_postnet_input is False
), "when train only the converter, you have no decoder states"
if speaker_ids is not None:
speaker_embed = model.speaker_embedding(speaker_ids)
else:
speaker_embed = None
linear_outputs = model.converter(mel, speaker_embed)
alignments = None
mel_outputs = None
done_hat = None
if not args.train_seq2seq_only:
n_priority_freq = int(hparams.priority_freq /
(hparams.sample_rate * 0.5) *
model.linear_dim)
linear_mask = fluid.layers.sequence_mask(
target_lengths, maxlen=linear.shape[-1], dtype="float32")
linear_mask = linear_mask[:, linear_shift:]
linear_predicted = linear_outputs[:, :, :, :-linear_shift]
linear_target = linear[:, :, :, linear_shift:]
lin_l1_loss = criterion.l1_loss(
linear_predicted,
linear_target,
linear_mask,
priority_bin=n_priority_freq)
lin_div = criterion.binary_divergence(
linear_predicted, linear_target, linear_mask)
lin_loss = criterion.binary_divergence_weight * lin_div \
+ (1 - criterion.binary_divergence_weight) * lin_l1_loss
if writer is not None and local_rank == 0:
writer.add_scalar("linear_loss",
float(lin_loss.numpy()), global_step)
writer.add_scalar("linear_l1_loss",
float(lin_l1_loss.numpy()), global_step)
writer.add_scalar("linear_binary_div_loss",
float(lin_div.numpy()), global_step)
if not args.train_postnet_only:
mel_lengths = target_lengths // hparams.downsample_step
mel_mask = fluid.layers.sequence_mask(
mel_lengths, maxlen=mel.shape[-1], dtype="float32")
mel_mask = mel_mask[:, mel_shift:]
mel_predicted = mel_outputs[:, :, :, :-mel_shift]
mel_target = mel[:, :, :, mel_shift:]
mel_l1_loss = criterion.l1_loss(mel_predicted, mel_target,
mel_mask)
mel_div = criterion.binary_divergence(mel_predicted, mel_target,
mel_mask)
mel_loss = criterion.binary_divergence_weight * mel_div \
+ (1 - criterion.binary_divergence_weight) * mel_l1_loss
if writer is not None and local_rank == 0:
writer.add_scalar("mel_loss",
float(mel_loss.numpy()), global_step)
writer.add_scalar("mel_l1_loss",
float(mel_l1_loss.numpy()), global_step)
writer.add_scalar("mel_binary_div_loss",
float(mel_div.numpy()), global_step)
done_loss = criterion.done_loss(done_hat, done)
if writer is not None and local_rank == 0:
writer.add_scalar("done_loss",
float(done_loss.numpy()), global_step)
if hparams.use_guided_attention:
decoder_length = target_lengths.numpy() / (
hparams.outputs_per_step * hparams.downsample_step)
attn_loss = criterion.attention_loss(alignments,
input_lengths.numpy(),
decoder_length)
if writer is not None and local_rank == 0:
writer.add_scalar("attention_loss",
float(attn_loss.numpy()), global_step)
if not (args.train_seq2seq_only or args.train_postnet_only):
if hparams.use_guided_attention:
loss = lin_loss + mel_loss + done_loss + attn_loss
else:
loss = lin_loss + mel_loss + done_loss
elif args.train_seq2seq_only:
if hparams.use_guided_attention:
loss = mel_loss + done_loss + attn_loss
else:
loss = mel_loss + done_loss
else:
loss = lin_loss
if writer is not None and local_rank == 0:
writer.add_scalar("loss", float(loss.numpy()), global_step)
if isinstance(optimizer._learning_rate,
fluid.optimizer.LearningRateDecay):
current_lr = optimizer._learning_rate.step().numpy()
else:
current_lr = optimizer._learning_rate
if writer is not None and local_rank == 0:
writer.add_scalar("learning_rate", current_lr, global_step)
epoch_loss += loss.numpy()[0]
if (local_rank == 0 and global_step > 0 and
global_step % hparams.checkpoint_interval == 0):
save_states(global_step, writer, mel_outputs, linear_outputs,
alignments, mel, linear,
input_lengths.numpy(), checkpoint_dir)
step_path = os.path.join(
checkpoint_dir, "checkpoint_{:09d}".format(global_step))
dg.save_dygraph(model.state_dict(), step_path)
dg.save_dygraph(optimizer.state_dict(), step_path)
if (local_rank == 0 and global_step > 0 and
global_step % hparams.eval_interval == 0):
eval_model(global_step, writer, model, checkpoint_dir,
ismultispeaker)
if args.use_data_parallel:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
if not (args.train_seq2seq_only or args.train_postnet_only):
param_list = model.parameters()
elif args.train_seq2seq_only:
if ismultispeaker:
param_list = chain(model.speaker_embedding.parameters(),
model.seq2seq.parameters())
else:
param_list = model.seq2seq.parameters()
else:
if ismultispeaker:
param_list = chain(model.speaker_embedding.parameters(),
model.seq2seq.parameters())
else:
param_list = model.converter.parameters()
optimizer.minimize(
loss, grad_clip=clipper, parameter_list=param_list)
if not (args.train_seq2seq_only or args.train_postnet_only):
model.clear_gradients()
elif args.train_seq2seq_only:
if ismultispeaker:
model.speaker_embedding.clear_gradients()
model.seq2seq.clear_gradients()
else:
if ismultispeaker:
model.speaker_embedding.clear_gradients()
model.converter.clear_gradients()
global_step += 1
average_loss_in_epoch = epoch_loss / (step + 1)
print("Epoch loss: {}".format(average_loss_in_epoch))
if writer is not None and local_rank == 0:
writer.add_scalar("average_loss_in_epoch", average_loss_in_epoch,
global_epoch)
ce_loss = average_loss_in_epoch
global_epoch += 1
end_time = time.time()
epoch_time = (end_time - start_time) / global_epoch
print("kpis\teach_epoch_duration_frame%s_card%s\t%s" %
(hparams.outputs_per_step, n_trainers, epoch_time))
print("kpis\ttrain_cost_frame%s_card%s\t%f" %
(hparams.outputs_per_step, n_trainers, ce_loss))

0
hparam_tf/__init__.py Normal file
View File

731
hparam_tf/hparam.py Normal file
View File

@ -0,0 +1,731 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparameter values."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numbers
import re
import six
## from tensorflow.contrib.training.python.training import hparam_pb2
## from tensorflow.python.framework import ops
## from tensorflow.python.util import compat
## from tensorflow.python.util import deprecation
# Define the regular expression for parsing a single clause of the input
# (delimited by commas). A legal clause looks like:
# <variable name>[<index>]? = <rhs>
# where <rhs> is either a single token or [] enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]"
PARAM_RE = re.compile(r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
\s*=\s*
((?P<val>[^,\[]*) # single value: "a" or None
|
\[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
($|,\s*)""", re.VERBOSE)
def _parse_fail(name, var_type, value, values):
"""Helper function for raising a value error for bad assignment."""
raise ValueError(
'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
(name, var_type.__name__, value, values))
def _reuse_fail(name, values):
"""Helper function for raising a value error for reuse of name."""
raise ValueError('Multiple assignments to variable \'%s\' in %s' %
(name, values))
def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
results_dictionary):
"""Update results_dictionary with a scalar value.
Used to update the results_dictionary to be returned by parse_values when
encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
Mutates results_dictionary.
Args:
name: Name of variable in assignment ("s" or "arr").
parse_fn: Function for parsing the actual value.
var_type: Type of named variable.
m_dict: Dictionary constructed from regex parsing.
m_dict['val']: RHS value (scalar)
m_dict['index']: List index value (or None)
values: Full expression being parsed
results_dictionary: The dictionary being updated for return by the parsing
function.
Raises:
ValueError: If the name has already been used.
"""
try:
parsed_value = parse_fn(m_dict['val'])
except ValueError:
_parse_fail(name, var_type, m_dict['val'], values)
# If no index is provided
if not m_dict['index']:
if name in results_dictionary:
_reuse_fail(name, values)
results_dictionary[name] = parsed_value
else:
if name in results_dictionary:
# The name has already been used as a scalar, then it
# will be in this dictionary and map to a non-dictionary.
if not isinstance(results_dictionary.get(name), dict):
_reuse_fail(name, values)
else:
results_dictionary[name] = {}
index = int(m_dict['index'])
# Make sure the index position hasn't already been assigned a value.
if index in results_dictionary[name]:
_reuse_fail('{}[{}]'.format(name, index), values)
results_dictionary[name][index] = parsed_value
def _process_list_value(name, parse_fn, var_type, m_dict, values,
results_dictionary):
"""Update results_dictionary from a list of values.
Used to update results_dictionary to be returned by parse_values when
encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
Mutates results_dictionary.
Args:
name: Name of variable in assignment ("arr").
parse_fn: Function for parsing individual values.
var_type: Type of named variable.
m_dict: Dictionary constructed from regex parsing.
m_dict['val']: RHS value (scalar)
values: Full expression being parsed
results_dictionary: The dictionary being updated for return by the parsing
function.
Raises:
ValueError: If the name has an index or the values cannot be parsed.
"""
if m_dict['index'] is not None:
raise ValueError('Assignment of a list to a list index.')
elements = filter(None, re.split('[ ,]', m_dict['vals']))
# Make sure the name hasn't already been assigned a value
if name in results_dictionary:
raise _reuse_fail(name, values)
try:
results_dictionary[name] = [parse_fn(e) for e in elements]
except ValueError:
_parse_fail(name, var_type, m_dict['vals'], values)
def _cast_to_type_if_compatible(name, param_type, value):
"""Cast hparam to the provided type, if compatible.
Args:
name: Name of the hparam to be cast.
param_type: The type of the hparam.
value: The value to be cast, if compatible.
Returns:
The result of casting `value` to `param_type`.
Raises:
ValueError: If the type of `value` is not compatible with param_type.
* If `param_type` is a string type, but `value` is not.
* If `param_type` is a boolean, but `value` is not, or vice versa.
* If `param_type` is an integer type, but `value` is not.
* If `param_type` is a float type, but `value` is not a numeric type.
"""
fail_msg = ("Could not cast hparam '%s' of type '%s' from value %r" %
(name, param_type, value))
# Some callers use None, for which we can't do any casting/checking. :(
if issubclass(param_type, type(None)):
return value
# Avoid converting a non-string type to a string.
if (issubclass(param_type, (six.string_types, six.binary_type)) and
not isinstance(value, (six.string_types, six.binary_type))):
raise ValueError(fail_msg)
# Avoid converting a number or string type to a boolean or vice versa.
if issubclass(param_type, bool) != isinstance(value, bool):
raise ValueError(fail_msg)
# Avoid converting float to an integer (the reverse is fine).
if (issubclass(param_type, numbers.Integral) and
not isinstance(value, numbers.Integral)):
raise ValueError(fail_msg)
# Avoid converting a non-numeric type to a numeric type.
if (issubclass(param_type, numbers.Number) and
not isinstance(value, numbers.Number)):
raise ValueError(fail_msg)
return param_type(value)
def parse_values(values, type_map):
"""Parses hyperparameter values from a string into a python map.
`values` is a string containing comma-separated `name=value` pairs.
For each pair, the value of the hyperparameter named `name` is set to
`value`.
If a hyperparameter name appears multiple times in `values`, a ValueError
is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
If a hyperparameter name in both an index assignment and scalar assignment,
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
The hyperparameter name may contain '.' symbols, which will result in an
attribute name that is only accessible through the getattr and setattr
functions. (And must be first explicit added through add_hparam.)
WARNING: Use of '.' in your variable names is allowed, but is not well
supported and not recommended.
The `value` in `name=value` must follows the syntax according to the
type of the parameter:
* Scalar integer: A Python-parsable integer point value. E.g.: 1,
100, -12.
* Scalar float: A Python-parsable floating point value. E.g.: 1.0,
-.54e89.
* Boolean: Either true or false.
* Scalar string: A non-empty sequence of characters, excluding comma,
spaces, and square brackets. E.g.: foo, bar_1.
* List: A comma separated list of scalar values of the parameter type
enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
When index assignment is used, the corresponding type_map key should be the
list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
"arr[1]").
Args:
values: String. Comma separated list of `name=value` pairs where
'value' must follow the syntax described above.
type_map: A dictionary mapping hyperparameter names to types. Note every
parameter name in values must be a key in type_map. The values must
conform to the types indicated, where a value V is said to conform to a
type T if either V has type T, or V is a list of elements of type T.
Hence, for a multidimensional parameter 'x' taking float values,
'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
Returns:
A python map mapping each name to either:
* A scalar value.
* A list of scalar values.
* A dictionary mapping index numbers to scalar values.
(e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
Raises:
ValueError: If there is a problem with input.
* If `values` cannot be parsed.
* If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
* If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
'a[1]=1,a[1]=2', or 'a=1,a=[1]')
"""
results_dictionary = {}
pos = 0
while pos < len(values):
m = PARAM_RE.match(values, pos)
if not m:
raise ValueError('Malformed hyperparameter value: %s' %
values[pos:])
# Check that there is a comma between parameters and move past it.
pos = m.end()
# Parse the values.
m_dict = m.groupdict()
name = m_dict['name']
if name not in type_map:
raise ValueError('Unknown hyperparameter type for %s' % name)
type_ = type_map[name]
# Set up correct parsing function (depending on whether type_ is a bool)
if type_ == bool:
def parse_bool(value):
if value in ['true', 'True']:
return True
elif value in ['false', 'False']:
return False
else:
try:
return bool(int(value))
except ValueError:
_parse_fail(name, type_, value, values)
parse = parse_bool
else:
parse = type_
# If a singe value is provided
if m_dict['val'] is not None:
_process_scalar_value(name, parse, type_, m_dict, values,
results_dictionary)
# If the assigned value is a list:
elif m_dict['vals'] is not None:
_process_list_value(name, parse, type_, m_dict, values,
results_dictionary)
else: # Not assigned a list or value
_parse_fail(name, type_, '', values)
return results_dictionary
class HParams(object):
"""Class to hold a set of hyperparameters as name-value pairs.
A `HParams` object holds hyperparameters used to build and train a model,
such as the number of hidden units in a neural net layer or the learning rate
to use when training.
You first create a `HParams` object by specifying the names and values of the
hyperparameters.
To make them easily accessible the parameter names are added as direct
attributes of the class. A typical usage is as follows:
```python
# Create a HParams object specifying names and values of the model
# hyperparameters:
hparams = HParams(learning_rate=0.1, num_hidden_units=100)
# The hyperparameter are available as attributes of the HParams object:
hparams.learning_rate ==> 0.1
hparams.num_hidden_units ==> 100
```
Hyperparameters have type, which is inferred from the type of their value
passed at construction type. The currently supported types are: integer,
float, boolean, string, and list of integer, float, boolean, or string.
You can override hyperparameter values by calling the
[`parse()`](#HParams.parse) method, passing a string of comma separated
`name=value` pairs. This is intended to make it possible to override
any hyperparameter values from a single command-line flag to which
the user passes 'hyper-param=value' pairs. It avoids having to define
one flag for each hyperparameter.
The syntax expected for each value depends on the type of the parameter.
See `parse()` for a description of the syntax.
Example:
```python
# Define a command line flag to pass name=value pairs.
# For example using argparse:
import argparse
parser = argparse.ArgumentParser(description='Train my model.')
parser.add_argument('--hparams', type=str,
help='Comma separated list of "name=value" pairs.')
args = parser.parse_args()
...
def my_program():
# Create a HParams object specifying the names and values of the
# model hyperparameters:
hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
activations=['relu', 'tanh'])
# Override hyperparameters values by parsing the command line
hparams.parse(args.hparams)
# If the user passed `--hparams=learning_rate=0.3` on the command line
# then 'hparams' has the following attributes:
hparams.learning_rate ==> 0.3
hparams.num_hidden_units ==> 100
hparams.activations ==> ['relu', 'tanh']
# If the hyperparameters are in json format use parse_json:
hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
```
"""
_HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
def __init__(self, hparam_def=None, model_structure=None, **kwargs):
"""Create an instance of `HParams` from keyword arguments.
The keyword arguments specify name-values pairs for the hyperparameters.
The parameter types are inferred from the type of the values passed.
The parameter names are added as attributes of `HParams` object, so they
can be accessed directly with the dot notation `hparams._name_`.
Example:
```python
# Define 3 hyperparameters: 'learning_rate' is a float parameter,
# 'num_hidden_units' an integer parameter, and 'activation' a string
# parameter.
hparams = tf.HParams(
learning_rate=0.1, num_hidden_units=100, activation='relu')
hparams.activation ==> 'relu'
```
Note that a few names are reserved and cannot be used as hyperparameter
names. If you use one of the reserved name the constructor raises a
`ValueError`.
Args:
hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef
protocol buffer. If provided, this object is initialized by
deserializing hparam_def. Otherwise **kwargs is used.
model_structure: An instance of ModelStructure, defining the feature
crosses to be used in the Trial.
**kwargs: Key-value pairs where the key is the hyperparameter name and
the value is the value for the parameter.
Raises:
ValueError: If both `hparam_def` and initialization values are provided,
or if one of the arguments is invalid.
"""
# Register the hyperparameters and their type in _hparam_types.
# This simplifies the implementation of parse().
# _hparam_types maps the parameter name to a tuple (type, bool).
# The type value is the type of the parameter for scalar hyperparameters,
# or the type of the list elements for multidimensional hyperparameters.
# The bool value is True if the value is a list, False otherwise.
self._hparam_types = {}
self._model_structure = model_structure
if hparam_def:
## self._init_from_proto(hparam_def)
## if kwargs:
## raise ValueError('hparam_def and initialization values are '
## 'mutually exclusive')
raise ValueError('hparam_def has been disabled in this version')
else:
for name, value in six.iteritems(kwargs):
self.add_hparam(name, value)
## def _init_from_proto(self, hparam_def):
## """Creates a new HParams from `HParamDef` protocol buffer.
##
## Args:
## hparam_def: `HParamDef` protocol buffer.
## """
## assert isinstance(hparam_def, hparam_pb2.HParamDef)
## for name, value in hparam_def.hparam.items():
## kind = value.WhichOneof('kind')
## if kind.endswith('_value'):
## # Single value.
## if kind.startswith('int64'):
## # Setting attribute value to be 'int' to ensure the type is compatible
## # with both Python2 and Python3.
## self.add_hparam(name, int(getattr(value, kind)))
## elif kind.startswith('bytes'):
## # Setting attribute value to be 'str' to ensure the type is compatible
## # with both Python2 and Python3. UTF-8 encoding is assumed.
## self.add_hparam(name, compat.as_str(getattr(value, kind)))
## else:
## self.add_hparam(name, getattr(value, kind))
## else:
## # List of values.
## if kind.startswith('int64'):
## # Setting attribute value to be 'int' to ensure the type is compatible
## # with both Python2 and Python3.
## self.add_hparam(name, [int(v) for v in getattr(value, kind).value])
## elif kind.startswith('bytes'):
## # Setting attribute value to be 'str' to ensure the type is compatible
## # with both Python2 and Python3. UTF-8 encoding is assumed.
## self.add_hparam(
## name, [compat.as_str(v) for v in getattr(value, kind).value])
## else:
## self.add_hparam(name, [v for v in getattr(value, kind).value])
def add_hparam(self, name, value):
"""Adds {name, value} pair to hyperparameters.
Args:
name: Name of the hyperparameter.
value: Value of the hyperparameter. Can be one of the following types:
int, float, string, int list, float list, or string list.
Raises:
ValueError: if one of the arguments is invalid.
"""
# Keys in kwargs are unique, but 'name' could the name of a pre-existing
# attribute of this object. In that case we refuse to use it as a
# hyperparameter name.
if getattr(self, name, None) is not None:
raise ValueError('Hyperparameter name is reserved: %s' % name)
if isinstance(value, (list, tuple)):
if not value:
raise ValueError(
'Multi-valued hyperparameters cannot be empty: %s' % name)
self._hparam_types[name] = (type(value[0]), True)
else:
self._hparam_types[name] = (type(value), False)
setattr(self, name, value)
def set_hparam(self, name, value):
"""Set the value of an existing hyperparameter.
This function verifies that the type of the value matches the type of the
existing hyperparameter.
Args:
name: Name of the hyperparameter.
value: New value of the hyperparameter.
Raises:
ValueError: If there is a type mismatch.
"""
param_type, is_list = self._hparam_types[name]
if isinstance(value, list):
if not is_list:
raise ValueError(
'Must not pass a list for single-valued parameter: %s' %
name)
setattr(self, name, [
_cast_to_type_if_compatible(name, param_type, v) for v in value
])
else:
if is_list:
raise ValueError(
'Must pass a list for multi-valued parameter: %s.' % name)
setattr(self, name,
_cast_to_type_if_compatible(name, param_type, value))
def del_hparam(self, name):
"""Removes the hyperparameter with key 'name'.
Args:
name: Name of the hyperparameter.
"""
if hasattr(self, name):
delattr(self, name)
del self._hparam_types[name]
def parse(self, values):
"""Override hyperparameter values, parsing new values from a string.
See parse_values for more detail on the allowed format for values.
Args:
values: String. Comma separated list of `name=value` pairs where
'value' must follow the syntax described above.
Returns:
The `HParams` instance.
Raises:
ValueError: If `values` cannot be parsed.
"""
type_map = dict()
for name, t in self._hparam_types.items():
param_type, _ = t
type_map[name] = param_type
values_map = parse_values(values, type_map)
return self.override_from_dict(values_map)
def override_from_dict(self, values_dict):
"""Override hyperparameter values, parsing new values from a dictionary.
Args:
values_dict: Dictionary of name:value pairs.
Returns:
The `HParams` instance.
Raises:
ValueError: If `values_dict` cannot be parsed.
"""
for name, value in values_dict.items():
self.set_hparam(name, value)
return self
## @deprecation.deprecated(None, 'Use `override_from_dict`.')
def set_from_map(self, values_map):
"""DEPRECATED. Use override_from_dict."""
return self.override_from_dict(values_dict=values_map)
def set_model_structure(self, model_structure):
self._model_structure = model_structure
def get_model_structure(self):
return self._model_structure
def to_json(self, indent=None, separators=None, sort_keys=False):
"""Serializes the hyperparameters into JSON.
Args:
indent: If a non-negative integer, JSON array elements and object members
will be pretty-printed with that indent level. An indent level of 0, or
negative, will only insert newlines. `None` (the default) selects the
most compact representation.
separators: Optional `(item_separator, key_separator)` tuple. Default is
`(', ', ': ')`.
sort_keys: If `True`, the output dictionaries will be sorted by key.
Returns:
A JSON string.
"""
return json.dumps(
self.values(),
indent=indent,
separators=separators,
sort_keys=sort_keys)
def parse_json(self, values_json):
"""Override hyperparameter values, parsing new values from a json object.
Args:
values_json: String containing a json object of name:value pairs.
Returns:
The `HParams` instance.
Raises:
ValueError: If `values_json` cannot be parsed.
"""
values_map = json.loads(values_json)
return self.override_from_dict(values_map)
def values(self):
"""Return the hyperparameter values as a Python dictionary.
Returns:
A dictionary with hyperparameter names as keys. The values are the
hyperparameter values.
"""
return {n: getattr(self, n) for n in self._hparam_types.keys()}
def get(self, key, default=None):
"""Returns the value of `key` if it exists, else `default`."""
if key in self._hparam_types:
# Ensure that default is compatible with the parameter type.
if default is not None:
param_type, is_param_list = self._hparam_types[key]
type_str = 'list<%s>' % param_type if is_param_list else str(
param_type)
fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
'default=%s' % (key, type_str, default))
is_default_list = isinstance(default, list)
if is_param_list != is_default_list:
raise ValueError(fail_msg)
try:
if is_default_list:
for value in default:
_cast_to_type_if_compatible(key, param_type, value)
else:
_cast_to_type_if_compatible(key, param_type, default)
except ValueError as e:
raise ValueError('%s. %s' % (fail_msg, e))
return getattr(self, key)
return default
def __contains__(self, key):
return key in self._hparam_types
def __str__(self):
return str(sorted(self.values().items()))
def __repr__(self):
return '%s(%s)' % (type(self).__name__, self.__str__())
@staticmethod
def _get_kind_name(param_type, is_list):
"""Returns the field name given parameter type and is_list.
Args:
param_type: Data type of the hparam.
is_list: Whether this is a list.
Returns:
A string representation of the field name.
Raises:
ValueError: If parameter type is not recognized.
"""
if issubclass(param_type, bool):
# This check must happen before issubclass(param_type, six.integer_types),
# since Python considers bool to be a subclass of int.
typename = 'bool'
elif issubclass(param_type, six.integer_types):
# Setting 'int' and 'long' types to be 'int64' to ensure the type is
# compatible with both Python2 and Python3.
typename = 'int64'
elif issubclass(param_type, (six.string_types, six.binary_type)):
# Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
# compatible with both Python2 and Python3.
typename = 'bytes'
elif issubclass(param_type, float):
typename = 'float'
else:
raise ValueError('Unsupported parameter type: %s' % str(param_type))
suffix = 'list' if is_list else 'value'
return '_'.join([typename, suffix])
## def to_proto(self, export_scope=None): # pylint: disable=unused-argument
## """Converts a `HParams` object to a `HParamDef` protocol buffer.
##
## Args:
## export_scope: Optional `string`. Name scope to remove.
##
## Returns:
## A `HParamDef` protocol buffer.
## """
## hparam_proto = hparam_pb2.HParamDef()
## for name in self._hparam_types:
## # Parse the values.
## param_type, is_list = self._hparam_types.get(name, (None, None))
## kind = HParams._get_kind_name(param_type, is_list)
##
## if is_list:
## if kind.startswith('bytes'):
## v_list = [compat.as_bytes(v) for v in getattr(self, name)]
## else:
## v_list = [v for v in getattr(self, name)]
## getattr(hparam_proto.hparam[name], kind).value.extend(v_list)
## else:
## v = getattr(self, name)
## if kind.startswith('bytes'):
## v = compat.as_bytes(getattr(self, name))
## setattr(hparam_proto.hparam[name], kind, v)
##
## return hparam_proto
## @staticmethod
## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument
## return HParams(hparam_def=hparam_def)
## ops.register_proto_function(
## 'hparams',
## proto_type=hparam_pb2.HParamDef,
## to_proto=HParams.to_proto,
## from_proto=HParams.from_proto)

8
hparam_tf/readme.md Normal file
View File

@ -0,0 +1,8 @@
Source: hparam.py copied from tensorflow v1.12.0.
https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
with the following:
wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py
Once all other tensorflow dependencies of these file are removed, the class keeps its goal. Functions not available due to this process are not used in this project.

0
modules/__init__.py Normal file
View File

222
modules/conv.py Normal file
View File

@ -0,0 +1,222 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
from weight_norm import Conv2D, Conv2DTranspose
class Conv1D(dg.Layer):
"""
A convolution 1D block implemented with Conv2D. Form simplicity and
ensuring the output has the same length as the input, it does not allow
stride > 1.
"""
def __init__(self,
name_scope,
in_cahnnels,
num_filters,
filter_size=3,
dilation=1,
groups=None,
causal=False,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype="float32"):
super(Conv1D, self).__init__(name_scope, dtype=dtype)
if causal:
padding = dilation * (filter_size - 1)
else:
padding = (dilation * (filter_size - 1)) // 2
self.in_channels = in_cahnnels
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.padding = padding
self.act = act
self.conv = Conv2D(
self.full_name(),
num_filters=num_filters,
filter_size=(1, filter_size),
stride=(1, 1),
dilation=(1, dilation),
padding=(0, padding),
groups=groups,
param_attr=param_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
def forward(self, x):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
input channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
output channels (num_filters).
"""
x = self.conv(x)
if self.filter_size > 1:
if self.causal:
x = fluid.layers.slice(
x, axes=[3], starts=[0], ends=[-self.padding])
elif self.filter_size % 2 == 0:
x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1])
return x
def start_new_sequence(self):
self.temp_weight = None
self.input_buffer = None
def add_input(self, x):
"""
Adding input for a time step and compute an output for a time step.
Args:
x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
input channels, and T = 1.
Returns:
out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out
means output channels (num_filters), and T = 1.
"""
if self.temp_weight is None:
self.temp_weight = self._reshaped_weight()
window_size = 1 + (self.filter_size - 1) * self.dilation
batch_size = x.shape[0]
in_channels = x.shape[1]
if self.filter_size > 1:
if self.input_buffer is None:
self.input_buffer = fluid.layers.fill_constant(
[batch_size, in_channels, 1, window_size - 1],
dtype=x.dtype,
value=0.0)
else:
self.input_buffer = self.input_buffer[:, :, :, 1:]
self.input_buffer = fluid.layers.concat(
[self.input_buffer, x], axis=3)
x = self.input_buffer
if self.dilation > 1:
if not hasattr(self, "indices"):
self.indices = dg.to_variable(
np.arange(0, window_size, self.dilation))
tmp = fluid.layers.transpose(
self.input_buffer, perm=[3, 1, 2, 0])
tmp = fluid.layers.gather(tmp, index=self.indices)
tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0])
x = tmp
inputs = fluid.layers.reshape(
x, shape=[batch_size, in_channels * 1 * self.filter_size])
out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True)
out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1)
out = fluid.layers.reshape(out, out.shape + [1, 1])
out = self._helper.append_activation(out, act=self.act)
return out
def _reshaped_weight(self):
"""
Get the linearized weight of convolution filter, cause it is by nature
a matmul weight. And because the model uses weight norm, compute the
weight by weight_v * weight_g to make it faster.
Returns:
weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size)
"""
shape = self.conv._filter_param_v.shape
matrix_shape = [shape[0], np.prod(shape[1:])]
weight_matrix = fluid.layers.reshape(
self.conv._filter_param_v, shape=matrix_shape)
weight_matrix = fluid.layers.elementwise_mul(
fluid.layers.l2_normalize(
weight_matrix, axis=1),
self.conv._filter_param_g,
axis=0)
return weight_matrix
class Conv1DTranspose(dg.Layer):
"""
A convolutional transpose 1D block implemented with convolutional transpose
2D. It does not ensure that the output is exactly expanded stride times in
time dimension.
"""
def __init__(self,
name_scope,
in_channels,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype="float32"):
super(Conv1DTranspose, self).__init__(name_scope, dtype=dtype)
self.in_channels = in_channels
self.num_filters = num_filters
self.filter_size = filter_size
self.padding = padding
self.stride = stride
self.dilation = dilation
self.groups = groups
self.conv_transpose = Conv2DTranspose(
self.full_name(),
num_filters,
filter_size=(1, filter_size),
padding=(0, padding),
stride=(1, stride),
dilation=(1, dilation),
groups=groups,
param_attr=param_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
def forward(self, x):
"""
Argss:
x (Variable): Shape(B, C_in, 1, T_in), where C_in means the input
channels and T_in means the number of time steps of input.
Returns:
out (Variable): shape(B, C_out, 1, T_out), where C_out means the
output channels and T_out means the number of time steps of
input.
"""
return self.conv_transpose(x)

View File

@ -0,0 +1 @@
This package is adapted from https://github.com/r9y9/deepvoice3_pytorch/tree/master/deepvoice3_pytorch/frontend, Copyright (c) 2017: Ryuichi Yamamoto, whose license applies.

View File

@ -0,0 +1,33 @@
# coding: utf-8
"""Text processing frontend
All frontend module should have the following functions:
- text_to_sequence(text, p)
- sequence_to_text(sequence)
and the property:
- n_vocab
"""
from . import en
# optinoal Japanese frontend
try:
from . import jp
except ImportError:
jp = None
try:
from . import ko
except ImportError:
ko = None
# if you are going to use the frontend, you need to modify _characters in
# symbol.py:
# _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' + '¡¿ñáéíóúÁÉÍÓÚÑ'
try:
from . import es
except ImportError:
es = None

View File

@ -0,0 +1,35 @@
# coding: utf-8
from modules.frontend.text.symbols import symbols
import nltk
from random import random
n_vocab = len(symbols)
_arpabet = nltk.corpus.cmudict.dict()
def _maybe_get_arpabet(word, p):
try:
phonemes = _arpabet[word][0]
phonemes = " ".join(phonemes)
except KeyError:
return word
return '{%s}' % phonemes if random() < p else word
def mix_pronunciation(text, p):
text = ' '.join(_maybe_get_arpabet(word, p) for word in text.split(' '))
return text
def text_to_sequence(text, p=0.0):
if p >= 0:
text = mix_pronunciation(text, p)
from modules.frontend.text import text_to_sequence
text = text_to_sequence(text, ["english_cleaners"])
return text
from modules.frontend.text import sequence_to_text

View File

@ -0,0 +1,16 @@
# coding: utf-8
from deepvoice3_paddle.frontend.text.symbols import symbols
import nltk
from random import random
n_vocab = len(symbols)
def text_to_sequence(text, p=0.0):
from deepvoice3_paddle.frontend.text import text_to_sequence
text = text_to_sequence(text, ["basic_cleaners"])
return text
from deepvoice3_paddle.frontend.text import sequence_to_text

View File

@ -0,0 +1,77 @@
# coding: utf-8
import MeCab
import jaconv
from random import random
n_vocab = 0xffff
_eos = 1
_pad = 0
_tagger = None
def _yomi(mecab_result):
tokens = []
yomis = []
for line in mecab_result.split("\n")[:-1]:
s = line.split("\t")
if len(s) == 1:
break
token, rest = s
rest = rest.split(",")
tokens.append(token)
yomi = rest[7] if len(rest) > 7 else None
yomi = None if yomi == "*" else yomi
yomis.append(yomi)
return tokens, yomis
def _mix_pronunciation(tokens, yomis, p):
return "".join(yomis[idx]
if yomis[idx] is not None and random() < p else tokens[idx]
for idx in range(len(tokens)))
def mix_pronunciation(text, p):
global _tagger
if _tagger is None:
_tagger = MeCab.Tagger("")
tokens, yomis = _yomi(_tagger.parse(text))
return _mix_pronunciation(tokens, yomis, p)
def add_punctuation(text):
last = text[-1]
if last not in [".", ",", "", "", "", "", "!", "?"]:
text = text + ""
return text
def normalize_delimitor(text):
text = text.replace(",", "")
text = text.replace(".", "")
text = text.replace("", "")
text = text.replace("", "")
return text
def text_to_sequence(text, p=0.0):
for c in [" ", " ", "", "", "", "", "", "", "", "", "", "(", ")"]:
text = text.replace(c, "")
text = text.replace("!", "")
text = text.replace("?", "")
text = normalize_delimitor(text)
text = jaconv.normalize(text)
if p > 0:
text = mix_pronunciation(text, p)
text = jaconv.hira2kata(text)
text = add_punctuation(text)
return [ord(c) for c in text] + [_eos] # EOS
def sequence_to_text(seq):
return "".join(chr(n) for n in seq)

View File

@ -0,0 +1,17 @@
# coding: utf-8
from random import random
n_vocab = 0xffff
_eos = 1
_pad = 0
_tagger = None
def text_to_sequence(text, p=0.0):
return [ord(c) for c in text] + [_eos] # EOS
def sequence_to_text(seq):
return "".join(chr(n) for n in seq)

View File

@ -0,0 +1,74 @@
import re
from . import cleaners
from .symbols import symbols
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
def text_to_sequence(text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
# Check for curly braces and treat their contents as ARPAbet:
while len(text):
m = _curly_re.match(text)
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
# Append EOS token
sequence.append(_symbol_to_id['~'])
return sequence
def sequence_to_text(sequence):
'''Converts a sequence of IDs back to a string'''
result = ''
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
result += s
return result.replace('}{', ' ')
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def _symbols_to_sequence(symbols):
return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s is not '_' and s is not '~'

View File

@ -0,0 +1,104 @@
'''
Cleaners are transformations that run over the input text at both training and
eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as
the "cleaners" hyperparameter. Some cleaners are English-specific. You'll
typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated
to ASCII using the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you
should also update the symbols in symbols.py to match your data).
'''
import re
from unidecode import unidecode
from .numbers import normalize_numbers
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def add_punctuation(text):
if len(text) == 0:
return text
if text[-1] not in '!,.:;?':
text = text + '.' # without this decoder is confused when to output EOS
return text
def basic_cleaners(text):
'''
Basic pipeline that lowercases and collapses whitespace without
transliteration.
'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''
Pipeline for English text, including number and abbreviation expansion.
'''
text = convert_to_ascii(text)
text = add_punctuation(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
return text

View File

@ -0,0 +1,67 @@
import re
valid_symbols = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1',
'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T',
'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y',
'Z', 'ZH'
]
_valid_symbol_set = set(valid_symbols)
class CMUDict:
'''
Thin wrapper around CMUDict data.
http://www.speech.cs.cmu.edu/cgi-bin/cmudict
'''
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {
word: pron
for word, pron in entries.items() if len(pron) == 1
}
self._entries = entries
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
return self._entries.get(word.upper())
_alt_re = re.compile(r'\([0-9]+\)')
def _parse_cmudict(file):
cmudict = {}
for line in file:
if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
cmudict[word].append(pronunciation)
else:
cmudict[word] = [pronunciation]
return cmudict
def _get_pronunciation(s):
parts = s.strip().split(' ')
for part in parts:
if part not in _valid_symbol_set:
return None
return ' '.join(parts)

View File

@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
import inflect
import re
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(
num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text

View File

@ -0,0 +1,18 @@
'''
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text
that has been run through Unidecode. For other data, you can modify _characters.
See TRAINING_DATA.md for details.
'''
from .cmudict import valid_symbols
_pad = '_'
_eos = '~'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad, _eos] + list(_characters) + _arpabet

158
modules/loss.py Normal file
View File

@ -0,0 +1,158 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from numba import jit
from paddle import fluid
import paddle.fluid.dygraph as dg
def masked_mean(inputs, mask):
"""
Args:
inputs (Variable): Shape(B, C, 1, T), the input, where B means
batch size, C means channels of input, T means timesteps of
the input.
mask (Variable): Shape(B, T), a mask.
Returns:
loss (Variable): Shape(1, ), masked mean.
"""
channels = inputs.shape[1]
reshaped_mask = fluid.layers.reshape(
mask, shape=[mask.shape[0], 1, 1, mask.shape[-1]])
expanded_mask = fluid.layers.expand(
reshaped_mask, expand_times=[1, channels, 1, 1])
expanded_mask.stop_gradient = True
valid_cnt = fluid.layers.reduce_sum(expanded_mask)
valid_cnt.stop_gradient = True
masked_inputs = inputs * expanded_mask
loss = fluid.layers.reduce_sum(masked_inputs) / valid_cnt
return loss
@jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
W = np.zeros((max_N, max_T), dtype=np.float32)
for n in range(N):
for t in range(T):
W[n, t] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
return W
def guided_attentions(input_lengths, target_lengths, max_target_len, g=0.2):
B = len(input_lengths)
max_input_len = input_lengths.max()
W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
for b in range(B):
W[b] = guided_attention(input_lengths[b], max_input_len,
target_lengths[b], max_target_len, g).T
return W
class TTSLoss(object):
def __init__(self,
masked_weight=0.0,
priority_weight=0.0,
binary_divergence_weight=0.0,
guided_attention_sigma=0.2):
self.masked_weight = masked_weight
self.priority_weight = priority_weight
self.binary_divergence_weight = binary_divergence_weight
self.guided_attention_sigma = guided_attention_sigma
def l1_loss(self, prediction, target, mask, priority_bin=None):
abs_diff = fluid.layers.abs(prediction - target)
# basic mask-weighted l1 loss
w = self.masked_weight
if w > 0 and mask is not None:
base_l1_loss = w * masked_mean(abs_diff, mask) + (
1 - w) * fluid.layers.reduce_mean(abs_diff)
else:
base_l1_loss = fluid.layers.reduce_mean(abs_diff)
if self.priority_weight > 0 and priority_bin is not None:
# mask-weighted priority channels' l1-loss
priority_abs_diff = fluid.layers.slice(
abs_diff, axes=[1], starts=[0], ends=[priority_bin])
if w > 0 and mask is not None:
priority_loss = w * masked_mean(priority_abs_diff, mask) + (
1 - w) * fluid.layers.reduce_mean(priority_abs_diff)
else:
priority_loss = fluid.layers.reduce_mean(priority_abs_diff)
# priority weighted sum
p = self.priority_weight
loss = p * priority_loss + (1 - p) * base_l1_loss
else:
loss = base_l1_loss
return loss
def binary_divergence(self, prediction, target, mask):
flattened_prediction = fluid.layers.reshape(prediction, [-1, 1])
flattened_target = fluid.layers.reshape(target, [-1, 1])
flattened_loss = fluid.layers.log_loss(
flattened_prediction, flattened_target, epsilon=1e-8)
bin_div = fluid.layers.reshape(flattened_loss, prediction.shape)
w = self.masked_weight
if w > 0 and mask is not None:
loss = w * masked_mean(bin_div, mask) + (
1 - w) * fluid.layers.reduce_mean(bin_div)
else:
loss = fluid.layers.reduce_mean(bin_div)
return loss
@staticmethod
def done_loss(done_hat, done):
flat_done_hat = fluid.layers.reshape(done_hat, [-1, 1])
flat_done = fluid.layers.reshape(done, [-1, 1])
loss = fluid.layers.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
loss = fluid.layers.reduce_mean(loss)
return loss
def attention_loss(self, predicted_attention, input_lengths,
target_lengths):
"""
Given valid encoder_lengths and decoder_lengths, compute a diagonal
guide, and compute loss from the predicted attention and the guide.
Args:
predicted_attention (Variable): Shape(*, B, T_dec, T_enc), the
alignment tensor, where B means batch size, T_dec means number
of time steps of the decoder, T_enc means the number of time
steps of the encoder, * means other possible dimensions.
input_lengths (numpy.ndarray): Shape(B,), dtype:int64, valid lengths
(time steps) of encoder outputs.
target_lengths (numpy.ndarray): Shape(batch_size,), dtype:int64,
valid lengths (time steps) of decoder outputs.
Returns:
loss (Variable): Shape(1, ) attention loss.
"""
n_attention, batch_size, max_target_len, max_input_len = (
predicted_attention.shape)
soft_mask = guided_attentions(input_lengths, target_lengths,
max_target_len,
self.guided_attention_sigma)
soft_mask_ = dg.to_variable(soft_mask)
loss = fluid.layers.reduce_mean(predicted_attention * soft_mask_)
return loss

458
modules/modules.py Normal file
View File

@ -0,0 +1,458 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np
import conv
import weight_norm as weight_norm
def FC(name_scope,
in_features,
size,
num_flatten_dims=1,
dropout=0.0,
epsilon=1e-30,
act=None,
is_test=False,
dtype="float32"):
"""
A special Linear Layer, when it is used with dropout, the weight is
initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
"""
# stds
if isinstance(in_features, int):
in_features = [in_features]
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
weight_inits = [
fluid.initializer.NormalInitializer(scale=std) for std in stds
]
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = weight_norm.FC(name_scope,
size,
num_flatten_dims=num_flatten_dims,
param_attr=weight_attrs,
bias_attr=bias_attr,
act=act,
dtype=dtype)
return layer
def Conv1D(name_scope,
in_channels,
num_filters,
filter_size=3,
dilation=1,
groups=None,
causal=False,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
"""
A special Conv1D Layer, when it is used with dropout, the weight is
initialized as
normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_features)))
"""
# std
std = np.sqrt((std_mul * (1 - dropout)) / (filter_size * in_channels))
weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
bias_init = fluid.initializer.ConstantInitializer(0.0)
# param attrs
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = conv.Conv1D(
name_scope,
in_channels,
num_filters,
filter_size,
dilation,
groups=groups,
causal=causal,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
def Embedding(name_scope,
num_embeddings,
embed_dim,
is_sparse=False,
is_distributed=False,
padding_idx=None,
std=0.01,
dtype="float32"):
# param attrs
weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=std))
layer = dg.Embedding(
name_scope, (num_embeddings, embed_dim),
padding_idx=padding_idx,
param_attr=weight_attr,
dtype=dtype)
return layer
class Conv1DGLU(dg.Layer):
"""
A Convolution 1D block with GLU activation. It also applys dropout for the
input x. It fuses speaker embeddings through a FC activated by softsign. It
has residual connection from the input x, and scale the output by
np.sqrt(0.5).
"""
def __init__(self,
name_scope,
n_speakers,
speaker_dim,
in_channels,
num_filters,
filter_size,
dilation,
std_mul=4.0,
dropout=0.0,
causal=False,
residual=True,
dtype="float32"):
super(Conv1DGLU, self).__init__(name_scope, dtype=dtype)
# conv spec
self.in_channels = in_channels
self.n_speakers = n_speakers
self.speaker_dim = speaker_dim
self.num_filters = num_filters
self.filter_size = filter_size
self.dilation = dilation
self.causal = causal
self.residual = residual
# weight init and dropout
self.std_mul = std_mul
self.dropout = dropout
if residual:
assert (
in_channels == num_filters
), "this block uses residual connection"\
"the input_channes should equals num_filters"
self.conv = Conv1D(
self.full_name(),
in_channels,
2 * num_filters,
filter_size,
dilation,
causal=causal,
std_mul=std_mul,
dropout=dropout,
dtype=dtype)
if n_speakers > 1:
assert (speaker_dim is not None
), "speaker embed should not be null in multi-speaker case"
self.fc = Conv1D(
self.full_name(),
speaker_dim,
num_filters,
filter_size=1,
dilation=1,
causal=False,
act="softsign",
dtype=dtype)
def forward(self, x, speaker_embed_bc1t=None):
"""
Args:
x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
layer, where B means batch_size, C_in means the input channels
T means input time steps.
speaker_embed_bct1 (Variable): Shape(B, C_sp, 1, T), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns:
x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
C_out means the output channels of Conv1DGLU.
"""
residual = x
x = fluid.layers.dropout(
x, self.dropout, dropout_implementation="upscale_in_train")
x = self.conv(x)
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
if speaker_embed_bc1t is not None:
sp = self.fc(speaker_embed_bc1t)
content = content + sp
# glu
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
if self.residual:
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
return x
def add_input(self, x, speaker_embed_bc11=None):
"""
Inputs:
x: shape(B, num_filters, 1, time_steps)
speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps)
Outputs:
out: shape(B, num_filters, 1, time_steps), where time_steps = 1
"""
residual = x
# add step input and produce step output
x = fluid.layers.dropout(
x, self.dropout, dropout_implementation="upscale_in_train")
x = self.conv.add_input(x)
content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
if speaker_embed_bc11 is not None:
sp = self.fc(speaker_embed_bc11)
content = content + sp
x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content)
if self.residual:
x = fluid.layers.scale(x + residual, np.sqrt(0.5))
return x
def Conv1DTranspose(name_scope,
in_channels,
num_filters,
filter_size,
padding=0,
stride=1,
dilation=1,
groups=None,
std_mul=1.0,
dropout=0.0,
use_cudnn=True,
act=None,
dtype="float32"):
std = np.sqrt(std_mul * (1 - dropout) / (in_channels * filter_size))
weight_init = fluid.initializer.NormalInitializer(scale=std)
weight_attr = fluid.ParamAttr(initializer=weight_init)
bias_init = fluid.initializer.ConstantInitializer(0.0)
bias_attr = fluid.ParamAttr(initializer=bias_init)
layer = conv.Conv1DTranspose(
name_scope,
in_channels,
num_filters,
filter_size,
padding=padding,
stride=stride,
dilation=dilation,
groups=groups,
param_attr=weight_attr,
bias_attr=bias_attr,
use_cudnn=use_cudnn,
act=act,
dtype=dtype)
return layer
def compute_position_embedding(rad):
# rad is a transposed radius, shape(embed_dim, n_vocab)
embed_dim, n_vocab = rad.shape
even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32"))
odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32"))
even_rads = fluid.layers.gather(rad, even_dims)
odd_rads = fluid.layers.gather(rad, odd_dims)
sines = fluid.layers.sin(even_rads)
cosines = fluid.layers.cos(odd_rads)
temp = fluid.layers.scatter(rad, even_dims, sines)
out = fluid.layers.scatter(temp, odd_dims, cosines)
out = fluid.layers.transpose(out, perm=[1, 0])
return out
def position_encoding_init(n_position,
d_pos_vec,
position_rate=1.0,
sinusoidal=True):
""" Init the sinusoid position encoding table """
# keep idx 0 for padding token position encoding zero vector
position_enc = np.array([[
position_rate * pos / np.power(10000, 2 * (i // 2) / d_pos_vec)
for i in range(d_pos_vec)
] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])
if sinusoidal:
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
return position_enc
class PositionEmbedding(dg.Layer):
def __init__(self,
name_scope,
n_position,
d_pos_vec,
position_rate=1.0,
is_sparse=False,
is_distributed=False,
param_attr=None,
max_norm=None,
padding_idx=None,
dtype="float32"):
super(PositionEmbedding, self).__init__(name_scope, dtype=dtype)
self.embed = dg.Embedding(
self.full_name(),
size=(n_position, d_pos_vec),
is_sparse=is_sparse,
is_distributed=is_distributed,
padding_idx=None,
param_attr=param_attr,
dtype=dtype)
self.set_weight(
position_encoding_init(
n_position,
d_pos_vec,
position_rate=position_rate,
sinusoidal=False).astype(dtype))
self._is_sparse = is_sparse
self._is_distributed = is_distributed
self._remote_prefetch = self._is_sparse and (not self._is_distributed)
if self._remote_prefetch:
assert self._is_sparse is True and self._is_distributed is False
self._padding_idx = (-1 if padding_idx is None else padding_idx if
padding_idx >= 0 else (n_position + padding_idx))
self._position_rate = position_rate
self._max_norm = max_norm
self._dtype = dtype
def set_weight(self, array):
assert self.embed._w.shape == list(array.shape), "shape does not match"
self.embed._w._ivar.value().get_tensor().set(
array, fluid.framework._current_expected_place())
def forward(self, indices, speaker_position_rate=None):
"""
Args:
indices (Variable): Shape (B, T, 1), dtype: int64, position
indices, where B means the batch size, T means the time steps.
speaker_position_rate (Variable | float, optional), position
rate. It can be a float point number or a Variable with
shape (1,), then this speaker_position_rate is used for every
example. It can also be a Variable with shape (B, 1), which
contains a speaker position rate for each speaker.
Returns:
out (Variable): Shape(B, C_pos), position embedding, where C_pos
means position embedding size.
"""
rad = fluid.layers.transpose(self.embed._w, perm=[1, 0])
batch_size = indices.shape[0]
if speaker_position_rate is None:
weight = compute_position_embedding(rad)
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="lookup_table",
inputs={"Ids": indices,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx":
self._padding_idx, # special value for lookup table op
})
return out
elif (np.isscalar(speaker_position_rate) or
isinstance(speaker_position_rate, fluid.framework.Variable) and
speaker_position_rate.shape == [1, 1]):
# # make a weight
# scale the weight (the operand for sin & cos)
if np.isscalar(speaker_position_rate):
scaled_rad = fluid.layers.scale(rad, speaker_position_rate)
else:
scaled_rad = fluid.layers.elementwise_mul(
rad, speaker_position_rate[0])
weight = compute_position_embedding(scaled_rad)
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="lookup_table",
inputs={"Ids": indices,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx":
self._padding_idx, # special value for lookup table op
})
return out
elif np.prod(speaker_position_rate.shape) > 1:
assert speaker_position_rate.shape == [batch_size, 1]
outputs = []
for i in range(batch_size):
rate = speaker_position_rate[i] # rate has shape [1]
scaled_rad = fluid.layers.elementwise_mul(rad, rate)
weight = compute_position_embedding(scaled_rad)
out = self._helper.create_variable_for_type_inference(
self._dtype)
sequence = indices[i]
self._helper.append_op(
type="lookup_table",
inputs={"Ids": sequence,
"W": weight},
outputs={"Out": out},
attrs={
"is_sparse": self._is_sparse,
"is_distributed": self._is_distributed,
"remote_prefetch": self._remote_prefetch,
"padding_idx": -1,
})
outputs.append(out)
out = fluid.layers.stack(outputs)
return out
else:
raise Exception("Then you can just use position rate at init")

863
modules/weight_norm.py Normal file
View File

@ -0,0 +1,863 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
from six.moves import reduce
from copy import deepcopy
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
from paddle.fluid import core
from paddle.fluid.layers import utils
from paddle.fluid.framework import Variable
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
def _norm(p, dim):
"""Computes the norm over all dimensions except dim.
It differs from pytorch implementation that it does not keep dim.
This difference is related with the broadcast mechanism in paddle.
Read elementeise_mul for more.
"""
if dim is None:
return np.linalg.norm(p, ord=2, axis=None)
elif dim == 0:
p = np.reshape(p, newshape=(p.shape[0], -1))
return np.linalg.norm(p, ord=2, axis=1)
elif dim == p.ndim - 1:
p = np.reshape(p, newshape=(-1, p.shape[-1]))
return np.linalg.norm(p, ord=2, axis=0)
else:
perm = list(range(p.ndim))
perm[0] = dim
perm[dim] = 0
return _norm(np.transpose(p, axes=perm))
class FC(dg.Layer):
"""
**Fully Connected Layer**
This function creates a fully connected layer in the network. It can take
one or multiple tensors as its inputs(input can be a list of Variable, see
Args in detail). It creates a pair of variables called (magnitude(g),
direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected
weight matrix from each input unit to each output unit.
The fully connected layer multiplies each input tensor
with its corresponding weight to produce an output Tensor with shape [M, `size`],
where M is batch size. If multiple input tensors are given, the results of
multiple output tensors with shape [M, `size`] will be summed up. If bias_attr
is not None, a bias variable will be created and added to the output.
Finally, if activation is not None, it will be applied to the output as well.
When the input is single tensor:
.. math::
Out = Act({X(normalize(V)g) + b})
When the input are multiple tensors:
.. math::
Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b})
In the above equation:
* :math:`N`: Number of the input. N equals to len(input) if input is list of Variable.
* :math:`X_i`: The i-th input tensor.
* :math:`V_i`: The i-th direction matrix corresponding i-th input tensor.
* :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math:`Act`: The activation function.
* :math:`Out`: The output tensor.
See below for an example.
.. code-block:: text
Given:
data_1.data = [[[0.1, 0.2],
[0.3, 0.4]]]
data_1.shape = (1, 2, 2) # 1 is batch_size
data_2 = [[[0.1, 0.2, 0.3]]]
data_2.shape = (1, 1, 3)
out = fluid.layers.fc(input=[data_1, data_2], size=2)
Then:
out.data = [[0.18669507, 0.1893476]]
out.shape = (1, 2)
Args:
name_scope(str): The name of this class.
size(int): The number of output units in this layer.
num_flatten_dims (int): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional tensor will first be flattened
into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
dimensions will be flatten to form the first dimension of the final matrix (height of
the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
form the second dimension of the final matrix (width of the matrix). For example, suppose
`X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1
param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable
parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str|None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase. Default: False
dtype(str): Dtype used for weight
Raises:
ValueError: If rank of the input tensor is less than 2.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32')
with fluid.dygraph.guard():
fc = FC( "fc", 64, num_flatten_dims=2)
data = to_variable( data )
conv = fc( data )
"""
def __init__(self,
name_scope,
size,
num_flatten_dims=1,
epsilon=1e-30,
param_attr=None,
bias_attr=None,
act=None,
is_test=False,
dtype="float32"):
super(FC, self).__init__(name_scope, dtype)
self._size = size
self._num_flatten_dims = num_flatten_dims
self._epsilon = epsilon
self._dtype = dtype
self._param_attr = param_attr
self._bias_attr = bias_attr
self._act = act
self.__g = list()
self.__v = list()
@property
def _v(self, i=0):
return self.__v[i]
@property
def _g(self, i=0):
return self.__g[i]
@_v.setter
def _v(self, value, i=0):
assert isinstance(value, Parameter)
self.__v[i] = value
@_g.setter
def _g(self, value, i=0):
assert isinstance(value, Parameter)
self.__g[i] = value
def _build_once(self, input):
i = 0
for inp, param in self._helper.iter_inputs_and_params(input,
self._param_attr):
input_shape = inp.shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:],
1)
] + [self._size]
self.__v.append(
self.add_parameter(
"_v%d" % i,
self.create_parameter(
attr=param,
shape=param_shape,
dtype=self._dtype,
is_bias=False)))
magnitude_shape = param_shape[1:]
magnitude_value = np.linalg.norm(self.__v[i].numpy(), ord=2, axis=0)
self.__g.append(
self.add_parameter(
"_g%d" % i,
self.create_parameter(
attr=fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
magnitude_value)),
shape=magnitude_shape,
dtype=self._dtype,
is_bias=False)))
i += 1
size = list([self._size])
self._b = self.create_parameter(
attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True)
def forward(self, input):
mul_results = list()
i = 0
for inp, param in self._helper.iter_inputs_and_params(input,
self._param_attr):
v_norm = self._helper.create_variable_for_type_inference(
self._dtype)
v_normalized = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="norm",
inputs={"X": self.__v[i]},
outputs={"Out": v_normalized,
"Norm": v_norm},
attrs={"axis": 0,
"epsilon": self._epsilon})
weight = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="elementwise_mul",
inputs={"X": [v_normalized],
"Y": [self.__g[i]]},
outputs={"Out": [weight]},
attrs={"axis": 1})
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": inp,
"Y": weight},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
i += 1
mul_results.append(tmp)
if len(mul_results) == 1:
pre_bias = mul_results[0]
else:
pre_bias = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": mul_results},
outputs={"Out": pre_bias},
attrs={"use_mkldnn": False})
if self._b:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type="elementwise_add",
inputs={"X": [pre_bias],
"Y": [self._b]},
outputs={"Out": [pre_activation]},
attrs={"axis": self._num_flatten_dims})
else:
pre_activation = pre_bias
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(pre_activation, act=self._act)
class Conv2D(dg.Layer):
"""
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
channels, H is the height of the feature, and W is the width of the feature.
Filter is in MCHW format, where M is the number of output image channels,
C is the number of input image channels, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input image channels divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`
for more detials.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma ((Vg) \\ast X + b)
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`V`: Filter direction value, a tensor with MCHW format.
* :math:`g`: Filter magnitude value, a tensor with M format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
name_scope(str) : The name for this class.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
import numpy as np
data = np.random.uniform( -1, 1, [10, 3, 32, 32] ).astype('float32')
with fluid.dygraph.guard():
conv2d = Conv2D( "conv2d", 2, 3)
data = to_variable( data )
conv = conv2d( data )
"""
def __init__(self,
name_scope,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
epsilon=1e-30,
dtype="float32"):
assert param_attr is not False, "param_attr should not be False here."
super(Conv2D, self).__init__(name_scope, dtype)
self._groups = groups
self._stride = utils.convert_to_list(stride, 2, "stride")
self._padding = utils.convert_to_list(padding, 2, "padding")
self._dilation = utils.convert_to_list(dilation, 2, "dilation")
self._act = act
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
self._use_cudnn = use_cudnn
self._filter_size = filter_size
self._num_filters = num_filters
self._param_attr = param_attr
self._bias_attr = bias_attr
self._epsilon = epsilon
self._dtype = dtype
# if (self._num_channels == self._groups and
# num_filters % self._num_channels == 0 and not self._use_cudnn):
# self._l_type = 'depthwise_conv2d'
# else:
# TODO(jiabin): recover the usage of depthwise_conv2d when it's
# kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275
self._l_type = "conv2d"
def _build_once(self, input):
self._num_channels = input.shape[1]
if self._groups is None:
num_filter_channels = self._num_channels
else:
if self._num_channels % self._groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = self._num_channels // self._groups
filter_size = utils.convert_to_list(self._filter_size, 2, "filter_size")
filter_shape = [self._num_filters, int(num_filter_channels)
] + filter_size
def _get_default_param_initializer():
filter_elem_num = filter_size[0] * filter_size[
1] * self._num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
# weight_v
self._filter_param_v = self.create_parameter(
attr=self._param_attr,
shape=filter_shape,
dtype=self._dtype,
default_initializer=_get_default_param_initializer())
# weight_g
norm_value = _norm(
self._filter_param_v.numpy(), dim=0) # CAUTION: hard-code
self._filter_param_g = self.create_parameter(
attr=fluid.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(
norm_value)),
shape=norm_value.shape,
dtype=self._dtype,
default_initializer=_get_default_param_initializer())
if self._use_cudnn:
self.create_variable(
name="kCUDNNFwdAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
self.create_variable(
name="kCUDNNBwdDataAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
self.create_variable(
name="kCUDNNBwdFilterAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
self._bias_param = self.create_parameter(
attr=self._bias_attr,
shape=[self._num_filters],
dtype=self._dtype,
is_bias=True)
def forward(self, input):
matrix = self._helper.create_variable_for_type_inference(self._dtype)
tmp = self._helper.create_variable_for_type_inference(self._dtype)
new_shape = [
self._filter_param_v.shape[0],
reduce(lambda x, y: x * y, self._filter_param_v.shape[1:], 1),
]
self._helper.append_op(
type="reshape2",
inputs={"X": self._filter_param_v},
attrs={"shape": new_shape},
outputs={"Out": matrix,
"XShape": tmp})
m_norm = self._helper.create_variable_for_type_inference(self._dtype)
m_normalized = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="norm",
inputs={"X": matrix},
outputs={"Out": m_normalized,
"Norm": m_norm},
attrs={"axis": 1,
"epsilon": self._epsilon})
v_normalized = self._helper.create_variable_for_type_inference(
self._dtype)
tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="reshape2",
inputs={"X": m_normalized},
attrs={"shape": self._filter_param_v.shape},
outputs={"Out": v_normalized,
"XShape": tmp2})
filter_param = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="elementwise_mul",
inputs={"X": [v_normalized],
"Y": [self._filter_param_g]},
outputs={"Out": [filter_param]},
attrs={"axis": 0}, # CAUTION: hard-code
)
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type=self._l_type,
inputs={"Input": input,
"Filter": filter_param},
outputs={"Output": pre_bias},
attrs={
"strides": self._stride,
"paddings": self._padding,
"dilations": self._dilation,
"groups": self._groups if self._groups else 1,
"use_cudnn": self._use_cudnn,
"use_mkldnn": False,
})
if self._bias_param is not None:
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type="elementwise_add",
inputs={"X": [pre_bias],
"Y": [self._bias_param]},
outputs={"Out": [pre_act]},
attrs={"axis": 1})
else:
pre_act = pre_bias
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(pre_act, act=self._act)
class Conv2DTranspose(dg.Layer):
"""
**Convlution2D transpose layer**
The convolution2D transpose layer calculates the output based on the input,
filter, and dilations, strides, paddings. Input(Input) and output(Output)
are in NCHW format. Where N is batch size, C is the number of channels,
H is the height of the feature, and W is the width of the feature.
Parameters(dilations, strides, paddings) are two elements. These two elements
represent height and width, respectively. The details of convolution transpose
layer, please refer to the following explanation and references
`therein <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_.
If bias attribution and activation type are provided, bias is added to
the output of the convolution, and the corresponding activation function
is applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma ((Vg) \\ast X + b)
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`V`: Filter value, a tensor with MCHW format.
* :math:`g`: Filter value, a tensor with M format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
Args:
name_scope(str): The name of this class.
num_filters(int): The number of the filter. It is as same as the output
image channel.
output_size(int|tuple|None): The output image size. If output size is a
tuple, it must contain two integers, (image_H, image_W). None if use
filter_size, padding, and stride to calculate output_size.
if output_size and filter_size are specified at the same time, They
should follow the formula above. Default: None.
filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square. None if use output size to
calculate filter_size. Default: None.
padding(int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
stride(int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
dilation(int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups(int): The groups number of the Conv2d transpose layer. Inspired by
grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: groups = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
Returns:
Variable: The tensor variable storing the convolution transpose result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy
with fluid.dygraph.guard():
data = numpy.random.random((3, 32, 32)).astype('float32')
conv2DTranspose = fluid.dygraph.nn.Conv2DTranspose(
'Conv2DTranspose', num_filters=2, filter_size=3)
ret = conv2DTranspose(fluid.dygraph.base.to_variable(data))
"""
def __init__(self,
name_scope,
num_filters,
output_size=None,
filter_size=None,
padding=0,
stride=1,
dilation=1,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
epsilon=1e-30,
act=None,
dtype="float32"):
super(Conv2DTranspose, self).__init__(name_scope, dtype)
assert (param_attr is not False
), "param_attr should not be False in conv2d_transpose."
self._param_attr = param_attr
self._bias_attr = bias_attr
self._groups = groups
self._num_filters = num_filters
self._use_cudnn = use_cudnn
self._padding = padding
self._stride = stride
self._dilation = dilation
self._filter_size = filter_size
self._output_size = output_size
self._op_type = "conv2d_transpose"
self._epsilon = epsilon
def _build_once(self, input):
input_channel = input.shape[1]
if (input_channel == self._groups and
self._num_filters == input_channel and not self._use_cudnn):
self._op_type = "depthwise_conv2d_transpose"
if not isinstance(input, Variable):
raise TypeError("Input of conv2d_transpose must be Variable")
self._padding = utils.convert_to_list(self._padding, 2, "padding")
self._stride = utils.convert_to_list(self._stride, 2, "stride")
self._dilation = utils.convert_to_list(self._dilation, 2, "dilation")
if not isinstance(self._use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
if self._filter_size is None:
if self._output_size is None:
raise ValueError(
"output_size must be set when filter_size is None")
if isinstance(self._output_size, int):
self._output_size = [self._output_size, self._output_size]
h_in = input.shape[2]
w_in = input.shape[3]
filter_size_h = (self._output_size[0] -
(h_in - 1) * self._stride[0] + 2 * self._padding[0]
- 1) // self._dilation[0] + 1
filter_size_w = (self._output_size[1] -
(w_in - 1) * self._stride[1] + 2 * self._padding[1]
- 1) // self._dilation[1] + 1
self._filter_size = [filter_size_h, filter_size_w]
else:
self._filter_size = utils.convert_to_list(
self._filter_size, 2, "conv2d_transpose.filter_size")
if self._output_size is None:
self._output_size = []
elif isinstance(self._output_size, list) or isinstance(
self._output_size, int):
self._output_size = utils.convert_to_list(self._output_size, 2,
"output_size")
else:
raise ValueError("output_size should be list or int")
self._padding = utils.convert_to_list(self._padding, 2, "padding")
self._groups = 1 if self._groups is None else self._groups
filter_shape = [
input_channel,
self._num_filters // self._groups,
] + self._filter_size
# img filter v (direction)
self._img_filter_v = self.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=self._param_attr)
# img filter g (magnitude)
img_filter_magnitude = _norm(
self._img_filter_v.numpy(), dim=0) # CAUTION: hard-code
self._img_filter_g = self.create_parameter(
dtype=input.dtype,
shape=img_filter_magnitude.shape,
attr=fluid.ParamAttr(
initializer=NumpyArrayInitializer(img_filter_magnitude)))
self._img_bias = self.create_parameter(
attr=self._bias_attr,
shape=[self._num_filters],
dtype=self._dtype,
is_bias=True)
def forward(self, input):
matrix = self._helper.create_variable_for_type_inference(self._dtype)
tmp = self._helper.create_variable_for_type_inference(self._dtype)
new_shape = [
self._img_filter_v.shape[0],
reduce(lambda x, y: x * y, self._img_filter_v.shape[1:], 1),
]
self._helper.append_op(
type="reshape2",
inputs={"X": self._img_filter_v},
attrs={"shape": new_shape},
outputs={"Out": matrix,
"XShape": tmp})
m_norm = self._helper.create_variable_for_type_inference(self._dtype)
m_normalized = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="norm",
inputs={"X": matrix},
outputs={"Out": m_normalized,
"Norm": m_norm},
attrs={"axis": 1,
"epsilon": self._epsilon})
v_normalized = self._helper.create_variable_for_type_inference(
self._dtype)
tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="reshape2",
inputs={"X": m_normalized},
attrs={"shape": self._img_filter_v.shape},
outputs={"Out": v_normalized,
"XShape": tmp2})
img_filter = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="elementwise_mul",
inputs={"X": [v_normalized],
"Y": [self._img_filter_g]},
outputs={"Out": [img_filter]},
attrs={"axis": 0}, # CAUTION: hard-code
)
pre_bias = self._helper.create_variable_for_type_inference(
dtype=input.dtype)
self._helper.append_op(
type=self._op_type,
inputs={"Input": [input],
"Filter": [img_filter]},
outputs={"Output": pre_bias},
attrs={
"output_size": self._output_size,
"strides": self._stride,
"paddings": self._padding,
"dilations": self._dilation,
"groups": self._groups,
"use_cudnn": self._use_cudnn,
})
if self._img_bias is not None:
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type="elementwise_add",
inputs={"X": [pre_bias],
"Y": [self._img_bias]},
outputs={"Out": [pre_act]},
attrs={"axis": 1})
else:
pre_act = pre_bias
out = self._helper.append_activation(pre_act)
return out

14
requirements.txt Normal file
View File

@ -0,0 +1,14 @@
numba==0.45.1
numpy==1.16.4
nltk==3.4.4
scipy
unidecode==1.1.1
inflect==2.1.0
librosa==0.7.0
tqdm==4.35.0
tensorboardX==1.8
matplotlib
requests==2.22.0
lws==1.2.4
nnmnkwii
tensorboard