diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index fedd58e..839e18c 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -26,6 +26,7 @@ from parakeet.modules import masking from parakeet.modules.conv import Conv1dBatchNorm from parakeet.modules import positional_encoding as pe from parakeet.modules import losses as L +from parakeet.utils import checkpoint, scheduler __all__ = ["TransformerTTS", "TransformerTTSLoss"] @@ -285,8 +286,7 @@ class TransformerDecoder(nn.LayerList): d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder)) def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0): - """[summary] - + """ Args: q (Tensor): shape(batch_size, time_steps_q, d_model) k (Tensor): shape(batch_size, time_steps_k, d_encoder) @@ -330,40 +330,6 @@ class MLPPreNet(nn.Layer): return l3 -# NOTE: not used in -class CNNPreNet(nn.Layer): - def __init__(self, - d_input, - d_hidden, - d_output, - kernel_size, - n_layers, - dropout=0.): - # (conv + bn + relu + dropout) * n + last projection - super(CNNPreNet, self).__init__() - self.convs = nn.LayerList() - c_in = d_input - for _ in range(n_layers): - self.convs.append( - Conv1dBatchNorm( - c_in, - d_hidden, - kernel_size, - weight_attr=I.XavierUniform(), - padding="same", - data_format="NLC")) - c_in = d_hidden - self.affine_out = nn.Linear(d_hidden, d_output) - self.dropout = dropout - - def forward(self, x): - for layer in self.convs: - x = F.dropout( - F.relu(layer(x)), self.dropout, training=self.training) - x = self.affine_out(x) - return x - - class CNNPostNet(nn.Layer): def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers): super(CNNPostNet, self).__init__() @@ -536,7 +502,8 @@ class TransformerTTS(nn.Layer): return mel_output, mel_intermediate, cross_attention_weights, stop_logits - def predict(self, input, raw_input=True, max_length=1000, verbose=True): + @paddle.no_grad() + def infer(self, input, 
max_length=1000, verbose=True): """Predict log scale magnitude mel spectrogram from text input. Args: @@ -544,19 +511,13 @@ class TransformerTTS(nn.Layer): max_length (int, optional): max decoder steps. Defaults to 1000. verbose (bool, optional): display progress bar. Defaults to True. """ - if raw_input: - text_ids = paddle.to_tensor(self.frontend(input)) - text_input = paddle.unsqueeze(text_ids, 0) # (1, T) - else: - text_input = input - decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C) # encoder the text sequence encoder_output, encoder_attentions, encoder_padding_mask = self.encode( - text_input) - for _ in range(int(max_length // self.r) + 1): + input) + for _ in trange(int(max_length // self.r) + 1): mel_output, _, cross_attention_weights, stop_logits = self.decode( encoder_output, decoder_input, encoder_padding_mask) @@ -584,10 +545,45 @@ class TransformerTTS(nn.Layer): } return outputs + @paddle.no_grad() + def predict(self, input, max_length=1000, verbose=True): + text_ids = paddle.to_tensor(self.frontend(input)) + input = paddle.unsqueeze(text_ids, 0) # (1, T) + outputs = self.infer(input, max_length=max_length, verbose=verbose) + outputs = {k: v[0].numpy() for k, v in outputs.items()} + return outputs + def set_constants(self, reduction_factor, drop_n_heads): self.r = reduction_factor self.drop_n_heads = drop_n_heads + @classmethod + def from_pretrained(cls, frontend, config, checkpoint_path): + model = TransformerTTS( + frontend, + d_encoder=config.model.d_encoder, + d_decoder=config.model.d_decoder, + d_mel=config.data.d_mel, + n_heads=config.model.n_heads, + d_ffn=config.model.d_ffn, + encoder_layers=config.model.encoder_layers, + decoder_layers=config.model.decoder_layers, + d_prenet=config.model.d_prenet, + d_postnet=config.model.d_postnet, + postnet_layers=config.model.postnet_layers, + postnet_kernel_size=config.model.postnet_kernel_size, + 
max_reduction_factor=config.model.max_reduction_factor, + decoder_prenet_dropout=config.model.decoder_prenet_dropout, + dropout=config.model.dropout) + + iteration = checkpoint.load_parameters(model, checkpoint_path=checkpoint_path) + drop_n_heads = scheduler.StepWise(config.training.drop_n_heads) + reduction_factor = scheduler.StepWise(config.training.reduction_factor) + model.set_constants( + reduction_factor=reduction_factor(iteration), + drop_n_heads=drop_n_heads(iteration)) + return model + class TransformerTTSLoss(nn.Layer): def __init__(self, stop_loss_scale): @@ -618,34 +614,3 @@ class TransformerTTSLoss(nn.Layer): stop_loss=stop_loss # stop prob loss ) return losses - - -class AdaptiveTransformerTTSLoss(nn.Layer): - def __init__(self): - super(AdaptiveTransformerTTSLoss, self).__init__() - - def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, - stop_probs): - mask = masking.feature_mask( - mel_target, axis=-1, dtype=mel_target.dtype) - mask1 = paddle.unsqueeze(mask, -1) - mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1) - mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1) - - batch_size, mel_len = mask.shape - valid_lengths = mask.sum(-1).astype("int64") - last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len) - stop_loss_scale = valid_lengths.sum() / batch_size - 1 - mask2 = mask + last_position.scale(stop_loss_scale - 1).astype( - mask.dtype) - stop_loss = L.masked_softmax_with_cross_entropy( - stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1)) - - loss = mel_loss1 + mel_loss2 + stop_loss - losses = dict( - loss=loss, # total loss - mel_loss1=mel_loss1, # ouput mel loss - mel_loss2=mel_loss2, # intermediate mel loss - stop_loss=stop_loss # stop prob loss - ) - return losses diff --git a/parakeet/models/waveflow.py b/parakeet/models/waveflow.py index 5398b22..cd4f3ed 100644 --- a/parakeet/models/waveflow.py +++ b/parakeet/models/waveflow.py @@ -1,10 +1,12 @@ import math import numpy as np 
+from typing import List, Union import paddle from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I +from parakeet.utils import checkpoint from parakeet.modules import geometry as geo __all__ = ["UpsampleNet", "WaveFlow", "ConditionalWaveFlow", "WaveFlowLoss"] @@ -478,10 +480,23 @@ class WaveFlow(nn.LayerList): class ConditionalWaveFlow(nn.LayerList): - def __init__(self, encoder, decoder): + def __init__(self, + upsample_factors: List[int], + n_flows: int, + n_layers: int, + n_group: int, + channels: int, + n_mels: int, + kernel_size: Union[int, List[int]]): super(ConditionalWaveFlow, self).__init__() - self.encoder = encoder - self.decoder = decoder + self.encoder = UpsampleNet(upsample_factors) + self.decoder = WaveFlow( + n_flows=n_flows, + n_layers=n_layers, + n_group=n_group, + channels=channels, + mel_bands=n_mels, + kernel_size=kernel_size) def forward(self, audio, mel): condition = self.encoder(mel) @@ -489,12 +504,33 @@ class ConditionalWaveFlow(nn.LayerList): return z, log_det_jacobian @paddle.no_grad() - def synthesize(self, mel): + def infer(self, mel): condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T) batch_size, _, time_steps = condition.shape z = paddle.randn([batch_size, time_steps], dtype=mel.dtype) x = self.decoder.inverse(z, condition) return x + + @paddle.no_grad() + def predict(self, mel): + mel = paddle.to_tensor(mel) + mel = paddle.unsqueeze(mel, 0) + audio = self.infer(mel) + audio = audio[0].numpy() + return audio + + @classmethod + def from_pretrained(cls, config, checkpoint_path): + model = cls( + upsample_factors=config.model.upsample_factors, + n_flows=config.model.n_flows, + n_layers=config.model.n_layers, + n_group=config.model.n_group, + channels=config.model.channels, + n_mels=config.data.n_mels, + kernel_size=config.model.kernel_size) + checkpoint.load_parameters(model, checkpoint_path=checkpoint_path) + return model class WaveFlowLoss(nn.Layer): diff --git 
a/parakeet/models/wavenet.py b/parakeet/models/wavenet.py index 7f5eb3d..135c8e4 100644 --- a/parakeet/models/wavenet.py +++ b/parakeet/models/wavenet.py @@ -14,7 +14,7 @@ import math import time -from typing import Union, Sequence +from typing import Union, Sequence, List from tqdm import trange import numpy as np @@ -26,6 +26,7 @@ import paddle.fluid.layers.distributions as D from parakeet.modules.conv import Conv1dCell from parakeet.modules.audio import quantize, dequantize, STFT +from parakeet.utils import checkpoint, layer_tools def crop(x, audio_start, audio_length): @@ -290,18 +291,18 @@ class WaveNet(nn.Layer): if (output_dim % 3 != 0): raise ValueError( "with Mixture of Gaussians(mog) output, the output dim must be divisible by 3, but get {}".format(output_dim)) - self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=-1) + self.embed = nn.utils.weight_norm(nn.Linear(1, residual_channels), dim=1) self.resnet = ResidualNet(n_stack, n_loop, residual_channels, condition_dim, filter_size) self.context_size = self.resnet.context_size skip_channels = residual_channels # assume the same channel - self.proj1 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=-1) - self.proj2 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=-1) + self.proj1 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=1) + self.proj2 = nn.utils.weight_norm(nn.Linear(skip_channels, skip_channels), dim=1) # if loss_type is softmax, output_dim is n_vocab of waveform magnitude. 
# if loss_type is mog, output_dim is 3 * gaussian, (weight, mean and stddev) - self.proj3 = nn.utils.weight_norm(nn.Linear(skip_channels, output_dim), dim=-1) + self.proj3 = nn.utils.weight_norm(nn.Linear(skip_channels, output_dim), dim=1) self.loss_type = loss_type self.output_dim = output_dim @@ -509,17 +510,29 @@ class WaveNet(nn.Layer): return self.compute_mog_loss(y, t) -class ConditionalWavenet(nn.Layer): - def __init__(self, encoder, decoder): +class ConditionalWaveNet(nn.Layer): + def __init__(self, + upsample_factors: List[int], + n_stack: int, + n_loop: int, + residual_channels: int, + output_dim: int, + n_mels: int, + filter_size: int=2, + loss_type: str="mog", + log_scale_min: float=-9.0): """Conditional Wavenet, which contains an UpsampleNet as the encoder and a WaveNet as the decoder. It is an autoregressive model. - - Args: - encoder (UpsampleNet): the UpsampleNet as the encoder. - decoder (WaveNet): the WaveNet as the decoder. """ - super(ConditionalWavenet, self).__init__() - self.encoder = encoder - self.decoder = decoder + super(ConditionalWaveNet, self).__init__() + self.encoder = UpsampleNet(upsample_factors) + self.decoder = WaveNet(n_stack=n_stack, + n_loop=n_loop, + residual_channels=residual_channels, + output_dim=output_dim, + condition_dim=n_mels, + filter_size=filter_size, + loss_type=loss_type, + log_scale_min=log_scale_min) def forward(self, audio, mel, audio_start): """Compute the output distribution given the mel spectrogram and the input(for teacher force training). @@ -570,7 +583,7 @@ class ConditionalWavenet(nn.Layer): return samples @paddle.no_grad() - def synthesis(self, mel): + def infer(self, mel): """Synthesize waveform from mel spectrogram. 
Args: @@ -595,3 +608,29 @@ class ConditionalWavenet(nn.Layer): samples = paddle.concat(samples, -1) return samples + + @paddle.no_grad() + def predict(self, mel): + mel = paddle.to_tensor(mel) + mel = paddle.unsqueeze(mel, 0) + audio = self.infer(mel) + audio = audio[0].numpy() + return audio + + @classmethod + def from_pretrained(cls, config, checkpoint_path): + model = cls( + upsample_factors=config.model.upsample_factors, + n_stack=config.model.n_stack, + n_loop=config.model.n_loop, + residual_channels=config.model.residual_channels, + output_dim=config.model.output_dim, + n_mels=config.data.n_mels, + filter_size=config.model.filter_size, + loss_type=config.model.loss_type, + log_scale_min=config.model.log_scale_min) + layer_tools.summary(model) + checkpoint.load_parameters(model, checkpoint_path=checkpoint_path) + return model + + diff --git a/parakeet/utils/layer_tools.py b/parakeet/utils/layer_tools.py index 2033eb0..31777a8 100644 --- a/parakeet/utils/layer_tools.py +++ b/parakeet/utils/layer_tools.py @@ -36,6 +36,14 @@ def gradient_norm(layer: nn.Layer): grad_norm_dict[name] = np.linalg.norm(grad) / grad.size return grad_norm_dict +def recursively_remove_weight_norm(layer: nn.Layer): + for layer in layer.sublayers(): + try: + nn.utils.remove_weight_norm(layer) + except: + # there is no weight norm hook in this layer + pass + def freeze(layer: nn.Layer): for param in layer.parameters(): param.trainable = False