WIP: baker
This commit is contained in:
parent
3c60fec900
commit
a005cc88a3
|
@ -17,13 +17,15 @@ import time
|
|||
from pathlib import Path
|
||||
import numpy as np
|
||||
import paddle
|
||||
import matplotlib
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import parakeet
|
||||
from parakeet.frontend import English
|
||||
from parakeet.models.transformer_tts import TransformerTTS
|
||||
from parakeet.utils import scheduler
|
||||
from parakeet.training.cli import default_argument_parser
|
||||
from parakeet.utils.display import add_attention_plots
|
||||
from parakeet.utils.display import add_attention_plots, pack_attention_images
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
@ -49,7 +51,16 @@ def main(config, args):
|
|||
for i, sentence in enumerate(sentences):
|
||||
outputs = model.predict(sentence, verbose=args.verbose)
|
||||
mel_output = outputs["mel_output"]
|
||||
# cross_attention_weights = outputs["cross_attention_weights"]
|
||||
cross_attention_weights = outputs["cross_attention_weights"]
|
||||
attns = [attn for attn in cross_attention_weights]
|
||||
|
||||
fig = plt.figure(figsize=(40, 40))
|
||||
for j, attn in enumerate(attns):
|
||||
plt.subplot(1, 4, j+1)
|
||||
plt.imshow(attn[0])
|
||||
plt.tight_layout()
|
||||
plt.savefig(str(output_dir / f"sentence_{i}.png"))
|
||||
|
||||
mel_output = mel_output.T #(C, T)
|
||||
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
||||
if args.verbose:
|
||||
|
|
|
@ -21,6 +21,7 @@ import paddle
|
|||
import parakeet
|
||||
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow
|
||||
from parakeet.utils import layer_tools, checkpoint
|
||||
import time
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
@ -34,8 +35,9 @@ def main(config, args):
|
|||
mel_dir = Path(args.input).expanduser()
|
||||
output_dir = Path(args.output).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
for file_path in mel_dir.iterdir():
|
||||
for file_path in mel_dir.glob("*.npy"):
|
||||
mel = np.load(str(file_path))
|
||||
with paddle.amp.auto_cast():
|
||||
audio = model.predict(mel)
|
||||
audio_path = output_dir / (
|
||||
os.path.splitext(file_path.name)[0] + ".wav")
|
||||
|
|
|
@ -18,6 +18,7 @@ import string
|
|||
__all__ = ["get_punctuations"]
|
||||
|
||||
EN_PUNCT = [
|
||||
" ",
|
||||
"-",
|
||||
"...",
|
||||
",",
|
||||
|
|
|
@ -22,6 +22,7 @@ from parakeet.modules.conv import Conv1dBatchNorm
|
|||
from parakeet.modules.attention import LocationSensitiveAttention
|
||||
from parakeet.modules import masking
|
||||
from parakeet.utils import checkpoint
|
||||
from tqdm import trange
|
||||
|
||||
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||
|
||||
|
@ -475,10 +476,11 @@ class Tacotron2Decoder(nn.Layer):
|
|||
dtype=key.dtype) #[B, C]
|
||||
|
||||
self._initialize_decoder_states(key)
|
||||
T_enc = key.shape[1]
|
||||
self.mask = None
|
||||
|
||||
mel_outputs, stop_logits, alignments = [], [], []
|
||||
while True:
|
||||
for _ in trange(max_decoder_steps):
|
||||
query = self.prenet(query)
|
||||
mel_output, stop_logit, alignment = self._decode(query)
|
||||
|
||||
|
@ -487,8 +489,12 @@ class Tacotron2Decoder(nn.Layer):
|
|||
alignments += [alignment]
|
||||
|
||||
if F.sigmoid(stop_logit) > stop_threshold:
|
||||
print("hits stop condition!")
|
||||
break
|
||||
elif len(mel_outputs) == max_decoder_steps:
|
||||
if int(paddle.argmax(alignment[0])) == T_enc - 1:
|
||||
print("content exhausted!")
|
||||
break
|
||||
if len(mel_outputs) == max_decoder_steps:
|
||||
print("Warning! Reached max decoder steps!!!")
|
||||
break
|
||||
|
||||
|
@ -718,7 +724,7 @@ class Tacotron2(nn.Layer):
|
|||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
if self.toned:
|
||||
embedded_inputs = paddle.concat([embedded_inputs, self.embedding_tones(tones)], -1)
|
||||
embedded_inputs += self.embedding_tones(tones)
|
||||
encoder_outputs = self.encoder(embedded_inputs)
|
||||
mel_outputs, stop_logits, alignments = self.decoder.infer(
|
||||
encoder_outputs,
|
||||
|
|
|
@ -340,13 +340,15 @@ class CNNPostNet(nn.Layer):
|
|||
c_in = d_input if i == 0 else d_hidden
|
||||
c_out = d_output if i == n_layers - 1 else d_hidden
|
||||
self.convs.append(
|
||||
nn.Conv1D(
|
||||
Conv1dBatchNorm(
|
||||
c_in,
|
||||
c_out,
|
||||
kernel_size,
|
||||
weight_attr=I.XavierUniform(),
|
||||
padding=padding))
|
||||
# self.last_bn = nn.BatchNorm1D(d_output)
|
||||
padding=padding,
|
||||
momentum=0.99,
|
||||
epsilon=1e-03))
|
||||
self.last_bn = nn.BatchNorm1D(d_output, momentum=0.99, epsilon=1e-3)
|
||||
# for a layer that ends with a normalization layer that is targeted to
|
||||
# output a non zero-central output, it may take a long time to
|
||||
# train the scale and bias
|
||||
|
@ -359,8 +361,8 @@ class CNNPostNet(nn.Layer):
|
|||
if i != (len(self.convs) - 1):
|
||||
x = F.tanh(x)
|
||||
# TODO: check it
|
||||
x = x_in + x
|
||||
# x = self.last_bn(x_in + x)
|
||||
# x = x_in + x
|
||||
x = self.last_bn(x_in + x)
|
||||
return x
|
||||
|
||||
|
||||
|
@ -567,8 +569,13 @@ class TransformerTTS(nn.Layer):
|
|||
text_ids = paddle.to_tensor(self.frontend(input))
|
||||
input = paddle.unsqueeze(text_ids, 0) # (1, T)
|
||||
outputs = self.infer(input, max_length=max_length, verbose=verbose)
|
||||
outputs = {k: v[0].numpy() for k, v in outputs.items()}
|
||||
return outputs
|
||||
npy_outputs = {
|
||||
"mel_output": outputs["mel_output"][0].numpy(),
|
||||
"encoder_attention_weights": [item[0].numpy() for item in outputs["encoder_attention_weights"]],
|
||||
"cross_attention_weights": [item[0].numpy() for item in outputs["cross_attention_weights"]],
|
||||
}
|
||||
|
||||
return npy_outputs
|
||||
|
||||
def set_constants(self, reduction_factor, drop_n_heads):
|
||||
self.r = reduction_factor
|
||||
|
|
|
@ -19,6 +19,7 @@ import paddle
|
|||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle.nn import initializer as I
|
||||
import time
|
||||
|
||||
from parakeet.utils import checkpoint
|
||||
from parakeet.modules import geometry as geo
|
||||
|
@ -798,10 +799,13 @@ class ConditionalWaveFlow(nn.LayerList):
|
|||
Tensor : [shape=(B, T)]
|
||||
The synthesized audio, where``T <= T_mel \* upsample_factors``.
|
||||
"""
|
||||
start = time.time()
|
||||
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
|
||||
batch_size, _, time_steps = condition.shape
|
||||
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
|
||||
x = self.decoder.inverse(z, condition)
|
||||
end = time.time()
|
||||
print("time: {}s".format(end - start))
|
||||
return x
|
||||
|
||||
@paddle.no_grad()
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import librosa
|
||||
import librosa.display
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
from matplotlib import cm, pyplot
|
||||
|
@ -69,6 +71,7 @@ def plot_alignment(alignment, title=None):
|
|||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Encoder timestep')
|
||||
plt.tight_layout()
|
||||
return fig
|
||||
|
||||
fig.canvas.draw()
|
||||
data = save_figure_to_numpy(fig)
|
||||
|
|
Loading…
Reference in New Issue