diff --git a/examples/transformer_tts/synthesize.py b/examples/transformer_tts/synthesize.py index 6758819..5c5ddee 100644 --- a/examples/transformer_tts/synthesize.py +++ b/examples/transformer_tts/synthesize.py @@ -17,13 +17,15 @@ import time from pathlib import Path import numpy as np import paddle +import matplotlib +from matplotlib import pyplot as plt import parakeet from parakeet.frontend import English from parakeet.models.transformer_tts import TransformerTTS from parakeet.utils import scheduler from parakeet.training.cli import default_argument_parser -from parakeet.utils.display import add_attention_plots +from parakeet.utils.display import add_attention_plots, pack_attention_images from config import get_cfg_defaults @@ -49,7 +51,16 @@ def main(config, args): for i, sentence in enumerate(sentences): outputs = model.predict(sentence, verbose=args.verbose) mel_output = outputs["mel_output"] - # cross_attention_weights = outputs["cross_attention_weights"] + cross_attention_weights = outputs["cross_attention_weights"] + attns = [attn for attn in cross_attention_weights] + + fig = plt.figure(figsize=(40, 40)) + for j, attn in enumerate(attns): + plt.subplot(1, 4, j+1) + plt.imshow(attn[0]) + plt.tight_layout() + plt.savefig(str(output_dir / f"sentence_{i}.png")) + mel_output = mel_output.T #(C, T) np.save(str(output_dir / f"sentence_{i}"), mel_output) if args.verbose: diff --git a/examples/waveflow/synthesize.py b/examples/waveflow/synthesize.py index 45c751a..0dfe1c6 100644 --- a/examples/waveflow/synthesize.py +++ b/examples/waveflow/synthesize.py @@ -21,6 +21,7 @@ import paddle import parakeet from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow from parakeet.utils import layer_tools, checkpoint +import time from config import get_cfg_defaults @@ -34,9 +35,10 @@ def main(config, args): mel_dir = Path(args.input).expanduser() output_dir = Path(args.output).expanduser() output_dir.mkdir(parents=True, exist_ok=True) - for file_path in mel_dir.iterdir(): + for file_path in mel_dir.glob("*.npy"): mel = np.load(str(file_path)) - audio = model.predict(mel) + with paddle.amp.auto_cast(): + audio = model.predict(mel) audio_path = output_dir / ( os.path.splitext(file_path.name)[0] + ".wav") sf.write(audio_path, audio, config.data.sample_rate) diff --git a/parakeet/frontend/punctuation.py b/parakeet/frontend/punctuation.py index e50bca4..099e759 100644 --- a/parakeet/frontend/punctuation.py +++ b/parakeet/frontend/punctuation.py @@ -18,6 +18,7 @@ import string __all__ = ["get_punctuations"] EN_PUNCT = [ + " ", "-", "...", ",", diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py index 3cc5ac9..e28f8df 100644 --- a/parakeet/models/tacotron2.py +++ b/parakeet/models/tacotron2.py @@ -22,6 +22,7 @@ from parakeet.modules.conv import Conv1dBatchNorm from parakeet.modules.attention import LocationSensitiveAttention from parakeet.modules import masking from parakeet.utils import checkpoint +from tqdm import trange __all__ = ["Tacotron2", "Tacotron2Loss"] @@ -475,10 +476,11 @@ class Tacotron2Decoder(nn.Layer): dtype=key.dtype) #[B, C] self._initialize_decoder_states(key) + T_enc = key.shape[1] self.mask = None mel_outputs, stop_logits, alignments = [], [], [] - while True: + for _ in trange(max_decoder_steps): query = self.prenet(query) mel_output, stop_logit, alignment = self._decode(query) @@ -487,8 +489,12 @@ class Tacotron2Decoder(nn.Layer): alignments += [alignment] if F.sigmoid(stop_logit) > stop_threshold: + print("hits stop condition!") break - elif len(mel_outputs) == max_decoder_steps: + if int(paddle.argmax(alignment[0])) == T_enc - 1: + print("content exhausted!") + break + if len(mel_outputs) == max_decoder_steps: print("Warning! Reached max decoder steps!!!") break @@ -718,7 +724,7 @@ class Tacotron2(nn.Layer): """ embedded_inputs = self.embedding(text_inputs) if self.toned: - embedded_inputs = paddle.concat([embedded_inputs, self.embedding_tones(tones)], -1) + embedded_inputs += self.embedding_tones(tones) encoder_outputs = self.encoder(embedded_inputs) mel_outputs, stop_logits, alignments = self.decoder.infer( encoder_outputs, diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index 64945eb..7259f9c 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -340,13 +340,15 @@ class CNNPostNet(nn.Layer): c_in = d_input if i == 0 else d_hidden c_out = d_output if i == n_layers - 1 else d_hidden self.convs.append( - nn.Conv1D( + Conv1dBatchNorm( c_in, c_out, kernel_size, weight_attr=I.XavierUniform(), - padding=padding)) - # self.last_bn = nn.BatchNorm1D(d_output) + padding=padding, + momentum=0.99, + epsilon=1e-03)) + self.last_bn = nn.BatchNorm1D(d_output, momentum=0.99, epsilon=1e-3) # for a layer that ends with a normalization layer that is targeted to # output a non zero-central output, it may take a long time to # train the scale and bias @@ -359,8 +361,8 @@ class CNNPostNet(nn.Layer): if i != (len(self.convs) - 1): x = F.tanh(x) # TODO: check it - x = x_in + x - # x = self.last_bn(x_in + x) + # x = x_in + x + x = self.last_bn(x_in + x) return x @@ -567,8 +569,13 @@ class TransformerTTS(nn.Layer): text_ids = paddle.to_tensor(self.frontend(input)) input = paddle.unsqueeze(text_ids, 0) # (1, T) outputs = self.infer(input, max_length=max_length, verbose=verbose) - outputs = {k: v[0].numpy() for k, v in outputs.items()} - return outputs + npy_outputs = { + "mel_output": outputs["mel_output"][0].numpy(), + "encoder_attention_weights": [item[0].numpy() for item in outputs["encoder_attention_weights"]], + "cross_attention_weights": [item[0].numpy() for item in outputs["cross_attention_weights"]], + } + + return npy_outputs def set_constants(self, reduction_factor, drop_n_heads): self.r = reduction_factor diff --git a/parakeet/models/waveflow.py b/parakeet/models/waveflow.py index 625e61f..6055497 100644 --- a/parakeet/models/waveflow.py +++ b/parakeet/models/waveflow.py @@ -19,6 +19,7 @@ import paddle from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I +import time from parakeet.utils import checkpoint from parakeet.modules import geometry as geo @@ -798,10 +799,13 @@ class ConditionalWaveFlow(nn.LayerList): Tensor : [shape=(B, T)] The synthesized audio, where``T <= T_mel \* upsample_factors``. """ + start = time.time() condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T) batch_size, _, time_steps = condition.shape z = paddle.randn([batch_size, time_steps], dtype=mel.dtype) x = self.decoder.inverse(z, condition) + end = time.time() + print("time: {}s".format(end - start)) return x @paddle.no_grad() diff --git a/parakeet/utils/display.py b/parakeet/utils/display.py index 6c13931..414cbc0 100644 --- a/parakeet/utils/display.py +++ b/parakeet/utils/display.py @@ -14,6 +14,8 @@ import numpy as np import matplotlib +import librosa +import librosa.display matplotlib.use("Agg") import matplotlib.pylab as plt from matplotlib import cm, pyplot @@ -69,6 +71,7 @@ def plot_alignment(alignment, title=None): plt.xlabel(xlabel) plt.ylabel('Encoder timestep') plt.tight_layout() + return fig fig.canvas.draw() data = save_figure_to_numpy(fig)