WIP: baker

This commit is contained in:
chenfeiyu 2021-03-27 12:43:03 +08:00
parent 3c60fec900
commit a005cc88a3
7 changed files with 48 additions and 14 deletions

View File

@ -17,13 +17,15 @@ import time
from pathlib import Path
import numpy as np
import paddle
import matplotlib
from matplotlib import pyplot as plt
import parakeet
from parakeet.frontend import English
from parakeet.models.transformer_tts import TransformerTTS
from parakeet.utils import scheduler
from parakeet.training.cli import default_argument_parser
from parakeet.utils.display import add_attention_plots
from parakeet.utils.display import add_attention_plots, pack_attention_images
from config import get_cfg_defaults
@ -49,7 +51,16 @@ def main(config, args):
for i, sentence in enumerate(sentences):
outputs = model.predict(sentence, verbose=args.verbose)
mel_output = outputs["mel_output"]
# cross_attention_weights = outputs["cross_attention_weights"]
cross_attention_weights = outputs["cross_attention_weights"]
attns = [attn for attn in cross_attention_weights]
fig = plt.figure(figsize=(40, 40))
for j, attn in enumerate(attns):
plt.subplot(1, 4, j+1)
plt.imshow(attn[0])
plt.tight_layout()
plt.savefig(str(output_dir / f"sentence_{i}.png"))
mel_output = mel_output.T #(C, T)
np.save(str(output_dir / f"sentence_{i}"), mel_output)
if args.verbose:

View File

@ -21,6 +21,7 @@ import paddle
import parakeet
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow
from parakeet.utils import layer_tools, checkpoint
import time
from config import get_cfg_defaults
@ -34,9 +35,10 @@ def main(config, args):
mel_dir = Path(args.input).expanduser()
output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
for file_path in mel_dir.iterdir():
for file_path in mel_dir.glob("*.npy"):
mel = np.load(str(file_path))
audio = model.predict(mel)
with paddle.amp.auto_cast():
audio = model.predict(mel)
audio_path = output_dir / (
os.path.splitext(file_path.name)[0] + ".wav")
sf.write(audio_path, audio, config.data.sample_rate)

View File

@ -18,6 +18,7 @@ import string
__all__ = ["get_punctuations"]
EN_PUNCT = [
" ",
"-",
"...",
",",

View File

@ -22,6 +22,7 @@ from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention
from parakeet.modules import masking
from parakeet.utils import checkpoint
from tqdm import trange
__all__ = ["Tacotron2", "Tacotron2Loss"]
@ -475,10 +476,11 @@ class Tacotron2Decoder(nn.Layer):
dtype=key.dtype) #[B, C]
self._initialize_decoder_states(key)
T_enc = key.shape[1]
self.mask = None
mel_outputs, stop_logits, alignments = [], [], []
while True:
for _ in trange(max_decoder_steps):
query = self.prenet(query)
mel_output, stop_logit, alignment = self._decode(query)
@ -487,8 +489,12 @@ class Tacotron2Decoder(nn.Layer):
alignments += [alignment]
if F.sigmoid(stop_logit) > stop_threshold:
print("hits stop condition!")
break
elif len(mel_outputs) == max_decoder_steps:
if int(paddle.argmax(alignment[0])) == T_enc - 1:
print("content exhausted!")
break
if len(mel_outputs) == max_decoder_steps:
print("Warning! Reached max decoder steps!!!")
break
@ -718,7 +724,7 @@ class Tacotron2(nn.Layer):
"""
embedded_inputs = self.embedding(text_inputs)
if self.toned:
embedded_inputs = paddle.concat([embedded_inputs, self.embedding_tones(tones)], -1)
embedded_inputs += self.embedding_tones(tones)
encoder_outputs = self.encoder(embedded_inputs)
mel_outputs, stop_logits, alignments = self.decoder.infer(
encoder_outputs,

View File

@ -340,13 +340,15 @@ class CNNPostNet(nn.Layer):
c_in = d_input if i == 0 else d_hidden
c_out = d_output if i == n_layers - 1 else d_hidden
self.convs.append(
nn.Conv1D(
Conv1dBatchNorm(
c_in,
c_out,
kernel_size,
weight_attr=I.XavierUniform(),
padding=padding))
# self.last_bn = nn.BatchNorm1D(d_output)
padding=padding,
momentum=0.99,
epsilon=1e-03))
self.last_bn = nn.BatchNorm1D(d_output, momentum=0.99, epsilon=1e-3)
# for a layer that ends with a normalization layer that is targeted to
# output a non zero-central output, it may take a long time to
# train the scale and bias
@ -359,8 +361,8 @@ class CNNPostNet(nn.Layer):
if i != (len(self.convs) - 1):
x = F.tanh(x)
# TODO: check it
x = x_in + x
# x = self.last_bn(x_in + x)
# x = x_in + x
x = self.last_bn(x_in + x)
return x
@ -567,8 +569,13 @@ class TransformerTTS(nn.Layer):
text_ids = paddle.to_tensor(self.frontend(input))
input = paddle.unsqueeze(text_ids, 0) # (1, T)
outputs = self.infer(input, max_length=max_length, verbose=verbose)
outputs = {k: v[0].numpy() for k, v in outputs.items()}
return outputs
npy_outputs = {
"mel_output": outputs["mel_output"][0].numpy(),
"encoder_attention_weights": [item[0].numpy() for item in outputs["encoder_attention_weights"]],
"cross_attention_weights": [item[0].numpy() for item in outputs["cross_attention_weights"]],
}
return npy_outputs
def set_constants(self, reduction_factor, drop_n_heads):
self.r = reduction_factor

View File

@ -19,6 +19,7 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
import time
from parakeet.utils import checkpoint
from parakeet.modules import geometry as geo
@ -798,10 +799,13 @@ class ConditionalWaveFlow(nn.LayerList):
Tensor : [shape=(B, T)]
The synthesized audio, where``T <= T_mel \* upsample_factors``.
"""
start = time.time()
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
batch_size, _, time_steps = condition.shape
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
x = self.decoder.inverse(z, condition)
end = time.time()
print("time: {}s".format(end - start))
return x
@paddle.no_grad()

View File

@ -14,6 +14,8 @@
import numpy as np
import matplotlib
import librosa
import librosa.display
matplotlib.use("Agg")
import matplotlib.pylab as plt
from matplotlib import cm, pyplot
@ -69,6 +71,7 @@ def plot_alignment(alignment, title=None):
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.tight_layout()
return fig
fig.canvas.draw()
data = save_figure_to_numpy(fig)