1. add interfaces for inference;;
2. add a function to recursively remove weight norm; 3. wavenet: fix weight norm dimension: explicitly specify dim=1 instead of -1.
This commit is contained in:
parent
b2bd479f46
commit
796e0b1e1f
|
@ -26,6 +26,7 @@ from parakeet.modules import masking
|
|||
from parakeet.modules.conv import Conv1dBatchNorm
|
||||
from parakeet.modules import positional_encoding as pe
|
||||
from parakeet.modules import losses as L
|
||||
from parakeet.utils import checkpoint, scheduler
|
||||
|
||||
__all__ = ["TransformerTTS", "TransformerTTSLoss"]
|
||||
|
||||
|
@ -285,8 +286,7 @@ class TransformerDecoder(nn.LayerList):
|
|||
d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
|
||||
|
||||
def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
|
||||
"""[summary]
|
||||
|
||||
"""
|
||||
Args:
|
||||
q (Tensor): shape(batch_size, time_steps_q, d_model)
|
||||
k (Tensor): shape(batch_size, time_steps_k, d_encoder)
|
||||
|
@ -330,40 +330,6 @@ class MLPPreNet(nn.Layer):
|
|||
return l3
|
||||
|
||||
|
||||
# NOTE: not used in
|
||||
class CNNPreNet(nn.Layer):
|
||||
def __init__(self,
|
||||
d_input,
|
||||
d_hidden,
|
||||
d_output,
|
||||
kernel_size,
|
||||
n_layers,
|
||||
dropout=0.):
|
||||
# (conv + bn + relu + dropout) * n + last projection
|
||||
super(CNNPreNet, self).__init__()
|
||||
self.convs = nn.LayerList()
|
||||
c_in = d_input
|
||||
for _ in range(n_layers):
|
||||
self.convs.append(
|
||||
Conv1dBatchNorm(
|
||||
c_in,
|
||||
d_hidden,
|
||||
kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding="same",
|
||||
data_format="NLC"))
|
||||
c_in = d_hidden
|
||||
self.affine_out = nn.Linear(d_hidden, d_output)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.convs:
|
||||
x = F.dropout(
|
||||
F.relu(layer(x)), self.dropout, training=self.training)
|
||||
x = self.affine_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class CNNPostNet(nn.Layer):
|
||||
def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
|
||||
super(CNNPostNet, self).__init__()
|
||||
|
@ -536,7 +502,8 @@ class TransformerTTS(nn.Layer):
|
|||
|
||||
return mel_output, mel_intermediate, cross_attention_weights, stop_logits
|
||||
|
||||
def predict(self, input, raw_input=True, max_length=1000, verbose=True):
|
||||
@paddle.no_grad()
|
||||
def infer(self, input, max_length=1000, verbose=True):
|
||||
"""Predict log scale magnitude mel spectrogram from text input.
|
||||
|
||||
Args:
|
||||
|
@ -544,19 +511,13 @@ class TransformerTTS(nn.Layer):
|
|||
max_length (int, optional): max decoder steps. Defaults to 1000.
|
||||
verbose (bool, optional): display progress bar. Defaults to True.
|
||||
"""
|
||||
if raw_input:
|
||||
text_ids = paddle.to_tensor(self.frontend(input))
|
||||
text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
||||
else:
|
||||
text_input = input
|
||||
|
||||
decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||
decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
|
||||
|
||||
# encoder the text sequence
|
||||
encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
|
||||
text_input)
|
||||
for _ in range(int(max_length // self.r) + 1):
|
||||
input)
|
||||
for _ in trange(int(max_length // self.r) + 1):
|
||||
mel_output, _, cross_attention_weights, stop_logits = self.decode(
|
||||
encoder_output, decoder_input, encoder_padding_mask)
|
||||
|
||||
|
@ -584,10 +545,45 @@ class TransformerTTS(nn.Layer):
|
|||
}
|
||||
return outputs
|
||||
|
||||
@paddle.no_grad()
|
||||
def predict(self, input, max_length=1000, verbose=True):
|
||||
text_ids = paddle.to_tensor(self.frontend(input))
|
||||
input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
||||
outputs = self.infer(input, max_length=max_length, verbose=verbose)
|
||||
outputs = {k: v[0].numpy() for k, v in outputs.items()}
|
||||
return outputs
|
||||
|
||||
def set_constants(self, reduction_factor, drop_n_heads):
|
||||
self.r = reduction_factor
|
||||
self.drop_n_heads = drop_n_heads
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, frontend, config, checkpoint_path):
|
||||
model = TransformerTTS(
|
||||
frontend,
|
||||
d_encoder=config.model.d_encoder,
|
||||
d_decoder=config.model.d_decoder,
|
||||
d_mel=config.data.d_mel,
|
||||
n_heads=config.model.n_heads,
|
||||
d_ffn=config.model.d_ffn,
|
||||
encoder_layers=config.model.encoder_layers,
|
||||
decoder_layers=config.model.decoder_layers,
|
||||
d_prenet=config.model.d_prenet,
|
||||
d_postnet=config.model.d_postnet,
|
||||
postnet_layers=config.model.postnet_layers,
|
||||
postnet_kernel_size=config.model.postnet_kernel_size,
|
||||
max_reduction_factor=config.model.max_reduction_factor,
|
||||
decoder_prenet_dropout=config.model.decoder_prenet_dropout,
|
||||
dropout=config.model.dropout)
|
||||
|
||||
iteration = checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
drop_n_heads = scheduler.StepWise(config.training.drop_n_heads)
|
||||
reduction_factor = scheduler.StepWise(config.training.reduction_factor)
|
||||
model.set_constants(
|
||||
reduction_factor=reduction_factor(iteration),
|
||||
drop_n_heads=drop_n_heads(iteration))
|
||||
return model
|
||||
|
||||
|
||||
class TransformerTTSLoss(nn.Layer):
|
||||
def __init__(self, stop_loss_scale):
|
||||
|
@ -618,34 +614,3 @@ class TransformerTTSLoss(nn.Layer):
|
|||
stop_loss=stop_loss # stop prob loss
|
||||
)
|
||||
return losses
|
||||
|
||||
|
||||
class AdaptiveTransformerTTSLoss(nn.Layer):
|
||||
def __init__(self):
|
||||
super(AdaptiveTransformerTTSLoss, self).__init__()
|
||||
|
||||
def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
|
||||
stop_probs):
|
||||
mask = masking.feature_mask(
|
||||
mel_target, axis=-1, dtype=mel_target.dtype)
|
||||
mask1 = paddle.unsqueeze(mask, -1)
|
||||
mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
|
||||
mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
|
||||
|
||||
batch_size, mel_len = mask.shape
|
||||
valid_lengths = mask.sum(-1).astype("int64")
|
||||
last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
|
||||
stop_loss_scale = valid_lengths.sum() / batch_size - 1
|
||||
mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
|
||||
mask.dtype)
|
||||
stop_loss = L.masked_softmax_with_cross_entropy(
|
||||
stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
|
||||
|
||||
loss = mel_loss1 + mel_loss2 + stop_loss
|
||||
losses = dict(
|
||||
loss=loss, # total loss
|
||||
mel_loss1=mel_loss1, # ouput mel loss
|
||||
mel_loss2=mel_loss2, # intermediate mel loss
|
||||
stop_loss=stop_loss # stop prob loss
|
||||
)
|
||||
return losses
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
import math
|
||||
import numpy as np
|
||||
from typing import List, Union
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle.nn import initializer as I
|
||||
|
||||
from parakeet.utils import checkpoint
|
||||
from parakeet.modules import geometry as geo
|
||||
|
||||
__all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"]
|
||||
|
@ -478,10 +480,23 @@ class WaveFlow(nn.LayerList):
|
|||
|
||||
|
||||
class ConditionalWaveFlow(nn.LayerList):
|
||||
def __init__(self, encoder, decoder):
|
||||
def __init__(self,
|
||||
upsample_factors: List[int],
|
||||
n_flows: int,
|
||||
n_layers: int,
|
||||
n_group: int,
|
||||
channels: int,
|
||||
n_mels: int,
|
||||
kernel_size: Union[int, List[int]]):
|
||||
super(ConditionalWaveFlow, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.encoder = UpsampleNet(upsample_factors)
|
||||
self.decoder = WaveFlow(
|
||||
n_flows=n_flows,
|
||||
n_layers=n_layers,
|
||||
n_group=n_group,
|
||||
channels=channels,
|
||||
mel_bands=n_mels,
|
||||
kernel_size=kernel_size)
|
||||
|
||||
def forward(self, audio, mel):
|
||||
condition = self.encoder(mel)
|
||||
|
@ -489,12 +504,33 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
return z, log_det_jacobian
|
||||
|
||||
@paddle.no_grad()
|
||||
def synthesize(self, mel):
|
||||
def infer(self, mel):
|
||||
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
|
||||
batch_size, _, time_steps = condition.shape
|
||||
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
|
||||
x = self.decoder.inverse(z, condition)
|
||||
return x
|
||||
|
||||
@paddle.no_grad()
|
||||
def predict(self, mel):
|
||||
mel = paddle.to_tensor(mel)
|
||||
mel = paddle.unsqueeze(mel, 0)
|
||||
audio = self.infer(mel)
|
||||
audio = audio[0].numpy()
|
||||
return audio
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, config, checkpoint_path):
|
||||
model = cls(
|
||||
upsample_factors=config.model.upsample_factors,
|
||||
n_flows=config.model.n_flows,
|
||||
n_layers=config.model.n_layers,
|
||||
n_group=config.model.n_group,
|
||||
channels=config.model.channels,
|
||||
n_mels=config.data.n_mels,
|
||||
kernel_size=config.model.kernel_size)
|
||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
class WaveFlowLoss(nn.Layer):
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
import math
|
||||
import time
|
||||
from typing import Union, Sequence
|
||||
from typing import Union, Sequence, List
|
||||
from tqdm import trange
|
||||
import numpy as np
|
||||
|
||||
|
@ -26,6 +26,7 @@ import paddle.fluid.layers.distributions as D
|
|||
|
||||
from parakeet.modules.conv import Conv1dCell
|
||||
from parakeet.modules.audio import quantize, dequantize, STFT
|
||||
from parakeet.utils import checkpoint, layer_tools
|
||||
|
||||
|
||||
def crop(x, audio_start, audio_length):
|
||||
|
@ -290,18 +291,18 @@ class WaveNet(nn.Layer):
|
|||
if (output_dim % 3 != 0):
|
||||
raise ValueError(
|
||||
"with Mixture of Gaussians(mog) output, the output dim must be divisible by 3, but get {}".format(output_dim))
|
||||
self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=-1)
|
||||
self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=1)
|
||||
|
||||
self.resnet = ResidualNet(n_stack, n_loop, residual_channels,
|
||||
condition_dim, filter_size)
|
||||
self.context_size = self.resnet.context_size
|
||||
|
||||
skip_channels = residual_channels # assume the same channel
|
||||
self.proj1 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=-1)
|
||||
self.proj2 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=-1)
|
||||
self.proj1 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=1)
|
||||
self.proj2 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=1)
|
||||
# if loss_type is softmax, output_dim is n_vocab of waveform magnitude.
|
||||
# if loss_type is mog, output_dim is 3 * gaussian, (weight, mean and stddev)
|
||||
self.proj3 = nn.utils.weight_norm(nn.Linear(skip_channels, output_dim), dim=-1)
|
||||
self.proj3 = nn.utils.weight_norm(nn.Linear(skip_channels, output_dim), dim=1)
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.output_dim = output_dim
|
||||
|
@ -509,17 +510,29 @@ class WaveNet(nn.Layer):
|
|||
return self.compute_mog_loss(y, t)
|
||||
|
||||
|
||||
class ConditionalWavenet(nn.Layer):
|
||||
def __init__(self, encoder, decoder):
|
||||
class ConditionalWaveNet(nn.Layer):
|
||||
def __init__(self,
|
||||
upsample_factors: List[int],
|
||||
n_stack: int,
|
||||
n_loop: int,
|
||||
residual_channels: int,
|
||||
output_dim: int,
|
||||
n_mels: int,
|
||||
filter_size: int=2,
|
||||
loss_type: str="mog",
|
||||
log_scale_min: float=-9.0):
|
||||
"""Conditional Wavenet, which contains an UpsampleNet as the encoder and a WaveNet as the decoder. It is an autoregressive model.
|
||||
|
||||
Args:
|
||||
encoder (UpsampleNet): the UpsampleNet as the encoder.
|
||||
decoder (WaveNet): the WaveNet as the decoder.
|
||||
"""
|
||||
super(ConditionalWavenet, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super(ConditionalWaveNet, self).__init__()
|
||||
self.encoder = UpsampleNet(upsample_factors)
|
||||
self.decoder = WaveNet(n_stack=n_stack,
|
||||
n_loop=n_loop,
|
||||
residual_channels=residual_channels,
|
||||
output_dim=output_dim,
|
||||
condition_dim=n_mels,
|
||||
filter_size=filter_size,
|
||||
loss_type=loss_type,
|
||||
log_scale_min=log_scale_min)
|
||||
|
||||
def forward(self, audio, mel, audio_start):
|
||||
"""Compute the output distribution given the mel spectrogram and the input(for teacher force training).
|
||||
|
@ -570,7 +583,7 @@ class ConditionalWavenet(nn.Layer):
|
|||
return samples
|
||||
|
||||
@paddle.no_grad()
|
||||
def synthesis(self, mel):
|
||||
def infer(self, mel):
|
||||
"""Synthesize waveform from mel spectrogram.
|
||||
|
||||
Args:
|
||||
|
@ -595,3 +608,29 @@ class ConditionalWavenet(nn.Layer):
|
|||
|
||||
samples = paddle.concat(samples, -1)
|
||||
return samples
|
||||
|
||||
@paddle.no_grad()
|
||||
def predict(self, mel):
|
||||
mel = paddle.to_tensor(mel)
|
||||
mel = paddle.unsqueeze(mel, 0)
|
||||
audio = self.infer(mel)
|
||||
audio = audio[0].numpy()
|
||||
return audio
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, config, checkpoint_path):
|
||||
model = cls(
|
||||
upsample_factors=config.model.upsample_factors,
|
||||
n_stack=config.model.n_stack,
|
||||
n_loop=config.model.n_loop,
|
||||
residual_channels=config.model.residual_channels,
|
||||
output_dim=config.model.output_dim,
|
||||
n_mels=config.data.n_mels,
|
||||
filter_size=config.model.filter_size,
|
||||
loss_type=config.model.loss_type,
|
||||
log_scale_min=config.model.log_scale_min)
|
||||
layer_tools.summary(model)
|
||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -36,6 +36,14 @@ def gradient_norm(layer: nn.Layer):
|
|||
grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
|
||||
return grad_norm_dict
|
||||
|
||||
def recursively_remove_weight_norm(layer: nn.Layer):
|
||||
for layer in layer.sublayers():
|
||||
try:
|
||||
nn.utils.remove_weight_norm(layer)
|
||||
except:
|
||||
# ther is not weight norm hoom in this layer
|
||||
pass
|
||||
|
||||
def freeze(layer: nn.Layer):
|
||||
for param in layer.parameters():
|
||||
param.trainable = False
|
||||
|
|
Loading…
Reference in New Issue