commit
07ce84c680
|
@ -12,6 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__version__ = "0.0.0"
|
||||
__version__ = "0.2.0"
|
||||
|
||||
from parakeet import data, frontend, models, modules
|
||||
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
|
||||
|
|
|
@ -16,6 +16,8 @@ import librosa
|
|||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["AudioProcessor"]
|
||||
|
||||
class AudioProcessor(object):
|
||||
def __init__(self,
|
||||
sample_rate:int,
|
||||
|
@ -26,7 +28,7 @@ class AudioProcessor(object):
|
|||
f_min:int=0,
|
||||
f_max:int=None,
|
||||
window="hann",
|
||||
center="True",
|
||||
center=True,
|
||||
pad_mode="reflect"):
|
||||
# read & write
|
||||
self.sample_rate = sample_rate
|
||||
|
|
|
@ -13,6 +13,9 @@ https://github.com/mozilla/TTS/issues/377
|
|||
"""
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["NormalizerBase", "LogMagnitude", "UnitMagnitude"]
|
||||
|
||||
|
||||
class NormalizerBase(object):
|
||||
def transform(self, spec):
|
||||
raise NotImplementedError("transform must be implemented")
|
||||
|
|
|
@ -13,5 +13,4 @@
|
|||
# limitations under the License.
|
||||
|
||||
from .dataset import *
|
||||
from .sampler import *
|
||||
from .batch import *
|
||||
|
|
|
@ -17,6 +17,10 @@ Batch functions for text sequences, audio and spectrograms are provided.
|
|||
"""
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
"batch_text_id", "batch_wav", "batch_spec",
|
||||
"TextIDBatcher", "WavBatcher", "SpecBatcher",
|
||||
]
|
||||
|
||||
class TextIDBatcher(object):
|
||||
"""A wrapper class for `batch_text_id`."""
|
||||
|
|
|
@ -16,6 +16,11 @@ import six
|
|||
import paddle
|
||||
from paddle.io import Dataset
|
||||
|
||||
__all__ = [
|
||||
"split", "TransformDataset", "CacheDataset", "TupleDataset",
|
||||
"DictDataset", "SliceDataset", "SubsetDataset", "FilterDataset",
|
||||
"ChainDataset",
|
||||
]
|
||||
|
||||
def split(dataset, first_size):
|
||||
"""A utility function to split a dataset into two datasets."""
|
||||
|
|
|
@ -1,200 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
At most cases, we have non-stream dataset, which means we can random access it with __getitem__, and we can get the length of the dataset with __len__.
|
||||
|
||||
This suffices for a sampler. We implemente sampler as iterable of valid indices. By valid, we mean 0 <= index < N, where N is the length of the dataset. We then collect several indices within a batch and use them to collect examples from the dataset with __getitem__. Then transform these examples into a batch.
|
||||
|
||||
So the sampler is only responsible for generating valid indices.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
import paddle
|
||||
from paddle.io import Sampler
|
||||
|
||||
|
||||
class PartialyRandomizedSimilarTimeLengthSampler(Sampler):
|
||||
"""Partially randmoized sampler, implemented as a example sampler
|
||||
1. Sort by lengths
|
||||
2. Pick a small patch and randomize it
|
||||
3. Permutate mini-batchs
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
lengths,
|
||||
batch_size=4,
|
||||
batch_group_size=None,
|
||||
permutate=True):
|
||||
"""[summary]
|
||||
|
||||
Args:
|
||||
lengths (List[int]): The length of the examples of the dataset. This is the key to be considered as 'time length'.
|
||||
batch_size (int, optional): batch size. Defaults to 4.
|
||||
batch_group_size (int, optional): the size of a small batch. Random shuffling is applied within such patches. If `batch_group_size` is not provided, it is set to min(batch_size * 32, len(self.lengths)). Batch_group_size should be perfectly divided by batch_size. Defaults to None.
|
||||
permutate (bool, optional): permutate batches. Defaults to True.
|
||||
"""
|
||||
_lengths = np.array(
|
||||
lengths,
|
||||
dtype=np.int64) # maybe better implement length as a sort key
|
||||
self.lengths = np.sort(_lengths)
|
||||
self.sorted_indices = np.argsort(_lengths)
|
||||
|
||||
self.batch_size = batch_size
|
||||
if batch_group_size is None:
|
||||
batch_group_size = min(batch_size * 32, len(self.lengths))
|
||||
if batch_group_size % batch_size != 0:
|
||||
batch_group_size -= batch_group_size % batch_size
|
||||
|
||||
self.batch_group_size = batch_group_size
|
||||
assert batch_group_size % batch_size == 0
|
||||
self.permutate = permutate
|
||||
|
||||
def __iter__(self):
|
||||
indices = np.copy(self.sorted_indices)
|
||||
batch_group_size = self.batch_group_size
|
||||
s, e = 0, 0
|
||||
for i in range(len(indices) // batch_group_size):
|
||||
s = i * batch_group_size
|
||||
e = s + batch_group_size
|
||||
random.shuffle(indices[s:e]) # inplace
|
||||
|
||||
# Permutate batches
|
||||
if self.permutate:
|
||||
perm = np.arange(len(indices[:e]) // self.batch_size)
|
||||
random.shuffle(perm)
|
||||
indices[:e] = indices[:e].reshape(
|
||||
-1, self.batch_size)[perm, :].reshape(-1)
|
||||
|
||||
# Handle last elements
|
||||
s += batch_group_size
|
||||
#print(indices)
|
||||
if s < len(indices):
|
||||
random.shuffle(indices[s:])
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sorted_indices)
|
||||
|
||||
|
||||
class BucketSampler(Sampler):
|
||||
def __init__(self,
|
||||
lengths,
|
||||
batch_size=4,
|
||||
batch_group_size=None,
|
||||
permutate=True,
|
||||
num_trainers=1,
|
||||
rank=0):
|
||||
# maybe better implement length as a sort key
|
||||
_lengths = np.array(lengths, dtype=np.int64)
|
||||
self.lengths = np.sort(_lengths)
|
||||
self.sorted_indices = np.argsort(_lengths)
|
||||
self.num_trainers = num_trainers
|
||||
self.rank = rank
|
||||
|
||||
self.dataset_size = len(_lengths)
|
||||
self.num_samples = int(np.ceil(self.dataset_size / num_trainers))
|
||||
self.total_size = self.num_samples * num_trainers
|
||||
assert self.total_size >= self.dataset_size
|
||||
|
||||
self.batch_size = batch_size
|
||||
total_batch_size = num_trainers * batch_size
|
||||
self.total_batch_size = total_batch_size
|
||||
|
||||
if batch_group_size is None:
|
||||
batch_group_size = min(total_batch_size * 32, len(self.lengths))
|
||||
if batch_group_size % total_batch_size != 0:
|
||||
batch_group_size -= batch_group_size % total_batch_size
|
||||
|
||||
self.batch_group_size = batch_group_size
|
||||
assert batch_group_size % total_batch_size == 0
|
||||
self.permutate = permutate
|
||||
|
||||
def __iter__(self):
|
||||
indices = self.sorted_indices
|
||||
|
||||
# Append extra samples to make it evenly distributed on all trainers.
|
||||
num_extras = self.total_size - self.dataset_size
|
||||
extra_indices = np.random.choice(
|
||||
indices, size=(num_extras, ), replace=False)
|
||||
indices = np.concatenate((indices, extra_indices))
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
batch_group_size = self.batch_group_size
|
||||
s, e = 0, 0
|
||||
for i in range(len(indices) // batch_group_size):
|
||||
s = i * batch_group_size
|
||||
e = s + batch_group_size
|
||||
random.shuffle(indices[s:e]) # inplace
|
||||
|
||||
# Permutate batches
|
||||
total_batch_size = self.total_batch_size
|
||||
if self.permutate:
|
||||
perm = np.arange(len(indices[:e]) // total_batch_size)
|
||||
random.shuffle(perm)
|
||||
indices[:e] = indices[:e].reshape(
|
||||
-1, total_batch_size)[perm, :].reshape(-1)
|
||||
|
||||
# Handle last elements
|
||||
s += batch_group_size
|
||||
#print(indices)
|
||||
if s < len(indices):
|
||||
random.shuffle(indices[s:])
|
||||
|
||||
# Subset samples for each trainer.
|
||||
indices = indices[self.rank:self.total_size:self.num_trainers]
|
||||
assert len(indices) == self.num_samples
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sorted_indices)
|
||||
|
||||
|
||||
class WeightedRandomSampler(Sampler):
|
||||
"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
|
||||
Args:
|
||||
weights (List[float]): a sequence of weights, not necessary summing up to 1.
|
||||
num_samples (int): number of samples to draw.
|
||||
replacement (bool): whether samples are drawn with replacement. When replacement is False, num_samples should not be larger than len(weights).
|
||||
Example:
|
||||
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
|
||||
[0, 0, 0, 1, 0]
|
||||
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
|
||||
[0, 1, 4, 3, 2]
|
||||
"""
|
||||
|
||||
def __init__(self, weights, num_samples, replacement):
|
||||
if not isinstance(num_samples, int) or num_samples <= 0:
|
||||
raise ValueError("num_samples should be a positive integer "
|
||||
"value, but got num_samples={}".format(
|
||||
num_samples))
|
||||
self.weights = np.array(weights, dtype=np.float64)
|
||||
self.num_samples = num_samples
|
||||
self.replacement = replacement
|
||||
if replacement is False and num_samples > len(weights):
|
||||
raise ValueError(
|
||||
"when replacement is False, num_samples should not be"
|
||||
"larger that length of weight.")
|
||||
|
||||
def __iter__(self):
|
||||
return iter(
|
||||
np.random.choice(
|
||||
len(self.weights),
|
||||
size=(self.num_samples, ),
|
||||
replace=self.replacement,
|
||||
p=self.weights).tolist())
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
|
@ -2,6 +2,8 @@ from paddle.io import Dataset
|
|||
import os
|
||||
import librosa
|
||||
|
||||
__all__ = ["AudioFolderDataset"]
|
||||
|
||||
class AudioFolderDataset(Dataset):
|
||||
def __init__(self, path, sample_rate, extension="wav"):
|
||||
self.root = os.path.expanduser(path)
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from paddle.io import Dataset
|
||||
from pathlib import Path
|
||||
|
||||
__all__ = ["LJSpeechMetaData"]
|
||||
|
||||
class LJSpeechMetaData(Dataset):
|
||||
def __init__(self, root):
|
||||
self.root = Path(root).expanduser()
|
||||
|
|
|
@ -6,6 +6,9 @@ from parakeet.frontend import Vocab
|
|||
from opencc import OpenCC
|
||||
from parakeet.frontend.punctuation import get_punctuations
|
||||
|
||||
__all__ = ["Phonetics", "English", "Chinese"]
|
||||
|
||||
|
||||
class Phonetics(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, sentence):
|
||||
|
|
|
@ -2,6 +2,10 @@ from typing import Dict, Iterable, List
|
|||
from ruamel import yaml
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
__all__ = ["Vocab"]
|
||||
|
||||
|
||||
class Vocab(object):
|
||||
def __init__(self, symbols: Iterable[str],
|
||||
padding_symbol="<pad>",
|
||||
|
|
|
@ -12,10 +12,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from parakeet.models.clarinet import *
|
||||
#from parakeet.models.clarinet import *
|
||||
from parakeet.models.waveflow import *
|
||||
from parakeet.models.wavenet import *
|
||||
#from parakeet.models.wavenet import *
|
||||
|
||||
from parakeet.models.transformer_tts import *
|
||||
from parakeet.models.deepvoice3 import *
|
||||
#from parakeet.models.deepvoice3 import *
|
||||
# from parakeet.models.fastspeech import *
|
||||
|
|
|
@ -273,12 +273,14 @@ class MLPPreNet(nn.Layer):
|
|||
super(MLPPreNet, self).__init__()
|
||||
self.lin1 = nn.Linear(d_input, d_hidden)
|
||||
self.lin2 = nn.Linear(d_hidden, d_hidden)
|
||||
self.lin3 = nn.Linear(d_hidden, d_hidden)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, dropout):
|
||||
l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
|
||||
l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
|
||||
return l2
|
||||
l3 = self.lin3(l2)
|
||||
return l3
|
||||
|
||||
# NOTE: not used in
|
||||
class CNNPreNet(nn.Layer):
|
||||
|
@ -317,6 +319,7 @@ class CNNPostNet(nn.Layer):
|
|||
Conv1dBatchNorm(c_in, c_out, kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding=padding))
|
||||
self.last_bn = nn.BatchNorm1D(d_output)
|
||||
# for a layer that ends with a normalization layer that is targeted to
|
||||
# output a non zero-central output, it may take a long time to
|
||||
# train the scale and bias
|
||||
|
@ -328,7 +331,7 @@ class CNNPostNet(nn.Layer):
|
|||
x = layer(x)
|
||||
if i != (len(self.convs) - 1):
|
||||
x = F.tanh(x)
|
||||
x = x_in + x
|
||||
x = self.last_bn(x_in + x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -491,7 +494,7 @@ class TransformerTTS(nn.Layer):
|
|||
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
|
||||
|
||||
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
|
||||
if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index):
|
||||
if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index):
|
||||
if verbose:
|
||||
print("Hits stop condition.")
|
||||
break
|
||||
|
@ -526,6 +529,34 @@ class TransformerTTSLoss(nn.Layer):
|
|||
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||
|
||||
loss = mel_loss1 + mel_loss2 + stop_loss
|
||||
losses = dict(
|
||||
loss=loss, # total loss
|
||||
mel_loss1=mel_loss1, # ouput mel loss
|
||||
mel_loss2=mel_loss2, # intermediate mel loss
|
||||
stop_loss=stop_loss # stop prob loss
|
||||
)
|
||||
return losses
|
||||
|
||||
|
||||
class AdaptiveTransformerTTSLoss(nn.Layer):
|
||||
def __init__(self):
|
||||
super(AdaptiveTransformerTTSLoss, self).__init__()
|
||||
|
||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
|
||||
mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
|
||||
mask1 = paddle.unsqueeze(mask, -1)
|
||||
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||
|
||||
batch_size, mel_len = mask.shape
|
||||
valid_lengths = mask.sum(-1).astype("int64")
|
||||
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
|
||||
stop_loss_scale = valid_lengths.sum() / batch_size - 1
|
||||
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
|
||||
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||
|
||||
loss = mel_loss1 + mel_loss2 + stop_loss
|
||||
losses = dict(
|
||||
loss=loss, # total loss
|
||||
|
|
|
@ -141,6 +141,15 @@ class ResidualBlock(nn.Layer):
|
|||
raise ValueError("Only use start sequence at evaluation mode.")
|
||||
self._conv_buffer = None
|
||||
|
||||
# NOTE: call self.conv's weight norm hook expliccitly since
|
||||
# its weight will be visited directly in `add_input` without
|
||||
# calling its `__call__` method. If we do not trigger the weight
|
||||
# norm hook, the weight may be outdated. e.g. after loading from
|
||||
# a saved checkpoint
|
||||
# see also: https://github.com/pytorch/pytorch/issues/47588
|
||||
for hook in self.conv._forward_pre_hooks.values():
|
||||
hook(self.conv, None)
|
||||
|
||||
def add_input(self, x_row, condition_row):
|
||||
"""Compute the output for a row and update the buffer.
|
||||
|
||||
|
@ -158,10 +167,6 @@ class ResidualBlock(nn.Layer):
|
|||
self._update_buffer(x_row)
|
||||
|
||||
rw = self.rw
|
||||
# call self.conv's weight norm hook expliccitly since its __call__
|
||||
# method is not called here
|
||||
for hook in self.conv._forward_pre_hooks.values():
|
||||
hook(self.conv, self._conv_buffer)
|
||||
x_row = F.conv2d(
|
||||
self._conv_buffer,
|
||||
self.conv.weight,
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
import math
|
||||
import time
|
||||
from typing import Union, Sequence
|
||||
from tqdm import trange
|
||||
import numpy as np
|
||||
|
||||
|
@ -25,48 +25,19 @@ import paddle.fluid.initializer as I
|
|||
import paddle.fluid.layers.distributions as D
|
||||
|
||||
from parakeet.modules.conv import Conv1dCell
|
||||
|
||||
__all__ = ["ConditionalWavenet"]
|
||||
|
||||
def quantize(values, n_bands):
|
||||
"""Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
|
||||
|
||||
Args:
|
||||
values (Variable): dtype: flaot32 or float64. the floating point value.
|
||||
n_bands (int): the number of bands. The output integer Tensor's value is in the range [0, n_bans).
|
||||
|
||||
Returns:
|
||||
Variable: the quantized tensor, dtype: int64.
|
||||
"""
|
||||
quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize(quantized, n_bands, dtype=None):
|
||||
"""Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
|
||||
|
||||
Args:
|
||||
quantized (Variable): dtype: int64. The quantized value in the range [0, n_bands).
|
||||
n_bands (int): number of bands. The input integer Tensor's value is in the range [0, n_bans).
|
||||
|
||||
Returns:
|
||||
Variable: the dequantized tensor, dtype is specified by dtype.
|
||||
"""
|
||||
dtype = dtype or paddle.get_default_dtype()
|
||||
value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
|
||||
return value
|
||||
from parakeet.modules.audio import quantize, dequantize, STFT
|
||||
|
||||
|
||||
def crop(x, audio_start, audio_length):
|
||||
"""Crop the upsampled condition to match audio_length. The upsampled condition has the same time steps as the whole audio does. But since audios are sliced to 0.5 seconds randomly while conditions are not, upsampled conditions should also be sliced to extaclt match the time steps of the audio slice.
|
||||
|
||||
Args:
|
||||
x (Variable): shape(B, C, T), dtype float32, the upsample condition.
|
||||
audio_start (Variable): shape(B, ), dtype: int64, the index the starting point.
|
||||
x (Tensor): shape(B, C, T), dtype float32, the upsample condition.
|
||||
audio_start (Tensor): shape(B, ), dtype: int64, the index the starting point.
|
||||
audio_length (int): the length of the audio (number of samples it contaions).
|
||||
|
||||
Returns:
|
||||
Variable: shape(B, C, audio_length), cropped condition.
|
||||
Tensor: shape(B, C, audio_length), cropped condition.
|
||||
"""
|
||||
# crop audio
|
||||
slices = [] # for each example
|
||||
|
@ -81,9 +52,52 @@ def crop(x, audio_start, audio_length):
|
|||
return out
|
||||
|
||||
|
||||
class UpsampleNet(nn.LayerList):
|
||||
def __init__(self, upscale_factors=[16, 16]):
|
||||
"""UpsamplingNet.
|
||||
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose layer upsamples the time dimension by its `stride` times. And each Conv2DTranspose's filter_size at frequency dimension is 3.
|
||||
|
||||
Args:
|
||||
upscale_factors (list[int], optional): time upsampling factors for each Conv2DTranspose Layer. The `UpsampleNet` contains len(upscale_factor) Conv2DTranspose Layers. Each upscale_factor is used as the `stride` for the corresponding Conv2DTranspose. Defaults to [16, 16].
|
||||
Note:
|
||||
np.prod(upscale_factors) should equals the `hop_length` of the stft transformation used to extract spectrogram features from audios. For example, 16 * 16 = 256, then the spectram extracted using a stft transformation whose `hop_length` is 256. See `librosa.stft` for more details.
|
||||
"""
|
||||
super(UpsampleNet, self).__init__()
|
||||
self.upscale_factors = list(upscale_factors)
|
||||
self.upscale_factor = 1
|
||||
for item in upscale_factors:
|
||||
self.upscale_factor *= item
|
||||
|
||||
for factor in self.upscale_factors:
|
||||
self.append(
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv2DTranspose(1, 1,
|
||||
kernel_size=(3, 2 * factor),
|
||||
stride=(1, factor),
|
||||
padding=(1, factor // 2))))
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute the upsampled condition.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B, F, T), dtype float32, the condition (mel spectrogram here.) (F means the frequency bands). In the internal Conv2DTransposes, the frequency dimension is treated as `height` dimension instead of `in_channels`.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, F, T * upscale_factor), dtype float32, the upsampled condition.
|
||||
"""
|
||||
x = paddle.unsqueeze(x, 1)
|
||||
for sublayer in self:
|
||||
x = F.leaky_relu(sublayer(x), 0.4)
|
||||
x = paddle.squeeze(x, 1)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualBlock(nn.Layer):
|
||||
def __init__(self, residual_channels, condition_dim, filter_size,
|
||||
dilation):
|
||||
def __init__(self,
|
||||
residual_channels: int,
|
||||
condition_dim: int,
|
||||
filter_size: Union[int, Sequence[int]],
|
||||
dilation: int):
|
||||
"""A Residual block in wavenet. It does not have parametric residual or skip connection. It consists of a Conv1DCell and an Conv1D(filter_size = 1) to integrate the condition.
|
||||
|
||||
Args:
|
||||
|
@ -108,7 +122,7 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
std = math.sqrt(1 / condition_dim)
|
||||
condition_proj = Conv1dCell(condition_dim, dilated_channels, (1,),
|
||||
weight_attr=I.Normal(scale=std))
|
||||
weight_attr=I.Normal(scale=std))
|
||||
self.condition_proj = nn.utils.weight_norm(condition_proj)
|
||||
|
||||
self.filter_size = filter_size
|
||||
|
@ -121,20 +135,13 @@ class ResidualBlock(nn.Layer):
|
|||
"""Conv1D gated-tanh Block.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B, C_res, T), the input. (B stands for batch_size,
|
||||
C_res stands for residual channels, T stands for time steps.)
|
||||
dtype float32.
|
||||
condition (Tensor, optional): shape(B, C_cond, T), the condition,
|
||||
it has been upsampled in time steps, so it has the same time
|
||||
steps as the input does.(C_cond stands for the condition's channels).
|
||||
Defaults to None.
|
||||
x (Tensor): shape(B, C_res, T), the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.) dtype float32.
|
||||
condition (Tensor, optional): shape(B, C_cond, T), the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels). Defaults to None.
|
||||
|
||||
Returns:
|
||||
(residual, skip_connection)
|
||||
residual (Tensor): shape(B, C_res, T), the residual, which is used
|
||||
as the input to the next layer of ResidualBlock.
|
||||
skip_connection (Tensor): shape(B, C_res, T), the skip connection.
|
||||
This output is accumulated with that of other ResidualBlocks.
|
||||
residual (Tensor): shape(B, C_res, T), the residual, which is used as the input to the next layer of ResidualBlock.
|
||||
skip_connection (Tensor): shape(B, C_res, T), the skip connection. This output is accumulated with that of other ResidualBlocks.
|
||||
"""
|
||||
h = x
|
||||
|
||||
|
@ -155,30 +162,22 @@ class ResidualBlock(nn.Layer):
|
|||
return residual, skip_connection
|
||||
|
||||
def start_sequence(self):
|
||||
"""
|
||||
Prepare the ResidualBlock to generate a new sequence. This method
|
||||
should be called before starting calling `add_input` multiple times.
|
||||
"""Prepare the ResidualBlock to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||
"""
|
||||
self.conv.start_sequence()
|
||||
self.condition_proj.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""
|
||||
Add a step input. This method works similarily with `forward` but
|
||||
in a `step-in-step-out` fashion.
|
||||
"""Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
|
||||
|
||||
Args:
|
||||
x (Variable): shape(B, C_res), input for a step, dtype float32.
|
||||
condition (Variable, optional): shape(B, C_cond). condition for a
|
||||
step, dtype float32. Defaults to None.
|
||||
x (Tensor): shape(B, C_res), input for a step, dtype float32.
|
||||
condition (Tensor, optional): shape(B, C_cond). condition for a step, dtype float32. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(residual, skip_connection)
|
||||
residual (Variable): shape(B, C_res), the residual for a step,
|
||||
which is used as the input to the next layer of ResidualBlock.
|
||||
skip_connection (Variable): shape(B, C_res), the skip connection
|
||||
for a step. This output is accumulated with that of other
|
||||
ResidualBlocks.
|
||||
residual (Tensor): shape(B, C_res), the residual for a step, which is used as the input to the next layer of ResidualBlock.
|
||||
skip_connection (Tensor): shape(B, C_res), the skip connection for a step. This output is accumulated with that of other ResidualBlocks.
|
||||
"""
|
||||
h = x
|
||||
|
||||
|
@ -200,22 +199,24 @@ class ResidualBlock(nn.Layer):
|
|||
|
||||
|
||||
class ResidualNet(nn.LayerList):
|
||||
def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
|
||||
filter_size):
|
||||
"""The residual network in wavenet. It consists of `n_layer` stacks,
|
||||
each of which consists of `n_loop` ResidualBlocks.
|
||||
def __init__(self,
|
||||
n_stack: int,
|
||||
n_loop: int,
|
||||
residual_channels: int,
|
||||
condition_dim: int,
|
||||
filter_size: int):
|
||||
"""The residual network in wavenet. It consists of `n_layer` stacks, each of which consists of `n_loop` ResidualBlocks.
|
||||
|
||||
Args:
|
||||
n_stack (int): number of stacks in the `ResidualNet`.
|
||||
n_loop (int): number of ResidualBlocks in a stack.
|
||||
n_layer (int): number of stacks in the `ResidualNet`.
|
||||
residual_channels (int): channels of each `ResidualBlock`'s input.
|
||||
condition_dim (int): channels of the condition.
|
||||
filter_size (int): filter size of the internal Conv1DCell of each
|
||||
`ResidualBlock`.
|
||||
filter_size (int): filter size of the internal Conv1DCell of each `ResidualBlock`.
|
||||
"""
|
||||
super(ResidualNet, self).__init__()
|
||||
# double the dilation at each layer in a loop(n_loop layers)
|
||||
dilations = [2**i for i in range(n_loop)] * n_layer
|
||||
# double the dilation at each layer in a stack
|
||||
dilations = [2**i for i in range(n_loop)] * n_stack
|
||||
self.context_size = 1 + sum(dilations)
|
||||
for dilation in dilations:
|
||||
self.append(ResidualBlock(residual_channels, condition_dim, filter_size, dilation))
|
||||
|
@ -223,13 +224,8 @@ class ResidualNet(nn.LayerList):
|
|||
def forward(self, x, condition=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): shape(B, C_res, T), dtype float32, the input.
|
||||
(B stands for batch_size, C_res stands for residual channels,
|
||||
T stands for time steps.)
|
||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32,
|
||||
the condition, it has been upsampled in time steps, so it has
|
||||
the same time steps as the input does.(C_cond stands for the
|
||||
condition's channels) Defaults to None.
|
||||
x (Tensor): shape(B, C_res, T), dtype float32, the input. (B stands for batch_size, C_res stands for residual channels, T stands for time steps.)
|
||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32, the condition, it has been upsampled in time steps, so it has the same time steps as the input does.(C_cond stands for the condition's channels) Defaults to None.
|
||||
|
||||
Returns:
|
||||
skip_connection (Tensor): shape(B, C_res, T), dtype float32, the output.
|
||||
|
@ -244,24 +240,20 @@ class ResidualNet(nn.LayerList):
|
|||
return skip_connections
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the ResidualNet to generate a new sequence. This method
|
||||
should be called before starting calling `add_input` multiple times.
|
||||
"""Prepare the ResidualNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||
"""
|
||||
for block in self:
|
||||
block.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""Add a step input. This method works similarily with `forward` but
|
||||
in a `step-in-step-out` fashion.
|
||||
"""Add a step input. This method works similarily with `forward` but in a `step-in-step-out` fashion.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B, C_res), dtype float32, input for a step.
|
||||
condition (Tensor, optional): shape(B, C_cond), dtype float32,
|
||||
condition for a step. Defaults to None.
|
||||
condition (Tensor, optional): shape(B, C_cond), dtype float32, condition for a step. Defaults to None.
|
||||
|
||||
Returns:
|
||||
skip_connection (Tensor): shape(B, C_res), dtype float32, the
|
||||
output for a step.
|
||||
skip_connection (Tensor): shape(B, C_res), dtype float32, the output for a step.
|
||||
"""
|
||||
|
||||
for i, func in enumerate(self):
|
||||
|
@ -275,31 +267,19 @@ class ResidualNet(nn.LayerList):
|
|||
|
||||
|
||||
class WaveNet(nn.Layer):
|
||||
def __init__(self, n_loop, n_layer, residual_channels, output_dim,
|
||||
def __init__(self, n_stack, n_loop, residual_channels, output_dim,
|
||||
condition_dim, filter_size, loss_type, log_scale_min):
|
||||
"""Wavenet that transform upsampled mel spectrogram into waveform.
|
||||
|
||||
Args:
|
||||
n_stack (int): n_stack for the internal ResidualNet.
|
||||
n_loop (int): n_loop for the internal ResidualNet.
|
||||
n_layer (int): n_loop for the internal ResidualNet.
|
||||
residual_channels (int): the channel of the input.
|
||||
output_dim (int): the channel of the output distribution.
|
||||
condition_dim (int): the channel of the condition.
|
||||
filter_size (int): the filter size of the internal ResidualNet.
|
||||
loss_type (str): loss type of the wavenet. Possible values are
|
||||
'softmax' and 'mog'.
|
||||
If `loss_type` is 'softmax', the output is the logits of the
|
||||
catrgotical(multinomial) distribution, `output_dim` means the
|
||||
number of classes of the categorical distribution.
|
||||
If `loss_type` is mog(mixture of gaussians), the output is the
|
||||
parameters of a mixture of gaussians, which consists of weight
|
||||
(in the form of logit) of each gaussian distribution and its
|
||||
mean and log standard deviaton. So when `loss_type` is 'mog',
|
||||
`output_dim` should be perfectly divided by 3.
|
||||
log_scale_min (int): the minimum value of log standard deviation
|
||||
of the output gaussian distributions. Note that this value is
|
||||
only used for computing loss if `loss_type` is 'mog', values
|
||||
less than `log_scale_min` is clipped when computing loss.
|
||||
loss_type (str): loss type of the wavenet. Possible values are 'softmax' and 'mog'. If `loss_type` is 'softmax', the output is the logits of the catrgotical(multinomial) distribution, `output_dim` means the number of classes of the categorical distribution. If `loss_type` is mog(mixture of gaussians), the output is the parameters of a mixture of gaussians, which consists of weight(in the form of logit) of each gaussian distribution and its mean and log standard deviaton. So when `loss_type` is 'mog', `output_dim` should be perfectly divided by 3.
|
||||
log_scale_min (int): the minimum value of log standard deviation of the output gaussian distributions. Note that this value is only used for computing loss if `loss_type` is 'mog', values less than `log_scale_min` is clipped when computing loss.
|
||||
"""
|
||||
super(WaveNet, self).__init__()
|
||||
if loss_type not in ["softmax", "mog"]:
|
||||
|
@ -312,7 +292,7 @@ class WaveNet(nn.Layer):
|
|||
"with Mixture of Gaussians(mog) output, the output dim must be divisible by 3, but get {}".format(output_dim))
|
||||
self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=-1)
|
||||
|
||||
self.resnet = ResidualNet(n_loop, n_layer, residual_channels,
|
||||
self.resnet = ResidualNet(n_stack, n_loop, residual_channels,
|
||||
condition_dim, filter_size)
|
||||
self.context_size = self.resnet.context_size
|
||||
|
||||
|
@ -334,12 +314,10 @@ class WaveNet(nn.Layer):
|
|||
|
||||
Args:
|
||||
x (Tensor): shape(B, T), dtype float32, the input waveform.
|
||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32,
|
||||
the upsampled condition. Defaults to None.
|
||||
condition (Tensor, optional): shape(B, C_cond, T), dtype float32, the upsampled condition. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T, C_output), dtype float32, the parameter of
|
||||
the output distributions.
|
||||
Tensor: shape(B, T, C_output), dtype float32, the parameter of the output distributions.
|
||||
"""
|
||||
|
||||
# Causal Conv
|
||||
|
@ -362,24 +340,19 @@ class WaveNet(nn.Layer):
|
|||
return y
|
||||
|
||||
def start_sequence(self):
|
||||
"""Prepare the WaveNet to generate a new sequence. This method should
|
||||
be called before starting calling `add_input` multiple times.
|
||||
"""Prepare the WaveNet to generate a new sequence. This method should be called before starting calling `add_input` multiple times.
|
||||
"""
|
||||
self.resnet.start_sequence()
|
||||
|
||||
def add_input(self, x, condition=None):
|
||||
"""compute the output distribution (represented by its parameters) for
|
||||
a step. It works similarily with the `forward` method but in a
|
||||
`step-in-step-out` fashion.
|
||||
"""compute the output distribution (represented by its parameters) for a step. It works similarily with the `forward` method but in a `step-in-step-out` fashion.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B,), dtype float32, a step of the input waveform.
|
||||
condition (Tensor, optional): shape(B, C_cond, ), dtype float32, a
|
||||
step of the upsampled condition. Defaults to None.
|
||||
condition (Tensor, optional): shape(B, C_cond, ), dtype float32, a step of the upsampled condition. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, C_output), dtype float32, the parameter of the
|
||||
output distributions.
|
||||
Tensor: shape(B, C_output), dtype float32, the parameter of the output distributions.
|
||||
"""
|
||||
# Causal Conv
|
||||
if self.loss_type == "softmax":
|
||||
|
@ -402,12 +375,8 @@ class WaveNet(nn.Layer):
|
|||
"""compute the loss where output distribution is a categorial distribution.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the logits of the
|
||||
output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that
|
||||
the target's corresponding time index is one step ahead of the
|
||||
output distribution. And output distribution whose input contains
|
||||
padding is neglected in loss computation.
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the logits of the output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(1, ), dtype float32, the loss.
|
||||
|
@ -420,15 +389,14 @@ class WaveNet(nn.Layer):
|
|||
label = paddle.unsqueeze(quantized, -1)
|
||||
|
||||
loss = F.softmax_with_cross_entropy(y, label)
|
||||
reduced_loss = paddle.reduce_mean(loss)
|
||||
reduced_loss = paddle.mean(loss)
|
||||
return reduced_loss
|
||||
|
||||
def sample_from_softmax(self, y):
|
||||
"""Sample from the output distribution where the output distribution is
|
||||
a categorical distriobution.
|
||||
"""Sample from the output distribution where the output distribution is a categorical distriobution.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), the logits of the output distribution.
|
||||
y (Tensor): shape(B, T, C_output), the logits of the output distribution
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||
|
@ -446,16 +414,8 @@ class WaveNet(nn.Layer):
|
|||
"""compute the loss where output distribution is a mixture of Gaussians.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
||||
the output distribution. It is the concatenation of 3 parts,
|
||||
the logits of every distribution, the mean of each distribution
|
||||
and the log standard deviation of each distribution. Each part's
|
||||
shape is (B, T, n_mixture), where `n_mixture` means the number
|
||||
of Gaussians in the mixture.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that
|
||||
the target's corresponding time index is one step ahead of the
|
||||
output distribution. And output distribution whose input contains
|
||||
padding is neglected in loss computation.
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(1, ), dtype float32, the loss.
|
||||
|
@ -483,22 +443,16 @@ class WaveNet(nn.Layer):
|
|||
|
||||
pdf_x = p_mixture * pdf_x
|
||||
# pdf_x: [bs, len]
|
||||
pdf_x = paddle.reduce_sum(pdf_x, -1)
|
||||
pdf_x = paddle.sum(pdf_x, -1)
|
||||
per_sample_loss = -paddle.log(pdf_x + 1e-9)
|
||||
|
||||
loss = paddle.reduce_mean(per_sample_loss)
|
||||
loss = paddle.mean(per_sample_loss)
|
||||
return loss
|
||||
|
||||
def sample_from_mog(self, y):
|
||||
"""Sample from the output distribution where the output distribution is
|
||||
a mixture of Gaussians.
|
||||
"""Sample from the output distribution where the output distribution is a mixture of Gaussians.
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
||||
the output distribution. It is the concatenation of 3 parts, the
|
||||
logits of every distribution, the mean of each distribution and the
|
||||
log standard deviation of each distribution. Each part's shape is
|
||||
(B, T, n_mixture), where `n_mixture` means the number of Gaussians
|
||||
in the mixture.
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution. It is the concatenation of 3 parts, the logits of every distribution, the mean of each distribution and the log standard deviation of each distribution. Each part's shape is (B, T, n_mixture), where `n_mixture` means the number of Gaussians in the mixture.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||
|
@ -529,8 +483,7 @@ class WaveNet(nn.Layer):
|
|||
def sample(self, y):
|
||||
"""Sample from the output distribution.
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
||||
the output distribution.
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), waveform sampled from the output distribution.
|
||||
|
@ -544,12 +497,8 @@ class WaveNet(nn.Layer):
|
|||
"""compute the loss where output distribution is a mixture of Gaussians.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of
|
||||
the output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that
|
||||
the target's corresponding time index is one step ahead of the
|
||||
output distribution. And output distribution whose input contains
|
||||
padding is neglected in loss computation.
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, the parameterd of the output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, the target audio. Note that the target's corresponding time index is one step ahead of the output distribution. And output distribution whose input contains padding is neglected in loss computation.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(1, ), dtype float32, the loss.
|
||||
|
@ -560,64 +509,9 @@ class WaveNet(nn.Layer):
|
|||
return self.compute_mog_loss(y, t)
|
||||
|
||||
|
||||
class UpsampleNet(nn.LayerList):
|
||||
def __init__(self, upscale_factors=[16, 16]):
|
||||
"""UpsamplingNet.
|
||||
It consists of several layers of Conv2DTranspose. Each Conv2DTranspose
|
||||
layer upsamples the time dimension by its `stride` times. And each
|
||||
Conv2DTranspose's filter_size at frequency dimension is 3.
|
||||
|
||||
Args:
|
||||
upscale_factors (list[int], optional): time upsampling factors for
|
||||
each Conv2DTranspose Layer. The `UpsampleNet` contains
|
||||
len(upscale_factor) Conv2DTranspose Layers. Each upscale_factor
|
||||
is used as the `stride` for the corresponding Conv2DTranspose.
|
||||
Defaults to [16, 16].
|
||||
Note:
|
||||
np.prod(upscale_factors) should equals the `hop_length` of the stft
|
||||
transformation used to extract spectrogram features from audios.
|
||||
For example, 16 * 16 = 256, then the spectram extracted using a
|
||||
stft transformation whose `hop_length` is 256. See `librosa.stft`
|
||||
for more details.
|
||||
"""
|
||||
super(UpsampleNet, self).__init__()
|
||||
self.upscale_factors = list(upscale_factors)
|
||||
self.upscale_factor = 1
|
||||
for item in upscale_factors:
|
||||
self.upscale_factor *= item
|
||||
|
||||
for factor in self.upscale_factors:
|
||||
self.append(
|
||||
nn.utils.weight_norm(
|
||||
nn.ConvTranspose2d(1, 1,
|
||||
kernel_size=(3, 2 * factor),
|
||||
stride=(1, factor),
|
||||
padding=(1, factor // 2))))
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute the upsampled condition.
|
||||
|
||||
Args:
|
||||
x (Tensor): shape(B, F, T), dtype float32, the condition
|
||||
(mel spectrogram here.) (F means the frequency bands). In the
|
||||
internal Conv2DTransposes, the frequency dimension is treated
|
||||
as `height` dimension instead of `in_channels`.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, F, T * upscale_factor), dtype float32, the
|
||||
upsampled condition.
|
||||
"""
|
||||
x = paddle.unsqueeze(x, 1)
|
||||
for sublayer in self:
|
||||
x = F.leaky_relu(sublayer(x), 0.4)
|
||||
x = paddle.squeeze(x, 1)
|
||||
return x
|
||||
|
||||
|
||||
class ConditionalWavenet(nn.Layer):
|
||||
def __init__(self, encoder, decoder):
|
||||
"""Conditional Wavenet, which contains an UpsampleNet as the encoder
|
||||
and a WaveNet as the decoder. It is an autoregressive model.
|
||||
"""Conditional Wavenet, which contains an UpsampleNet as the encoder and a WaveNet as the decoder. It is an autoregressive model.
|
||||
|
||||
Args:
|
||||
encoder (UpsampleNet): the UpsampleNet as the encoder.
|
||||
|
@ -628,20 +522,15 @@ class ConditionalWavenet(nn.Layer):
|
|||
self.decoder = decoder
|
||||
|
||||
def forward(self, audio, mel, audio_start):
|
||||
"""Compute the output distribution given the mel spectrogram and the
|
||||
input(for teacher force training).
|
||||
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
||||
|
||||
Args:
|
||||
audio (Tensor): shape(B, T_audio), dtype float32, ground truth
|
||||
waveform, used for teacher force training.
|
||||
mel (Tensor): shape(B, F, T_mel), dtype float32, mel spectrogram.
|
||||
Note that it is the spectrogram for the whole utterance.
|
||||
audio_start (Tensor): shape(B, ), dtype: int, audio slices' start
|
||||
positions for each utterance.
|
||||
audio (Tensor): shape(B, T_audio), dtype float32, ground truth waveform, used for teacher force training.
|
||||
mel (Tensor): shape(B, F, T_mel), dtype float32, mel spectrogram. Note that it is the spectrogram for the whole utterance.
|
||||
audio_start (Tensor): shape(B, ), dtype: int, audio slices' start positions for each utterance.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T_audio - 1, C_putput), parameters for the output
|
||||
distribution.(C_output is the `output_dim` of the decoder.)
|
||||
Tensor: shape(B, T_audio - 1, C_putput), parameters for the output distribution.(C_output is the `output_dim` of the decoder.)
|
||||
"""
|
||||
audio_length = audio.shape[1] # audio clip's length
|
||||
condition = self.encoder(mel)
|
||||
|
@ -655,12 +544,10 @@ class ConditionalWavenet(nn.Layer):
|
|||
return y
|
||||
|
||||
def loss(self, y, t):
|
||||
"""compute loss with respect to the output distribution and the targer
|
||||
audio.
|
||||
"""compute loss with respect to the output distribution and the targer audio.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T - 1, C_output), dtype float32, parameters of
|
||||
the output distribution.
|
||||
y (Tensor): shape(B, T - 1, C_output), dtype float32, parameters of the output distribution.
|
||||
t (Tensor): shape(B, T), dtype float32, target waveform.
|
||||
|
||||
Returns:
|
||||
|
@ -674,12 +561,10 @@ class ConditionalWavenet(nn.Layer):
|
|||
"""Sample from the output distribution.
|
||||
|
||||
Args:
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, parameters of the
|
||||
output distribution.
|
||||
y (Tensor): shape(B, T, C_output), dtype float32, parameters of the output distribution.
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T), dtype float32, sampled waveform from the output
|
||||
distribution.
|
||||
Tensor: shape(B, T), dtype float32, sampled waveform from the output distribution.
|
||||
"""
|
||||
samples = self.decoder.sample(y)
|
||||
return samples
|
||||
|
@ -692,9 +577,7 @@ class ConditionalWavenet(nn.Layer):
|
|||
mel (Tensor): shape(B, F, T), condition(mel spectrogram here).
|
||||
|
||||
Returns:
|
||||
Tensor: shape(B, T * upsacle_factor), synthesized waveform.
|
||||
(`upscale_factor` is the `upscale_factor` of the encoder
|
||||
`UpsampleNet`)
|
||||
Tensor: shape(B, T * upsacle_factor), synthesized waveform.(`upscale_factor` is the `upscale_factor` of the encoder `UpsampleNet`)
|
||||
"""
|
||||
condition = self.encoder(mel)
|
||||
batch_size, _, time_steps = condition.shape
|
||||
|
@ -712,6 +595,3 @@ class ConditionalWavenet(nn.Layer):
|
|||
|
||||
samples = paddle.concat(samples, -1)
|
||||
return samples
|
||||
|
||||
|
||||
# TODO WaveNetLoss
|
|
@ -4,6 +4,38 @@ from paddle.nn import functional as F
|
|||
from scipy import signal
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["quantize", "dequantize", "STFT"]
|
||||
|
||||
|
||||
def quantize(values, n_bands):
|
||||
"""Linearlly quantize a float Tensor in [-1, 1) to an interger Tensor in [0, n_bands).
|
||||
|
||||
Args:
|
||||
values (Tensor): dtype: flaot32 or float64. the floating point value.
|
||||
n_bands (int): the number of bands. The output integer Tensor's value is in the range [0, n_bans).
|
||||
|
||||
Returns:
|
||||
Tensor: the quantized tensor, dtype: int64.
|
||||
"""
|
||||
quantized = paddle.cast((values + 1.0) / 2.0 * n_bands, "int64")
|
||||
return quantized
|
||||
|
||||
|
||||
def dequantize(quantized, n_bands, dtype=None):
|
||||
"""Linearlly dequantize an integer Tensor into a float Tensor in the range [-1, 1).
|
||||
|
||||
Args:
|
||||
quantized (Tensor): dtype: int64. The quantized value in the range [0, n_bands).
|
||||
n_bands (int): number of bands. The input integer Tensor's value is in the range [0, n_bans).
|
||||
dtype (str, optional): data type of the output.
|
||||
Returns:
|
||||
Tensor: the dequantized tensor, dtype is specified by dtype.
|
||||
"""
|
||||
dtype = dtype or paddle.get_default_dtype()
|
||||
value = (paddle.cast(quantized, dtype) + 0.5) * (2.0 / n_bands) - 1.0
|
||||
return value
|
||||
|
||||
|
||||
class STFT(nn.Layer):
|
||||
def __init__(self, n_fft, hop_length, win_length, window="hanning"):
|
||||
"""A module for computing differentiable stft transform. See `librosa.stft` for more details.
|
|
@ -42,6 +42,14 @@ class Conv1dCell(nn.Conv1D):
|
|||
if self.training:
|
||||
raise Exception("only use start_sequence in evaluation")
|
||||
self._buffer = None
|
||||
|
||||
# NOTE: call self's weight norm hook expliccitly since self.weight
|
||||
# is visited directly in this method without calling self.__call__
|
||||
# method. If we do not trigger the weight norm hook, the weight
|
||||
# may be outdated. e.g. after loading from a saved checkpoint
|
||||
# see also: https://github.com/pytorch/pytorch/issues/47588
|
||||
for hook in self._forward_pre_hooks.values():
|
||||
hook(self, None)
|
||||
self._reshaped_weight = paddle.reshape(
|
||||
self.weight, (self._out_channels, -1))
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import numba
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
@ -12,7 +14,7 @@ def weighted_mean(input, weight):
|
|||
Returns:
|
||||
Tensor: shape(1,), weighted mean tensor with the same dtype as input.
|
||||
"""
|
||||
weight = paddle.cast(weight, input.dtype)
|
||||
weight = paddle.cast(weight, input.dtype)
|
||||
return paddle.mean(input * weight)
|
||||
|
||||
def masked_l1_loss(prediction, target, mask):
|
||||
|
@ -22,3 +24,32 @@ def masked_l1_loss(prediction, target, mask):
|
|||
def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
|
||||
ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
|
||||
return weighted_mean(ce, mask)
|
||||
|
||||
def diagonal_loss(attentions, input_lengths, target_lengths, g=0.2, multihead=False):
|
||||
"""A metric to evaluate how diagonal a attention distribution is."""
|
||||
W = guided_attentions(input_lengths, target_lengths, g)
|
||||
W_tensor = paddle.to_tensor(W)
|
||||
if not multihead:
|
||||
return paddle.mean(attentions * W_tensor)
|
||||
else:
|
||||
return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))
|
||||
|
||||
@numba.jit(nopython=True)
|
||||
def guided_attention(N, max_N, T, max_T, g):
|
||||
W = np.zeros((max_T, max_N), dtype=np.float32)
|
||||
for t in range(T):
|
||||
for n in range(N):
|
||||
W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
|
||||
# (T_dec, T_enc)
|
||||
return W
|
||||
|
||||
def guided_attentions(input_lengths, target_lengths, g=0.2):
|
||||
B = len(input_lengths)
|
||||
max_input_len = input_lengths.max()
|
||||
max_target_len = target_lengths.max()
|
||||
W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
|
||||
for b in range(B):
|
||||
W[b] = guided_attention(input_lengths[b], max_input_len,
|
||||
target_lengths[b], max_target_len, g)
|
||||
# (B, T_dec, T_enc)
|
||||
return W
|
|
@ -1,202 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
import paddle.fluid.layers as layers
|
||||
|
||||
|
||||
class Linear(dg.Layer):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
out_features,
|
||||
is_bias=True,
|
||||
dtype="float32"):
|
||||
super(Linear, self).__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.dtype = dtype
|
||||
self.weight = fluid.ParamAttr(
|
||||
initializer=fluid.initializer.XavierInitializer())
|
||||
self.bias = is_bias
|
||||
|
||||
if is_bias is not False:
|
||||
k = math.sqrt(1.0 / in_features)
|
||||
self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
|
||||
low=-k, high=k))
|
||||
|
||||
self.linear = dg.Linear(
|
||||
in_features,
|
||||
out_features,
|
||||
param_attr=self.weight,
|
||||
bias_attr=self.bias, )
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class ScaledDotProductAttention(dg.Layer):
|
||||
def __init__(self, d_key):
|
||||
"""Scaled dot product attention module.
|
||||
|
||||
Args:
|
||||
d_key (int): the dim of key in multihead attention.
|
||||
"""
|
||||
super(ScaledDotProductAttention, self).__init__()
|
||||
|
||||
self.d_key = d_key
|
||||
|
||||
# please attention this mask is diff from pytorch
|
||||
def forward(self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
mask=None,
|
||||
query_mask=None,
|
||||
dropout=0.1):
|
||||
"""
|
||||
Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
|
||||
value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
|
||||
query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
|
||||
mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None.
|
||||
query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None.
|
||||
dropout (float32, optional): the probability of dropout. Defaults to 0.1.
|
||||
Returns:
|
||||
result (Variable): shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable): shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
# Compute attention score
|
||||
attention = layers.matmul(
|
||||
query, key, transpose_y=True, alpha=self.d_key
|
||||
**-0.5) #transpose the last dim in y
|
||||
|
||||
# Mask key to ignore padding
|
||||
if mask is not None:
|
||||
attention = attention + mask
|
||||
attention = layers.softmax(attention, use_cudnn=True)
|
||||
attention = layers.dropout(
|
||||
attention, dropout, dropout_implementation='upscale_in_train')
|
||||
|
||||
# Mask query to ignore padding
|
||||
if query_mask is not None:
|
||||
attention = attention * query_mask
|
||||
|
||||
result = layers.matmul(attention, value)
|
||||
return result, attention
|
||||
|
||||
|
||||
class MultiheadAttention(dg.Layer):
|
||||
def __init__(self,
|
||||
num_hidden,
|
||||
d_k,
|
||||
d_q,
|
||||
num_head=4,
|
||||
is_bias=False,
|
||||
dropout=0.1,
|
||||
is_concat=True):
|
||||
"""Multihead Attention.
|
||||
|
||||
Args:
|
||||
num_hidden (int): the number of hidden layer in network.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
is_bias (bool, optional): whether have bias in linear layers. Default to False.
|
||||
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
|
||||
is_concat (bool, optional): whether concat query and result. Default to True.
|
||||
"""
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
self.d_k = d_k
|
||||
self.d_q = d_q
|
||||
self.dropout = dropout
|
||||
self.is_concat = is_concat
|
||||
|
||||
self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
|
||||
self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
|
||||
self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)
|
||||
|
||||
self.scal_attn = ScaledDotProductAttention(d_k)
|
||||
|
||||
if self.is_concat:
|
||||
self.fc = Linear(num_head * d_q * 2, num_hidden)
|
||||
else:
|
||||
self.fc = Linear(num_head * d_q, num_hidden)
|
||||
|
||||
self.layer_norm = dg.LayerNorm(num_hidden)
|
||||
|
||||
def forward(self, key, value, query_input, mask=None, query_mask=None):
|
||||
"""
|
||||
Compute attention.
|
||||
|
||||
Args:
|
||||
key (Variable): shape(B, T, C), dtype float32, the input key of attention.
|
||||
value (Variable): shape(B, T, C), dtype float32, the input value of attention.
|
||||
query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
|
||||
mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None.
|
||||
query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None.
|
||||
|
||||
Returns:
|
||||
result (Variable): shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable): shape(num_head * B, T, C), the attention of key and query.
|
||||
"""
|
||||
|
||||
batch_size = key.shape[0]
|
||||
seq_len_key = key.shape[1]
|
||||
seq_len_query = query_input.shape[1]
|
||||
|
||||
# Make multihead attention
|
||||
key = layers.reshape(
|
||||
self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
|
||||
value = layers.reshape(
|
||||
self.value(value),
|
||||
[batch_size, seq_len_key, self.num_head, self.d_k])
|
||||
query = layers.reshape(
|
||||
self.query(query_input),
|
||||
[batch_size, seq_len_query, self.num_head, self.d_q])
|
||||
|
||||
key = layers.reshape(
|
||||
layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
|
||||
value = layers.reshape(
|
||||
layers.transpose(value, [2, 0, 1, 3]),
|
||||
[-1, seq_len_key, self.d_k])
|
||||
query = layers.reshape(
|
||||
layers.transpose(query, [2, 0, 1, 3]),
|
||||
[-1, seq_len_query, self.d_q])
|
||||
|
||||
result, attention = self.scal_attn(
|
||||
key, value, query, mask=mask, query_mask=query_mask)
|
||||
|
||||
# concat all multihead result
|
||||
result = layers.reshape(
|
||||
result, [self.num_head, batch_size, seq_len_query, self.d_q])
|
||||
result = layers.reshape(
|
||||
layers.transpose(result, [1, 2, 0, 3]),
|
||||
[batch_size, seq_len_query, -1])
|
||||
if self.is_concat:
|
||||
result = layers.concat([query_input, result], axis=-1)
|
||||
result = layers.dropout(
|
||||
self.fc(result),
|
||||
self.dropout,
|
||||
dropout_implementation='upscale_in_train')
|
||||
result = result + query_input
|
||||
|
||||
result = self.layer_norm(result)
|
||||
return result, attention
|
|
@ -0,0 +1,21 @@
|
|||
import argparse
|
||||
|
||||
def default_argument_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# data and outpu
|
||||
parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.")
|
||||
parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.")
|
||||
parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and log. If not provided, a directory is created in runs/ to save outputs.")
|
||||
|
||||
# load from saved checkpoint
|
||||
parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load")
|
||||
|
||||
# running
|
||||
parser.add_argument("--device", type=str, choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.")
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
|
||||
|
||||
# overwrite extra config and default config
|
||||
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
|
||||
|
||||
return parser
|
|
@ -0,0 +1,12 @@
|
|||
from yacs.config import CfgNode
|
||||
|
||||
_C = CfgNode(
|
||||
dict(
|
||||
valid_interval=1000, # validation
|
||||
save_interval=10000, # checkpoint
|
||||
max_iteration=900000, # max iteration to train
|
||||
)
|
||||
)
|
||||
|
||||
def get_default_training_config():
|
||||
return _C.clone()
|
|
@ -0,0 +1,180 @@
|
|||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader, DistributedBatchSampler
|
||||
from tensorboardX import SummaryWriter
|
||||
from collections import defaultdict
|
||||
|
||||
import parakeet
|
||||
from parakeet.utils import checkpoint, mp_tools
|
||||
|
||||
class ExperimentBase(object):
|
||||
"""
|
||||
An experiment template in order to structure the training code and take care of saving, loading, logging, visualization stuffs. It's intended to be flexible and simple.
|
||||
|
||||
So it only handles output directory (create directory for the outut, create a checkpoint directory, dump the config in use and create visualizer and logger)in a standard way without restricting the input/output protocols of the model and dataloader. It leaves the main part for the user to implement their own(setup the model, criterion, optimizer, defaine a training step, define a validation function and customize all the text and visual logs).
|
||||
|
||||
It does not save too much boilerplate code. The users still have to write the forward/backward/update mannually, but they are free to add non-standard behaviors if needed.
|
||||
|
||||
We have some conventions to follow.
|
||||
1. Experiment should have `.model`, `.optimizer`, `.train_loader` and `.valid_loader`, `.config`, `.args` attributes.
|
||||
2. The config should have a `.training` field, which has `valid_interval`, `save_interval` and `max_iteration` keys. It is used as the trigger to invoke validation, checkpointing and stop of the experiment.
|
||||
3. There are three method, namely `train_batch`, `valid`, `setup_model` and `setup_dataloader` that should be implemented.
|
||||
|
||||
Feel free to add/overwrite other methods and standalone functions if you need.
|
||||
|
||||
Examples:
|
||||
--------
|
||||
def main_sp(config, args):
|
||||
exp = Experiment(config, args)
|
||||
exp.setup()
|
||||
exp.run()
|
||||
|
||||
def main(config, args):
|
||||
if args.nprocs > 1 and args.device == "gpu":
|
||||
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
|
||||
else:
|
||||
main_sp(config, args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = get_cfg_defaults()
|
||||
parser = default_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
print(args)
|
||||
|
||||
main(config, args)
|
||||
|
||||
"""
|
||||
def __init__(self, config, args):
|
||||
self.config = config
|
||||
self.args = args
|
||||
|
||||
def setup(self):
|
||||
if self.parallel:
|
||||
self.init_parallel()
|
||||
|
||||
self.setup_output_dir()
|
||||
self.dump_config()
|
||||
self.setup_visualizer()
|
||||
self.setup_logger()
|
||||
self.setup_checkpointer()
|
||||
|
||||
self.setup_dataloader()
|
||||
self.setup_model()
|
||||
|
||||
self.iteration = 0
|
||||
self.epoch = 0
|
||||
|
||||
@property
|
||||
def parallel(self):
|
||||
return self.args.device == "gpu" and self.args.nprocs > 1
|
||||
|
||||
def init_parallel(self):
|
||||
dist.init_parallel_env()
|
||||
|
||||
def save(self):
|
||||
checkpoint.save_parameters(
|
||||
self.checkpoint_dir, self.iteration, self.model, self.optimizer)
|
||||
|
||||
def resume_or_load(self):
|
||||
iteration = checkpoint.load_parameters(
|
||||
self.model,
|
||||
self.optimizer,
|
||||
checkpoint_dir=self.checkpoint_dir,
|
||||
checkpoint_path=self.args.checkpoint_path)
|
||||
self.iteration = iteration
|
||||
|
||||
def read_batch(self):
|
||||
try:
|
||||
batch = next(self.iterator)
|
||||
except StopIteration:
|
||||
self.new_epoch()
|
||||
batch = next(self.iterator)
|
||||
return batch
|
||||
|
||||
def new_epoch(self):
|
||||
self.epoch += 1
|
||||
if self.parallel:
|
||||
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
||||
self.iterator = iter(self.train_loader)
|
||||
|
||||
def train(self):
|
||||
self.new_epoch()
|
||||
while self.iteration <= self.config.training.max_iteration:
|
||||
self.iteration += 1
|
||||
self.train_batch()
|
||||
|
||||
if self.iteration % self.config.training.valid_interval == 0:
|
||||
self.valid()
|
||||
|
||||
if self.iteration % self.config.training.save_interval == 0:
|
||||
self.save()
|
||||
|
||||
def run(self):
|
||||
self.resume_or_load()
|
||||
try:
|
||||
self.train()
|
||||
except KeyboardInterrupt:
|
||||
self.save()
|
||||
exit(-1)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def setup_output_dir(self):
|
||||
# output dir
|
||||
output_dir = Path(self.args.output).expanduser()
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.output_dir = output_dir
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def setup_checkpointer(self):
|
||||
# checkpoint dir
|
||||
checkpoint_dir = self.output_dir / "checkpoints"
|
||||
checkpoint_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def setup_visualizer(self):
|
||||
# visualizer
|
||||
visualizer = SummaryWriter(logdir=str(self.output_dir))
|
||||
|
||||
self.visualizer = visualizer
|
||||
|
||||
def setup_logger(self):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel("INFO")
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
logger.addHandler(logging.FileHandler(str(log_file)))
|
||||
|
||||
self.logger = logger
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def dump_config(self):
|
||||
with open(self.output_dir / "config.yaml", 'wt') as f:
|
||||
print(self.config, file=f)
|
||||
|
||||
def train_batch(self):
|
||||
raise NotImplementedError("train_batch should be implemented.")
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def valid(self):
|
||||
raise NotImplementedError("valid should be implemented.")
|
||||
|
||||
def setup_model(self):
|
||||
raise NotImplementedError("setup_model should be implemented.")
|
||||
|
||||
def setup_dataloader(self):
|
||||
raise NotImplementedError("setup_dataloader should be implemented.")
|
||||
|
|
@ -12,4 +12,4 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import io, layer_tools, scheduler, display
|
||||
from . import checkpoint, layer_tools, scheduler, display, mp_tools
|
||||
|
|
|
@ -14,41 +14,19 @@
|
|||
|
||||
import os
|
||||
import time
|
||||
|
||||
import ruamel.yaml
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle.fluid.framework import convert_np_dtype_to_dtype_ as convert_np_dtype
|
||||
import paddle
|
||||
from paddle import distributed as dist
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
|
||||
from parakeet.utils import mp_tools
|
||||
|
||||
__all__ = ["load_parameters", "save_parameters"]
|
||||
|
||||
|
||||
def is_main_process():
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
return local_rank == 0
|
||||
|
||||
|
||||
def add_yaml_config_to_args(config):
|
||||
""" Add args in yaml config to the args parsed by argparse. The argument in
|
||||
yaml config will be overwritten by the same argument in argparse if they
|
||||
are both valid.
|
||||
|
||||
Args:
|
||||
config (args): the args returned by `argparse.ArgumentParser().parse_args()`
|
||||
|
||||
Returns:
|
||||
config: the args added yaml config.
|
||||
"""
|
||||
with open(config.config, 'rt') as f:
|
||||
yaml_cfg = ruamel.yaml.safe_load(f)
|
||||
cfg_vars = vars(config)
|
||||
for k, v in yaml_cfg.items():
|
||||
if k in cfg_vars and cfg_vars[k] is not None:
|
||||
continue
|
||||
cfg_vars[k] = v
|
||||
return config
|
||||
|
||||
|
||||
def _load_latest_checkpoint(checkpoint_dir):
|
||||
"""Get the iteration number corresponding to the latest saved checkpoint
|
||||
def _load_latest_checkpoint(checkpoint_dir: str) -> int:
|
||||
"""Get the iteration number corresponding to the latest saved checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
|
@ -57,19 +35,17 @@ def _load_latest_checkpoint(checkpoint_dir):
|
|||
int: the latest iteration number.
|
||||
"""
|
||||
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Create checkpoint index file if not exist.
|
||||
if (not os.path.isfile(checkpoint_record)):
|
||||
return 0
|
||||
|
||||
# Fetch the latest checkpoint index.
|
||||
with open(checkpoint_record, "r") as handle:
|
||||
with open(checkpoint_record, "rt") as handle:
|
||||
latest_checkpoint = handle.readline().split()[-1]
|
||||
iteration = int(latest_checkpoint.split("-")[-1])
|
||||
|
||||
return iteration
|
||||
|
||||
|
||||
def _save_checkpoint(checkpoint_dir, iteration):
|
||||
def _save_checkpoint(checkpoint_dir: str, iteration: int):
|
||||
"""Save the iteration number of the latest model to be checkpointed.
|
||||
|
||||
Args:
|
||||
|
@ -81,24 +57,20 @@ def _save_checkpoint(checkpoint_dir, iteration):
|
|||
"""
|
||||
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Update the latest checkpoint index.
|
||||
with open(checkpoint_record, "w") as handle:
|
||||
with open(checkpoint_record, "wt") as handle:
|
||||
handle.write("model_checkpoint_path: step-{}".format(iteration))
|
||||
|
||||
|
||||
def load_parameters(model,
|
||||
optimizer=None,
|
||||
checkpoint_dir=None,
|
||||
iteration=None,
|
||||
checkpoint_path=None):
|
||||
"""Load a specific model checkpoint from disk.
|
||||
|
||||
Args:
|
||||
model (obj): model to load parameters.
|
||||
optimizer (obj, optional): optimizer to load states if needed.
|
||||
model (Layer): model to load parameters.
|
||||
optimizer (Optimizer, optional): optimizer to load states if needed.
|
||||
Defaults to None.
|
||||
checkpoint_dir (str, optional): the directory where checkpoint is saved.
|
||||
iteration (int, optional): if specified, load the specific checkpoint,
|
||||
if not specified, load the latest one. Defaults to None.
|
||||
checkpoint_path (str, optional): if specified, load the checkpoint
|
||||
stored in the checkpoint_path and the argument 'checkpoint_dir' will
|
||||
be ignored. Defaults to None.
|
||||
|
@ -110,8 +82,7 @@ def load_parameters(model,
|
|||
if checkpoint_path is not None:
|
||||
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
|
||||
elif checkpoint_dir is not None:
|
||||
if iteration is None:
|
||||
iteration = _load_latest_checkpoint(checkpoint_dir)
|
||||
iteration = _load_latest_checkpoint(checkpoint_dir)
|
||||
if iteration == 0:
|
||||
return iteration
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
|
@ -121,52 +92,49 @@ def load_parameters(model,
|
|||
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
|
||||
)
|
||||
|
||||
local_rank = dg.parallel.Env().local_rank
|
||||
model_dict, optimizer_dict = dg.load_dygraph(checkpoint_path)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
# cast to desired data type, for mixed-precision training/inference.
|
||||
for k, v in model_dict.items():
|
||||
if k in state_dict and convert_np_dtype(v.dtype) != state_dict[
|
||||
k].dtype:
|
||||
model_dict[k] = v.astype(state_dict[k].numpy().dtype)
|
||||
local_rank = dist.get_rank()
|
||||
|
||||
params_path = checkpoint_path + ".pdparams"
|
||||
model_dict = paddle.load(params_path)
|
||||
model.set_state_dict(model_dict)
|
||||
|
||||
print("[checkpoint] Rank {}: loaded model from {}.pdparams".format(
|
||||
local_rank, checkpoint_path))
|
||||
|
||||
if optimizer and optimizer_dict:
|
||||
print("[checkpoint] Rank {}: loaded model from {}".format(
|
||||
local_rank, params_path))
|
||||
|
||||
optimizer_path = checkpoint_path + ".pdopt"
|
||||
if optimizer and os.path.isfile(optimizer_path):
|
||||
optimizer_dict = paddle.load(optimizer_path)
|
||||
optimizer.set_state_dict(optimizer_dict)
|
||||
print("[checkpoint] Rank {}: loaded optimizer state from {}.pdopt".
|
||||
format(local_rank, checkpoint_path))
|
||||
print("[checkpoint] Rank {}: loaded optimizer state from {}".
|
||||
format(local_rank, optimizer_path))
|
||||
|
||||
return iteration
|
||||
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||
"""Checkpoint the latest trained model parameters.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
iteration (int): the latest iteration number.
|
||||
model (obj): model to be checkpointed.
|
||||
optimizer (obj, optional): optimizer to be checkpointed.
|
||||
model (Layer): model to be checkpointed.
|
||||
optimizer (Optimizer, optional): optimizer to be checkpointed.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
|
||||
|
||||
model_dict = model.state_dict()
|
||||
dg.save_dygraph(model_dict, checkpoint_path)
|
||||
print("[checkpoint] Saved model to {}.pdparams".format(checkpoint_path))
|
||||
params_path = checkpoint_path + ".pdparams"
|
||||
paddle.save(model_dict, params_path)
|
||||
print("[checkpoint] Saved model to {}".format(params_path))
|
||||
|
||||
if optimizer:
|
||||
opt_dict = optimizer.state_dict()
|
||||
dg.save_dygraph(opt_dict, checkpoint_path)
|
||||
print("[checkpoint] Saved optimzier state to {}.pdopt".format(
|
||||
checkpoint_path))
|
||||
optimizer_path = checkpoint_path + ".pdopt"
|
||||
paddle.save(opt_dict, optimizer_path)
|
||||
print("[checkpoint] Saved optimzier state to {}".format(
|
||||
optimizer_path))
|
||||
|
||||
_save_checkpoint(checkpoint_dir, iteration)
|
|
@ -2,6 +2,9 @@ import numpy as np
|
|||
import matplotlib
|
||||
from matplotlib import cm, pyplot
|
||||
|
||||
__all__ = ["pack_attention_images", "add_attention_plots", "min_max_normalize"]
|
||||
|
||||
|
||||
def pack_attention_images(attention_weights, rotate=False):
|
||||
# add a box
|
||||
attention_weights = np.pad(attention_weights,
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
import numpy as np
|
||||
from paddle.framework import core
|
||||
|
||||
__all__ = ["convert_dtype_to_np_dtype_"]
|
||||
|
||||
|
||||
def convert_dtype_to_np_dtype_(dtype):
|
||||
"""
|
||||
Convert paddle's data type to corrsponding numpy data type.
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
import numpy as np
|
||||
from paddle import nn
|
||||
|
||||
__all__ = ["summary","gradient_norm", "freeze", "unfreeze"]
|
||||
|
||||
|
||||
def summary(layer: nn.Layer):
|
||||
num_params = num_elements = 0
|
||||
|
|
|
@ -2,6 +2,9 @@ import paddle
|
|||
from paddle import distributed as dist
|
||||
from functools import wraps
|
||||
|
||||
__all__ = ["rank_zero_only"]
|
||||
|
||||
|
||||
def rank_zero_only(func):
|
||||
local_rank = dist.get_rank()
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import math
|
||||
|
||||
__all__ = ["SchedulerBase", "Constant", "PieceWise", "StepWise"]
|
||||
|
||||
|
||||
class SchedulerBase(object):
|
||||
def __call__(self, step):
|
||||
raise NotImplementedError("You should implement the __call__ method.")
|
||||
|
|
Loading…
Reference in New Issue