Merge branch 'master' into 'master'

add docstring for parakeet.data and deep voice 3, wavenet and clarinet

See merge request !35
This commit is contained in:
liuyibing01 2020-03-10 09:47:36 +08:00
commit 0805eea568
20 changed files with 992 additions and 502 deletions


@ -2,59 +2,61 @@
This short guide shows the design of `parakeet.data` and how we use it in an experiment.
The most important concepts of `parakeet.data` are `DatasetMixin`, `DataCargo`, `Sampler`, `batch function` and `DataIterator`.
## Dataset
A dataset, as we assume here, is a list of examples. You can get its length by `len(dataset)` (which means its length is known, so we have to implement the `__len__` method), and you can access its items randomly by `dataset[i]` (so we have to implement the `__getitem__` method). Furthermore, you can iterate over it by `iter(dataset)` or `for example in dataset`, which means we have to implement the `__iter__` method.
### DatasetMixin
We provide a `DatasetMixin` object which provides the above methods. You can inherit `DatasetMixin` and implement the `get_example` method to define your own dataset class. The `get_example` method is called by the `__getitem__` method automatically.
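As a minimal sketch (assuming `DatasetMixin` is importable from `parakeet.data`, and with made-up field names), a custom dataset only needs `get_example` and `__len__`:

```python
from parakeet.data import DatasetMixin


class PairDataset(DatasetMixin):
    """A toy dataset where each example is a (text, audio_path) pair."""

    def __init__(self, records):
        # records: a list of (text, audio_path) tuples prepared elsewhere
        self._records = records

    def get_example(self, i):
        # called automatically by DatasetMixin.__getitem__
        return self._records[i]

    def __len__(self):
        return len(self._records)
```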
We also define several higher-order Dataset classes, whose objects can be built from given Dataset objects.
### TupleDataset
A dataset that combines several datasets of the same length. An example of a `TupleDataset` is a tuple of examples of its constituent datasets.
### DictDataset
A dataset that combines several datasets of the same length. An example of a `DictDataset` is a dict of examples of its constituent datasets.
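A small sketch of how the two compound datasets might be used (plain Python lists stand in for datasets here, since only indexing and length are required; the printed forms are assumptions based on the descriptions above):

```python
from parakeet.data import TupleDataset, DictDataset

texts = ["hello", "world"]
mels = [[0.1, 0.2], [0.3, 0.4]]

pair_ds = TupleDataset(texts, mels)
print(pair_ds[0])    # expected: ("hello", [0.1, 0.2])

named_ds = DictDataset(text=texts, mel=mels)
print(named_ds[0])   # expected: {"text": "hello", "mel": [0.1, 0.2]}
```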
### SliceDataset
`SliceDataset` is a slice of the base dataset.
### SubsetDataset
`SubsetDataset` is a subset of the base dataset.
### ChainDataset
`ChainDataset` is the concatenation of several datasets with the same fields.
### TransformDataset
A `TransformDataset` is created by applying a `transform` to the base dataset. The `transform` is a callable object which takes an example of the base dataset as its parameter and returns an example of the `TransformDataset`. The transformation is lazy, which means it is applied to an example only when requested.
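For example (a sketch, assuming `TransformDataset` is importable from `parakeet.data`), a transform that doubles every example:

```python
from parakeet.data import TransformDataset

base = [1, 2, 3, 4]          # anything indexable with a known length

def double(example):
    # takes an example of the base dataset, returns a new example
    return example * 2

doubled = TransformDataset(base, double)
print(doubled[1])            # 4, computed lazily on access
```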
### FilterDataset
A `FilterDataset` is created by applying a `filter` to the base dataset. A `filter` is a predicate that takes an example of the base dataset as its parameter and returns a boolean. Only those examples that pass the filter are included in the `FilterDataset`.
Note that the filter is applied to all the examples in the base dataset when initializing a `FilterDataset`.
### CacheDataset
By default, we preprocess the dataset lazily in `DatasetMixin.get_example`, so an example is preprocessed whenever it is requested. `CacheDataset` caches the base dataset lazily, so each example is processed only once, when it is first requested. When preprocessing the dataset is slow, you can use `CacheDataset` to speed it up, but caching may consume a lot of RAM if the dataset is large.
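A sketch of wrapping a slow-to-process dataset (the `slow_transform` below is just a stand-in for expensive feature extraction):

```python
from parakeet.data import TransformDataset, CacheDataset

def slow_transform(example):
    # imagine an expensive feature extraction here
    return example ** 2

processed = TransformDataset(list(range(1000)), slow_transform)
cached = CacheDataset(processed)

x = cached[42]   # computed on first access
y = cached[42]   # served from the in-memory cache afterwards
```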
Finally, if preprocessing the dataset is slow and the processed dataset is too large to cache, you can write your own code to save the processed examples into files or databases, and then define a Dataset to load them. `Dataset` is flexible, so you can create your own dataset painlessly.
## DataCargo
`DataCargo`, like `Dataset`, is an iterable, but it is an iterable of batches. We need `DataCargo` because in deep learning, processing examples in batches exploits the computational resources of modern hardware. You can iterate over it by `iter(datacargo)` or `for batch in datacargo`. `DataCargo` is an iterable but not an iterator, in that it can be iterated more than once.
### batch function
A batch is something transformed from a list of examples. Assume that an example is a structure (a tuple in Python, or a struct in C/C++) that consists of several fields; then a list of examples is an array of structures (AOS; a dataset, for example, is an AOS), and a batch here is a structure of arrays (SOA). Here is an example:
The table below represents 2 examples, each of which contains 5 fields.
@ -63,7 +65,7 @@ The table below represents 2 examples, each of which contains 5 fields.
| 1.2 | 1.1 | 1.3 | 1.4 | 0.8 |
| 1.6 | 1.4 | 1.2 | 0.6 | 1.4 |
The AOS representation and SOA representation of the table are shown below.
AOS:
```text
@ -81,15 +83,15 @@ SOA:
[0.8, 1.4])
```
For the example above, converting an AOS to an SOA is trivial: just stack every field for all the examples. But this is not always the case. When a field contains a sequence, you may have to pad all the sequences to the largest length and then stack them together. In some other cases, we may want to add a field for the batch, for example, a `valid_length` for each example. So in general, a function to transform an AOS into an SOA is needed to build a `DataCargo` from a dataset. We call this the batch function (`batch_fn`), but you can use any callable object if you need to.
Usually we need to define the batch function as a callable object which stores all the options and configurations as its members. Its `__call__` method transforms a list of examples into a batch.
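For instance, a minimal batch function that pads rank-1 sequences to the longest length and stacks them might look like the sketch below (`parakeet.data` also ships ready-made batchers such as `TextIDBatcher`, `WavBatcher` and `SpecBatcher`):

```python
import numpy as np

class PadAndStack(object):
    """A batch function: a list of rank-1 arrays -> one rank-2 array."""

    def __init__(self, pad_value=0.0, dtype=np.float32):
        self.pad_value = pad_value
        self.dtype = dtype

    def __call__(self, examples):
        max_len = max(len(x) for x in examples)
        batch = [
            np.pad(x, (0, max_len - len(x)),
                   mode="constant", constant_values=self.pad_value)
            for x in examples
        ]
        return np.array(batch, dtype=self.dtype)

batch_fn = PadAndStack()
print(batch_fn([np.array([1., 2.]), np.array([3., 4., 5.])]).shape)  # (2, 3)
```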
### Sampler
Equipped with a batch function (we know __how to batch__), here comes the next question: __what to batch?__ We need to decide which examples to pick when creating a batch. Since a dataset is a list of examples, we only need to pick indices of the corresponding examples. A sampler object is what we use to do this.
A `Sampler` is represented as an iterable of integers. Assume the dataset has `N` examples; then an iterable of integers in the range `[0, N)` is an appropriate sampler for this dataset to build a `DataCargo`.
We provide several samplers that are ready to use, for example, `SequentialSampler` and `RandomSampler`.
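A quick sketch of using them (assuming they are importable from `parakeet.data`):

```python
from parakeet.data import SequentialSampler, RandomSampler

dataset = list(range(8))            # anything with a known length

print(list(iter(SequentialSampler(dataset))))   # [0, 1, 2, ..., 7]
print(list(iter(RandomSampler(dataset))))       # a permutation of 0..7
```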
@ -215,9 +217,9 @@ class Transform(object):
return audio, mel_spectrogram
```
`Transform` loads the audio files and extracts `mel_spectrogram` from them. This transformation actually needs a lot of options, namely, the sample rate of the audio files, the `n_fft`, `win_length` and `hop_length` of the `stft` transformation, and `n_mels` for transforming a spectrogram into a mel spectrogram. So we define it as a callable class. You can also use a closure, or a `partial`, if you want to.
Then we define a functor to batch examples into a batch. Because the two fields (`audio` and `mel_spectrogram`) are both sequences, batching them is not trivial. Also, because the wavenet model trains on audio clips of a fixed length (0.5 seconds, for example), we have to truncate the audio when creating batches. We want to crop the audio randomly when creating batches, instead of truncating it when preprocessing each example, because this allows an audio clip to be truncated at different positions.
```python
class DataCollector(object):
@ -321,7 +323,7 @@ for batch in train_cargo:
# your training code here
```
In the code above, processing of the data and training of the model run in the same process, so the next batch starts to load only after training on the current batch has finished. There is actually a better solution: data processing and model training can run asynchronously. To accomplish this, we would use `DataLoader` from Paddle. It serves as an adapter that transforms an iterable of batches into another iterable of batches, which runs asynchronously and transforms each ndarray into a `Variable`.
```python
# connects our data cargos with corresponding DataLoader

docs/experiment_guide.md (new file, 87 lines)

@ -0,0 +1,87 @@
# How to build your own model and experiment?
For a general deep learning experiment, there are 4 parts to care for.
1. Preprocess the dataset to meet the needs of model training and iterate over it in batches;
2. Define the model and the optimizer;
3. Write the training process (including forward-backward computation, parameter updates, logging, evaluation, etc.);
4. Configure and launch the experiment.
## Data Processing
For processing data, `parakeet.data` provides `DatasetMixin`, `DataCargo` and `DataIterator`.
Dataset is an iterable of examples. `DatasetMixin` provides the standard indexing interface, and other classes in [parakeet.data.dataset](../parakeet/data/dataset.py) provide flexible interfaces for building customized datasets.
`DataCargo` is an iterable of batches. It differs from a dataset in that it can be iterated over in batches. In addition to a dataset, a `Sampler` and a `batch function` are required to build a `DataCargo`. The `Sampler` specifies which examples to pick, and the `batch function` specifies how to create a batch from them. Commonly used `Sampler`s are provided by [parakeet.data](../parakeet/data/). Users should define a `batch function` for a dataset in order to batch its examples.
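Putting the pieces together, building a `DataCargo` might look like the sketch below (`dataset` and `batch_fn` are placeholders for your own dataset and batch function):

```python
from parakeet.data import DataCargo

train_cargo = DataCargo(
    dataset,            # a DatasetMixin-like dataset (placeholder)
    batch_fn=batch_fn,  # your batch function (placeholder)
    batch_size=16,
    shuffle=True,       # creates a RandomSampler internally
    drop_last=True)

for batch in train_cargo:
    pass  # consume one epoch of batches
```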
`DataIterator` is an iterator class for `DataCargo`. It is created when explicitly creating an iterator of a `DataCargo` by `iter(DataCargo)`, or when iterating over a `DataCargo` with a `for` loop.
Data processing is split into two phases: sample-level processing and batching.
1. Sample-level processing. This process transforms an example into another example. It can be defined as the `get_example` method of a dataset, or as a `transform` (a callable object) used to build a `TransformDataset`.
2. Batching. This is the process of transforming a list of examples into a batch. The rationale is to transform an array of structures into a structure of arrays. We generally define a batch function (a callable object) to do this.
To connect a `DataCargo` with PaddlePaddle's asynchronous data loading mechanism, we need to create a `fluid.io.DataLoader` and connect it to the `DataCargo`.
The overview of data processing in an experiment with Parakeet is as follows:
```text
Dataset --(transform)--> Dataset --+
                         sampler --+
                        batch_fn --+-> DataCargo --> DataLoader
```
The user needs to define a customized transform and a batch function to accomplish this process. See [data](./data.md) for more details.
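A sketch of the last arrow in the diagram, assuming Paddle 1.x's `fluid.io.DataLoader.from_generator` / `set_batch_generator` API (check the examples in this repo for the exact usage in each experiment):

```python
import paddle.fluid as fluid

# train_cargo is a DataCargo built as described above; its __call__ method
# lets it act as a batch generator for the DataLoader.
loader = fluid.io.DataLoader.from_generator(capacity=10, return_list=True)
loader.set_batch_generator(train_cargo, places=fluid.CPUPlace())

for batch in loader():
    pass  # each field of the batch is now a Variable
```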
## Model
Parakeet provides commonly used functions, modules and models for users to define their own models. Functions contain no trainable `Parameter`s and are used in modules and models. Modules and models are subclasses of `fluid.dygraph.Layer`. The distinction is that `module`s tend to be generic, simple and highly reusable, while `model`s tend to be task-specific, complicated and not that reusable. Some models are so complicated that we extract building blocks from them as separate classes, but if these building blocks are not common and reusable enough, they are considered submodels.
In the structure of the project, modules are placed in [parakeet.modules](../parakeet/modules/), while models are in [parakeet.models](../parakeet/models) and grouped into folders like `waveflow` and `wavenet`, each of which includes a whole model and its submodels.
When developers want to add new models to `parakeet`, they can consider the distinctions described above and put the code in an appropriate place.
## Training Process
The training process is basically running a training loop multiple times. A typical training loop consists of the procedures below:
1. Iterating over the training dataset;
2. Preprocessing mini-batches;
3. Forward/backward computations of the neural networks;
4. Updating parameters;
5. Evaluating the model on the validation dataset;
6. Logging or saving intermediate results;
7. Saving checkpoints of the model and the optimizer.
The section `Data Processing` above covers steps 1 and 2.
The model and the optimizer cover steps 3 and 4.
To keep the training loop clear, it's a good idea to define functions for saving/loading checkpoints, evaluation on the validation set, logging and saving of intermediate results, etc. For some complicated models, it is also recommended to define a function that creates the model. This function can be used in both training and inference, to ensure that the model is identical in the two settings.
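A schematic of such a loop, assuming Paddle's dygraph API; `make_model`, `valid`, `save_checkpoint`, `config`, `place`, `checkpoint_dir` and the loaders are placeholders for the functions and objects described above:

```python
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard(place):
    model = make_model(config)                      # define the model
    optim = fluid.optimizer.Adam(
        learning_rate=config["lr"],
        parameter_list=model.parameters())          # and the optimizer

    global_step = 0
    for epoch in range(config["epochs"]):
        for batch in train_loader():                # iterate mini-batches
            loss = model(*batch)                    # forward
            loss.backward()                         # backward
            optim.minimize(loss)                    # update parameters
            model.clear_gradients()
            global_step += 1

            if global_step % config["eval_every"] == 0:
                valid(model, valid_loader)          # evaluation
            if global_step % config["save_every"] == 0:
                save_checkpoint(model, optim,
                                checkpoint_dir, global_step)
```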
Code is typically organized in this way:
```text
├── configs (example configuration)
├── data.py (definition of custom Dataset, transform and batch function)
├── README.md (README for the experiment)
├── synthesis.py (code for inference)
├── train.py (code for training)
└── utils.py (all other utility functions)
```
## Configuration
Deep learning experiments have many options to configure. These configurations can be roughly grouped into different types: configurations about the path of the dataset and the path to save results, configurations about how to process data, configurations about the model, and configurations about the training process.
Some configurations tend to change from run to run, for example, the path of the data, the path to save results, and whether to load a model before training. For these configurations, it's better to define them as command line arguments. We use `argparse` to handle them.
Other groups of configurations may overlap with each other. For example, data processing and the model may share some options. The recommended way is to save them in configuration files, for example, in `yaml` or `json`. We prefer `yaml`, for it is more human-readable.
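A small sketch of combining the two (the flag names are illustrative; `pyyaml` is assumed for parsing):

```python
import argparse
import yaml

parser = argparse.ArgumentParser(description="Train a Parakeet model.")
parser.add_argument("--config", type=str, help="path to the yaml config file")
parser.add_argument("--data", type=str, help="path of the dataset")
parser.add_argument("--output", type=str, help="path to save results")
parser.add_argument("--checkpoint", type=str, help="checkpoint to load before training")
args = parser.parse_args()

# options that rarely change between runs live in the yaml config file
with open(args.config, 'rt') as f:
    config = yaml.safe_load(f)
```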
There are several examples in this repo; check [Parakeet/examples](../examples) for more details. `Parakeet/examples` is where we place our experiments. Though experiments are not a part of the `parakeet` package, they are a part of the `Parakeet` repo. They are provided as examples and allow users to run our experiments out-of-the-box. Feel free to add new examples and contribute to `Parakeet`.


@ -67,13 +67,13 @@ def save_checkpoint(model, optim, checkpoint_dir, global_step):
def load_model(model, path):
model_dict, _ = dg.load_dygraph(path)
model.set_dict(model_dict)
print("loaded model from {}.pdparams".format(path))
def load_checkpoint(model, optim, path):
model_dict, optim_dict = dg.load_dygraph(path)
model.set_dict(model_dict)
print("loaded model from {}.pdparams".format(path))
if optim_dict:
optim.set_dict(optim_dict)


@ -69,7 +69,6 @@ def make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
padding_idx=None,
embedding_weight_std=embedding_std,
convolutions=encoder_convolutions,
max_positions=max_positions,
dropout=dropout)
if freeze_embedding:
freeze(enc.embed)
@ -91,7 +90,6 @@ def make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
mel_dim,
r=r,
max_positions=max_positions,
padding_idx=padding_idx,
preattention=prenet_convolutions,
convolutions=attentive_convolutions,
attention=attention,


@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions to create batches for arrays which satisfy some conditions.
Batch functions for text sequences, audio and spectrograms are provided.
"""
import numpy as np
class TextIDBatcher(object):
"""A wrapper class for `batch_text_id`."""
def __init__(self, pad_id=0, dtype=np.int64):
self.pad_id = pad_id
@ -30,9 +31,15 @@ class TextIDBatcher(object):
def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
"""Pad sequences of text_ids to the largest length and batch them.
Args:
minibatch (List[np.ndarray]): list of rank-1 arrays, shape(T,), dtype np.int64, text_ids.
pad_id (int, optional): the id which corresponds to the special pad token. Defaults to 0.
dtype (np.dtype, optional): the data type of the output. Defaults to np.int64.
Returns:
np.ndarray: rank-2 array of text_ids, shape(B, T), B stands for batch_size, T stands for length. The output batch.
"""
peek_example = minibatch[0]
assert len(peek_example.shape) == 1, "text example is an 1D tensor"
@ -53,6 +60,8 @@ def batch_text_id(minibatch, pad_id=0, dtype=np.int64):
class WavBatcher(object):
"""A wrapper class for `batch_wav`."""
def __init__(self, pad_value=0., dtype=np.float32):
self.pad_value = pad_value
self.dtype = dtype
@ -63,19 +72,25 @@ class WavBatcher(object):
def batch_wav(minibatch, pad_value=0., dtype=np.float32):
"""pad audios to the largest length and batch them.
Args:
minibatch (List[np.ndarray]): list of rank-1 float arrays (mono-channel audio, shape(T,)) or list of rank-2 float arrays (multi-channel audio, shape(C, T), C stands for number of channels, T stands for length), dtype float.
pad_value (float, optional): the pad value. Defaults to 0..
dtype (np.dtype, optional): the data type of the output. Defaults to np.float32.
Returns:
np.ndarray: the output batch. It is a rank-2 float array of shape(B, T) if the minibatch is a list of mono-channel audios, or a rank-3 float array of shape(B, C, T) if the minibatch is a list of multi-channel audios.
""" """
minibatch: List[Example]
Example: ndarray, shape(C, T) for multi-channel wav, shape(T,) for mono-channel wav, dtype: float32
"""
# detect data format, maybe better to specify it in __init__
peek_example = minibatch[0] peek_example = minibatch[0]
if len(peek_example.shape) == 1: if len(peek_example.shape) == 1:
mono_channel = True mono_channel = True
elif len(peek_example.shape) == 2: elif len(peek_example.shape) == 2:
mono_channel = False mono_channel = False
lengths = [example.shape[-1] for example in minibatch # assume (channel, n_samples) or (n_samples, )
] # assume (channel, n_samples) or (n_samples, ) lengths = [example.shape[-1] for example in minibatch]
max_len = np.max(lengths) max_len = np.max(lengths)
batch = [] batch = []
@ -90,12 +105,14 @@ def batch_wav(minibatch, pad_value=0., dtype=np.float32):
batch.append(
np.pad(example, [(0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype)
class SpecBatcher(object):
"""A wrapper class for `batch_spec`"""
def __init__(self, pad_value=0., dtype=np.float32):
self.pad_value = pad_value
self.dtype = dtype
@ -106,9 +123,15 @@ class SpecBatcher(object):
def batch_spec(minibatch, pad_value=0., dtype=np.float32):
"""Pad spectra to the largest length and batch them.
Args:
minibatch (List[np.ndarray]): list of rank-2 arrays of shape(F, T) for mono-channel spectrograms, or list of rank-3 arrays of shape(C, F, T) for multi-channel spectrograms (F stands for frequency bands), dtype float.
pad_value (float, optional): the pad value. Defaults to 0..
dtype (np.dtype, optional): data type of the output. Defaults to np.float32.
Returns:
np.ndarray: a rank-3 array of shape(B, F, T) when the minibatch is a list of mono-channel spectrograms, or a rank-4 array of shape(B, C, F, T) when the minibatch is a list of multi-channel spectrograms.
"""
# assume (F, T) or (C, F, T)
peek_example = minibatch[0]
@ -117,8 +140,8 @@ def batch_spec(minibatch, pad_value=0., dtype=np.float32):
elif len(peek_example.shape) == 3:
mono_channel = False
# assume (channel, F, n_frame) or (F, n_frame)
lengths = [example.shape[-1] for example in minibatch]
max_len = np.max(lengths)
batch = []
@ -133,6 +156,6 @@ def batch_spec(minibatch, pad_value=0., dtype=np.float32):
batch.append(
np.pad(example, [(0, 0), (0, 0), (0, pad_len)],
mode='constant',
constant_values=pad_value))
return np.array(batch, dtype=dtype)


@ -25,6 +25,17 @@ class DataCargo(object):
shuffle=False,
batch_sampler=None,
drop_last=False):
"""An iterable object of batches. It requires a dataset, a batch function and a sampler. The sampler yields the example ids, then the corresponding examples in the dataset are collected and transformed into a batch with the batch function.
Args:
dataset (Dataset): the dataset used to build a data cargo.
batch_fn (callable, optional): a callable that takes a list of examples of `dataset` and returns a batch, it can be None if the dataset has a `_batch_examples` method which satisfies the requirement. Defaults to None.
batch_size (int, optional): number of examples in a batch. Defaults to 1.
sampler (Sampler, optional): an iterable of example ids (integers), the example ids are used to pick examples. Defaults to None.
shuffle (bool, optional): when sampler is not provided, shuffle = True creates a RandomSampler and shuffle=False creates a SequentialSampler internally. Defaults to False.
batch_sampler (BatchSampler, optional): an iterable of lists of example ids (integers), the list is used to pick examples, `batch_sampler` option is mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`. Defaults to None.
drop_last (bool, optional): whether to drop the last minibatch. Defaults to False.
"""
self.dataset = dataset
self.batch_fn = batch_fn or self.dataset._batch_examples
@ -59,11 +70,12 @@ class DataCargo(object):
return DataIterator(self)
def __call__(self):
# protocol for paddle's DataLoader
return DataIterator(self)
@property
def _auto_collation(self):
# use auto batching
return self.batch_sampler is not None
@property
@ -79,6 +91,11 @@ class DataCargo(object):
class DataIterator(object):
def __init__(self, loader):
"""Iterator object of DataCargo.
Args:
loader (DataCargo): the data cargo to iterate.
"""
self.loader = loader
self._dataset = loader.dataset
@ -90,11 +107,9 @@ class DataIterator(object):
return self
def __next__(self):
# TODO(chenfeiyu): use dynamic batch size
index = self._next_index()
minibatch = [self._dataset[i] for i in index]
minibatch = self._batch_fn(minibatch) # list[Example] -> Batch
return minibatch


@ -18,9 +18,23 @@ from tqdm import tqdm
class DatasetMixin(object):
"""Standard indexing interface for dataset. Inherit this class to
get the indexing interface. Since it is a mixin class which does
not have an `__init__` method, the subclass does not need to call
`super().__init__()`.
"""
def __getitem__(self, index):
"""Standard indexing interface for dataset.
Args:
index (slice, list[int], np.array or int): the index. It can be an int, a slice, a list of integers, or an ndarray of integers. It calls `get_example` to pick an example.
Returns:
Example, or List[Example]: If `index` is an integer, it returns an
example. If `index` is a slice, a list of integers or an array of integers,
it returns a list of examples.
"""
if isinstance(index, slice):
start, stop, step = index.indices(len(self))
return [
@ -33,6 +47,12 @@ class DatasetMixin(object):
return self.get_example(index)
def get_example(self, i):
"""Get an example from the dataset. Custom datasets should have
this method implemented.
Args:
i (int): example index.
"""
raise NotImplementedError
def __len__(self):
@ -44,9 +64,13 @@ class DatasetMixin(object):
class TransformDataset(DatasetMixin):
"""Transform a dataset to another with a transform."""
def __init__(self, dataset, transform):
"""Dataset which is transformed from another with a transform.
Args:
dataset (DatasetMixin): the base dataset.
transform (callable): the transform which takes an example of the base dataset as parameter and returns a new example.
"""
self._dataset = dataset
self._transform = transform
@ -54,14 +78,17 @@ class TransformDataset(DatasetMixin):
return len(self._dataset)
def get_example(self, i):
# CAUTION: only int is supported?
# CAUTION: dataset support support __getitem__ and __len__
in_data = self._dataset[i]
return self._transform(in_data)
class CacheDataset(DatasetMixin):
def __init__(self, dataset):
"""A lazy cache of the base dataset.
Args:
dataset (DatasetMixin): the base dataset to cache.
"""
self._dataset = dataset
self._cache = dict()
@ -76,6 +103,11 @@ class CacheDataset(DatasetMixin):
class TupleDataset(object):
def __init__(self, *datasets):
"""A compound dataset made from several datasets of the same length. An example of the `TupleDataset` is a tuple of examples from the constituent datasets.
Args:
datasets: tuple[DatasetMixin], the constituent datasets.
"""
if not datasets:
raise ValueError("no datasets are given")
length = len(datasets[0])
@ -106,6 +138,11 @@ class TupleDataset(object):
class DictDataset(object):
def __init__(self, **datasets):
"""A compound dataset made from several datasets of the same length. An example of the `DictDataset` is a dict of examples from the constituent datasets.
Args:
datasets: Dict[DatasetMixin], the constituent datasets.
"""
if not datasets:
raise ValueError("no datasets are given")
length = None
@ -135,6 +172,14 @@ class DictDataset(object):
class SliceDataset(DatasetMixin):
def __init__(self, dataset, start, finish, order=None):
"""A Dataset which is a slice of the base dataset.
Args:
dataset (DatasetMixin): the base dataset.
start (int): the start of the slice.
finish (int): the end of the slice, not inclusive.
order (List[int], optional): the order, it is a permutation of the valid example ids of the base dataset. If `order` is provided, the slice is taken in `order`. Defaults to None.
"""
if start < 0 or finish > len(dataset):
raise ValueError("subset overruns the dataset.")
self._dataset = dataset
@ -169,6 +214,12 @@ class SliceDataset(DatasetMixin):
class SubsetDataset(DatasetMixin):
def __init__(self, dataset, indices):
"""A Dataset which is a subset of the base dataset.
Args:
dataset (DatasetMixin): the base dataset.
indices (Iterable[int]): the indices of the examples to pick.
"""
self._dataset = dataset
if len(indices) > len(dataset):
raise ValueError("subset's size larger that dataset's size!")
@ -185,6 +236,12 @@ class SubsetDataset(DatasetMixin):
class FilterDataset(DatasetMixin):
def __init__(self, dataset, filter_fn):
"""A filtered dataset.
Args:
dataset (DatasetMixin): the base dataset.
filter_fn (callable): a callable which takes an example of the base dataset and returns a boolean.
"""
self._dataset = dataset
self._indices = [
i for i in range(len(dataset)) if filter_fn(dataset[i])
@ -201,6 +258,11 @@ class FilterDataset(DatasetMixin):
class ChainDataset(DatasetMixin):
def __init__(self, *datasets):
"""A concatenation of several datasets which have the same structure.
Args:
datasets (Iterable[DatasetMixin]): datasets to concat.
"""
self._datasets = datasets
def __len__(self):


@ -14,7 +14,7 @@
""" """
At most cases, we have non-stream dataset, which means we can random access it with __getitem__, and we can get the length of the dataset with __len__. At most cases, we have non-stream dataset, which means we can random access it with __getitem__, and we can get the length of the dataset with __len__.
This suffices for a sampler. We implemente sampler as iterable of valid indices. By valid, we mean 0 <= index < N, where N is the length of the dataset. We then collect several indices within a batch and use it to collect examples from the dataset with __getitem__. Then collate this examples to form a batch. This suffices for a sampler. We implemente sampler as iterable of valid indices. By valid, we mean 0 <= index < N, where N is the length of the dataset. We then collect several indices within a batch and use them to collect examples from the dataset with __getitem__. Then transform these examples into a batch.
So the sampler is only responsible for generating valid indices. So the sampler is only responsible for generating valid indices.
""" """
@ -24,9 +24,6 @@ import random
class Sampler(object):
def __init__(self, data_source):
pass
def __iter__(self):
# return an iterator of indices
# or an iterator of list[int], for BatchSampler
@ -35,6 +32,11 @@ class Sampler(object):
class SequentialSampler(Sampler):
def __init__(self, data_source):
"""Sequential sampler, the simplest sampler that samples indices from 0 to N - 1, where N is the dataset's length.
Args:
data_source (DatasetMixin): the dataset. This is used to get the dataset's length.
"""
self.data_source = data_source
def __iter__(self):
@ -46,6 +48,13 @@ class SequentialSampler(Sampler):
class RandomSampler(Sampler):
def __init__(self, data_source, replacement=False, num_samples=None):
"""Random sampler.
Args:
data_source (DatasetMixin): the dataset. This is used to get the dataset's length.
replacement (bool, optional): whether replacement is enabled in sampling. When `replacement` is True, `num_samples` must be provided. Defaults to False.
num_samples (int, optional): number of indices to draw. This option should only be provided when replacement is True. Defaults to None.
"""
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
@ -66,7 +75,6 @@ class RandomSampler(Sampler):
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
@ -84,12 +92,16 @@ class RandomSampler(Sampler):
class SubsetRandomSampler(Sampler):
"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
"""
def __init__(self, indices):
"""
Args:
indices (List[int]): indices to sample from.
"""
self.indices = indices
def __iter__(self):
@ -112,6 +124,14 @@ class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
batch_size=4,
batch_group_size=None,
permutate=True):
"""A partially randomized sampler that batches examples with similar time lengths together.
Args:
lengths (List[int]): The lengths of the examples of the dataset. This is the key to be considered as the 'time length'.
batch_size (int, optional): batch size. Defaults to 4.
batch_group_size (int, optional): the size of a group. Random shuffling is applied within such groups. If `batch_group_size` is not provided, it is set to min(batch_size * 32, len(self.lengths)). `batch_group_size` should be divisible by `batch_size`. Defaults to None.
permutate (bool, optional): whether to permute the mini-batches. Defaults to True.
"""
_lengths = np.array(
lengths,
dtype=np.int64) # maybe better implement length as a sort key
@ -157,13 +177,11 @@ class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
class WeightedRandomSampler(Sampler):
"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
Args:
weights (List[float]): a sequence of weights, not necessarily summing up to 1.
num_samples (int): number of samples to draw.
replacement (bool): whether samples are drawn with replacement. When replacement is False, num_samples should not be larger than len(weights).
Example:
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
@ -179,6 +197,10 @@ class WeightedRandomSampler(Sampler):
self.weights = np.array(weights, dtype=np.float64)
self.num_samples = num_samples
self.replacement = replacement
if replacement is False and num_samples > len(weights):
raise ValueError(
"when replacement is False, num_samples should not be"
"larger that length of weight.")
def __iter__(self):
return iter(
@ -194,6 +216,21 @@ class WeightedRandomSampler(Sampler):
class DistributedSampler(Sampler):
def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
"""Sampler used for data parallel training. Indices are divided into num_trainers parts. Each trainer gets a subset and iterates over that subset. For example, if the dataset has 16 examples and there are 4 trainers:
Trainer 0 gets [0, 4, 8, 12];
Trainer 1 gets [1, 5, 9, 13];
Trainer 2 gets [2, 6, 10, 14];
Trainer 3 gets [3, 7, 11, 15].
It ensures that each trainer gets different parts of the dataset. If the dataset's length cannot be perfectly divided by num_trainers, some examples are appended to the dataset, to ensure that every trainer gets the same number of examples.
Args:
dataset_size (int): the length of the dataset.
num_trainers (int): number of trainers(training processes).
rank (int): local rank of the trainer.
shuffle (bool, optional): whether to shuffle the indices before iteration. Defaults to True.
"""
self.dataset_size = dataset_size
self.num_trainers = num_trainers
self.rank = rank
@ -222,20 +259,20 @@ class DistributedSampler(Sampler):
class BatchSampler(Sampler):
"""Wraps another sampler to yield a mini-batch of indices."""
def __init__(self, sampler, batch_size, drop_last):
"""
Args:
sampler (Sampler): Base sampler.
batch_size (int): Size of mini-batch.
drop_last (bool): If True, the sampler will drop the last batch if its size is less than batch_size.
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
if not isinstance(sampler, Sampler):
raise ValueError("sampler should be an instance of "
"Sampler, but got sampler={}".format(sampler))


@ -37,28 +37,41 @@ class Clarinet(dg.Layer):
stft,
min_log_scale=-6.0,
lmd=4.0):
"""Clarinet model.
Args:
encoder (UpsampleNet): an UpsampleNet to upsample mel spectrogram.
teacher (WaveNet): a WaveNet, the teacher.
student (ParallelWaveNet): a ParallelWaveNet model, the student.
stft (STFT): a STFT model to perform differentiable stft transform.
min_log_scale (float, optional): used only for computing loss, the minimal value of log standard deviation of the output distribution of both the teacher and the student . Defaults to -6.0.
lmd (float, optional): weight for stft loss. Defaults to 4.0.
"""
super(Clarinet, self).__init__()
self.encoder = encoder
self.teacher = teacher
self.student = student
self.stft = stft
self.lmd = lmd
self.min_log_scale = min_log_scale
def forward(self, audio, mel, audio_start, clip_kl=True):
"""Compute loss of Clarinet model.
Args:
audio (Variable): shape(B, T_audio), dtype float32, ground truth waveform.
mel (Variable): shape(B, F, T_mel), dtype float32, condition (mel spectrogram here).
audio_start (Variable): shape(B, ), dtype int64, audio start positions.
clip_kl (bool, optional): whether to clip kl_loss by maximum=100. Defaults to True.
Returns:
Dict(str, Variable)
loss (Variable): shape(1, ), dtype float32, total loss.
kl (Variable): shape(1, ), dtype float32, kl divergence between the teacher's output distribution and the student's output distribution.
regularization (Variable): shape(1, ), dtype float32, a regularization term of the KL divergence.
spectrogram_frame_loss (Variable): shape(1, ), dtype float32, stft loss, the L1-distance of the magnitudes of the spectrograms of the ground truth waveform and the synthesized waveform.
"""
batch_size, audio_length = audio.shape # audio clip's length
z = F.gaussian_random(audio.shape)
@ -104,13 +117,13 @@ class Clarinet(dg.Layer):
@dg.no_grad
def synthesis(self, mel):
"""Synthesize waveform using the encoder and the student network.
Args:
mel (Variable): shape(B, F, T_mel), the condition (mel spectrogram here).
Returns:
Variable: shape(B, T_audio), the synthesized waveform. (T_audio = T_mel * upscale_factor, where upscale_factor is the `upscale_factor` of the encoder.)
"""
condition = self.encoder(mel)
samples_shape = (condition.shape[0], condition.shape[-1])
@ -121,6 +134,14 @@ class Clarinet(dg.Layer):
class STFT(dg.Layer):
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
"""A module for computing differentiable stft transform. See `librosa.stft` for more details.
Args:
n_fft (int): number of samples in a frame.
hop_length (int): number of samples shifted between adjacent frames.
win_length (int): length of the window function.
window (str, optional): name of window function, see `scipy.signal.get_window` for more details. Defaults to "hanning".
"""
super(STFT, self).__init__()
self.hop_length = hop_length
self.n_bin = 1 + n_fft // 2
@ -146,6 +167,16 @@
self.weight = dg.to_variable(w)
def forward(self, x):
"""Compute the stft transform.
Args:
x (Variable): shape(B, T), dtype float32, the input waveform.
Returns:
(real, imag)
real (Variable): shape(B, C, 1, T), dtype float32, the real part of the spectrogram. (C = 1 + n_fft // 2)
imag (Variable): shape(B, C, 1, T), dtype float32, the imaginary part of the spectrogram. (C = 1 + n_fft // 2)
"""
# x(batch_size, time_steps)
# pad it first with reflect mode
pad_start = F.reverse(x[:, 1:1 + self.n_fft // 2], axis=1)
@ -159,11 +190,31 @@ class STFT(dg.Layer):
return real, imag
def power(self, x):
"""Compute the power spectrogram.
Args:
x (Variable): shape(B, T), dtype float32, the input waveform.
Returns:
Variable: shape(B, C, 1, T), dtype float32, the power spectrogram.
"""
real, imag = self(x)
power = real**2 + imag**2
return power
def magnitude(self, x):
"""Compute the magnitude spectrogram.
Args:
x (Variable): shape(B, T), dtype float32, the input waveform.
Returns:
Variable: shape(B, C, 1, T), dtype float32, the magnitude spectrogram. It is the square root of the power spectrogram.
"""
power = self.power(x)
magnitude = F.sqrt(power)
return magnitude


@ -29,6 +29,15 @@ from parakeet.models.wavenet import WaveNet
class ParallelWaveNet(dg.Layer):
def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
filter_size):
"""ParallelWaveNet, an inverse autoregressive flow model, it contains several flows(WaveNets).
Args:
n_loops (List[int]): `n_loop` for each flow.
n_layers (List[int]): `n_layer` for each flow.
residual_channels (int): `residual_channels` for every flow.
condition_dim (int): `condition_dim` for every flow.
filter_size (int): `filter_size` for every flow.
"""
super(ParallelWaveNet, self).__init__()
self.flows = dg.LayerList()
for n_loop, n_layer in zip(n_loops, n_layers):
@ -38,20 +47,18 @@
filter_size, "mog", -100.0))
def forward(self, z, condition=None):
"""Inverse Autoregressive Flow. Several wavenets. """Transform a random noise sampled from a standard Gaussian distribution into sample from the target distribution. And output the mean and log standard deviation of the output distribution.
Arguments: Args:
z {Variable} -- shape(batch_size, time_steps), hidden variable, sampled from a standard normal distribution. z (Variable): shape(B, T), random noise sampled from a standard gaussian disribution.
condition (Variable, optional): shape(B, F, T), dtype float, the upsampled condition. Defaults to None.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps), condition, basically upsampled mel spectrogram. (default: {None})
Returns: Returns:
Variable -- shape(batch_size, time_steps), transformed z. (z, out_mu, out_log_std)
Variable -- shape(batch_size, time_steps), output distribution's mu. z (Variable): shape(B, T), dtype float, transformed noise, it is the synthesized waveform.
Variable -- shape(batch_size, time_steps), output distribution's log_std. out_mu (Variable): shape(B, T), dtype float, means of the output distributions.
out_log_std (Variable): shape(B, T), dtype float, log standard deviations of the output distributions.
""" """
for i, flow in enumerate(self.flows):
theta = flow(z, condition) # w, mu, log_std [0: T]
w, mu, log_std = F.split(theta, 3, dim=-1) # (B, T, 1) for each


@ -31,6 +31,16 @@ class Attention(dg.Layer):
window_range=WindowRange(-1, 3),
key_projection=True,
value_projection=True):
"""Attention Layer for Deep Voice 3.
Args:
query_dim (int): the dimension of query vectors. (The size of a single vector of query.)
embed_dim (int): the dimension of keys and values.
dropout (float, optional): dropout probability of attention. Defaults to 0.0.
window_range (WindowRange, optional): range of attention, this is only used at inference. Defaults to WindowRange(-1, 3).
key_projection (bool, optional): whether the `Attention` Layer has a Linear Layer for the keys to pass through before computing attention. Defaults to True.
value_projection (bool, optional): whether the `Attention` Layer has a Linear Layer for the values to pass through before computing attention. Defaults to True.
"""
super(Attention, self).__init__()
std = np.sqrt(1 / query_dim)
self.query_proj = Linear(
@ -54,29 +64,19 @@
def forward(self, query, encoder_out, mask=None, last_attended=None):
"""
Compute pooled context representation and alignment scores. Compute contextualized representation and alignment scores.
Args: Args:
query (Variable): shape(B, T_dec, C_q), the query tensor, query (Variable): shape(B, T_dec, C_q), dtype float32, the query tensor, where C_q means the query dim.
where C_q means the channel of query. encoder_out (keys, values):
encoder_out (Tuple(Variable, Variable)): keys (Variable): shape(B, T_enc, C_emb), dtype float32, the key representation from an encoder, where C_emb means embed dim.
keys (Variable): shape(B, T_enc, C_emb), the key values (Variable): shape(B, T_enc, C_emb), dtype float32, the value representation from an encoder, where C_emb means embed dim.
representation from an encoder, where C_emb means mask (Variable, optional): shape(B, T_enc), dtype float32, mask generated with valid text lengths. Pad tokens corresponds to 1, and valid tokens correspond to 0.
text embedding size. last_attended (int, optional): The position that received the most attention at last time step. This is only used at inference.
values (Variable): shape(B, T_enc, C_emb), the value
representation from an encoder, where C_emb means
text embedding size.
mask (Variable, optional): Shape(B, T_enc), mask generated with
valid text lengths.
last_attended (int, optional): The position that received most
attention at last timestep. This is only used at decoding.
Outpus: Outputs:
x (Variable): Shape(B, T_dec, C_q), the context representation x (Variable): shape(B, T_dec, C_q), dtype float32, the contextualized representation from the attention mechanism.
pooled from attention mechanism. attn_scores (Variable): shape(B, T_dec, T_enc), dtype float32, the alignment tensor, where T_dec means the number of decoder time steps and T_enc means the number of encoder time steps.
attn_scores (Variable): shape(B, T_dec, T_enc), the alignment
tensor, where T_dec means the number of decoder time steps and
T_enc means number the number of decoder time steps.
""" """
keys, values = encoder_out keys, values = encoder_out
residual = query residual = query
@ -85,7 +85,6 @@ class Attention(dg.Layer):
if self.key_projection: if self.key_projection:
keys = self.key_proj(keys) keys = self.key_proj(keys)
x = self.query_proj(query) x = self.query_proj(query)
# TODO: check the code
x = F.matmul(x, keys, transpose_y=True) x = F.matmul(x, keys, transpose_y=True)
@ -97,7 +96,6 @@ class Attention(dg.Layer):
# if last_attended is provided, focus only on a window range around it # if last_attended is provided, focus only on a window range around it
# to enforce monotonic attention. # to enforce monotonic attention.
# TODO: if last attended is a shape(B,) array
if last_attended is not None: if last_attended is not None:
locality_mask = np.ones(shape=x.shape, dtype=np.float32) locality_mask = np.ones(shape=x.shape, dtype=np.float32)
backward, ahead = self.window_range backward, ahead = self.window_range
@ -116,7 +114,7 @@ class Attention(dg.Layer):
x, self.dropout, dropout_implementation="upscale_in_train") x, self.dropout, dropout_implementation="upscale_in_train")
x = F.matmul(x, values) x = F.matmul(x, values)
encoder_length = keys.shape[1] encoder_length = keys.shape[1]
# CAUTION: is it wrong? let it be now
x = F.scale(x, encoder_length * np.sqrt(1.0 / encoder_length)) x = F.scale(x, encoder_length * np.sqrt(1.0 / encoder_length))
x = self.out_proj(x) x = self.out_proj(x)
x = F.scale((x + residual), np.sqrt(0.5)) x = F.scale((x + residual), np.sqrt(0.5))
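The windowed attention described above can be pictured with a small NumPy sketch; the helper below is hypothetical and assumes the window keeps positions in `[last_attended + backward, last_attended + ahead)` attendable and masks everything else before the softmax.

```python
import numpy as np

def window_mask(T_enc, last_attended, backward=-1, ahead=3):
    # True marks encoder positions that are masked out (assumed semantics)
    start = max(0, last_attended + backward)
    end = min(T_enc, last_attended + ahead)
    mask = np.ones(T_enc, dtype=bool)
    mask[start:end] = False
    return mask

scores = np.random.randn(10).astype("float32")            # toy scores over T_enc = 10
masked = np.where(window_mask(10, last_attended=4), -np.inf, scores)
probs = np.exp(masked - masked.max())
probs /= probs.sum()                                       # softmax over the windowed scores
```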

View File

@ -24,10 +24,7 @@ from parakeet.modules.weight_norm import Conv1D, Conv1DCell, Conv2D, Linear
class Conv1DGLU(dg.Layer): class Conv1DGLU(dg.Layer):
""" """
A Convolution 1D block with GLU activation. It also applys dropout for the A Convolution 1D block with GLU activation. It also applies dropout to the input x. It integrates speaker embeddings through a Linear layer activated by softsign. It has a residual connection from the input x, and scales the output by np.sqrt(0.5).
input x. It fuses speaker embeddings through a FC activated by softsign. It
has residual connection from the input x, and scale the output by
np.sqrt(0.5).
""" """
def __init__(self, def __init__(self,
@ -41,8 +38,21 @@ class Conv1DGLU(dg.Layer):
dropout=0.0, dropout=0.0,
causal=False, causal=False,
residual=True): residual=True):
super(Conv1DGLU, self).__init__() """[summary]
Args:
n_speakers (int): number of speakers.
speaker_dim (int): speaker embedding's size.
in_channels (int): channels of the input.
num_filters (int): channels of the output.
filter_size (int, optional): filter size of the internal Conv1DCell. Defaults to 1.
dilation (int, optional): dilation of the internal Conv1DCell. Defaults to 1.
std_mul (float, optional): a multiplier used when computing the standard deviation for weight initialization. Defaults to 4.0.
dropout (float, optional): dropout probability. Defaults to 0.0.
causal (bool, optional): whether the internal Conv1DCell uses causal padding. It should be True if the `add_input` method of `Conv1DCell` is ever used. Defaults to False.
residual (bool, optional): whether to use a residual connection. If True, `in_channels` should equal `num_filters`. Defaults to True.
"""
super(Conv1DGLU, self).__init__()
# conv spec # conv spec
self.in_channels = in_channels self.in_channels = in_channels
self.n_speakers = n_speakers self.n_speakers = n_speakers
@ -83,18 +93,12 @@ class Conv1DGLU(dg.Layer):
def forward(self, x, speaker_embed=None): def forward(self, x, speaker_embed=None):
""" """
Args: Args:
x (Variable): Shape(B, C_in, T), the input of Conv1DGLU x (Variable): shape(B, C_in, T), dtype float32, the input of the Conv1DGLU layer, where B means batch_size, C_in means the input channels, and T means input time steps.
layer, where B means batch_size, C_in means the input channels speaker_embed (Variable): shape(B, C_sp), dtype float32, speaker embedding, where C_sp means the speaker embedding size.
T means input time steps.
speaker_embed_bct1 (Variable): Shape(B, C_sp), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns: Returns:
x (Variable): Shape(B, C_out, T), the output of Conv1DGLU, where x (Variable): shape(B, C_out, T), the output of Conv1DGLU, where
C_out means the output channels of Conv1DGLU. C_out means the `num_filters`.
""" """
residual = x residual = x
x = F.dropout( x = F.dropout(
@ -114,22 +118,20 @@ class Conv1DGLU(dg.Layer):
return x return x
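As a rough illustration of the gating and residual scaling that the Conv1DGLU docstrings describe, here is a minimal NumPy sketch (assumed behaviour, not the fluid implementation): the internal convolution emits `2 * C` channels, half of which gates the other half, and the residual sum is scaled by `sqrt(0.5)`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def glu_with_residual(x, conv_out):
    # conv_out has 2 * C channels: half content, half gate
    content, gate = np.split(conv_out, 2, axis=1)
    out = content * sigmoid(gate)
    return (out + x) * np.sqrt(0.5)       # residual connection, scaled by sqrt(0.5)

B, C, T = 2, 4, 8
x = np.random.randn(B, C, T).astype("float32")
conv_out = np.random.randn(B, 2 * C, T).astype("float32")   # stand-in for the dilated conv
assert glu_with_residual(x, conv_out).shape == (B, C, T)
```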
def start_sequence(self): def start_sequence(self):
"""Prepare the Conv1DGLU to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
"""
self.conv.start_sequence() self.conv.start_sequence()
def add_input(self, x_t, speaker_embed=None): def add_input(self, x_t, speaker_embed=None):
""" """
Take a step of inputs and return a step of outputs. It works similarly to the `forward` method, but in a `step-in-step-out` fashion.
Args: Args:
x (Variable): Shape(B, C_in), the input of Conv1DGLU x_t (Variable): shape(B, C_in, T=1), dtype float32, the input of Conv1DGLU layer, where B means batch_size, C_in means the input channels.
layer, where B means batch_size, C_in means the input channels. speaker_embed (Variable): Shape(B, C_sp), dtype float32, speaker embed, where C_sp means speaker embedding size.
speaker_embed_bct1 (Variable): Shape(B, C_sp), expanded
speaker embed, where C_sp means speaker embedding size. Note
that when using residual connection, the Conv1DGLU does not
change the number of channels, so out channels equals input
channels.
Returns: Returns:
x (Variable): Shape(B, C_out), the output of Conv1DGLU, where x (Variable): shape(B, C_out), the output of Conv1DGLU, where C_out means `num_filters`.
C_out means the output channels of Conv1DGLU.
""" """
residual = x_t residual = x_t
x_t = F.dropout( x_t = F.dropout(

View File

@ -25,6 +25,17 @@ from parakeet.models.deepvoice3.encoder import ConvSpec
def upsampling_4x_blocks(n_speakers, speaker_dim, target_channels, dropout): def upsampling_4x_blocks(n_speakers, speaker_dim, target_channels, dropout):
"""Return a list of Layers that upsamples the input by 4 times in time dimension.
Args:
n_speakers (int): number of speakers of the Conv1DGLU layers used.
speaker_dim (int): speaker embedding size of the Conv1DGLU layers used.
target_channels (int): channels of the input and the output.(the list of layers does not change the number of channels.)
dropout (float): dropout probability.
Returns:
List[Layer]: upsampling layers.
"""
# upsampling convolitions # upsampling convolitions
upsampling_convolutions = [ upsampling_convolutions = [
Conv1DTranspose( Conv1DTranspose(
@ -41,42 +52,56 @@ def upsampling_4x_blocks(n_speakers, speaker_dim, target_channels, dropout):
3, 3,
dilation=1, dilation=1,
std_mul=1., std_mul=1.,
dropout=dropout), Conv1DGLU( dropout=dropout),
n_speakers, Conv1DGLU(
speaker_dim, n_speakers,
target_channels, speaker_dim,
target_channels, target_channels,
3, target_channels,
dilation=3, 3,
std_mul=4., dilation=3,
dropout=dropout), Conv1DTranspose( std_mul=4.,
target_channels, dropout=dropout),
target_channels, Conv1DTranspose(
2, target_channels,
stride=2, target_channels,
param_attr=I.Normal(scale=np.sqrt( 2,
4. / (2 * target_channels)))), Conv1DGLU( stride=2,
n_speakers, param_attr=I.Normal(scale=np.sqrt(4. / (2 * target_channels)))),
speaker_dim, Conv1DGLU(
target_channels, n_speakers,
target_channels, speaker_dim,
3, target_channels,
dilation=1, target_channels,
std_mul=1., 3,
dropout=dropout), Conv1DGLU( dilation=1,
n_speakers, std_mul=1.,
speaker_dim, dropout=dropout),
target_channels, Conv1DGLU(
target_channels, n_speakers,
3, speaker_dim,
dilation=3, target_channels,
std_mul=4., target_channels,
dropout=dropout) 3,
dilation=3,
std_mul=4.,
dropout=dropout),
] ]
return upsampling_convolutions return upsampling_convolutions
def upsampling_2x_blocks(n_speakers, speaker_dim, target_channels, dropout): def upsampling_2x_blocks(n_speakers, speaker_dim, target_channels, dropout):
"""Return a list of Layers that upsamples the input by 2 times in time dimension.
Args:
n_speakers (int): number of speakers of the Conv1DGLU layers used.
speaker_dim (int): speaker embedding size of the Conv1DGLU layers used.
target_channels (int): channels of the input and the output.(the list of layers does not change the number of channels.)
dropout (float): dropout probability.
Returns:
List[Layer]: upsampling layers.
"""
upsampling_convolutions = [ upsampling_convolutions = [
Conv1DTranspose( Conv1DTranspose(
target_channels, target_channels,
@ -106,6 +131,17 @@ def upsampling_2x_blocks(n_speakers, speaker_dim, target_channels, dropout):
def upsampling_1x_blocks(n_speakers, speaker_dim, target_channels, dropout): def upsampling_1x_blocks(n_speakers, speaker_dim, target_channels, dropout):
"""Return a list of Layers that upsamples the input by 1 times in time dimension.
Args:
n_speakers (int): number of speakers of the Conv1DGLU layers used.
speaker_dim (int): speaker embedding size of the Conv1DGLU layers used.
target_channels (int): channels of the input and the output.(the list of layers does not change the number of channels.)
dropout (float): dropout probability.
Returns:
List[Layer]: upsampling layers.
"""
upsampling_convolutions = [ upsampling_convolutions = [
Conv1DGLU( Conv1DGLU(
n_speakers, n_speakers,
@ -121,11 +157,6 @@ def upsampling_1x_blocks(n_speakers, speaker_dim, target_channels, dropout):
class Converter(dg.Layer): class Converter(dg.Layer):
"""
Vocoder that transforms mel spectrogram (or ecoder hidden states)
to waveform.
"""
def __init__(self, def __init__(self,
n_speakers, n_speakers,
speaker_dim, speaker_dim,
@ -134,6 +165,17 @@ class Converter(dg.Layer):
convolutions=(ConvSpec(256, 5, 1), ) * 4, convolutions=(ConvSpec(256, 5, 1), ) * 4,
time_upsampling=1, time_upsampling=1,
dropout=0.0): dropout=0.0):
"""Vocoder that transforms mel spectrogram (or ecoder hidden states) to waveform.
Args:
n_speakers (int): number of speakers.
speaker_dim (int): speaker embedding size.
in_channels (int): channels of the input.
linear_dim (int): channels of the linear spectrogram.
convolutions (Iterable[ConvSpec], optional): specifications of the internal convolutional layers. ConvSpec is a namedtuple of (output_channels, filter_size, dilation). Defaults to (ConvSpec(256, 5, 1), )*4.
time_upsampling (int, optional): time upsampling factor of the converter, possible options are {1, 2, 4}. Note that this should equal the downsample factor of the mel spectrogram. Defaults to 1.
dropout (float, optional): dropout probability. Defaults to 0.0.
"""
super(Converter, self).__init__() super(Converter, self).__init__()
self.n_speakers = n_speakers self.n_speakers = n_speakers
@ -215,23 +257,12 @@ class Converter(dg.Layer):
Convert mel spectrogram or decoder hidden states to linear spectrogram. Convert mel spectrogram or decoder hidden states to linear spectrogram.
Args: Args:
x (Variable): Shape(B, T_mel, C_in), converter inputs, where x (Variable): Shape(B, T_mel, C_in), dtype float32, converter inputs, where C_in means the input channel for the converter. Note that it can be either C_mel (channel of mel spectrogram) or C_dec // r.
C_in means the input channel for the converter. Note that it When mel_spectrogram is used as the input of the converter, C_in = C_mel; and when decoder states are used as the input of the converter, C_in = C_dec // r.
can be either C_mel (channel of mel spectrogram) or C_dec // r. speaker_embed (Variable, optional): shape(B, C_sp), dtype float32, speaker embedding, where C_sp means the speaker embedding size.
When use mel_spectrogram as the input of converter, C_in =
C_mel; and when use decoder states as the input of converter,
C_in = C_dec // r. In this scenario, decoder hidden states are
treated as if they were r outputs per decoder step and are
unpacked before passing to the converter.
speaker_embed (Variable, optional): shape(B, C_sp), speaker
embedding, where C_sp means the speaker embedding size.
Returns: Returns:
out (Variable): Shape(B, T_lin, C_lin), the output linear out (Variable): Shape(B, T_lin, C_lin), the output linear spectrogram, where C_lin means the channel of linear spectrogram and T_lin means the length (time steps) of linear spectrogram. T_lin = time_upsampling * T_mel, which depends on the time_upsampling of the converter.
spectrogram, where C_lin means the channel of linear
spectrogram and T_linear means the length(time steps) of linear
spectrogram. T_line = time_upsampling * T_mel, which depends
on the time_upsampling converter.
""" """
x = F.transpose(x, [0, 2, 1]) x = F.transpose(x, [0, 2, 1])
x = self.first_conv_proj(x) x = self.first_conv_proj(x)
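The note about C_in = C_dec // r can be made concrete with a small NumPy sketch of the assumed bookkeeping: decoder states are treated as r frames per decoder step and unpacked along time before entering the converter.

```python
import numpy as np

B, T_dec, C_dec, r = 2, 5, 8, 4
decoder_states = np.random.randn(B, T_dec, C_dec).astype("float32")

# treat every decoder step as r "frames" of C_dec // r channels each
unpacked = decoder_states.reshape(B, T_dec * r, C_dec // r)
assert unpacked.shape == (2, 20, 2)   # T_mel = r * T_dec, C_in = C_dec // r
```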

View File

@ -36,15 +36,12 @@ def gen_mask(valid_lengths, max_len, dtype="float32"):
[0, 0, 0, 0, 0, 0, 0]]. [0, 0, 0, 0, 0, 0, 0]].
Args: Args:
valid_lengths (Variable): Shape(B), dtype: int64. A 1D-Tensor containing valid_lengths (Variable): shape(B, ), dtype: int64. A rank-1 Tensor containing the valid lengths (timesteps) of each example, where B means batch_size.
the valid lengths (timesteps) of each example, where B means max_len (int): The length (number of time steps) of the mask.
beatch_size. dtype (str, optional): A string that specifies the data type of the returned mask. Defaults to 'float32'.
max_len (int): The length (number of timesteps) of the mask.
dtype (str, optional): A string that specifies the data type of the
returned mask.
Returns: Returns:
mask (Variable): A mask computed from valid lengths. mask (Variable): shape(B, max_len), dtype float32, a mask computed from valid lengths.
""" """
mask = F.sequence_mask(valid_lengths, maxlen=max_len, dtype=dtype) mask = F.sequence_mask(valid_lengths, maxlen=max_len, dtype=dtype)
mask = 1 - mask mask = 1 - mask
@ -54,14 +51,13 @@ def gen_mask(valid_lengths, max_len, dtype="float32"):
def fold_adjacent_frames(frames, r): def fold_adjacent_frames(frames, r):
"""fold multiple adjacent frames. """fold multiple adjacent frames.
Arguments: Args:
frames {Variable} -- shape(batch_size, time_steps, channels), the spectrogram frames (Variable): shape(B, T, C), the spectrogram.
r {int} -- frames per step. r (int): frames per step.
Returns: Returns:
Variable -- shape(batch_size, time_steps // r, r *channels), folded frames Variable: shape(B, T // r, r * C), folded frames.
""" """
if r == 1: if r == 1:
return frames return frames
batch_size, time_steps, channels = frames.shape batch_size, time_steps, channels = frames.shape
@ -75,16 +71,15 @@ def fold_adjacent_frames(frames, r):
def unfold_adjacent_frames(folded_frames, r): def unfold_adjacent_frames(folded_frames, r):
"""fold multiple adjacent frames. """unfold the folded frames.
Arguments: Args:
folded_frames {Variable} -- shape(batch_size, time_steps // r, r * channels), the spectrogram folded_frames (Variable): shape(B, T, C), the folded spectrogram.
r {int} -- frames per step. r (int): frames per step.
Returns: Returns:
Variable -- shape(batch_size, time_steps, channels), folded frames Variable: shape(B, T * r, C // r), unfolded frames.
""" """
if r == 1: if r == 1:
return folded_frames return folded_frames
batch_size, time_steps, channels = folded_frames.shape batch_size, time_steps, channels = folded_frames.shape
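The fold/unfold pair above is essentially a reshape; below is a minimal NumPy sketch of the intended bookkeeping (the helper names are hypothetical).

```python
import numpy as np

def fold_frames(frames, r):
    # (B, T, C) -> (B, T // r, r * C): r adjacent frames are packed into one step
    B, T, C = frames.shape
    return frames.reshape(B, T // r, r * C)

def unfold_frames(folded, r):
    # inverse of fold_frames: (B, T // r, r * C) -> (B, T, C)
    B, T_r, rC = folded.shape
    return folded.reshape(B, T_r * r, rC // r)

x = np.arange(2 * 6 * 3, dtype="float32").reshape(2, 6, 3)
assert np.allclose(unfold_frames(fold_frames(x, r=2), r=2), x)
```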
@ -93,26 +88,44 @@ def unfold_adjacent_frames(folded_frames, r):
class Decoder(dg.Layer): class Decoder(dg.Layer):
def __init__( def __init__(self,
self, n_speakers,
n_speakers, speaker_dim,
speaker_dim, embed_dim,
embed_dim, mel_dim,
mel_dim, r=1,
r=1, max_positions=512,
max_positions=512, preattention=(ConvSpec(128, 5, 1), ) * 4,
padding_idx=None, # remove it! convolutions=(ConvSpec(128, 5, 1), ) * 4,
preattention=(ConvSpec(128, 5, 1), ) * 4, attention=True,
convolutions=(ConvSpec(128, 5, 1), ) * 4, dropout=0.0,
attention=True, use_memory_mask=False,
dropout=0.0, force_monotonic_attention=False,
use_memory_mask=False, query_position_rate=1.0,
force_monotonic_attention=False, key_position_rate=1.0,
query_position_rate=1.0, window_range=WindowRange(-1, 3),
key_position_rate=1.0, key_projection=True,
window_range=WindowRange(-1, 3), value_projection=True):
key_projection=True, """Decoder of the Deep Voice 3 model.
value_projection=True):
Args:
n_speakers (int): number of speakers.
speaker_dim (int): speaker embedding size.
embed_dim (int): text embedding size.
mel_dim (int): channels of the mel input (mel bands).
r (int, optional): number of frames generated per decoder step. Defaults to 1.
max_positions (int, optional): max position for text and decoder steps. Defaults to 512.
convolutions (Iterable[ConvSpec], optional): specification of causal convolutional layers inside the decoder. ConvSpec is a namedtuple of output_channels, filter_size and dilation. Defaults to (ConvSpec(128, 5, 1), )*4.
attention (bool or List[bool], optional): whether to use attention, it should have the same length with `convolutions` if it is a list of bool, indicating whether to have an Attention layer coupled with the corresponding convolutional layer. If it is a bool, it is repeated len(convolutions) times internally. Defaults to True.
dropout (float, optional): dropout probability. Defaults to 0.0.
use_memory_mask (bool, optional): whether to use memory mask at the Attention layer. It should have the same length with `attention` if it is a list of bool, indicating whether to use memory mask at the corresponding Attention layer. If it is a bool, it is repeated len(attention) times internally. Defaults to False.
force_monotonic_attention (bool, optional): whether to use monotonic_attention at the Attention layer when inferencing. It should have the same length with `attention` if it is a list of bool, indicating whether to use monotonic_attention at the corresponding Attention layer. If it is a bool, it is repeated len(attention) times internally. Defaults to False.
query_position_rate (float, optional): position_rate of the PositionEmbedding for query. Defaults to 1.0.
key_position_rate (float, optional): position_rate of the PositionEmbedding for key. Defaults to 1.0.
window_range (WindowRange, optional): window range of monotonic attention. Defaults to WindowRange(-1, 3).
key_projection (bool, optional): `key_projection` of Attention layers. Defaults to True.
value_projection (bool, optional): `value_projection` of Attention layers. Defaults to True.
"""
super(Decoder, self).__init__() super(Decoder, self).__init__()
self.dropout = dropout self.dropout = dropout
@ -125,10 +138,9 @@ class Decoder(dg.Layer):
conv_channels = convolutions[0].out_channels conv_channels = convolutions[0].out_channels
# only when padding idx is 0 can we easily handle it # only when padding idx is 0 can we easily handle it
self.embed_keys_positions = PositionEmbedding( self.embed_keys_positions = PositionEmbedding(max_positions, embed_dim)
max_positions, embed_dim, padding_idx=0) self.embed_query_positions = PositionEmbedding(max_positions,
self.embed_query_positions = PositionEmbedding( conv_channels)
max_positions, conv_channels, padding_idx=0)
if n_speakers > 1: if n_speakers > 1:
std = np.sqrt((1 - dropout) / speaker_dim) std = np.sqrt((1 - dropout) / speaker_dim)
@ -248,41 +260,20 @@ class Decoder(dg.Layer):
Compute decoder outputs with ground truth mel spectrogram. Compute decoder outputs with ground truth mel spectrogram.
Args: Args:
encoder_out (Tuple(Variable, Variable)): encoder_out (keys, values):
keys (Variable): shape(B, T_enc, C_emb), the key keys (Variable): shape(B, T_enc, C_emb), dtype float32, the key representation from an encoder, where C_emb means text embedding size.
representation from an encoder, where C_emb means values (Variable): shape(B, T_enc, C_emb), dtype float32, the value representation from an encoder, where C_emb means text embedding size.
text embedding size. lengths (Variable): shape(batch_size,), dtype: int64, valid lengths of text inputs for each example.
values (Variable): shape(B, T_enc, C_emb), the value inputs (Variable): shape(B, T_mel, C_mel), ground truth mel-spectrogram, which is used as decoder inputs when training.
representation from an encoder, where C_emb means text_positions (Variable): shape(B, T_enc), dtype: int64. Positions indices for text inputs for the encoder, where T_enc means the encoder timesteps.
text embedding size. frame_positions (Variable): shape(B, T_mel // r), dtype: int64. Positions indices for each decoder time steps.
lengths (Variable): shape(batch_size,), dtype: int64, valid lengths speaker_embed (Variable, optionals): shape(batch_size, speaker_dim), speaker embedding, only used for multispeaker model.
of text inputs for each example.
inputs (Variable): shape(B, T_mel, C_mel), ground truth
mel-spectrogram, which is used as decoder inputs when training.
text_positions (Variable): shape(B, T_enc), dtype: int64.
Positions indices for text inputs for the encoder, where
T_enc means the encoder timesteps.
frame_positions (Variable): shape(B, T_mel // r), dtype:
int64. Positions indices for each decoder time steps.
speaker_embed: shape(batch_size, speaker_dim), speaker embedding,
only used for multispeaker model.
Returns: Returns:
outputs (Variable): Shape(B, T_mel // r, r * C_mel). Decoder outputs (Variable): shape(B, T_mel, C_mel), dtype float32, decoder outputs, where C_mel means the channels of mel-spectrogram, T_mel means the length(time steps) of mel spectrogram.
outputs, where C_mel means the channels of mel-spectrogram, r alignments (Variable): shape(N, B, T_mel // r, T_enc), dtype float32, the alignment tensor between the decoder and the encoder, where N means number of Attention Layers, T_mel means the length of mel spectrogram, r means the outputs per decoder step, T_enc means the encoder time steps.
means the outputs per decoder step, T_mel means the length(time done (Variable): shape(B, T_mel // r), dtype float32, probability that the last frame has been generated.
steps) of mel spectrogram. Note that, when r > 1, the decoder decoder_states (Variable): shape(B, T_mel, C_dec // r), ddtype float32, decoder hidden states, where C_dec means the channels of decoder states (the output channels of the last `convolutions`). Note that it should be perfectlt devided by `r`.
outputs r frames of mel spectrogram per step.
alignments (Variable): Shape(N, B, T_mel // r, T_enc), the alignment
tensor between the decoder and the encoder, where N means number
of Attention Layers, T_mel means the length of mel spectrogram,
r means the outputs per decoder step, T_enc means the encoder
time steps.
done (Variable): Shape(B, T_mel // r), probability that the
outputs should stop.
decoder_states (Variable): Shape(B, T_mel // r, C_dec), decoder
hidden states, where C_dec means the channels of decoder states.
""" """
if speaker_embed is not None: if speaker_embed is not None:
speaker_embed = F.dropout( speaker_embed = F.dropout(
@ -366,6 +357,8 @@ class Decoder(dg.Layer):
return r return r
def start_sequence(self): def start_sequence(self):
"""Prepare the Decoder to decode. This method is called by `decode`.
"""
for layer in self.prenet: for layer in self.prenet:
if isinstance(layer, Conv1DGLU): if isinstance(layer, Conv1DGLU):
layer.start_sequence() layer.start_sequence()
@ -379,6 +372,25 @@ class Decoder(dg.Layer):
text_positions, text_positions,
speaker_embed=None, speaker_embed=None,
test_inputs=None): test_inputs=None):
"""Decode from the encoder's output and other conditions.
Args:
encoder_out (keys, values):
keys (Variable): shape(B, T_enc, C_emb), dtype float32, the key representation from an encoder, where C_emb means text embedding size.
values (Variable): shape(B, T_enc, C_emb), dtype float32, the value representation from an encoder, where C_emb means text embedding size.
text_positions (Variable): shape(B, T_enc), dtype: int64. Positions indices for text inputs for the encoder, where T_enc means the encoder timesteps.
speaker_embed (Variable, optional): shape(B, C_sp), speaker embedding, only used for multispeaker model.
test_inputs (Variable, optional): shape(B, T_test, C_mel). test input, it is only used for debugging. Defaults to None.
Returns:
outputs (Variable): shape(B, T_mel, C_mel), dtype float32, decoder outputs, where C_mel means the channels of mel-spectrogram, T_mel means the length(time steps) of mel spectrogram.
alignments (Variable): shape(N, B, T_mel // r, T_enc), dtype float32, the alignment tensor between the decoder and the encoder, where N means number of Attention Layers, T_mel means the length of mel spectrogram, r means the outputs per decoder step, T_enc means the encoder time steps.
done (Variable): shape(B, T_mel // r), dtype float32, probability that the last frame has been generated. If the probability is larger than 0.5 at a step, the generation stops.
decoder_states (Variable): shape(B, T_mel, C_dec // r), dtype float32, decoder hidden states, where C_dec means the channels of decoder states (the output channels of the last `convolutions`). Note that C_dec should be exactly divisible by `r`.
Note:
Only single instance inference is supported now, so B = 1.
"""
self.start_sequence() self.start_sequence()
keys, values = encoder_out keys, values = encoder_out
batch_size = keys.shape[0] batch_size = keys.shape[0]
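Two conventions from the Decoder (and Encoder/Converter) constructor docstrings above are worth spelling out: `ConvSpec` is described as a namedtuple of output channels, filter size and dilation (the decoder code reads `convolutions[0].out_channels`), and boolean arguments such as `attention` or `force_monotonic_attention` are repeated once per convolution layer when a single bool is given. The sketch below assumes exactly that; the helper `broadcast_flag` is hypothetical, not a library function.

```python
from collections import namedtuple

# assumed definition, mirroring how the docstrings describe ConvSpec
ConvSpec = namedtuple("ConvSpec", ["out_channels", "filter_size", "dilation"])
convolutions = (ConvSpec(128, 5, 1), ) * 4           # the decoder default

def broadcast_flag(flag, n_layers):
    # expand a single bool into a per-layer list; pass lists through unchanged
    if isinstance(flag, bool):
        return [flag] * n_layers
    assert len(flag) == n_layers
    return list(flag)

print(convolutions[0].out_channels)                  # 128
print(broadcast_flag(True, len(convolutions)))       # [True, True, True, True]
print(broadcast_flag([False, False, True, True], len(convolutions)))
```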

View File

@ -34,10 +34,20 @@ class Encoder(dg.Layer):
padding_idx=None, padding_idx=None,
embedding_weight_std=0.1, embedding_weight_std=0.1,
convolutions=(ConvSpec(64, 5, 1), ) * 7, convolutions=(ConvSpec(64, 5, 1), ) * 7,
max_positions=512,
dropout=0.): dropout=0.):
super(Encoder, self).__init__() """Encoder of Deep Voice 3.
Args:
n_vocab (int): vocabulary size of the text embedding.
embed_dim (int): embedding size of the text embedding.
n_speakers (int): number of speakers.
speaker_dim (int): speaker embedding size.
padding_idx (int, optional): padding index of text embedding. Defaults to None.
embedding_weight_std (float, optional): standard deviation of the embedding weights when initialized. Defaults to 0.1.
convolutions (Iterable[ConvSpec], optional): specifications of the convolutional layers. ConvSpec is a namedtuple of output channels, filter_size and dilation. Defaults to (ConvSpec(64, 5, 1), )*7.
dropout (float, optional): dropout probability. Defaults to 0.0.
"""
super(Encoder, self).__init__()
self.embedding_weight_std = embedding_weight_std self.embedding_weight_std = embedding_weight_std
self.embed = dg.Embedding( self.embed = dg.Embedding(
(n_vocab, embed_dim), (n_vocab, embed_dim),
@ -101,18 +111,12 @@ class Encoder(dg.Layer):
Encode text sequence. Encode text sequence.
Args: Args:
x (Variable): Shape(B, T_enc), dtype: int64. Ihe input text x (Variable): shape(B, T_enc), dtype: int64. The input text indices. T_enc means the number of timesteps of the encoder input text.
indices. T_enc means the timesteps of decoder input x. speaker_embed (Variable, optional): shape(B, C_sp), dtype float32, speaker embeddings. This arg is not None only when the model is a multispeaker model.
speaker_embed (Variable, optional): Shape(batch_size, speaker_dim),
dtype: float32. Speaker embeddings. This arg is not None only
when the model is a multispeaker model.
Returns: Returns:
keys (Variable), Shape(B, T_enc, C_emb), the encoded keys (Variable), Shape(B, T_enc, C_emb), dtype float32, the encoded representation for keys, where C_emb means the text embedding size.
representation for keys, where C_emb menas the text embedding values (Variable), Shape(B, T_enc, C_emb), dtype float32, the encoded representation for values.
size.
values (Variable), Shape(B, T_enc, C_emb), the encoded
representation for values.
""" """
x = self.embed(x) x = self.embed(x)
x = F.dropout( x = F.dropout(

View File

@ -23,12 +23,10 @@ import paddle.fluid.dygraph as dg
def masked_mean(inputs, mask): def masked_mean(inputs, mask):
""" """
Args: Args:
inputs (Variable): Shape(B, T, C), the input, where B means inputs (Variable): shape(B, T, C), dtype float32, the input.
batch size, C means channels of input, T means timesteps of mask (Variable): shape(B, T), dtype float32, a mask.
the input.
mask (Variable): Shape(B, T), a mask.
Returns: Returns:
loss (Variable): Shape(1, ), masked mean. loss (Variable): shape(1, ), dtype float32, masked mean.
""" """
channels = inputs.shape[-1] channels = inputs.shape[-1]
masked_inputs = F.elementwise_mul(inputs, mask, axis=0) masked_inputs = F.elementwise_mul(inputs, mask, axis=0)
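A NumPy sketch of the masked mean, assumed to match the fluid version above: the (B, T) mask is broadcast over channels and the sum is normalized by the number of valid entries.

```python
import numpy as np

def masked_mean_np(inputs, mask):
    # inputs: (B, T, C), mask: (B, T); average only over valid positions
    masked = inputs * mask[:, :, None]
    return masked.sum() / (mask.sum() * inputs.shape[-1])

inputs = np.ones((2, 4, 3), dtype="float32")
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype="float32")
print(masked_mean_np(inputs, mask))   # 1.0, since every valid entry is 1
```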
@ -38,6 +36,18 @@ def masked_mean(inputs, mask):
@jit(nopython=True) @jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g): def guided_attention(N, max_N, T, max_T, g):
"""Generate an diagonal attention guide.
Args:
N (int): valid length of encoder.
max_N (int): max length of encoder.
T (int): valid length of decoder.
max_T (int): max length of decoder.
g (float): sigma to adjust the degree of diagonal guide.
Returns:
np.ndarray: shape(max_N, max_T), dtype float32, the diagonal guide.
"""
W = np.zeros((max_N, max_T), dtype=np.float32) W = np.zeros((max_N, max_T), dtype=np.float32)
for n in range(N): for n in range(N):
for t in range(T): for t in range(T):
@ -47,6 +57,17 @@ def guided_attention(N, max_N, T, max_T, g):
def guided_attentions(encoder_lengths, decoder_lengths, max_decoder_len, def guided_attentions(encoder_lengths, decoder_lengths, max_decoder_len,
g=0.2): g=0.2):
"""Generate a diagonal attention guide for a batch.
Args:
encoder_lengths (np.ndarray): shape(B, ), dtype: int64, encoder valid lengths.
decoder_lengths (np.ndarray): shape(B, ), dtype: int64, decoder valid lengths.
max_decoder_len (int): max length of decoder.
g (float, optional): sigma to adjust the degree of the diagonal guide. Defaults to 0.2.
Returns:
np.ndarray: shape(B, max_T, max_N), dtype float32, the diagonal guide. (max_N: max encoder length, max_T: max decoder length.)
"""
B = len(encoder_lengths) B = len(encoder_lengths)
max_input_len = encoder_lengths.max() max_input_len = encoder_lengths.max()
W = np.zeros((B, max_decoder_len, max_input_len), dtype=np.float32) W = np.zeros((B, max_decoder_len, max_input_len), dtype=np.float32)
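A NumPy sketch of one attention guide, assuming the usual exponential diagonal form `W[n, t] = 1 - exp(-(n/N - t/T)^2 / (2 g^2))`; entries beyond the valid lengths stay zero.

```python
import numpy as np

def guided_attention_np(N, max_N, T, max_T, g=0.2):
    W = np.zeros((max_N, max_T), dtype=np.float32)
    for n in range(N):
        for t in range(T):
            # small near the diagonal n/N == t/T, close to 1 far away from it
            W[n, t] = 1.0 - np.exp(-((n / N - t / T) ** 2) / (2.0 * g ** 2))
    return W

guide = guided_attention_np(N=6, max_N=8, T=10, max_T=12)
print(guide.shape)   # (8, 12)
```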
@ -65,6 +86,17 @@ class TTSLoss(object):
guided_attention_sigma=0.2, guided_attention_sigma=0.2,
downsample_factor=4, downsample_factor=4,
r=1): r=1):
"""Compute loss for Deep Voice 3 model.
Args:
masked_weight (float, optional): the weight of masked loss. Defaults to 0.0.
priority_bin (int, optional): the number of frequency bands of the linear spectrogram loss to be prioritized. Defaults to None.
priority_weight (float, optional): weight for the prioritized frequency bands. Defaults to 0.0.
binary_divergence_weight (float, optional): weight for binary cross entropy (used for spectrogram loss). Defaults to 0.0.
guided_attention_sigma (float, optional): `sigma` for attention guide. Defaults to 0.2.
downsample_factor (int, optional): the downsample factor for mel spectrogram. Defaults to 4.
r (int, optional): frames per decoder step. Defaults to 1.
"""
self.masked_weight = masked_weight self.masked_weight = masked_weight
self.priority_bin = priority_bin # only used for lin-spec loss self.priority_bin = priority_bin # only used for lin-spec loss
self.priority_weight = priority_weight # only used for lin-spec loss self.priority_weight = priority_weight # only used for lin-spec loss
@ -76,6 +108,17 @@ class TTSLoss(object):
self.downsample_factor = downsample_factor self.downsample_factor = downsample_factor
def l1_loss(self, prediction, target, mask, priority_bin=None): def l1_loss(self, prediction, target, mask, priority_bin=None):
"""L1 loss for spectrogram.
Args:
prediction (Variable): shape(B, T, C), dtype float32, predicted spectrogram.
target (Variable): shape(B, T, C), dtype float32, target spectrogram.
mask (Variable): shape(B, T), mask.
priority_bin (int, optional): frequency bands for linear spectrogram loss to be prioritized. Defaults to None.
Returns:
Variable: shape(1,), dtype float32, L1 loss (with mask and possibly priority bin applied).
"""
abs_diff = F.abs(prediction - target) abs_diff = F.abs(prediction - target)
# basic mask-weighted l1 loss # basic mask-weighted l1 loss
@ -103,6 +146,16 @@ class TTSLoss(object):
return loss return loss
def binary_divergence(self, prediction, target, mask): def binary_divergence(self, prediction, target, mask):
"""Binary cross entropy loss for spectrogram. All the values in the spectrogram are treated as logits in a logistic regression.
Args:
prediction (Variable): shape(B, T, C), dtype float32, predicted spectrogram.
target (Variable): shape(B, T, C), dtype float32, target spectrogram.
mask (Variable): shape(B, T), mask.
Returns:
Variable: shape(1,), dtype float32, binary cross entropy loss.
"""
flattened_prediction = F.reshape(prediction, [-1, 1]) flattened_prediction = F.reshape(prediction, [-1, 1])
flattened_target = F.reshape(target, [-1, 1]) flattened_target = F.reshape(target, [-1, 1])
flattened_loss = F.log_loss( flattened_loss = F.log_loss(
@ -119,6 +172,15 @@ class TTSLoss(object):
@staticmethod @staticmethod
def done_loss(done_hat, done): def done_loss(done_hat, done):
"""Compute done loss
Args:
done_hat (Variable): shape(B, T), dtype float32, predicted done probability(the probability that the final frame has been generated.)
done (Variable): shape(B, T), dtype float32, ground truth done probability(the probability that the final frame has been generated.)
Returns:
Variable: shape(1, ), dtype float32, done loss.
"""
flat_done_hat = F.reshape(done_hat, [-1, 1]) flat_done_hat = F.reshape(done_hat, [-1, 1])
flat_done = F.reshape(done, [-1, 1]) flat_done = F.reshape(done, [-1, 1])
loss = F.log_loss(flat_done_hat, flat_done, epsilon=1e-8) loss = F.log_loss(flat_done_hat, flat_done, epsilon=1e-8)
@ -128,21 +190,15 @@ class TTSLoss(object):
def attention_loss(self, predicted_attention, input_lengths, def attention_loss(self, predicted_attention, input_lengths,
target_lengths): target_lengths):
""" """
Given valid encoder_lengths and decoder_lengths, compute a diagonal Given valid encoder_lengths and decoder_lengths, compute a diagonal guide, and compute loss from the predicted attention and the guide.
guide, and compute loss from the predicted attention and the guide.
Args: Args:
predicted_attention (Variable): Shape(*, B, T_dec, T_enc), the predicted_attention (Variable): shape(*, B, T_dec, T_enc), dtype float32, the alignment tensor, where B means batch size, T_dec means number of time steps of the decoder, T_enc means the number of time steps of the encoder, * means other possible dimensions.
alignment tensor, where B means batch size, T_dec means number input_lengths (numpy.ndarray): shape(B,), dtype:int64, valid lengths (time steps) of encoder outputs.
of time steps of the decoder, T_enc means the number of time target_lengths (numpy.ndarray): shape(batch_size,), dtype:int64, valid lengths (time steps) of decoder outputs.
steps of the encoder, * means other possible dimensions.
input_lengths (numpy.ndarray): Shape(B,), dtype:int64, valid lengths
(time steps) of encoder outputs.
target_lengths (numpy.ndarray): Shape(batch_size,), dtype:int64,
valid lengths (time steps) of decoder outputs.
Returns: Returns:
loss (Variable): Shape(1, ) attention loss. loss (Variable): shape(1, ), dtype float32, attention loss.
""" """
n_attention, batch_size, max_target_len, max_input_len = ( n_attention, batch_size, max_target_len, max_input_len = (
predicted_attention.shape) predicted_attention.shape)
@ -167,6 +223,26 @@ class TTSLoss(object):
compute_mel_loss=True, compute_mel_loss=True,
compute_done_loss=True, compute_done_loss=True,
compute_attn_loss=True): compute_attn_loss=True):
"""Total loss
Args:
mel_hyp (Variable): shape(B, T, C_mel), dtype float32, predicted mel spectrogram.
lin_hyp (Variable): shape(B, T, C_lin), dtype float32, predicted linear spectrogram.
done_hyp (Variable): shape(B, T), dtype float32, predicted done probability.
attn_hyp (Variable): shape(N, B, T_dec, T_enc), dtype float32, predicted attention.
mel_ref (Variable): shape(B, T, C_mel), dtype float32, ground truth mel spectrogram.
lin_ref (Variable): shape(B, T, C_lin), dtype float32, ground truth linear spectrogram.
done_ref (Variable): shape(B, T), dtype float32, ground truth done flag.
input_lengths (Variable): shape(B, ), dtype: int, encoder valid lengths.
n_frames (Variable): shape(B, ), dtype: int, decoder valid lengths.
compute_lin_loss (bool, optional): whether to compute linear loss. Defaults to True.
compute_mel_loss (bool, optional): whether to compute mel loss. Defaults to True.
compute_done_loss (bool, optional): whether to compute done loss. Defaults to True.
compute_attn_loss (bool, optional): whether to compute attention loss. Defaults to True.
Returns:
Dict(str, Variable): details of loss.
"""
total_loss = 0. total_loss = 0.
# n_frames # mel_lengths # decoder_lengths # n_frames # mel_lengths # decoder_lengths

View File

@ -22,6 +22,15 @@ import paddle.fluid.dygraph as dg
class DeepVoice3(dg.Layer): class DeepVoice3(dg.Layer):
def __init__(self, encoder, decoder, converter, speaker_embedding, def __init__(self, encoder, decoder, converter, speaker_embedding,
use_decoder_states): use_decoder_states):
"""Deep Voice 3 TTS model.
Args:
encoder (Layer): the encoder.
decoder (Layer): the decoder.
converter (Layer): the converter.
speaker_embedding (Layer): the speaker embedding (for multispeaker cases).
use_decoder_states (bool): use decoder states instead of predicted mel spectrogram as the input of the converter.
"""
super(DeepVoice3, self).__init__() super(DeepVoice3, self).__init__()
if speaker_embedding is None: if speaker_embedding is None:
self.n_speakers = 1 self.n_speakers = 1
@ -34,6 +43,24 @@ class DeepVoice3(dg.Layer):
def forward(self, text_sequences, text_positions, valid_lengths, def forward(self, text_sequences, text_positions, valid_lengths,
speaker_indices, mel_inputs, frame_positions): speaker_indices, mel_inputs, frame_positions):
"""Compute predicted value in a teacher forcing training manner.
Args:
text_sequences (Variable): shape(B, T_enc), dtype: int64, text indices.
text_positions (Variable): shape(B, T_enc), dtype: int64, positions of text indices.
valid_lengths (Variable): shape(B, ), dtype: int64, valid lengths of utterances.
speaker_indices (Variable): shape(B, ), dtype: int64, speaker indices for utterances.
mel_inputs (Variable): shape(B, T_mel, C_mel), dtype float32, ground truth mel spectrogram.
frame_positions (Variable): shape(B, T_dec), dtype: int64, positions of decoder steps.
Returns:
(mel_outputs, linear_outputs, alignments, done)
mel_outputs (Variable): shape(B, T_mel, C_mel), dtype float32, predicted mel spectrogram.
linear_outputs (Variable): shape(B, T_lin, C_lin), dtype float32, predicted linear spectrogram.
alignments (Variable): shape(N, B, T_dec, T_enc), dtype float32, predicted attention.
done (Variable): shape(B, T_dec), dtype float32, predicted done probability.
(T_mel: time steps of mel spectrogram, T_lin: time steps of linear spectrogram, T_dec: time steps of decoder, T_enc: time steps of encoder.)
"""
if hasattr(self, "speaker_embedding"): if hasattr(self, "speaker_embedding"):
speaker_embed = self.speaker_embedding(speaker_indices) speaker_embed = self.speaker_embedding(speaker_indices)
else: else:
@ -49,6 +76,21 @@ class DeepVoice3(dg.Layer):
return mel_outputs, linear_outputs, alignments, done return mel_outputs, linear_outputs, alignments, done
def transduce(self, text_sequences, text_positions, speaker_indices=None): def transduce(self, text_sequences, text_positions, speaker_indices=None):
"""Generate output without teacher forcing. Only batch_size = 1 is supported.
Args:
text_sequences (Variable): shape(B, T_enc), dtype: int64, text indices.
text_positions (Variable): shape(B, T_enc), dtype: int64, positions of text indices.
speaker_indices (Variable): shape(B, ), dtype: int64, speaker indices for utterances.
Returns:
(mel_outputs, linear_outputs, alignments, done)
mel_outputs (Variable): shape(B, T_mel, C_mel), dtype float32, predicted mel spectrogram.
linear_outputs (Variable): shape(B, T_lin, C_lin), dtype float32, predicted linear spectrogram.
alignments (Variable): shape(B, T_dec, T_enc), dtype float32, predicted average attention of all attention layers.
done (Variable): shape(B, T_dec), dtype float32, predicted done probability.
(T_mel: time steps of mel spectrogram, T_lin: time steps of linear spectrogram, T_dec: time steps of decoder, T_enc: time steps of encoder.)
"""
if hasattr(self, "speaker_embedding"): if hasattr(self, "speaker_embedding"):
speaker_embed = self.speaker_embedding(speaker_indices) speaker_embed = self.speaker_embedding(speaker_indices)
else: else:

View File

@ -19,14 +19,14 @@ import paddle.fluid.dygraph as dg
def compute_position_embedding(radians, speaker_position_rate): def compute_position_embedding(radians, speaker_position_rate):
"""compute sin/cos separately and scatter them to a zero. """Compute sin/cos interleaved matrix from the radians.
Arguments: Args:
radians {Variable} -- shape(n_vocab, embed_dim), the radians matrix. radians (Variable): shape(n_vocab, embed_dim), dtype float32, the radians matrix.
speaker_position_rate {Variable} -- shape(batch_size, ), speaker positioning rate. speaker_position_rate (Variable): shape(B, ), speaker positioning rate.
Returns: Returns:
Variable -- shape(batch_size, n_vocab, embed_dim), the sin, cos matrix. Variable: shape(B, n_vocab, embed_dim), the sin, cos interleaved matrix.
""" """
_, embed_dim = radians.shape _, embed_dim = radians.shape
batch_size = speaker_position_rate.shape[0] batch_size = speaker_position_rate.shape[0]
@ -48,10 +48,20 @@ def position_encoding_init(n_position,
d_pos_vec, d_pos_vec,
position_rate=1.0, position_rate=1.0,
padding_idx=None): padding_idx=None):
"""init the position encoding table""" """Init the position encoding.
Args:
n_position (int): max position, vocab size for position embedding.
d_pos_vec (int): position embedding size.
position_rate (float, optional): position rate (this should only be used when all the utterances are from one speaker.). Defaults to 1.0.
padding_idx (int, optional): padding index for the position embedding(it is set as 0 internally if not provided.). Defaults to None.
Returns:
np.ndarray: shape(n_position, d_pos_vec), dtype float32, the table of radians (sin and cos are not yet applied).
"""
# init the position encoding table
# keep idx 0 for padding token position encoding zero vector # keep idx 0 for padding token position encoding zero vector
# CAUTION: it is radians here, sin and cos are not applied # CAUTION: it is radians here, sin and cos are not applied
# CAUTION: difference here
indices_range = np.expand_dims(np.arange(n_position), -1) indices_range = np.expand_dims(np.arange(n_position), -1)
embed_range = 2 * (np.arange(d_pos_vec) // 2) embed_range = 2 * (np.arange(d_pos_vec) // 2)
radians = position_rate \ radians = position_rate \
@ -63,31 +73,32 @@ def position_encoding_init(n_position,
class PositionEmbedding(dg.Layer): class PositionEmbedding(dg.Layer):
def __init__(self, def __init__(self, n_position, d_pos_vec, position_rate=1.0):
n_position, """Position Embedding for Deep Voice 3.
d_pos_vec,
position_rate=1.0, Args:
param_attr=None, n_position (int): max position, vocab size for position embedding.
max_norm=None, d_pos_vec (int): position embedding size.
padding_idx=None): position_rate (float, optional): position rate (this should only be used when all the utterances are from one speaker.). Defaults to 1.0.
"""
super(PositionEmbedding, self).__init__() super(PositionEmbedding, self).__init__()
self.weight = self.create_parameter((n_position, d_pos_vec)) self.weight = self.create_parameter((n_position, d_pos_vec))
self.weight.set_value( self.weight.set_value(
position_encoding_init(n_position, d_pos_vec, position_rate, position_encoding_init(n_position, d_pos_vec, position_rate)
padding_idx).astype("float32")) .astype("float32"))
def forward(self, indices, speaker_position_rate=None): def forward(self, indices, speaker_position_rate=None):
""" """
Args: Args:
indices (Variable): Shape (B, T), dtype: int64, position indices (Variable): shape (B, T), dtype: int64, position
indices, where B means the batch size, T means the time steps. indices, where B means the batch size, T means the time steps.
speaker_position_rate (Variable | float, optional), position speaker_position_rate (Variable | float, optional), position
rate. It can be a float point number or a Variable with rate. It can be a floating point number or a Variable with
shape (1,), then this speaker_position_rate is used for every shape (1,), then this speaker_position_rate is used for every
example. It can also be a Variable with shape (B, 1), which example. It can also be a Variable with shape (B, ), which
contains a speaker position rate for each speaker. contains a speaker position rate for each utterance.
Returns: Returns:
out (Variable): Shape(B, T, C_pos), position embedding, where C_pos out (Variable): shape(B, T, C_pos), dtype float32, position embedding, where C_pos
means position embedding size. means position embedding size.
""" """
batch_size, time_steps = indices.shape batch_size, time_steps = indices.shape
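A NumPy sketch of the position encoding table described above: a radians table is built first (the `10000 ** (2i / d)` denominator is the usual transformer convention and is assumed here), then sin is applied to even columns and cos to odd columns.

```python
import numpy as np

def position_encoding_radians(n_position, d_pos_vec, position_rate=1.0):
    indices = np.expand_dims(np.arange(n_position), -1)    # (n_position, 1)
    embed_range = 2 * (np.arange(d_pos_vec) // 2)          # (d_pos_vec, )
    return position_rate * indices / np.power(10000.0, embed_range / d_pos_vec)

def interleaved_sin_cos(radians):
    table = np.zeros(radians.shape, dtype=np.float32)
    table[:, 0::2] = np.sin(radians[:, 0::2])              # even columns: sin
    table[:, 1::2] = np.cos(radians[:, 1::2])              # odd columns: cos
    return table

table = interleaved_sin_cos(position_encoding_radians(512, 128))
print(table.shape)   # (512, 128)
```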

View File

@ -27,17 +27,16 @@ from parakeet.models.wavenet.wavenet import WaveNet
def crop(x, audio_start, audio_length): def crop(x, audio_start, audio_length):
"""Crop mel spectrogram. """Crop the upsampled condition to match audio_length. The upsampled condition has the same time steps as the whole audio does. But since audios are sliced to 0.5 seconds randomly while conditions are not, upsampled conditions should also be sliced to extaclt match the time steps of the audio slice.
Args: Args:
x (Variable): shape(batch_size, channels, time_steps), the condition, upsampled mel spectrogram. x (Variable): shape(B, C, T), dtype float32, the upsampled condition.
audio_start (int): starting point. audio_start (Variable): shape(B, ), dtype: int64, the index of the starting point for each audio slice.
audio_length (int): length. audio_length (int): the length of the audio (number of samples it contains).
Returns: Returns:
out: cropped condition. Variable: shape(B, C, audio_length), cropped condition.
""" """
# crop audio # crop audio
slices = [] # for each example slices = [] # for each example
starts = audio_start.numpy() starts = audio_start.numpy()
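A NumPy sketch of the per-example cropping (assumed behaviour of `crop`): each example keeps `audio_length` condition steps starting at its own `audio_start`.

```python
import numpy as np

def crop_np(x, audio_start, audio_length):
    # x: (B, C, T) upsampled condition, audio_start: (B, ) start indices
    slices = [x[b, :, s:s + audio_length] for b, s in enumerate(audio_start)]
    return np.stack(slices, axis=0)        # (B, C, audio_length)

x = np.random.randn(2, 3, 100).astype("float32")
starts = np.array([10, 37])
assert crop_np(x, starts, audio_length=50).shape == (2, 3, 50)
```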
@ -51,12 +50,15 @@ def crop(x, audio_start, audio_length):
class UpsampleNet(dg.Layer): class UpsampleNet(dg.Layer):
"""A upsampling net (bridge net) in clarinet to upsample spectrograms from frame level to sample level.
It consists of several(2) layers of transposed_conv2d. in time and frequency.
The time dim is dilated hop_length times. The frequency bands retains.
"""
def __init__(self, upscale_factors=[16, 16]): def __init__(self, upscale_factors=[16, 16]):
"""UpsamplingNet.
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose layer upsamples the time dimension by a factor of its `stride`, and each Conv2DTranspose's filter_size along the frequency dimension is 3.
Args:
upscale_factors (list[int], optional): time upsampling factors for each Conv2DTranspose Layer. The `UpsampleNet` contains len(upscale_factor) Conv2DTranspose Layers. Each upscale_factor is used as the `stride` for the corresponding Conv2DTranspose. Defaults to [16, 16].
Note:
np.prod(upscale_factors) should equal the `hop_length` of the stft transformation used to extract spectrogram features from audios. For example, with upscale_factors of [16, 16], np.prod gives 256, so the spectrogram should be extracted with an stft transformation whose `hop_length` is 256. See `librosa.stft` for more details.
"""
super(UpsampleNet, self).__init__() super(UpsampleNet, self).__init__()
self.upscale_factors = list(upscale_factors) self.upscale_factors = list(upscale_factors)
self.upsample_convs = dg.LayerList() self.upsample_convs = dg.LayerList()
@ -74,13 +76,13 @@ class UpsampleNet(dg.Layer):
return np.prod(self.upscale_factors) return np.prod(self.upscale_factors)
def forward(self, x): def forward(self, x):
"""upsample local condition to match time steps of input signals. i.e. upsample mel spectrogram to match time steps for waveform, for each layer of a wavenet. """Compute the upsampled condition.
Arguments: Args:
x {Variable} -- shape(batch_size, frequency, time_steps), local condition x (Variable): shape(B, F, T), dtype float32, the condition (mel spectrogram here.) (F means the frequency bands). In the internal Conv2DTransposes, the frequency dimension is treated as `height` dimension instead of `in_channels`.
Returns: Returns:
Variable -- shape(batch_size, frequency, time_steps * np.prod(upscale_factors)), upsampled condition for each layer. Variable: shape(B, F, T * upscale_factor), dtype float32, the upsampled condition.
""" """
x = F.unsqueeze(x, axes=[1]) x = F.unsqueeze(x, axes=[1])
for sublayer in self.upsample_convs: for sublayer in self.upsample_convs:
@ -91,27 +93,31 @@ class UpsampleNet(dg.Layer):
# AutoRegressive Model # AutoRegressive Model
class ConditionalWavenet(dg.Layer): class ConditionalWavenet(dg.Layer):
def __init__(self, encoder: UpsampleNet, decoder: WaveNet): def __init__(self, encoder, decoder):
"""Conditional Wavenet, which contains an UpsampleNet as the encoder and a WaveNet as the decoder. It is an autoregressive model.
Args:
encoder (UpsampleNet): the UpsampleNet as the encoder.
decoder (WaveNet): the WaveNet as the decoder.
"""
super(ConditionalWavenet, self).__init__() super(ConditionalWavenet, self).__init__()
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
def forward(self, audio, mel, audio_start): def forward(self, audio, mel, audio_start):
"""forward """Compute the output distribution given the mel spectrogram and the input(for teacher force training).
Arguments: Args:
audio {Variable} -- shape(batch_size, time_steps), waveform of 0.5 seconds audio (Variable): shape(B, T_audio), dtype float32, ground truth waveform, used for teacher forcing.
mel {Variable} -- shape(batch_size, frequency_bands, frames), mel spectrogram of the whole sentence mel (Variable): shape(B, F, T_mel), dtype float32, mel spectrogram. Note that it is the spectrogram for the whole utterance.
audio_start {Variable} -- shape(batch_size, ), audio start positions audio_start (Variable): shape(B, ), dtype: int, audio slices' start positions for each utterance.
Returns: Returns:
Variable -- shape(batch_size, time_steps - 1, output_dim), output distribution parameters Variable: shape(B, T_audio - 1, C_output), parameters for the output distribution (C_output is the `output_dim` of the decoder).
""" """
audio_length = audio.shape[1] # audio clip's length audio_length = audio.shape[1] # audio clip's length
condition = self.encoder(mel) condition = self.encoder(mel)
condition_slice = crop(condition, audio_start, condition_slice = crop(condition, audio_start, audio_length)
audio_length) # crop audio
# shifting 1 step # shifting 1 step
audio = audio[:, :-1] audio = audio[:, :-1]
@ -121,43 +127,41 @@ class ConditionalWavenet(dg.Layer):
return y return y
def loss(self, y, t): def loss(self, y, t):
"""compute loss """compute loss with respect to the output distribution and the targer audio.
Arguments: Args:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution parameters y (Variable): shape(B, T - 1, C_output), dtype float32, parameters of the output distribution.
t {Variable} -- shape(batch_size, time_steps), target waveform t (Variable): shape(B, T), dtype float32, target waveform.
Returns: Returns:
Variable -- shape(1, ), reduced loss Variable: shape(1, ), dtype float32, the loss.
""" """
t = t[:, 1:] t = t[:, 1:]
loss = self.decoder.loss(y, t) loss = self.decoder.loss(y, t)
return loss return loss
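The one-sample shift between `forward` (which feeds `audio[:, :-1]`) and `loss` (which scores against `t[:, 1:]`) can be pictured with a tiny NumPy sketch; the alignment below is the assumed intent.

```python
import numpy as np

audio = np.arange(8, dtype="float32")[None, :]   # (B=1, T_audio=8)
inputs = audio[:, :-1]                           # what forward feeds the decoder
targets = audio[:, 1:]                           # what loss compares against
assert inputs.shape == targets.shape == (1, 7)
# the output distribution at position i is trained to predict audio sample i + 1
```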
def sample(self, y): def sample(self, y):
"""sample from output distribution """Sample from the output distribution.
Arguments: Args:
y {Variable} -- shape(batch_size, time_steps, output_dim), output distribution parameters y (Variable): shape(B, T, C_output), dtype float32, parameters of the output distribution.
Returns: Returns:
Variable -- shape(batch_size, time_steps) samples Variable: shape(B, T), dtype float32, sampled waveform from the output distribution.
""" """
samples = self.decoder.sample(y) samples = self.decoder.sample(y)
return samples return samples
@dg.no_grad @dg.no_grad
def synthesis(self, mel): def synthesis(self, mel):
"""synthesize waveform from mel spectrogram """Synthesize waveform from mel spectrogram.
Arguments: Args:
mel {Variable} -- shape(batch_size, frequency_bands, frames), mel-spectrogram mel (Variable): shape(B, F, T), condition(mel spectrogram here).
Returns: Returns:
Variable -- shape(batch_size, time_steps), synthesized waveform. Variable: shape(B, T * upscale_factor), synthesized waveform (`upscale_factor` is the upsample factor of the encoder `UpsampleNet`).
""" """
condition = self.encoder(mel) condition = self.encoder(mel)
batch_size, _, time_steps = condition.shape batch_size, _, time_steps = condition.shape
samples = [] samples = []

View File

@ -27,11 +27,29 @@ from parakeet.modules.weight_norm import Linear, Conv1D, Conv1DCell, Conv2DTrans
# for wavenet with softmax loss # for wavenet with softmax loss
def quantize(values, n_bands): def quantize(values, n_bands):
"""Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
Args:
values (Variable): dtype: float32 or float64, the floating point values.
n_bands (int): the number of bands. The output integer Tensor's values are in the range [0, n_bands).
Returns:
Variable: the quantized tensor, dtype: int64.
"""
quantized = F.cast((values + 1.0) / 2.0 * n_bands, "int64") quantized = F.cast((values + 1.0) / 2.0 * n_bands, "int64")
return quantized return quantized
def dequantize(quantized, n_bands): def dequantize(quantized, n_bands):
"""Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
Args:
quantized (Variable): dtype: int64. The quantized value in the range [0, n_bands).
n_bands (int): number of bands. The input integer Tensor's values are in the range [0, n_bands).
Returns:
Variable: the dequantized tensor, dtype float32.
"""
value = (F.cast(quantized, "float32") + 0.5) * (2.0 / n_bands) - 1.0 value = (F.cast(quantized, "float32") + 0.5) * (2.0 / n_bands) - 1.0
return value return value
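A NumPy sketch of the two helpers above, using the same formulas as the code; the round-trip error is at most half a quantization bin.

```python
import numpy as np

def quantize_np(values, n_bands):
    # [-1, 1) -> {0, ..., n_bands - 1}
    return ((values + 1.0) / 2.0 * n_bands).astype("int64")

def dequantize_np(quantized, n_bands):
    # map each band back to its bin center in [-1, 1)
    return (quantized.astype("float32") + 0.5) * (2.0 / n_bands) - 1.0

x = np.array([-1.0, -0.5, 0.0, 0.5, 0.999], dtype="float32")
q = quantize_np(x, n_bands=256)
print(np.abs(x - dequantize_np(q, n_bands=256)).max())   # <= 1 / 256 here
```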
@ -39,6 +57,14 @@ def dequantize(quantized, n_bands):
class ResidualBlock(dg.Layer): class ResidualBlock(dg.Layer):
def __init__(self, residual_channels, condition_dim, filter_size, def __init__(self, residual_channels, condition_dim, filter_size,
dilation): dilation):
"""A Residual block in wavenet. It does not have parametric residual or skip connection. It consists of a Conv1DCell and an Conv1D(filter_size = 1) to integrate the condition.
Args:
residual_channels (int): the channels of the input, residual and skip.
condition_dim (int): the channels of the condition.
filter_size (int): filter size of the internal convolution cell.
dilation (int): dilation of the internal convolution cell.
"""
super(ResidualBlock, self).__init__() super(ResidualBlock, self).__init__()
dilated_channels = 2 * residual_channels dilated_channels = 2 * residual_channels
# following clarinet's implementation, we do not have parametric residual # following clarinet's implementation, we do not have parametric residual
@ -64,17 +90,16 @@ class ResidualBlock(dg.Layer):
self.condition_dim = condition_dim self.condition_dim = condition_dim
def forward(self, x, condition=None): def forward(self, x, condition=None):
"""Conv1D gated tanh Block """Conv1D gated-tanh Block.
Arguments: Args:
x {Variable} -- shape(batch_size, residual_channels, time_steps), the input. x (Variable): shape(B, C_res, T), the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.) dtype float32.
condition (Variable, optional): shape(B, C_cond, T), the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels). Defaults to None.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps), upsampled local condition, it has the shape time steps as the input x. (default: {None})
Returns: Returns:
Variable -- shape(batch_size, residual_channels, time_steps), the output which is used as the input of the next layer. (residual, skip_connection)
Variable -- shape(batch_size, residual_channels, time_steps), the output which is stacked alongside with other layers' as the output of wavenet. residual (Variable): shape(B, C_res, T), the residual, which is used as the input to the next layer of ResidualBlock.
skip_connection (Variable): shape(B, C_res, T), the skip connection. This output is accumulated with that of other ResidualBlocks.
""" """
time_steps = x.shape[-1] time_steps = x.shape[-1]
h = x h = x
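As a rough illustration of the gated-tanh step described in the docstring, here is a plain NumPy sketch; it omits the dilated causal convolution and the condition projection that precede the gating in the actual layer, and the sqrt(0.5) rescaling of the residual is an assumption based on the scaling used elsewhere in this file:

```python
import numpy as np

def gated_tanh(h):
    # h: (batch, 2 * residual_channels, time); split channels into content and gate
    content, gate = np.split(h, 2, axis=1)
    return np.tanh(content) / (1.0 + np.exp(-gate))  # tanh(content) * sigmoid(gate)

x = np.random.randn(4, 64, 100).astype(np.float32)      # block input
h = np.random.randn(4, 2 * 64, 100).astype(np.float32)  # after dilated conv (+ condition)
z = gated_tanh(h)                                        # (4, 64, 100)
residual = (x + z) * np.sqrt(0.5)  # residual fed to the next block (rescaling assumed)
skip = z                           # skip connection, no extra parameters
```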
@ -98,20 +123,21 @@ class ResidualBlock(dg.Layer):
return residual, skip_connection return residual, skip_connection
def start_sequence(self): def start_sequence(self):
"""Prepare the ResidualBlock to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
"""
self.conv.start_sequence() self.conv.start_sequence()
def add_input(self, x, condition=None): def add_input(self, x, condition=None):
"""add a step input. """Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
Arguments: Args:
x {Variable} -- shape(batch_size, in_channels, time_steps=1), step input x (Variable): shape(B, C_res, T=1), input for a step, dtype float32.
condition (Variable, optional): shape(B, C_cond, T=1). condition for a step, dtype float32. Defaults to None.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps=1) (default: {None})
Returns: Returns:
Variable -- shape(batch_size, in_channels, time_steps=1), residual connection, which is the input for the next layer (residual, skip_connection)
Variable -- shape(batch_size, in_channels, time_steps=1), skip connection residual (Variable): shape(B, C_res, T=1), the residual for a step, which is used as the input to the next layer of ResidualBlock.
skip_connection (Variable): shape(B, C_res, T=1), the skip connection for a step. This output is accumulated with that of other ResidualBlocks.
""" """
h = x h = x
@ -135,6 +161,15 @@ class ResidualBlock(dg.Layer):
class ResidualNet(dg.Layer): class ResidualNet(dg.Layer):
def __init__(self, n_loop, n_layer, residual_channels, condition_dim, def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
filter_size): filter_size):
"""The residual network in wavenet. It consists of `n_layer` stacks, each of which consists of `n_loop` ResidualBlocks.
Args:
n_loop (int): number of ResidualBlocks in a stack.
n_layer (int): number of stacks in the `ResidualNet`.
residual_channels (int): channels of each `ResidualBlock`'s input.
condition_dim (int): channels of the condition.
filter_size (int): filter size of the internal Conv1DCell of each `ResidualBlock`.
"""
super(ResidualNet, self).__init__() super(ResidualNet, self).__init__()
# double the dilation at each layer in a loop(n_loop layers) # double the dilation at each layer in a loop(n_loop layers)
dilations = [2**i for i in range(n_loop)] * n_layer dilations = [2**i for i in range(n_loop)] * n_layer
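For example, the dilation schedule produced by the line above doubles within each stack and then resets at the start of the next stack:

```python
n_loop, n_layer = 4, 2
dilations = [2**i for i in range(n_loop)] * n_layer
print(dilations)  # [1, 2, 4, 8, 1, 2, 4, 8]
```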
@ -145,19 +180,14 @@ class ResidualNet(dg.Layer):
]) ])
def forward(self, x, condition=None): def forward(self, x, condition=None):
"""n_layer layers of n_loop Residual Blocks. """
Args:
Arguments: x (Variable): shape(B, C_res, T), dtype float32, the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.)
x {Variable} -- shape(batch_size, residual_channels, time_steps), input of the residual net. condition (Variable, optional): shape(B, C_cond, T), dtype float32, the condition; it has been upsampled in time steps, so it has the same time steps as the input does. (C_cond stands for the condition's channels.) Defaults to None.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps), upsampled conditions, which has the same time steps as the input. (default: {None})
Returns: Returns:
Variable -- shape(batch_size, skip_channels, time_steps), output of the residual net. skip_connection (Variable): shape(B, C_res, T), dtype float32, the accumulated skip connections from all ResidualBlocks.
""" """
#before_resnet = time.time()
for i, func in enumerate(self.residual_blocks): for i, func in enumerate(self.residual_blocks):
x, skip = func(x, condition) x, skip = func(x, condition)
if i == 0: if i == 0:
@ -165,24 +195,23 @@ class ResidualNet(dg.Layer):
else: else:
skip_connections = F.scale(skip_connections + skip, skip_connections = F.scale(skip_connections + skip,
np.sqrt(0.5)) np.sqrt(0.5))
#print("resnet: ", time.time() - before_resnet)
return skip_connections return skip_connections
def start_sequence(self): def start_sequence(self):
"""Prepare the ResidualNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
"""
for block in self.residual_blocks: for block in self.residual_blocks:
block.start_sequence() block.start_sequence()
def add_input(self, x, condition=None): def add_input(self, x, condition=None):
"""add step input and return step output. """Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
Arguments: Args:
x {Variable} -- shape(batch_size, residual_channels, time_steps=1), step input. x (Variable): shape(B, C_res, T=1), dtype float32, input for a step.
condition (Variable, optional): shape(B, C_cond, T=1), dtype float32, condition for a step. Defaults to None.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps=1), step condition (default: {None})
Returns: Returns:
Variable -- shape(batch_size, skip_channels, time_steps=1), step output, parameters of the output distribution. skip_connection (Variable): shape(B, C_res, T=1), dtype float32, the accumulated skip connections for a step.
""" """
for i, func in enumerate(self.residual_blocks): for i, func in enumerate(self.residual_blocks):
@ -198,6 +227,18 @@ class ResidualNet(dg.Layer):
class WaveNet(dg.Layer): class WaveNet(dg.Layer):
def __init__(self, n_loop, n_layer, residual_channels, output_dim, def __init__(self, n_loop, n_layer, residual_channels, output_dim,
condition_dim, filter_size, loss_type, log_scale_min): condition_dim, filter_size, loss_type, log_scale_min):
"""Wavenet that transform upsampled mel spectrogram into waveform.
Args:
n_loop (int): n_loop for the internal ResidualNet.
n_layer (int): n_layer for the internal ResidualNet.
residual_channels (int): the channels of the input.
output_dim (int): the channels of the output distribution.
condition_dim (int): the channels of the condition.
filter_size (int): the filter size of the internal ResidualNet.
loss_type (str): loss type of the wavenet. Possible values are 'softmax' and 'mog'. If `loss_type` is 'softmax', the output is the logits of the categorical (multinomial) distribution, and `output_dim` means the number of classes of the categorical distribution. If `loss_type` is 'mog' (mixture of Gaussians), the output is the parameters of a mixture of Gaussians, which consists of the weight (in the form of a logit) of each Gaussian distribution and its mean and log standard deviation. So when `loss_type` is 'mog', `output_dim` should be divisible by 3.
log_scale_min (int): the minimum value of the log standard deviation of the output Gaussian distributions. Note that this value is only used for computing the loss if `loss_type` is 'mog'; values less than `log_scale_min` are clipped when computing the loss.
"""
super(WaveNet, self).__init__() super(WaveNet, self).__init__()
if loss_type not in ["softmax", "mog"]: if loss_type not in ["softmax", "mog"]:
raise ValueError("loss_type {} is not supported".format(loss_type)) raise ValueError("loss_type {} is not supported".format(loss_type))
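A rough sketch of the 'mog' parameterization described above; the values (`output_dim = 30`, clip at `-9.0`) and variable names are illustrative, and the split order (logits, then means, then log standard deviations) follows the docstrings later in this file:

```python
import numpy as np

output_dim = 30                  # must be divisible by 3 for 'mog'
n_mixture = output_dim // 3
y = np.random.randn(8, 100, output_dim).astype(np.float32)  # (B, T, C_output)

logits, mean, log_std = np.split(y, 3, axis=-1)              # each (B, T, n_mixture)
log_std = np.maximum(log_std, -9.0)                          # clip at log_scale_min (loss only)
weights = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)  # mixture weights
```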
@ -225,19 +266,16 @@ class WaveNet(dg.Layer):
self.log_scale_min = log_scale_min self.log_scale_min = log_scale_min
def forward(self, x, condition=None): def forward(self, x, condition=None):
"""(Possibly) Conditonal Wavenet. """compute the output distribution (represented by its parameters).
Arguments: Args:
x {Variable} -- shape(batch_size, time_steps), the input signal of wavenet. The waveform in 0.5 seconds. x (Variable): shape(B, T), dtype float32, the input waveform.
condition (Variable, optional): shape(B, C_cond, T), dtype float32, the upsampled condition. Defaults to None.
Keyword Arguments:
conditions {Variable} -- shape(batch_size, condition_dim, 1, time_steps), the upsampled local condition. (default: {None})
Returns: Returns:
Variable -- shape(batch_size, time_steps, output_dim), output distributions at each time_steps. Variable: shape(B, T, C_output), dtype float32, the parameters of the output distribution.
""" """
# CAUTION: rank-4 condition here
# Causal Conv # Causal Conv
if self.loss_type == "softmax": if self.loss_type == "softmax":
x = F.clip(x, min=-1., max=0.99999) x = F.clip(x, min=-1., max=0.99999)
@ -258,21 +296,20 @@ class WaveNet(dg.Layer):
return y return y
def start_sequence(self): def start_sequence(self):
"""Prepare the WaveNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
"""
self.resnet.start_sequence() self.resnet.start_sequence()
def add_input(self, x, condition=None): def add_input(self, x, condition=None):
"""add step input """compute the output distribution (represented by its parameters) for a step. It works similarily with the `forward` method but in a `step-in-step-out` fashion.
Arguments: Args:
x {Variable} -- shape(batch_size, time_steps=1), step input. x (Variable): shape(B, T=1), dtype float32, a step of the input waveform.
condition (Variable, optional): shape(B, C_cond, T=1), dtype float32, a step of the upsampled condition. Defaults to None.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim , 1, time_steps=1) (default: {None})
Returns: Returns:
Variable -- ouput parameter for the distribution. Variable: shape(B, T=1, C_output), dtype float32, the parameters of the output distribution.
""" """
# Causal Conv # Causal Conv
if self.loss_type == "softmax": if self.loss_type == "softmax":
x = quantize(x, self.output_dim) x = quantize(x, self.output_dim)
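Taken together, `start_sequence`, `add_input` and `sample` support an autoregressive generation loop roughly like the sketch below. This is illustrative only: `model` (a WaveNet instance) and `condition` (the upsampled condition) are assumed to exist, and the exact slicing of the condition may differ from the real synthesis code.

```python
import numpy as np
import paddle.fluid.dygraph as dg

# `model` is a WaveNet and `condition` has shape (B, C_cond, T); both assumed to exist.
model.start_sequence()                                            # reset Conv1DCell buffers
x = dg.to_variable(np.zeros((condition.shape[0], 1), "float32"))  # initial input sample
samples = []
for i in range(condition.shape[-1]):      # one step per upsampled condition frame
    c = condition[:, :, i:i + 1]          # (B, C_cond, T=1)
    y = model.add_input(x, c)             # (B, 1, C_output), distribution parameters
    x = model.sample(y)                   # (B, 1), next waveform sample
    samples.append(x)
```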
@ -292,16 +329,15 @@ class WaveNet(dg.Layer):
return y return y
def compute_softmax_loss(self, y, t): def compute_softmax_loss(self, y, t):
"""compute loss, it is basically a language_model-like loss. """compute the loss where output distribution is a categorial distribution.
Arguments: Args:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution of multinomial distribution. y (Variable): shape(B, T, C_output), dtype float32, the logits of the output distribution.
t {Variable} -- shape(batch_size, time_steps - 1), target waveform. t (Variable): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution's, and output distributions whose inputs contain padding are ignored in loss computation.
Returns: Returns:
Variable -- shape(1,), loss Variable: shape(1, ), dtype float32, the loss.
""" """
# context size is not taken into account # context size is not taken into account
y = y[:, self.context_size:, :] y = y[:, self.context_size:, :]
t = t[:, self.context_size:] t = t[:, self.context_size:]
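Ignoring the one-step shift and the padding handling mentioned above, the softmax loss amounts to a cross entropy between the logits and the quantized target; a minimal NumPy sketch (helper name illustrative):

```python
import numpy as np

def softmax_xent_np(y, t, n_bands):
    # y: (B, T, n_bands) logits; t: (B, T) waveform in [-1, 1)
    labels = ((t + 1.0) / 2.0 * n_bands).astype(np.int64)   # quantize targets
    y = y - y.max(-1, keepdims=True)                        # stable log-softmax
    logp = y - np.log(np.exp(y).sum(-1, keepdims=True))
    nll = -np.take_along_axis(logp, labels[..., None], axis=-1)
    return nll.mean()
```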
@ -314,15 +350,14 @@ class WaveNet(dg.Layer):
return reduced_loss return reduced_loss
def sample_from_softmax(self, y): def sample_from_softmax(self, y):
"""sample from output distribution. """Sample from the output distribution where the output distribution is a categorical distriobution.
Arguments: Args:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution. y (Variable): shape(B, T, C_output), the logits of the output distribution.
Returns: Returns:
Variable -- shape(batch_size, time_steps - 1), samples. Variable: shape(B, T), waveform sampled from the output distribution.
""" """
# dequantize # dequantize
batch_size, time_steps, output_dim, = y.shape batch_size, time_steps, output_dim, = y.shape
y = F.reshape(y, (batch_size * time_steps, output_dim)) y = F.reshape(y, (batch_size * time_steps, output_dim))
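A NumPy sketch of the sample-then-dequantize procedure described in the docstring (illustrative only, not the library's implementation):

```python
import numpy as np

def sample_from_softmax_np(y, n_bands):
    # y: (B, T, n_bands) logits -> sampled waveform in [-1, 1)
    B, T, _ = y.shape
    y = y - y.max(-1, keepdims=True)
    probs = np.exp(y) / np.exp(y).sum(-1, keepdims=True)
    flat = probs.reshape(B * T, n_bands).astype(np.float64)
    idx = np.array([np.random.choice(n_bands, p=p / p.sum()) for p in flat]).reshape(B, T)
    return (idx.astype(np.float32) + 0.5) * (2.0 / n_bands) - 1.0  # dequantize
```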
@ -333,17 +368,15 @@ class WaveNet(dg.Layer):
return samples return samples
def compute_mog_loss(self, y, t): def compute_mog_loss(self, y, t):
"""compute the loss with an mog output distribution. """compute the loss where output distribution is a mixture of Gaussians.
WARNING: this is not a legal probability, but a density. so it might be greater than 1.
Arguments: Args:
y {Variable} -- shape(batch_size, time_steps, output_dim), output distribution's parameter. To represent a mixture of Gaussians. The output for each example at each time_step consists of 3 parts. The mean, the stddev, and a weight for that gaussian. y (Variable): shape(B, T, C_output), dtype float32, the parameters of the output distribution. It is the concatenation of 3 parts: the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
t {Variable} -- shape(batch_size, time_steps), target waveform. t (Variable): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution's, and output distributions whose inputs contain padding are ignored in loss computation.
Returns: Returns:
Variable -- loss, note that it is computed with the pdf of the MoG distribution. Variable: shape(1, ), dtype float32, the loss.
""" """
n_mixture = self.output_dim // 3 n_mixture = self.output_dim // 3
# context size is not taken in to account # context size is not taken in to account
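The mixture-of-Gaussians negative log likelihood described here can be sketched in NumPy as follows (again ignoring the one-step shift and the handling of padded positions; names and the default clip value are illustrative):

```python
import numpy as np

def mog_nll_np(y, t, log_scale_min=-9.0):
    # y: (B, T, 3 * n_mixture) -> logits, means, log stds; t: (B, T) target waveform
    logit, mean, log_std = np.split(y, 3, axis=-1)
    log_std = np.maximum(log_std, log_scale_min)                  # clip the log std
    logit = logit - logit.max(-1, keepdims=True)
    log_w = logit - np.log(np.exp(logit).sum(-1, keepdims=True))  # log mixture weights
    t = t[..., None]                                              # broadcast over components
    log_pdf = -0.5 * np.log(2 * np.pi) - log_std \
              - 0.5 * ((t - mean) / np.exp(log_std)) ** 2         # per-component log density
    log_mix = np.log(np.exp(log_w + log_pdf).sum(-1))             # log-sum-exp over mixture
    return -log_mix.mean()
```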
@ -373,15 +406,13 @@ class WaveNet(dg.Layer):
return loss return loss
def sample_from_mog(self, y): def sample_from_mog(self, y):
"""sample from output distribution. """Sample from the output distribution where the output distribution is a mixture of Gaussians.
Args:
Arguments: y (Variable): shape(B, T, C_output), dtype float32, the parameters of the output distribution. It is the concatenation of 3 parts: the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution.
Returns: Returns:
Variable -- shape(batch_size, time_steps - 1), samples. Variable: shape(B, T), waveform sampled from the output distribution.
""" """
batch_size, time_steps, output_dim = y.shape batch_size, time_steps, output_dim = y.shape
n_mixture = output_dim // 3 n_mixture = output_dim // 3
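Sampling from the mixture first picks one component per time step according to its weight, then draws from that component's Gaussian; a rough NumPy sketch (illustrative, not the library's code):

```python
import numpy as np

def sample_from_mog_np(y):
    # y: (B, T, 3 * n_mixture) -> sampled waveform (B, T)
    logit, mean, log_std = np.split(y, 3, axis=-1)
    B, T, K = logit.shape
    logit = logit - logit.max(-1, keepdims=True)
    w = (np.exp(logit) / np.exp(logit).sum(-1, keepdims=True)).astype(np.float64)
    comp = np.array([np.random.choice(K, p=p / p.sum())
                     for p in w.reshape(-1, K)]).reshape(B, T, 1)   # chosen components
    mu = np.take_along_axis(mean, comp, axis=-1)[..., 0]
    sigma = np.exp(np.take_along_axis(log_std, comp, axis=-1)[..., 0])
    return mu + sigma * np.random.randn(B, T)
```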
@ -405,31 +436,28 @@ class WaveNet(dg.Layer):
return samples return samples
def sample(self, y): def sample(self, y):
"""sample from output distribution. """Sample from the output distribution.
Args:
Arguments: y (Variable): shape(B, T, C_output), dtype float32, the parameters of the output distribution.
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution.
Returns: Returns:
Variable -- shape(batch_size, time_steps - 1), samples. Variable: shape(B, T), waveform sampled from the output distribution.
""" """
if self.loss_type == "softmax": if self.loss_type == "softmax":
return self.sample_from_softmax(y) return self.sample_from_softmax(y)
else: else:
return self.sample_from_mog(y) return self.sample_from_mog(y)
def loss(self, y, t): def loss(self, y, t):
"""compute loss. """compute the loss where output distribution is a mixture of Gaussians.
Arguments: Args:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution of multinomial distribution. y (Variable): shape(B, T, C_output), dtype float32, the parameters of the output distribution.
t {Variable} -- shape(batch_size, time_steps - 1), target waveform. t (Variable): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution's, and output distributions whose inputs contain padding are ignored in loss computation.
Returns: Returns:
Variable -- shape(1,), loss Variable: shape(1, ), dtype float32, the loss.
""" """
if self.loss_type == "softmax": if self.loss_type == "softmax":
return self.compute_softmax_loss(y, t) return self.compute_softmax_loss(y, t)
else: else: