WIP: baker

This commit is contained in:
chenfeiyu 2021-03-27 12:43:03 +08:00
parent 3c60fec900
commit a005cc88a3
7 changed files with 48 additions and 14 deletions

View File

@ -17,13 +17,15 @@ import time
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle import paddle
import matplotlib
from matplotlib import pyplot as plt
import parakeet import parakeet
from parakeet.frontend import English from parakeet.frontend import English
from parakeet.models.transformer_tts import TransformerTTS from parakeet.models.transformer_tts import TransformerTTS
from parakeet.utils import scheduler from parakeet.utils import scheduler
from parakeet.training.cli import default_argument_parser from parakeet.training.cli import default_argument_parser
from parakeet.utils.display import add_attention_plots from parakeet.utils.display import add_attention_plots, pack_attention_images
from config import get_cfg_defaults from config import get_cfg_defaults
@ -49,7 +51,16 @@ def main(config, args):
for i, sentence in enumerate(sentences): for i, sentence in enumerate(sentences):
outputs = model.predict(sentence, verbose=args.verbose) outputs = model.predict(sentence, verbose=args.verbose)
mel_output = outputs["mel_output"] mel_output = outputs["mel_output"]
# cross_attention_weights = outputs["cross_attention_weights"] cross_attention_weights = outputs["cross_attention_weights"]
attns = [attn for attn in cross_attention_weights]
fig = plt.figure(figsize=(40, 40))
for j, attn in enumerate(attns):
plt.subplot(1, 4, j+1)
plt.imshow(attn[0])
plt.tight_layout()
plt.savefig(str(output_dir / f"sentence_{i}.png"))
mel_output = mel_output.T #(C, T) mel_output = mel_output.T #(C, T)
np.save(str(output_dir / f"sentence_{i}"), mel_output) np.save(str(output_dir / f"sentence_{i}"), mel_output)
if args.verbose: if args.verbose:

View File

@ -21,6 +21,7 @@ import paddle
import parakeet import parakeet
from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow from parakeet.models.waveflow import UpsampleNet, WaveFlow, ConditionalWaveFlow
from parakeet.utils import layer_tools, checkpoint from parakeet.utils import layer_tools, checkpoint
import time
from config import get_cfg_defaults from config import get_cfg_defaults
@ -34,9 +35,10 @@ def main(config, args):
mel_dir = Path(args.input).expanduser() mel_dir = Path(args.input).expanduser()
output_dir = Path(args.output).expanduser() output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
for file_path in mel_dir.iterdir(): for file_path in mel_dir.glob("*.npy"):
mel = np.load(str(file_path)) mel = np.load(str(file_path))
audio = model.predict(mel) with paddle.amp.auto_cast():
audio = model.predict(mel)
audio_path = output_dir / ( audio_path = output_dir / (
os.path.splitext(file_path.name)[0] + ".wav") os.path.splitext(file_path.name)[0] + ".wav")
sf.write(audio_path, audio, config.data.sample_rate) sf.write(audio_path, audio, config.data.sample_rate)

View File

@ -18,6 +18,7 @@ import string
__all__ = ["get_punctuations"] __all__ = ["get_punctuations"]
EN_PUNCT = [ EN_PUNCT = [
" ",
"-", "-",
"...", "...",
",", ",",

View File

@ -22,6 +22,7 @@ from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention from parakeet.modules.attention import LocationSensitiveAttention
from parakeet.modules import masking from parakeet.modules import masking
from parakeet.utils import checkpoint from parakeet.utils import checkpoint
from tqdm import trange
__all__ = ["Tacotron2", "Tacotron2Loss"] __all__ = ["Tacotron2", "Tacotron2Loss"]
@ -475,10 +476,11 @@ class Tacotron2Decoder(nn.Layer):
dtype=key.dtype) #[B, C] dtype=key.dtype) #[B, C]
self._initialize_decoder_states(key) self._initialize_decoder_states(key)
T_enc = key.shape[1]
self.mask = None self.mask = None
mel_outputs, stop_logits, alignments = [], [], [] mel_outputs, stop_logits, alignments = [], [], []
while True: for _ in trange(max_decoder_steps):
query = self.prenet(query) query = self.prenet(query)
mel_output, stop_logit, alignment = self._decode(query) mel_output, stop_logit, alignment = self._decode(query)
@ -487,8 +489,12 @@ class Tacotron2Decoder(nn.Layer):
alignments += [alignment] alignments += [alignment]
if F.sigmoid(stop_logit) > stop_threshold: if F.sigmoid(stop_logit) > stop_threshold:
print("hits stop condition!")
break break
elif len(mel_outputs) == max_decoder_steps: if int(paddle.argmax(alignment[0])) == T_enc - 1:
print("content exhausted!")
break
if len(mel_outputs) == max_decoder_steps:
print("Warning! Reached max decoder steps!!!") print("Warning! Reached max decoder steps!!!")
break break
@ -718,7 +724,7 @@ class Tacotron2(nn.Layer):
""" """
embedded_inputs = self.embedding(text_inputs) embedded_inputs = self.embedding(text_inputs)
if self.toned: if self.toned:
embedded_inputs = paddle.concat([embedded_inputs, self.embedding_tones(tones)], -1) embedded_inputs += self.embedding_tones(tones)
encoder_outputs = self.encoder(embedded_inputs) encoder_outputs = self.encoder(embedded_inputs)
mel_outputs, stop_logits, alignments = self.decoder.infer( mel_outputs, stop_logits, alignments = self.decoder.infer(
encoder_outputs, encoder_outputs,

View File

@ -340,13 +340,15 @@ class CNNPostNet(nn.Layer):
c_in = d_input if i == 0 else d_hidden c_in = d_input if i == 0 else d_hidden
c_out = d_output if i == n_layers - 1 else d_hidden c_out = d_output if i == n_layers - 1 else d_hidden
self.convs.append( self.convs.append(
nn.Conv1D( Conv1dBatchNorm(
c_in, c_in,
c_out, c_out,
kernel_size, kernel_size,
weight_attr=I.XavierUniform(), weight_attr=I.XavierUniform(),
padding=padding)) padding=padding,
# self.last_bn = nn.BatchNorm1D(d_output) momentum=0.99,
epsilon=1e-03))
self.last_bn = nn.BatchNorm1D(d_output, momentum=0.99, epsilon=1e-3)
# for a layer that ends with a normalization layer that is targeted to # for a layer that ends with a normalization layer that is targeted to
# output a non zero-central output, it may take a long time to # output a non zero-central output, it may take a long time to
# train the scale and bias # train the scale and bias
@ -359,8 +361,8 @@ class CNNPostNet(nn.Layer):
if i != (len(self.convs) - 1): if i != (len(self.convs) - 1):
x = F.tanh(x) x = F.tanh(x)
# TODO: check it # TODO: check it
x = x_in + x # x = x_in + x
# x = self.last_bn(x_in + x) x = self.last_bn(x_in + x)
return x return x
@ -567,8 +569,13 @@ class TransformerTTS(nn.Layer):
text_ids = paddle.to_tensor(self.frontend(input)) text_ids = paddle.to_tensor(self.frontend(input))
input = paddle.unsqueeze(text_ids, 0) # (1, T) input = paddle.unsqueeze(text_ids, 0) # (1, T)
outputs = self.infer(input, max_length=max_length, verbose=verbose) outputs = self.infer(input, max_length=max_length, verbose=verbose)
outputs = {k: v[0].numpy() for k, v in outputs.items()} npy_outputs = {
return outputs "mel_output": outputs["mel_output"][0].numpy(),
"encoder_attention_weights": [item[0].numpy() for item in outputs["encoder_attention_weights"]],
"cross_attention_weights": [item[0].numpy() for item in outputs["cross_attention_weights"]],
}
return npy_outputs
def set_constants(self, reduction_factor, drop_n_heads): def set_constants(self, reduction_factor, drop_n_heads):
self.r = reduction_factor self.r = reduction_factor

View File

@ -19,6 +19,7 @@ import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
import time
from parakeet.utils import checkpoint from parakeet.utils import checkpoint
from parakeet.modules import geometry as geo from parakeet.modules import geometry as geo
@ -798,10 +799,13 @@ class ConditionalWaveFlow(nn.LayerList):
Tensor : [shape=(B, T)] Tensor : [shape=(B, T)]
The synthesized audio, where``T <= T_mel \* upsample_factors``. The synthesized audio, where``T <= T_mel \* upsample_factors``.
""" """
start = time.time()
condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T) condition = self.encoder(mel, trim_conv_artifact=True) #(B, C, T)
batch_size, _, time_steps = condition.shape batch_size, _, time_steps = condition.shape
z = paddle.randn([batch_size, time_steps], dtype=mel.dtype) z = paddle.randn([batch_size, time_steps], dtype=mel.dtype)
x = self.decoder.inverse(z, condition) x = self.decoder.inverse(z, condition)
end = time.time()
print("time: {}s".format(end - start))
return x return x
@paddle.no_grad() @paddle.no_grad()

View File

@ -14,6 +14,8 @@
import numpy as np import numpy as np
import matplotlib import matplotlib
import librosa
import librosa.display
matplotlib.use("Agg") matplotlib.use("Agg")
import matplotlib.pylab as plt import matplotlib.pylab as plt
from matplotlib import cm, pyplot from matplotlib import cm, pyplot
@ -69,6 +71,7 @@ def plot_alignment(alignment, title=None):
plt.xlabel(xlabel) plt.xlabel(xlabel)
plt.ylabel('Encoder timestep') plt.ylabel('Encoder timestep')
plt.tight_layout() plt.tight_layout()
return fig
fig.canvas.draw() fig.canvas.draw()
data = save_figure_to_numpy(fig) data = save_figure_to_numpy(fig)