1. update docstrings for models.wavenet;
2. remove unnecessary code;
3. fix typos
parent dd2c5cc6c6
commit 84ad4c9e65
|
@ -20,22 +20,6 @@ parakeet.modules.audio module
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
parakeet.modules.cbhg module
|
|
||||||
----------------------------
|
|
||||||
|
|
||||||
.. automodule:: parakeet.modules.cbhg
|
|
||||||
:members:
|
|
||||||
:undoc-members:
|
|
||||||
:show-inheritance:
|
|
||||||
|
|
||||||
parakeet.modules.connections module
|
|
||||||
-----------------------------------
|
|
||||||
|
|
||||||
.. automodule:: parakeet.modules.connections
|
|
||||||
:members:
|
|
||||||
:undoc-members:
|
|
||||||
:show-inheritance:
|
|
||||||
|
|
||||||
parakeet.modules.conv module
|
parakeet.modules.conv module
|
||||||
----------------------------
|
----------------------------
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
__version__ = "0.2.0"
|
__version__ = "0.2.0-beta"
|
||||||
|
|
||||||
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
|
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
|
||||||
|
|
|
@ -1,36 +0,0 @@
|
||||||
import parakeet
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
package_path = Path(__file__).parent
|
|
||||||
print(package_path)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
subparser = parser.add_subparsers(dest="cmd")
|
|
||||||
|
|
||||||
list_exp_parser = subparser.add_parser("list-examples")
|
|
||||||
clone = subparser.add_parser("clone-example")
|
|
||||||
clone.add_argument("experiment_name", type=str, help="experiment name")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.cmd == "list-examples":
|
|
||||||
print(os.listdir(package_path / "examples"))
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
if args.cmd == "clone-example":
|
|
||||||
source = package_path / "examples" / (args.experiment_name)
|
|
||||||
target = Path(os.getcwd()) / (args.experiment_name)
|
|
||||||
if not os.path.exists(str(source)):
|
|
||||||
raise ValueError("{} does not exist".format(str(source)))
|
|
||||||
|
|
||||||
if os.path.exists(str(target)):
|
|
||||||
raise FileExistsError("{} already exists".format(str(target)))
|
|
||||||
|
|
||||||
shutil.copytree(str(source), str(target))
|
|
||||||
print("{} copied!".format(args.experiment_name))
|
|
||||||
exit(0)
|
|
|
@ -1,5 +1,4 @@
|
||||||
from typing import Dict, Iterable, List
|
from typing import Dict, Iterable, List
|
||||||
from ruamel import yaml
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
#from parakeet.models.clarinet import *
|
#from parakeet.models.clarinet import *
|
||||||
from parakeet.models.waveflow import *
|
from parakeet.models.waveflow import *
|
||||||
#from parakeet.models.wavenet import *
|
from parakeet.models.wavenet import *
|
||||||
|
|
||||||
from parakeet.models.transformer_tts import *
|
from parakeet.models.transformer_tts import *
|
||||||
#from parakeet.models.deepvoice3 import *
|
#from parakeet.models.deepvoice3 import *
|
||||||
|
|
|
@ -30,15 +30,26 @@ from parakeet.utils import checkpoint, layer_tools
|
||||||
|
|
||||||
|
|
||||||
def crop(x, audio_start, audio_length):
|
def crop(x, audio_start, audio_length):
|
||||||
"""Crop the upsampled condition to match audio_length. The upsampled condition has the same time steps as the whole audio does. But since audios are sliced to 0.5 seconds randomly while conditions are not, upsampled conditions should also be sliced to extaclt match the time steps of the audio slice.
|
"""Crop the upsampled condition to match audio_length.
|
||||||
|
|
||||||
|
The upsampled condition has the same time steps as the whole audio does.
|
||||||
|
But since audios are sliced to 0.5 seconds randomly while conditions are
|
||||||
|
not, upsampled conditions should also be sliced to exactly match the time
|
||||||
|
steps of the audio slice.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
x (Tensor): shape(B, C, T), dtype float32, the upsample condition.
|
----------
|
||||||
audio_start (Tensor): shape(B, ), dtype: int64, the index the starting point.
|
x : Tensor [shape=(B, C, T)]
|
||||||
audio_length (int): the length of the audio (number of samples it contaions).
|
The upsampled condition.
|
||||||
|
audio_start : Tensor [shape=(B,), dtype:int]
|
||||||
|
The index of the starting point of the audio clips.
|
||||||
|
audio_length : int
|
||||||
|
The length of the audio clip (number of samples it contains).
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
Tensor: shape(B, C, audio_length), cropped condition.
|
-------
|
||||||
|
Tensor [shape=(B, C, audio_length)]
|
||||||
|
Cropped condition.
|
||||||
"""
|
"""
|
||||||
# crop audio
|
# crop audio
|
||||||
slices = [] # for each example
|
slices = [] # for each example
|
||||||
|
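The docstring above can be illustrated with a small sketch of the slicing it describes. The helper name and the per-example loop below are illustrative only, not the module's actual implementation:

    import paddle

    def crop_sketch(x, audio_start, audio_length):
        # x: (B, C, T) upsampled condition; audio_start: (B,) int64 start indices
        slices = []
        for i, start in enumerate(audio_start.numpy().tolist()):
            # keep only the condition frames aligned with the audio slice
            slices.append(x[i, :, start:start + audio_length])
        return paddle.stack(slices, axis=0)   # (B, C, audio_length)

    x = paddle.randn([2, 80, 1000])
    audio_start = paddle.to_tensor([100, 250], dtype="int64")
    print(crop_sketch(x, audio_start, 500).shape)   # [2, 80, 500]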
@ -54,15 +65,37 @@ def crop(x, audio_start, audio_length):
|
||||||
|
|
||||||
|
|
||||||
class UpsampleNet(nn.LayerList):
|
class UpsampleNet(nn.LayerList):
|
||||||
def __init__(self, upscale_factors=[16, 16]):
|
"""A network used to upsample mel spectrogram to match the time steps of
|
||||||
"""UpsamplingNet.
|
audio.
|
||||||
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose layer upsamples the time dimension by its `stride` times. And each Conv2DTranspose's filter_size at frequency dimension is 3.
|
|
||||||
|
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose
|
||||||
|
layer upsamples the time dimension by its `stride` times.
|
||||||
|
|
||||||
|
Also, each Conv2DTranspose's filter_size at frequency dimension is 3.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
upscale_factors (list[int], optional): time upsampling factors for each Conv2DTranspose Layer. The `UpsampleNet` contains len(upscale_factor) Conv2DTranspose Layers. Each upscale_factor is used as the `stride` for the corresponding Conv2DTranspose. Defaults to [16, 16].
|
----------
|
||||||
Note:
|
upscale_factors : List[int], optional
|
||||||
np.prod(upscale_factors) should equals the `hop_length` of the stft transformation used to extract spectrogram features from audios. For example, 16 * 16 = 256, then the spectram extracted using a stft transformation whose `hop_length` is 256. See `librosa.stft` for more details.
|
Time upsampling factors for each Conv2DTranspose Layer.
|
||||||
"""
|
|
||||||
|
The ``UpsampleNet`` contains ``len(upscale_factor)`` Conv2DTranspose
|
||||||
|
Layers. Each upscale_factor is used as the ``stride`` for the
|
||||||
|
corresponding Conv2DTranspose. Defaults to [16, 16], so the default
|
||||||
|
upsampling factor is 256.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
------
|
||||||
|
``np.prod(upscale_factors)`` should equal the ``hop_length`` of the stft
|
||||||
|
transformation used to extract spectrogram features from audio.
|
||||||
|
|
||||||
|
For example, ``16 * 16 = 256``, so a spectrogram extracted with an stft
|
||||||
|
transformation whose ``hop_length`` equals 256 is suitable.
|
||||||
|
|
||||||
|
See Also
|
||||||
|
---------
|
||||||
|
``librosa.core.stft``
|
||||||
|
"""
|
||||||
|
def __init__(self, upscale_factors=[16, 16]):
|
||||||
super(UpsampleNet, self).__init__()
|
super(UpsampleNet, self).__init__()
|
||||||
self.upscale_factors = list(upscale_factors)
|
self.upscale_factors = list(upscale_factors)
|
||||||
self.upscale_factor = 1
|
self.upscale_factor = 1
|
||||||
|
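The constraint stated in the Notes above can be checked before building the network. A minimal sanity check, with an assumed hop_length of 256:

    import numpy as np

    hop_length = 256            # hop length used when extracting the spectrogram
    upscale_factors = [16, 16]
    # the total upsampling factor must match the stft hop_length
    assert np.prod(upscale_factors) == hop_length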
@ -78,13 +111,20 @@ class UpsampleNet(nn.LayerList):
|
||||||
padding=(1, factor // 2))))
|
padding=(1, factor // 2))))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""Compute the upsampled condition.
|
r"""Compute the upsampled condition.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
x (Tensor): shape(B, F, T), dtype float32, the condition (mel spectrogram here.) (F means the frequency bands). In the internal Conv2DTransposes, the frequency dimension is treated as `height` dimension instead of `in_channels`.
|
-----------
|
||||||
|
x : Tensor [shape=(B, F, T)]
|
||||||
|
The condition (mel spectrogram here). ``F`` means the frequency
|
||||||
|
bands, which is the feature size of the input.
|
||||||
|
|
||||||
|
In the internal Conv2DTransposes, the frequency dimension
|
||||||
|
is treated as ``height`` dimension instead of ``in_channels``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: shape(B, F, T * upscale_factor), dtype float32, the upsampled condition.
|
Tensor [shape=(B, F, T \* upscale_factor)]
|
||||||
|
The upsampled condition.
|
||||||
"""
|
"""
|
||||||
x = paddle.unsqueeze(x, 1)
|
x = paddle.unsqueeze(x, 1)
|
||||||
for sublayer in self:
|
for sublayer in self:
|
||||||
|
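A usage sketch of the forward pass documented above, assuming ``UpsampleNet`` is importable from ``parakeet.models.wavenet``; the batch size and mel length are arbitrary:

    import paddle
    from parakeet.models.wavenet import UpsampleNet

    net = UpsampleNet(upscale_factors=[16, 16])
    mel = paddle.randn([4, 80, 10])   # (B, F, T)
    condition = net(mel)
    print(condition.shape)            # [4, 80, 2560], i.e. T * prod(upscale_factors)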
@ -94,19 +134,36 @@ class UpsampleNet(nn.LayerList):
|
||||||
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Layer):
|
class ResidualBlock(nn.Layer):
|
||||||
|
"""A Residual block used in wavenet. Conv1D-gated-tanh Block.
|
||||||
|
|
||||||
|
It consists of a Conv1DCell and an Conv1D(kernel_size = 1) to integrate
|
||||||
|
information of the condition.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
--------
|
||||||
|
It does not have parametric residual or skip connection.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
residual_channels : int
|
||||||
|
The feature size of the input. It is also the feature size of the
|
||||||
|
residual output and skip output.
|
||||||
|
|
||||||
|
condition_dim : int
|
||||||
|
The feature size of the condition.
|
||||||
|
|
||||||
|
filter_size : int
|
||||||
|
Kernel size of the internal convolution cells.
|
||||||
|
|
||||||
|
dilation :int
|
||||||
|
Dilation of the internal convolution cells.
|
||||||
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
residual_channels: int,
|
residual_channels: int,
|
||||||
condition_dim: int,
|
condition_dim: int,
|
||||||
filter_size: Union[int, Sequence[int]],
|
filter_size: Union[int, Sequence[int]],
|
||||||
dilation: int):
|
dilation: int):
|
||||||
"""A Residual block in wavenet. It does not have parametric residual or skip connection. It consists of a Conv1DCell and an Conv1D(filter_size = 1) to integrate the condition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
residual_channels (int): the channels of the input, residual and skip.
|
|
||||||
condition_dim (int): the channels of the condition.
|
|
||||||
filter_size (int): filter size of the internal convolution cell.
|
|
||||||
dilation (int): dilation of the internal convolution cell.
|
|
||||||
"""
|
|
||||||
super(ResidualBlock, self).__init__()
|
super(ResidualBlock, self).__init__()
|
||||||
dilated_channels = 2 * residual_channels
|
dilated_channels = 2 * residual_channels
|
||||||
# following clarinet's implementation, we do not have parametric residual
|
# following clarinet's implementation, we do not have parametric residual
|
||||||
|
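The gated-tanh nonlinearity that gives the block its name splits the ``2 * residual_channels`` dilated-conv output into content and gate halves. A schematic of that step only, not the block's actual forward:

    import paddle
    import paddle.nn.functional as F

    residual_channels = 64
    h = paddle.randn([2, 2 * residual_channels, 100])   # dilated conv output
    content, gate = paddle.chunk(h, 2, axis=1)           # split along channels
    z = paddle.tanh(content) * F.sigmoid(gate)           # (B, residual_channels, T)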
@ -133,17 +190,29 @@ class ResidualBlock(nn.Layer):
|
||||||
self.condition_dim = condition_dim
|
self.condition_dim = condition_dim
|
||||||
|
|
||||||
def forward(self, x, condition=None):
|
def forward(self, x, condition=None):
|
||||||
"""Conv1D gated-tanh Block.
|
"""Forward pass of the ResidualBlock.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
x (Tensor): shape(B, C_res, T), the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.) dtype float32.
|
-----------
|
||||||
condition (Tensor, optional): shape(B, C_cond, T), the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels). Defaults to None.
|
x : Tensor [shape=(B, C, T)]
|
||||||
|
The input tensor.
|
||||||
|
|
||||||
|
condition : Tensor, optional [shape=(B, C_cond, T)]
|
||||||
|
The condition.
|
||||||
|
|
||||||
|
It has been upsampled in time steps, so it has the same time steps
|
||||||
|
as the input does. (``C_cond`` stands for the condition's channels.)
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
(residual, skip_connection)
|
-----------
|
||||||
residual (Tensor): shape(B, C_res, T), the residual, which is used as the input to the next layer of ResidualBlock.
|
residual : Tensor [shape=(B, C, T)]
|
||||||
skip_connection (Tensor): shape(B, C_res, T), the skip connection. This output is accumulated with that of other ResidualBlocks.
|
The residual, which is used as the input to the next ResidualBlock.
|
||||||
"""
|
|
||||||
|
skip_connection : Tensor [shape=(B, C, T)]
|
||||||
|
The skip connection. This output is accumulated with that of
|
||||||
|
other ResidualBlocks.
|
||||||
|
"""
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
# dilated conv
|
# dilated conv
|
||||||
|
@ -163,22 +232,38 @@ class ResidualBlock(nn.Layer):
|
||||||
return residual, skip_connection
|
return residual, skip_connection
|
||||||
|
|
||||||
def start_sequence(self):
|
def start_sequence(self):
|
||||||
"""Prepare the ResidualBlock to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
"""Prepare the ResidualBlock to generate a new sequence.
|
||||||
|
|
||||||
|
Warnings
|
||||||
|
---------
|
||||||
|
This method should be called before calling ``add_input`` multiple times.
|
||||||
"""
|
"""
|
||||||
self.conv.start_sequence()
|
self.conv.start_sequence()
|
||||||
self.condition_proj.start_sequence()
|
self.condition_proj.start_sequence()
|
||||||
|
|
||||||
def add_input(self, x, condition=None):
|
def add_input(self, x, condition=None):
|
||||||
"""Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
|
"""Take a step input and return a step output.
|
||||||
|
|
||||||
|
This method works similarly to ``forward`` but in a
|
||||||
|
``step-in-step-out`` fashion.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
x (Tensor): shape(B, C_res), input for a step, dtype float32.
|
----------
|
||||||
condition (Tensor, optional): shape(B, C_cond). condition for a step, dtype float32. Defaults to None.
|
x : Tensor [shape=(B, C)]
|
||||||
|
Input for a step.
|
||||||
|
|
||||||
|
condition : Tensor, optional [shape=(B, C_cond)]
|
||||||
|
Condition for a step. Defaults to None.
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
(residual, skip_connection)
|
----------
|
||||||
residual (Tensor): shape(B, C_res), the residual for a step, which is used as the input to the next layer of ResidualBlock.
|
residual : Tensor [shape=(B, C)]
|
||||||
skip_connection (Tensor): shape(B, C_res), the skip connection for a step. This output is accumulated with that of other ResidualBlocks.
|
The residual for a step, which is used as the input to the next
|
||||||
|
layer of ResidualBlock.
|
||||||
|
|
||||||
|
skip_connection : Tensor [shape=(B, C)]
|
||||||
|
The skip connection for a step. This output is accumulated with
|
||||||
|
that of other ResidualBlocks.
|
||||||
"""
|
"""
|
||||||
h = x
|
h = x
|
||||||
|
|
||||||
|
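A sketch of the step-in-step-out protocol described by ``start_sequence`` and ``add_input``, assuming a block built with the signature shown above; the sizes are arbitrary:

    import paddle
    from parakeet.models.wavenet import ResidualBlock

    block = ResidualBlock(residual_channels=64, condition_dim=80,
                          filter_size=2, dilation=1)
    block.eval()                          # incremental mode is evaluation-only
    block.start_sequence()                # reset the internal buffers first

    x_t = paddle.randn([4, 64])           # one step of input, (B, C)
    c_t = paddle.randn([4, 80])           # one step of condition, (B, C_cond)
    residual, skip = block.add_input(x_t, c_t)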
@ -511,6 +596,54 @@ class WaveNet(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class ConditionalWaveNet(nn.Layer):
|
class ConditionalWaveNet(nn.Layer):
|
||||||
|
r"""Conditional Wavenet. An implementation of
|
||||||
|
`WaveNet: A Generative Model for Raw Audio <http://arxiv.org/abs/1609.03499>`_.
|
||||||
|
|
||||||
|
It contains an UpsampleNet as the encoder and a WaveNet as the decoder.
|
||||||
|
It is an autoregressive model that generates raw audio.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
upsample_factors : List[int]
|
||||||
|
The upsampling factors of the UpsampleNet.
|
||||||
|
|
||||||
|
n_stack : int
|
||||||
|
Number of convolution stacks in the WaveNet.
|
||||||
|
|
||||||
|
n_loop : int
|
||||||
|
Number of convolution layers in a convolution stack.
|
||||||
|
|
||||||
|
Convolution layers in a stack have exponentially growing dilations,
|
||||||
|
from 1 to :math:`k^{n_{loop} - 1}`, where ``k`` is the kernel size.
|
||||||
|
|
||||||
|
residual_channels : int
|
||||||
|
Feature size of each ResidualBlocks.
|
||||||
|
|
||||||
|
output_dim : int
|
||||||
|
Feature size of the output. See ``loss_type`` for details.
|
||||||
|
|
||||||
|
n_mels : int
|
||||||
|
The number of bands of mel spectrogram.
|
||||||
|
|
||||||
|
filter_size : int, optional
|
||||||
|
Convolution kernel size of each ResidualBlock, by default 2.
|
||||||
|
|
||||||
|
loss_type : str, optional ["mog" or "softmax"]
|
||||||
|
The output type and loss type of the model, by default "mog".
|
||||||
|
|
||||||
|
If "softmax", the model input should be quantized audio and the model
|
||||||
|
outputs a discrete distribution.
|
||||||
|
|
||||||
|
If "mog", the model input is audio in floating point format, and the
|
||||||
|
model outputs parameters for a mixture of gaussian distributions.
|
||||||
|
Namely, the weight, mean and log scale of each Gaussian distribution.
|
||||||
|
Thus, the ``output_dim`` should be a multiple of 3.
|
||||||
|
|
||||||
|
log_scale_min : float, optional
|
||||||
|
Minimum value of the log probability density, by default -9.0.
|
||||||
|
|
||||||
|
This is only used for computing the loss when ``loss_type`` is "mog".
|
||||||
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
upsample_factors: List[int],
|
upsample_factors: List[int],
|
||||||
n_stack: int,
|
n_stack: int,
|
||||||
|
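A construction sketch following the parameter descriptions above. The hyperparameter values are illustrative; with ``loss_type="mog"``, ``output_dim`` is three times the number of mixture components:

    from parakeet.models.wavenet import ConditionalWaveNet

    n_mixtures = 10
    model = ConditionalWaveNet(
        upsample_factors=[16, 16],      # total upsampling = hop_length = 256
        n_stack=2,                      # number of dilated-conv stacks
        n_loop=10,                      # layers per stack
        residual_channels=128,
        output_dim=3 * n_mixtures,      # weight, mean, log scale per component
        n_mels=80,
        filter_size=2,
        loss_type="mog",
        log_scale_min=-9.0)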
@ -521,29 +654,37 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
filter_size: int=2,
|
filter_size: int=2,
|
||||||
loss_type: str="mog",
|
loss_type: str="mog",
|
||||||
log_scale_min: float=-9.0):
|
log_scale_min: float=-9.0):
|
||||||
"""Conditional Wavenet, which contains an UpsampleNet as the encoder and a WaveNet as the decoder. It is an autoregressive model.
|
|
||||||
"""
|
|
||||||
super(ConditionalWaveNet, self).__init__()
|
super(ConditionalWaveNet, self).__init__()
|
||||||
self.encoder = UpsampleNet(upsample_factors)
|
self.encoder = UpsampleNet(upsample_factors)
|
||||||
self.decoder = WaveNet(n_stack=n_stack,
|
self.decoder = WaveNet(n_stack=n_stack,
|
||||||
n_loop=n_loop,
|
n_loop=n_loop,
|
||||||
residual_channels=residual_channels,
|
residual_channels=residual_channels,
|
||||||
output_dim=output_dim,
|
output_dim=output_dim,
|
||||||
condition_dim=n_mels,
|
condition_dim=n_mels,
|
||||||
filter_size=filter_size,
|
filter_size=filter_size,
|
||||||
loss_type=loss_type,
|
loss_type=loss_type,
|
||||||
log_scale_min=log_scale_min)
|
log_scale_min=log_scale_min)
|
||||||
|
|
||||||
def forward(self, audio, mel, audio_start):
|
def forward(self, audio, mel, audio_start):
|
||||||
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
audio (Tensor): shape(B, T_audio), dtype float32, ground truth waveform, used for teacher force training.
|
-----------
|
||||||
mel (Tensor): shape(B, F, T_mel), dtype float32, mel spectrogram. Note that it is the spectrogram for the whole utterance.
|
audio : Tensor [shape=(B, T_audio)]
|
||||||
audio_start (Tensor): shape(B, ), dtype: int, audio slices' start positions for each utterance.
|
The ground truth waveform, used for teacher-forced training.
|
||||||
|
|
||||||
|
mel : Tensor [shape=(B, F, T_mel)]
|
||||||
|
Mel spectrogram. Note that it is the spectrogram for the whole
|
||||||
|
utterance.
|
||||||
|
|
||||||
|
audio_start : Tensor [shape=(B,), dtype: int]
|
||||||
|
Audio slices' start positions for each utterance.
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
Tensor: shape(B, T_audio - 1, C_putput), parameters for the output distribution.(C_output is the `output_dim` of the decoder.)
|
----------
|
||||||
|
Tensor [shape=(B, T_audio - 1, C_output)]
|
||||||
|
Parameters for the output distribution, where ``C_output`` is the
|
||||||
|
``output_dim`` of the decoder.
|
||||||
"""
|
"""
|
||||||
audio_length = audio.shape[1] # audio clip's length
|
audio_length = audio.shape[1] # audio clip's length
|
||||||
condition = self.encoder(mel)
|
condition = self.encoder(mel)
|
||||||
|
@ -557,14 +698,21 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def loss(self, y, t):
|
def loss(self, y, t):
|
||||||
"""compute loss with respect to the output distribution and the targer audio.
|
"""Compute loss with respect to the output distribution and the target
|
||||||
|
audio.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
y (Tensor): shape(B, T - 1, C_output), dtype float32, parameters of the output distribution.
|
-----------
|
||||||
t (Tensor): shape(B, T), dtype float32, target waveform.
|
y : Tensor [shape=(B, T - 1, C_output)]
|
||||||
|
Parameters of the output distribution.
|
||||||
|
|
||||||
|
t : Tensor [shape=(B, T)]
|
||||||
|
The target waveform.
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
Tensor: shape(1, ), dtype float32, the loss.
|
--------
|
||||||
|
Tensor [shape=(1,)]
|
||||||
|
The loss.
|
||||||
"""
|
"""
|
||||||
t = t[:, 1:]
|
t = t[:, 1:]
|
||||||
loss = self.decoder.loss(y, t)
|
loss = self.decoder.loss(y, t)
|
||||||
|
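A teacher-forcing training step assembled from the ``forward`` and ``loss`` docstrings above; the shapes, hyperparameters and random data are stand-ins:

    import paddle
    from parakeet.models.wavenet import ConditionalWaveNet

    model = ConditionalWaveNet(upsample_factors=[16, 16], n_stack=2, n_loop=10,
                               residual_channels=128, output_dim=30, n_mels=80,
                               filter_size=2, loss_type="mog", log_scale_min=-9.0)

    audio = paddle.randn([4, 4000])                 # (B, T_audio) waveform slices
    mel = paddle.randn([4, 80, 50])                 # (B, F, T_mel) whole-utterance mels
    audio_start = paddle.zeros([4], dtype="int64")  # slice start index per utterance

    y = model(audio, mel, audio_start)              # (B, T_audio - 1, C_output)
    loss = model.loss(y, audio)                     # scalar tensor, shape (1,)
    loss.backward()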
@ -573,24 +721,35 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
def sample(self, y):
|
def sample(self, y):
|
||||||
"""Sample from the output distribution.
|
"""Sample from the output distribution.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
y (Tensor): shape(B, T, C_output), dtype float32, parameters of the output distribution.
|
-----------
|
||||||
|
y : Tensor [shape=(B, T, C_output)]
|
||||||
|
Parameters of the output distribution.
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
Tensor: shape(B, T), dtype float32, sampled waveform from the output distribution.
|
--------
|
||||||
|
Tensor [shape=(B, T)]
|
||||||
|
Sampled waveform from the output distribution.
|
||||||
"""
|
"""
|
||||||
samples = self.decoder.sample(y)
|
samples = self.decoder.sample(y)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def infer(self, mel):
|
def infer(self, mel):
|
||||||
"""Synthesize waveform from mel spectrogram.
|
r"""Synthesize waveform from mel spectrogram.
|
||||||
|
|
||||||
Args:
|
Parameters
|
||||||
mel (Tensor): shape(B, F, T), condition(mel spectrogram here).
|
-----------
|
||||||
|
mel : Tensor [shape=(B, F, T)]
|
||||||
|
The condition (mel spectrogram here).
|
||||||
|
|
||||||
Returns:
|
Returns
|
||||||
Tensor: shape(B, T * upsacle_factor), synthesized waveform.(`upscale_factor` is the `upscale_factor` of the encoder `UpsampleNet`)
|
-----------
|
||||||
|
Tensor [shape=(B, T \* upscale_factor)]
|
||||||
|
Synthesized waveform.
|
||||||
|
|
||||||
|
``upscale_factor`` is the ``upscale_factor`` of the encoder
|
||||||
|
``UpsampleNet``.
|
||||||
"""
|
"""
|
||||||
condition = self.encoder(mel)
|
condition = self.encoder(mel)
|
||||||
batch_size, _, time_steps = condition.shape
|
batch_size, _, time_steps = condition.shape
|
||||||
|
@ -610,6 +769,20 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
|
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def predict(self, mel):
|
def predict(self, mel):
|
||||||
|
r"""Synthesize audio from mel spectrogram.
|
||||||
|
|
||||||
|
The output and input are numpy arrays without batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
mel : np.ndarray [shape=(C, T)]
|
||||||
|
Mel spectrogram of an utterance.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray [shape=(C, T \* upsample_factor)]
|
||||||
|
The synthesized waveform of an utterance.
|
||||||
|
"""
|
||||||
mel = paddle.to_tensor(mel)
|
mel = paddle.to_tensor(mel)
|
||||||
mel = paddle.unsqueeze(mel, 0)
|
mel = paddle.unsqueeze(mel, 0)
|
||||||
audio = self.infer(mel)
|
audio = self.infer(mel)
|
||||||
|
@ -618,6 +791,21 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, config, checkpoint_path):
|
def from_pretrained(cls, config, checkpoint_path):
|
||||||
|
"""Build a ConditionalWaveNet model from a pretrained model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config: yacs.config.CfgNode
|
||||||
|
The model configuration.
|
||||||
|
|
||||||
|
checkpoint_path: Path or str
|
||||||
|
The path of the pretrained model checkpoint, without extension name.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ConditionalWaveNet
|
||||||
|
The model built from pretrained result.
|
||||||
|
"""
|
||||||
model = cls(
|
model = cls(
|
||||||
upsample_factors=config.model.upsample_factors,
|
upsample_factors=config.model.upsample_factors,
|
||||||
n_stack=config.model.n_stack,
|
n_stack=config.model.n_stack,
|
||||||
|
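An end-to-end synthesis sketch combining ``from_pretrained`` and ``predict``. The config import, checkpoint path and mel file below are placeholders and depend on the concrete example being run:

    import numpy as np
    from parakeet.models.wavenet import ConditionalWaveNet
    from config import get_cfg_defaults   # the per-example config module (assumed)

    config = get_cfg_defaults()
    model = ConditionalWaveNet.from_pretrained(config, "exp/checkpoints/step-1000000")
    model.eval()

    mel = np.load("mel.npy")               # (C, T) mel spectrogram of one utterance
    audio = model.predict(mel)             # synthesized waveform as an np.ndarray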
@ -631,5 +819,3 @@ class ConditionalWaveNet(nn.Layer):
|
||||||
layer_tools.summary(model)
|
layer_tools.summary(model)
|
||||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,3 +12,11 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from parakeet.modules.attention import *
|
||||||
|
from parakeet.modules.audio import *
|
||||||
|
from parakeet.modules.conv import *
|
||||||
|
from parakeet.modules.geometry import *
|
||||||
|
from parakeet.modules.losses import *
|
||||||
|
from parakeet.modules.masking import *
|
||||||
|
from parakeet.modules.positional_encoding import *
|
||||||
|
from parakeet.modules.transformer import *
|
|
@ -1,90 +0,0 @@
|
||||||
import math
|
|
||||||
import paddle
|
|
||||||
from paddle import nn
|
|
||||||
from paddle.nn import functional as F
|
|
||||||
from paddle.nn import initializer as I
|
|
||||||
|
|
||||||
from parakeet.modules.conv import Conv1dBatchNorm
|
|
||||||
|
|
||||||
|
|
||||||
class Highway(nn.Layer):
|
|
||||||
def __init__(self, num_features):
|
|
||||||
super(Highway, self).__init__()
|
|
||||||
self.H = nn.Linear(num_features, num_features)
|
|
||||||
self.T = nn.Linear(num_features, num_features,
|
|
||||||
bias_attr=I.Constant(-1.))
|
|
||||||
|
|
||||||
self.num_features = num_features
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
H = F.relu(self.H(x))
|
|
||||||
T = F.sigmoid(self.T(x)) # gate
|
|
||||||
return H * T + x * (1.0 - T)
|
|
||||||
|
|
||||||
|
|
||||||
class CBHG(nn.Layer):
|
|
||||||
def __init__(self, in_channels, out_channels_per_conv, max_kernel_size,
|
|
||||||
projection_channels,
|
|
||||||
num_highways, highway_features,
|
|
||||||
gru_features):
|
|
||||||
super(CBHG, self).__init__()
|
|
||||||
self.conv1d_banks = nn.LayerList(
|
|
||||||
[Conv1dBatchNorm(in_channels, out_channels_per_conv, (k,),
|
|
||||||
padding=((k - 1) // 2, k // 2))
|
|
||||||
for k in range(1, 1 + max_kernel_size)])
|
|
||||||
|
|
||||||
self.projections = nn.LayerList()
|
|
||||||
projection_channels = list(projection_channels)
|
|
||||||
proj_in_channels = [max_kernel_size *
|
|
||||||
out_channels_per_conv] + projection_channels
|
|
||||||
proj_out_channels = projection_channels + \
|
|
||||||
[in_channels] # ensure residual connection
|
|
||||||
for c_in, c_out in zip(proj_in_channels, proj_out_channels):
|
|
||||||
conv = nn.Conv1D(c_in, c_out, (3,), padding=(1, 1))
|
|
||||||
self.projections.append(conv)
|
|
||||||
|
|
||||||
if in_channels != highway_features:
|
|
||||||
self.pre_highway = nn.Linear(in_channels, highway_features)
|
|
||||||
|
|
||||||
self.highways = nn.LayerList(
|
|
||||||
[Highway(highway_features) for _ in range(num_highways)])
|
|
||||||
|
|
||||||
self.gru = nn.GRU(highway_features, gru_features,
|
|
||||||
direction="bidirectional")
|
|
||||||
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.out_channels_per_conv = out_channels_per_conv
|
|
||||||
self.max_kernel_size = max_kernel_size
|
|
||||||
self.num_projections = 1 + len(projection_channels)
|
|
||||||
self.num_highways = num_highways
|
|
||||||
self.highway_features = highway_features
|
|
||||||
self.gru_features = gru_features
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
input = x
|
|
||||||
|
|
||||||
# conv banks
|
|
||||||
conv_outputs = []
|
|
||||||
for conv in self.conv1d_banks:
|
|
||||||
conv_outputs.append(conv(x))
|
|
||||||
x = F.relu(paddle.concat(conv_outputs, 1))
|
|
||||||
|
|
||||||
# max pool
|
|
||||||
x = F.max_pool1d(x, 2, stride=1, padding=(0, 1))
|
|
||||||
|
|
||||||
# conv1d projections
|
|
||||||
n_projections = len(self.projections)
|
|
||||||
for i, conv in enumerate(self.projections):
|
|
||||||
x = conv(x)
|
|
||||||
if i != n_projections:
|
|
||||||
x = F.relu(x)
|
|
||||||
x += input # residual connection
|
|
||||||
|
|
||||||
# highway
|
|
||||||
x = paddle.transpose(x, [0, 2, 1])
|
|
||||||
if hasattr(self, "pre_highway"):
|
|
||||||
x = self.pre_highway(x)
|
|
||||||
|
|
||||||
# gru
|
|
||||||
x, _ = self.gru(x)
|
|
||||||
return x
|
|
|
@ -1,62 +0,0 @@
|
||||||
import paddle
|
|
||||||
from paddle import nn
|
|
||||||
from paddle.nn import functional as F
|
|
||||||
|
|
||||||
def residual_connection(input, layer):
|
|
||||||
"""residual connection, only used for single input-single output layer.
|
|
||||||
y = x + F(x) where F corresponds to the layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor): the input tensor.
|
|
||||||
layer (callable): a callable that preserve tensor shape.
|
|
||||||
"""
|
|
||||||
return input + layer(input)
|
|
||||||
|
|
||||||
class ResidualWrapper(nn.Layer):
|
|
||||||
def __init__(self, layer):
|
|
||||||
super(ResidualWrapper, self).__init__()
|
|
||||||
self.layer = layer
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return residual_connection(x, self.layer)
|
|
||||||
|
|
||||||
|
|
||||||
class PreLayerNormWrapper(nn.Layer):
|
|
||||||
def __init__(self, layer, d_model):
|
|
||||||
super(PreLayerNormWrapper, self).__init__()
|
|
||||||
self.layer = layer
|
|
||||||
self.layer_norm = nn.LayerNorm([d_model], epsilon=1e-6)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x + self.layer(self.layer_norm(x))
|
|
||||||
|
|
||||||
|
|
||||||
class PostLayerNormWrapper(nn.Layer):
|
|
||||||
def __init__(self, layer, d_model):
|
|
||||||
super(PostLayerNormWrapper, self).__init__()
|
|
||||||
self.layer = layer
|
|
||||||
self.layer_norm = nn.LayerNorm([d_model], epsilon=1e-6)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.layer_norm(x + self.layer(x))
|
|
||||||
|
|
||||||
|
|
||||||
def context_gate(input, axis):
|
|
||||||
"""sigmoid gate the content by gate.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input (Tensor): shape(*, d_axis, *), the input, treated as content & gate.
|
|
||||||
axis (int): the axis to chunk content and gate.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if input.shape[axis] is not even.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor: shape(*, d_axis / 2 , *), the gated content.
|
|
||||||
"""
|
|
||||||
size = input.shape[axis]
|
|
||||||
if size % 2 != 0:
|
|
||||||
raise ValueError("the size of the {}-th dimension of input should "
|
|
||||||
"be even, but received {}".format(axis, size))
|
|
||||||
content, gate = paddle.chunk(input, 2, axis)
|
|
||||||
return F.sigmoid(gate) * content
|
|
|
@ -15,19 +15,69 @@
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Conv1dCell",
|
||||||
|
"Conv1dBatchNorm",
|
||||||
|
]
|
||||||
|
|
||||||
class Conv1dCell(nn.Conv1D):
|
class Conv1dCell(nn.Conv1D):
|
||||||
"""
|
"""A subclass of Conv1D layer, which can be used in an autoregressive
|
||||||
A subclass of Conv1d layer, which can be used like an RNN cell. It can take
|
decoder like an RNN cell.
|
||||||
step input and return step output. It is done by keeping an internal buffer,
|
|
||||||
when adding a step input, we shift the buffer and return a step output. For
|
When used in autoregressive decoding, it performs causal temporal
|
||||||
single step case, convolution devolves to a linear transformation.
|
convolution incrementally. At each time step, it takes a step input and
|
||||||
|
returns a step output.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
------
|
||||||
|
It is done by caching an internal buffer of length ``receptive_field - 1``.
|
||||||
|
When adding a step input, the buffer is shifted by one step, the latest
|
||||||
|
input is added to the buffer and the oldest step is discarded. And it
|
||||||
|
returns a step output. For the single step case, the convolution is equivalent to a
|
||||||
|
linear transformation.
|
||||||
|
|
||||||
That it can be used as a cell depends on several restrictions:
|
That it can be used as a cell depends on several restrictions:
|
||||||
1. stride must be 1;
|
|
||||||
2. padding must be an asymmetric padding (recpetive_field - 1, 0).
|
|
||||||
|
|
||||||
As a result, these arguments are removed form the initializer.
|
1. stride must be 1;
|
||||||
|
2. padding must be a causal padding (receptive_field - 1, 0).
|
||||||
|
|
||||||
|
Thus, these arguments are removed from the ``__init__`` method of this
|
||||||
|
class.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_channels: int
|
||||||
|
The feature size of the input.
|
||||||
|
|
||||||
|
out_channels: int
|
||||||
|
The feature size of the output.
|
||||||
|
|
||||||
|
kernel_size: int or Tuple[int]
|
||||||
|
The size of the kernel.
|
||||||
|
|
||||||
|
dilation: int or Tuple[int]
|
||||||
|
The dilation of the convolution, by default 1
|
||||||
|
|
||||||
|
weight_attr: ParamAttr, Initializer, str or bool, optional
|
||||||
|
The parameter attribute of the convolution kernel, by default None.
|
||||||
|
|
||||||
|
bias_attr: ParamAttr, Initializer, str or bool, optional
|
||||||
|
The parameter attribute of the bias. If ``False``, this layer does not
|
||||||
|
have a bias, by default None.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> cell = Conv1dCell(3, 4, kernel_size=5)
|
||||||
|
>>> inputs = [paddle.randn([4, 3]) for _ in range(16)]
|
||||||
|
>>> outputs = []
|
||||||
|
>>> cell.eval()
|
||||||
|
>>> cell.start_sequence()
|
||||||
|
>>> for xt in inputs:
|
||||||
|
>>> outputs.append(cell.add_input(xt))
|
||||||
|
>>> len(outputs)
|
||||||
|
16
|
||||||
|
>>> outputs[0].shape
|
||||||
|
[4, 4]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
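Beyond the doctest above, the incremental protocol can be checked against the ordinary convolution: in evaluation mode, feeding the sequence step by step should reproduce the causally padded Conv1D output (a sketch, assuming the import path below):

    import paddle
    from parakeet.modules.conv import Conv1dCell

    cell = Conv1dCell(3, 4, kernel_size=5)
    cell.eval()

    x = paddle.randn([2, 3, 16])                  # (B, C_in, T)
    y_full = cell(x)                              # causal convolution, (B, C_out, T)

    cell.start_sequence()
    y_steps = [cell.add_input(x[:, :, t]) for t in range(16)]   # each (B, C_out)
    y_inc = paddle.stack(y_steps, axis=-1)        # (B, C_out, T)

    print(paddle.abs(y_full - y_inc).max())       # ~0, up to numerical error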
@ -54,9 +104,23 @@ class Conv1dCell(nn.Conv1D):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def receptive_field(self):
|
def receptive_field(self):
|
||||||
|
"""The receptive field of the Conv1dCell.
|
||||||
|
"""
|
||||||
return self._r
|
return self._r
|
||||||
|
|
||||||
def start_sequence(self):
|
def start_sequence(self):
|
||||||
|
"""Prepare the layer for a series of incremental forward.
|
||||||
|
|
||||||
|
Warnings
|
||||||
|
---------
|
||||||
|
This method should be called before a sequence of calls to
|
||||||
|
``add_input``.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
Exception
|
||||||
|
If this method is called when the layer is in training mode.
|
||||||
|
"""
|
||||||
if self.training:
|
if self.training:
|
||||||
raise Exception("only use start_sequence in evaluation")
|
raise Exception("only use start_sequence in evaluation")
|
||||||
self._buffer = None
|
self._buffer = None
|
||||||
|
@ -72,21 +136,41 @@ class Conv1dCell(nn.Conv1D):
|
||||||
(self._out_channels, -1))
|
(self._out_channels, -1))
|
||||||
|
|
||||||
def initialize_buffer(self, x_t):
|
def initialize_buffer(self, x_t):
|
||||||
|
"""Initialize the buffer for the step input.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x_t : Tensor [shape=(batch_size, in_channels)]
|
||||||
|
The step input.
|
||||||
|
"""
|
||||||
batch_size, _ = x_t.shape
|
batch_size, _ = x_t.shape
|
||||||
self._buffer = paddle.zeros(
|
self._buffer = paddle.zeros(
|
||||||
(batch_size, self._in_channels, self.receptive_field),
|
(batch_size, self._in_channels, self.receptive_field),
|
||||||
dtype=x_t.dtype)
|
dtype=x_t.dtype)
|
||||||
|
|
||||||
def update_buffer(self, x_t):
|
def update_buffer(self, x_t):
|
||||||
|
"""Shift the buffer by one step.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x_t : Tensor [shape=(batch_size, in_channels)]
|
||||||
|
The step input.
|
||||||
|
"""
|
||||||
self._buffer = paddle.concat(
|
self._buffer = paddle.concat(
|
||||||
[self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
|
[self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
|
||||||
|
|
||||||
def add_input(self, x_t):
|
def add_input(self, x_t):
|
||||||
"""
|
"""Add step input and compute step output.
|
||||||
Arguments:
|
|
||||||
x_t (Tensor): shape (batch_size, in_channels), step input.
|
Parameters
|
||||||
Rerurns:
|
-----------
|
||||||
y_t (Tensor): shape (batch_size, out_channels), step output.
|
x_t : Tensor [shape=(batch_size, in_channels)]
|
||||||
|
The step input.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
y_t : Tensor [shape=(batch_size, out_channels)]
|
||||||
|
The step output.
|
||||||
"""
|
"""
|
||||||
batch_size = x_t.shape[0]
|
batch_size = x_t.shape[0]
|
||||||
if self.receptive_field > 1:
|
if self.receptive_field > 1:
|
||||||
|
@ -108,6 +192,45 @@ class Conv1dCell(nn.Conv1D):
|
||||||
|
|
||||||
|
|
||||||
class Conv1dBatchNorm(nn.Layer):
|
class Conv1dBatchNorm(nn.Layer):
|
||||||
|
"""A Conv1D Layer followed by a BatchNorm1D.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_channels : int
|
||||||
|
The feature size of the input.
|
||||||
|
|
||||||
|
out_channels : int
|
||||||
|
The feature size of the output.
|
||||||
|
|
||||||
|
kernel_size : int
|
||||||
|
The size of the convolution kernel.
|
||||||
|
|
||||||
|
stride : int, optional
|
||||||
|
The stride of the convolution, by default 1.
|
||||||
|
|
||||||
|
padding : int, str or Tuple[int], optional
|
||||||
|
The padding of the convolution.
|
||||||
|
If int, a symmetrical padding is applied before convolution;
|
||||||
|
If str, it should be "same" or "valid";
|
||||||
|
If Tuple[int], its length should be 2, meaning
|
||||||
|
``(pad_before, pad_after)``, by default 0.
|
||||||
|
|
||||||
|
weight_attr : ParamAttr, Initializer, str or bool, optional
|
||||||
|
The parameter attribute of the convolution kernel, by default None.
|
||||||
|
|
||||||
|
bias_attr : ParamAttr, Initializer, str or bool, optional
|
||||||
|
The parameter attribute of the bias of the convolution, by default
|
||||||
|
None.
|
||||||
|
|
||||||
|
data_format : str ["NCL" or "NLC"], optional
|
||||||
|
The data layout of the input, by default "NCL"
|
||||||
|
|
||||||
|
momentum : float, optional
|
||||||
|
The momentum of the BatchNorm1D layer, by default 0.9
|
||||||
|
|
||||||
|
epsilon : float, optional
|
||||||
|
The epsilon of the BatchNorm1D layer, by default 1e-05
|
||||||
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
|
@ -136,6 +259,18 @@ class Conv1dBatchNorm(nn.Layer):
|
||||||
data_format=data_format)
|
data_format=data_format)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
"""Forward pass of the Conv1dBatchNorm layer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor [shape=(B, C_in, T_in) or (B, T_in, C_in)]
|
||||||
|
The input tensor. Its data layout depends on ``data_format``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor [shape=(B, C_out, T_out) or (B, T_out, C_out)]
|
||||||
|
The output tensor.
|
||||||
|
"""
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.bn(x)
|
x = self.bn(x)
|
||||||
return x
|
return x
|
||||||
|
|
|
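A usage sketch for the layer documented above; the channel sizes are arbitrary and the padding tuple is just one example of the ``(pad_before, pad_after)`` form:

    import paddle
    from parakeet.modules.conv import Conv1dBatchNorm

    layer = Conv1dBatchNorm(in_channels=80, out_channels=256, kernel_size=5,
                            stride=1, padding=(4, 0), data_format="NCL")
    x = paddle.randn([8, 80, 100])    # (B, C_in, T) for the "NCL" layout
    y = layer(x)
    print(y.shape)                    # [8, 256, 100]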
@ -1,12 +1,40 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
def default_argument_parser():
|
def default_argument_parser():
|
||||||
|
r"""A simple yet genral argument parser for experiments with parakeet.
|
||||||
|
|
||||||
|
This is used in examples with parakeet, and it is intended to be used by
|
||||||
|
other experiments with parakeet. It requires a minimal set of command line
|
||||||
|
arguments to start a training script.
|
||||||
|
|
||||||
|
The ``--config`` and ``--opts`` options are used to overwrite the default
|
||||||
|
configuration.
|
||||||
|
|
||||||
|
The ``--data`` and ``--output`` options specify the data path and the output path.
|
||||||
|
Resuming training from existing progress at the output directory is the
|
||||||
|
intended default behavior.
|
||||||
|
|
||||||
|
The ``--checkpoint_path`` specifies the checkpoint to load from.
|
||||||
|
|
||||||
|
The ``--device`` and ``--nprocs`` options specify how to run the training.
|
||||||
|
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
parakeet.training.experiment
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
argparse.ArgumentParser
|
||||||
|
The parser.
|
||||||
|
"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
# data and outpu
|
# data and output
|
||||||
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
|
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
|
||||||
parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
|
parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
|
||||||
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and log. If not provided, a directory is created in runs/ to save outputs.")
|
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.")
|
||||||
|
|
||||||
# load from saved checkpoint
|
# load from saved checkpoint
|
||||||
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
|
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
|
||||||
|
@ -17,5 +45,6 @@ def default_argument_parser():
|
||||||
|
|
||||||
# overwrite extra config and default config
|
# overwrite extra config and default config
|
||||||
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
|
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
|
||||||
|
# yapf: enable
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
|
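A sketch of how the parser is consumed in a training script. The module path ``parakeet.training.cli`` and the argument values are assumptions for illustration:

    from parakeet.training.cli import default_argument_parser

    parser = default_argument_parser()
    args = parser.parse_args([
        "--config", "conf/wavenet.yaml",
        "--data", "datasets/ljspeech",
        "--output", "runs/wavenet",
        "--device", "gpu",
        "--nprocs", "1",
    ])
    print(args.config, args.output, args.device)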
@ -29,47 +29,63 @@ __all__ = ["ExperimentBase"]
|
||||||
|
|
||||||
class ExperimentBase(object):
|
class ExperimentBase(object):
|
||||||
"""
|
"""
|
||||||
An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
|
An experiment template in order to structure the training code and take
|
||||||
|
care of saving, loading, logging, visualization stuffs. It's intended to
|
||||||
|
be flexible and simple.
|
||||||
|
|
||||||
So it only handles output directory (create directory for the outut, create a checkpoint directory, dump the config in use and create visualizer and logger)in a standard way without restricting the input/output protocols of the model and dataloader. It leaves the main part for the user to implement their own(setup the model, criterion, optimizer, defaine a training step, define a validation function and customize all the text and visual logs).
|
So it only handles output directory (create directory for the output,
|
||||||
|
create a checkpoint directory, dump the config in use and create
|
||||||
|
visualizer and logger) in a standard way without enforcing any
|
||||||
|
input-output protocols to the model and dataloader. It leaves the main
|
||||||
|
part for the user to implement their own (setup the model, criterion,
|
||||||
|
optimizer, define a training step, define a validation function and
|
||||||
|
customize all the text and visual logs).
|
||||||
|
|
||||||
It does not save too much boilerplate code. The users still have to write the forward/backward/update mannually, but they are free to add non-standard behaviors if needed.
|
It does not save too much boilerplate code. The users still have to write
|
||||||
|
the forward/backward/update manually, but they are free to add
|
||||||
|
non-standard behaviors if needed.
|
||||||
|
|
||||||
We have some conventions to follow.
|
We have some conventions to follow.
|
||||||
1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
|
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
|
||||||
2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
|
``valid_loader``, ``config`` and ``args`` attributes.
|
||||||
3. There are four method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
|
2. The config should have a ``training`` field, which has
|
||||||
|
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
|
||||||
|
used as the trigger to invoke validation, checkpointing and stop of the
|
||||||
|
experiment.
|
||||||
|
3. There are four methods, namely ``train_batch``, ``valid``,
|
||||||
|
``setup_model`` and ``setup_dataloader`` that should be implemented.
|
||||||
|
|
||||||
Feel free to add/overwrite other methods and standalone functions if you need.
|
Feel free to add/overwrite other methods and standalone functions if you
|
||||||
|
need.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config: yacs.config.CfgNode
|
||||||
|
The configuration used for the experiment.
|
||||||
|
|
||||||
|
args: argparse.Namespace
|
||||||
|
The parsed command line arguments.
|
||||||
|
|
||||||
|
Examples
|
||||||
Examples:
|
|
||||||
--------
|
--------
|
||||||
def main_sp(config, args):
|
>>> def main_sp(config, args):
|
||||||
exp = Experiment(config, args)
|
>>> exp = Experiment(config, args)
|
||||||
exp.setup()
|
>>> exp.setup()
|
||||||
exp.run()
|
>>> exp.run()
|
||||||
|
>>>
|
||||||
def main(config, args):
|
>>> config = get_cfg_defaults()
|
||||||
if args.nprocs > 1 and args.device == "gpu":
|
>>> parser = default_argument_parser()
|
||||||
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
>>> args = parser.parse_args()
|
||||||
else:
|
>>> if args.config:
|
||||||
main_sp(config, args)
|
>>> config.merge_from_file(args.config)
|
||||||
|
>>> if args.opts:
|
||||||
if __name__ == "__main__":
|
>>> config.merge_from_list(args.opts)
|
||||||
config = get_cfg_defaults()
|
>>> config.freeze()
|
||||||
parser = default_argument_parser()
|
>>>
|
||||||
args = parser.parse_args()
|
>>> if args.nprocs > 1 and args.device == "gpu":
|
||||||
if args.config:
|
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||||
config.merge_from_file(args.config)
|
>>> else:
|
||||||
if args.opts:
|
>>> main_sp(config, args)
|
||||||
config.merge_from_list(args.opts)
|
|
||||||
config.freeze()
|
|
||||||
print(config)
|
|
||||||
print(args)
|
|
||||||
|
|
||||||
main(config, args)
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config, args):
|
def __init__(self, config, args):
|
||||||
|
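A toy subclass sketch following the conventions listed above; the model, data and loss are stand-ins, and only the four required methods are filled in:

    import paddle
    from paddle import nn
    from parakeet.training.experiment import ExperimentBase

    class ToyExperiment(ExperimentBase):
        def setup_model(self):
            self.model = nn.Linear(16, 1)
            self.criterion = nn.MSELoss()
            self.optimizer = paddle.optimizer.Adam(
                parameters=self.model.parameters())

        def setup_dataloader(self):
            dataset = paddle.io.TensorDataset(
                [paddle.randn([128, 16]), paddle.randn([128, 1])])
            self.train_loader = paddle.io.DataLoader(dataset, batch_size=32)
            self.valid_loader = paddle.io.DataLoader(dataset, batch_size=32)

        def train_batch(self):
            x, y = self.read_batch()
            loss = self.criterion(self.model(x), y)
            self.optimizer.clear_grad()
            loss.backward()
            self.optimizer.step()

        @paddle.no_grad()
        def valid(self):
            losses = [float(self.criterion(self.model(x), y))
                      for x, y in self.valid_loader]
            print("valid loss:", sum(losses) / len(losses))

Running such a subclass then follows the ``main_sp`` pattern shown in the Examples section above.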
@ -77,6 +93,8 @@ class ExperimentBase(object):
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
|
"""Setup the experiment.
|
||||||
|
"""
|
||||||
paddle.set_device(self.args.device)
|
paddle.set_device(self.args.device)
|
||||||
if self.parallel:
|
if self.parallel:
|
||||||
self.init_parallel()
|
self.init_parallel()
|
||||||
|
@ -95,16 +113,29 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parallel(self):
|
def parallel(self):
|
||||||
|
"""A flag indicating whether the experiment should run with
|
||||||
|
multiprocessing.
|
||||||
|
"""
|
||||||
return self.args.device == "gpu" and self.args.nprocs > 1
|
return self.args.device == "gpu" and self.args.nprocs > 1
|
||||||
|
|
||||||
def init_parallel(self):
|
def init_parallel(self):
|
||||||
|
"""Init environment for multiprocess training.
|
||||||
|
"""
|
||||||
dist.init_parallel_env()
|
dist.init_parallel_env()
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
|
"""Save checkpoint (model parameters and optimizer states).
|
||||||
|
"""
|
||||||
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
|
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
|
||||||
self.model, self.optimizer)
|
self.model, self.optimizer)
|
||||||
|
|
||||||
def resume_or_load(self):
|
def resume_or_load(self):
|
||||||
|
Resume from the latest checkpoint in the checkpoint directory of the output
|
||||||
|
directory or load a specified checkpoint.
|
||||||
|
|
||||||
|
If ``args.checkpoint_path`` is not None, load the checkpoint, else
|
||||||
|
resume training.
|
||||||
|
"""
|
||||||
iteration = checkpoint.load_parameters(
|
iteration = checkpoint.load_parameters(
|
||||||
self.model,
|
self.model,
|
||||||
self.optimizer,
|
self.optimizer,
|
||||||
|
@ -113,6 +144,13 @@ class ExperimentBase(object):
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
|
|
||||||
def read_batch(self):
|
def read_batch(self):
|
||||||
|
"""Read a batch from the train_loader.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[Tensor]
|
||||||
|
A batch.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
batch = next(self.iterator)
|
batch = next(self.iterator)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
|
@ -121,12 +159,19 @@ class ExperimentBase(object):
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def new_epoch(self):
|
def new_epoch(self):
|
||||||
|
"""Reset the train loader and increment ``epoch``.
|
||||||
|
"""
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
if self.parallel:
|
if self.parallel:
|
||||||
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
||||||
self.iterator = iter(self.train_loader)
|
self.iterator = iter(self.train_loader)
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
"""The training process.
|
||||||
|
|
||||||
|
It includes forward/backward/update and periodic validation and
|
||||||
|
saving.
|
||||||
|
"""
|
||||||
self.new_epoch()
|
self.new_epoch()
|
||||||
while self.iteration < self.config.training.max_iteration:
|
while self.iteration < self.config.training.max_iteration:
|
||||||
self.iteration += 1
|
self.iteration += 1
|
||||||
|
@ -139,6 +184,9 @@ class ExperimentBase(object):
|
||||||
self.save()
|
self.save()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
"""The routine of the experiment after setup. This method is intended
|
||||||
|
to be used by the user.
|
||||||
|
"""
|
||||||
self.resume_or_load()
|
self.resume_or_load()
|
||||||
try:
|
try:
|
||||||
self.train()
|
self.train()
|
||||||
|
@ -148,6 +196,8 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def setup_output_dir(self):
|
def setup_output_dir(self):
|
||||||
|
"""Create a directory used for output.
|
||||||
|
"""
|
||||||
# output dir
|
# output dir
|
||||||
output_dir = Path(self.args.output).expanduser()
|
output_dir = Path(self.args.output).expanduser()
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -156,6 +206,10 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def setup_checkpointer(self):
|
def setup_checkpointer(self):
|
||||||
|
"""Create a directory used to save checkpoints into.
|
||||||
|
|
||||||
|
It is "checkpoints" inside the output directory.
|
||||||
|
"""
|
||||||
# checkpoint dir
|
# checkpoint dir
|
||||||
checkpoint_dir = self.output_dir / "checkpoints"
|
checkpoint_dir = self.output_dir / "checkpoints"
|
||||||
checkpoint_dir.mkdir(exist_ok=True)
|
checkpoint_dir.mkdir(exist_ok=True)
|
||||||
|
@ -164,12 +218,28 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def setup_visualizer(self):
|
def setup_visualizer(self):
|
||||||
|
"""Initialize a visualizer to log the experiment.
|
||||||
|
|
||||||
|
The visual log is saved in the output directory.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
------
|
||||||
|
Only the main process has a visualizer with it. Using multiple
|
||||||
|
visualizers in multiple processes to write to the same log file may cause
|
||||||
|
unexpected behaviors.
|
||||||
|
"""
|
||||||
# visualizer
|
# visualizer
|
||||||
visualizer = SummaryWriter(logdir=str(self.output_dir))
|
visualizer = SummaryWriter(logdir=str(self.output_dir))
|
||||||
|
|
||||||
self.visualizer = visualizer
|
self.visualizer = visualizer
|
||||||
|
|
||||||
def setup_logger(self):
|
def setup_logger(self):
|
||||||
|
"""Initialize a text logger to log the experiment.
|
||||||
|
|
||||||
|
Each process has its own text logger. The logging messages are written to
|
||||||
|
the standard output and a text file named ``worker_n.log`` in the
|
||||||
|
output directory, where ``n`` means the rank of the process.
|
||||||
|
"""
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel("INFO")
|
logger.setLevel("INFO")
|
||||||
logger.addHandler(logging.StreamHandler())
|
logger.addHandler(logging.StreamHandler())
|
||||||
|
@ -180,19 +250,34 @@ class ExperimentBase(object):
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
def dump_config(self):
|
def dump_config(self):
|
||||||
|
"""Save the configuration used for this experiment.
|
||||||
|
|
||||||
|
It is saved in to ``config.yaml`` in the output directory at the
|
||||||
|
beginning of the experiment.
|
||||||
|
"""
|
||||||
with open(self.output_dir / "config.yaml", 'wt') as f:
|
with open(self.output_dir / "config.yaml", 'wt') as f:
|
||||||
print(self.config, file=f)
|
print(self.config, file=f)
|
||||||
|
|
||||||
def train_batch(self):
|
def train_batch(self):
|
||||||
|
"""The training loop. A subclass should implement this method.
|
||||||
|
"""
|
||||||
raise NotImplementedError("train_batch should be implemented.")
|
raise NotImplementedError("train_batch should be implemented.")
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
@mp_tools.rank_zero_only
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def valid(self):
|
def valid(self):
|
||||||
|
"""The validation. A subclass should implement this method.
|
||||||
|
"""
|
||||||
raise NotImplementedError("valid should be implemented.")
|
raise NotImplementedError("valid should be implemented.")
|
||||||
|
|
||||||
def setup_model(self):
|
def setup_model(self):
|
||||||
|
"""Setup model, criterion and optimizer, etc. A subclass should
|
||||||
|
implement this method.
|
||||||
|
"""
|
||||||
raise NotImplementedError("setup_model should be implemented.")
|
raise NotImplementedError("setup_model should be implemented.")
|
||||||
|
|
||||||
def setup_dataloader(self):
|
def setup_dataloader(self):
|
||||||
|
"""Setup training dataloader and validation dataloader. A subclass
|
||||||
|
should implement this method.
|
||||||
|
"""
|
||||||
raise NotImplementedError("setup_dataloader should be implemented.")
|
raise NotImplementedError("setup_dataloader should be implemented.")
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
|
|
||||||
__all__ = ["summary","gradient_norm", "freeze", "unfreeze"]
|
__all__ = ["summary", "gradient_norm", "freeze", "unfreeze"]
|
||||||
|
|
||||||
|
|
||||||
def summary(layer: nn.Layer):
|
def summary(layer: nn.Layer):
|
||||||
|
|