deepvoice3: update logging functionalities
parent 70e271ed95 · commit 2bde79514a
@@ -18,7 +18,7 @@ You can choose to install via pypi or clone the repository and install manually.
 
 1. Install via pypi.
 ```bash
-pip install parakeet
+pip install paddle-parakeet
 ```
 
 2. Install manually.
@@ -102,6 +102,19 @@ optional arguments:
 
 5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
 
+example script:
+
+```bash
+python train.py --config=./ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
+```
+
+You can monitor training log via tensorboard, using the script below.
+
+```bash
+cd experiment/log
+tensorboard --logdir=.
+```
+
 ## Synthesis
 ```text
 usage: synthesis.py [-h] [-c CONFIG] [-g DEVICE] checkpoint text output_path
@@ -127,3 +140,9 @@ optional arguments:
 4. `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`) and attention plots (*.png) for each sentence.
 5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
 
+example script:
+
+```bash
+python synthesis.py --config=./ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
+```
+
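For context, the tensorboard events read by the commands above are written with tensorboardX in the training script. A minimal sketch of the writer setup this workflow assumes (directory and tag names are illustrative, not taken from the diff):

```python
from tensorboardX import SummaryWriter

# assumed layout: train.py --output=experiment writes its logs under experiment/log
writer = SummaryWriter("experiment/log")
for step in range(1, 6):
    writer.add_scalar("loss", 1.0 / step, step)  # illustrative tag and values
writer.close()
```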
@@ -85,7 +85,6 @@ train:
   batch_size: 16
   epochs: 2000
 
-  report_interval: 100
   snap_interval: 1000
   eval_interval: 10000
   save_interval: 10000
@@ -22,7 +22,7 @@ from parakeet.models.deepvoice3.loss import TTSLoss
 from parakeet.utils.layer_tools import summary
 
 from data import LJSpeechMetaData, DataCollector, Transform
-from utils import make_model, eval_model, plot_alignment, plot_alignments, save_state, make_output_tree
+from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -176,6 +176,11 @@ if __name__ == "__main__":
                           parameter_list=dv3.parameters())
     gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)
 
+    # generation
+    synthesis_config = config["synthesis"]
+    power = synthesis_config["power"]
+    n_iter = synthesis_config["n_iter"]
+
     # =========================link(dataloader, paddle)=========================
     # CAUTION: it does not return a DataLoader
     loader = fluid.io.DataLoader.from_generator(capacity=10,
@@ -198,16 +203,14 @@ if __name__ == "__main__":
 
     # =========================train=========================
     epoch = train_config["epochs"]
-    report_interval = train_config["report_interval"]
     snap_interval = train_config["snap_interval"]
     save_interval = train_config["save_interval"]
     eval_interval = train_config["eval_interval"]
 
     global_step = 1
-    average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
 
     for j in range(1, 1 + epoch):
-        epoch_loss = {"mel": 0., "lin": 0., "done": 0., "attn": 0.}
+        epoch_loss = 0.
         for i, batch in tqdm.tqdm(enumerate(loader, 1)):
             dv3.train()  # CAUTION: don't forget to switch to train
             (text_sequences, text_lengths, text_positions, mel_specs,
@@ -225,7 +228,7 @@ if __name__ == "__main__":
             losses = criterion(mel_outputs, linear_outputs, done,
                                alignments, downsampled_mel_specs,
                                lin_specs, done_flags, text_lengths, frames)
-            l = criterion.compose_loss(losses)
+            l = losses["loss"]
             l.backward()
             # record learning rate before updating
             writer.add_scalar("learning_rate",
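With `compose_loss` gone (see the TTSLoss hunks further down), the criterion returns its composed total under the `"loss"` key, so the training loop just indexes into the dict. A minimal sketch of the convention, with stand-in arithmetic instead of the real sub-losses:

```python
# Hypothetical stand-in for TTSLoss.__call__: sub-losses plus a pre-composed total.
def criterion(pred, target):
    l1 = abs(pred - target)
    sq = (pred - target) ** 2
    return {"lin/l1_loss": l1, "lin/bce_loss": sq, "loss": l1 + sq}

losses = criterion(0.9, 0.4)
l = losses["loss"]  # the entry the training loop backpropagates
```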
@@ -235,41 +238,31 @@ if __name__ == "__main__":
             optim.clear_gradients()
 
             # ==================all kinds of tedious things=================
-            for k in epoch_loss.keys():
-                epoch_loss[k] += losses[k].numpy()[0]
-                average_loss[k] += losses[k].numpy()[0]
 
             # record step loss into tensorboard
+            epoch_loss += l.numpy()[0]
             step_loss = {k: v.numpy()[0] for k, v in losses.items()}
-            print(step_loss)
             for k, v in step_loss.items():
                 writer.add_scalar(k, v, global_step)
 
             # TODO: clean code
             # train state saving, the first sentence in the batch
             if global_step % snap_interval == 0:
-                linear_outputs_np = linear_outputs.numpy()[0].T
-                denoramlized = np.clip(linear_outputs_np, 0, 1) \
-                               * (-min_level_db) \
-                               + min_level_db
-                lin_scaled = np.exp(
-                    (denoramlized + ref_level_db) / 20 * np.log(10))
-                synthesis_config = config["synthesis"]
-                power = synthesis_config["power"]
-                n_iter = synthesis_config["n_iter"]
-                wav = librosa.griffinlim(lin_scaled**power,
-                                         n_iter=n_iter,
-                                         hop_length=hop_length,
-                                         win_length=win_length)
 
                 save_state(state_dir,
+                           writer,
                            global_step,
-                           mel_input=mel_specs.numpy()[0].T,
-                           mel_output=mel_outputs.numpy()[0].T,
-                           lin_input=lin_specs.numpy()[0].T,
-                           lin_output=linear_outputs.numpy()[0].T,
-                           alignments=alignments.numpy()[:, 0, :, :],
-                           wav=wav)
+                           mel_input=downsampled_mel_specs,
+                           mel_output=mel_outputs,
+                           lin_input=lin_specs,
+                           lin_output=linear_outputs,
+                           alignments=alignments,
+                           win_length=win_length,
+                           hop_length=hop_length,
+                           min_level_db=min_level_db,
+                           ref_level_db=ref_level_db,
+                           power=power,
+                           n_iter=n_iter,
+                           preemphasis=preemphasis,
+                           sample_rate=sample_rate)
 
             # evaluation
             if global_step % eval_interval == 0:
@@ -291,28 +284,31 @@ if __name__ == "__main__":
                         state_dir, "waveform",
                         "eval_sample_{:09d}.wav".format(global_step))
                     sf.write(wav_path, wav, sample_rate)
+                    writer.add_audio("eval_sample_{}".format(idx),
+                                     wav,
+                                     global_step,
+                                     sample_rate=sample_rate)
                     attn_path = os.path.join(
                         state_dir, "alignments",
                         "eval_sample_attn_{:09d}.png".format(global_step))
                     plot_alignment(attn, attn_path)
+                    writer.add_image("eval_sample_attn{}".format(idx),
+                                     cm.viridis(attn),
+                                     global_step,
+                                     dataformats="HWC")
 
             # save checkpoint
             if global_step % save_interval == 0:
-                dg.save_dygraph(dv3.state_dict(),
-                                os.path.join(ckpt_dir, "dv3"))
-                dg.save_dygraph(optim.state_dict(),
-                                os.path.join(ckpt_dir, "dv3"))
-
-            # report average loss
-            if global_step % report_interval == 0:
-                for k in epoch_loss.keys():
-                    average_loss[k] /= report_interval
-                print("[average_loss] ",
-                      "global_step: {}".format(global_step), average_loss)
-                average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
+                dg.save_dygraph(
+                    dv3.state_dict(),
+                    os.path.join(ckpt_dir,
+                                 "model_step_{}".format(global_step)))
+                dg.save_dygraph(
+                    optim.state_dict(),
+                    os.path.join(ckpt_dir,
+                                 "model_step_{}".format(global_step)))
 
             global_step += 1
         # epoch report
-        for k in epoch_loss.keys():
-            epoch_loss[k] /= i
-        print("[epoch_loss] ", "epoch: {}".format(j), epoch_loss)
+        writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
+        epoch_loss = 0.
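Checkpoints are now tagged with the global step instead of a fixed `dv3` prefix, which matches the `experiment/checkpoints/model_step_005000000` path in the README. A sketch of the naming; if I read the fluid API right, `save_dygraph` appends the `.pdparams`/`.pdopt` suffixes itself:

```python
import os

ckpt_dir = "experiment/checkpoints"  # assumed directory, matching the README
global_step = 5000000
prefix = os.path.join(ckpt_dir, "model_step_{}".format(global_step))
# dg.save_dygraph(dv3.state_dict(), prefix) would write prefix + ".pdparams";
# dg.save_dygraph(optim.state_dict(), prefix) would write prefix + ".pdopt".
print(prefix)  # experiment/checkpoints/model_step_5000000
```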
@@ -1,5 +1,6 @@
 import os
 import numpy as np
+from matplotlib import cm
 import matplotlib.pyplot as plt
 import librosa
 from scipy import signal
@@ -125,21 +126,32 @@ def eval_model(model, text, replace_pronounciation_prob, min_level_db,
     model.eval()
     mel_outputs, linear_outputs, alignments, done = model.transduce(
         dg.to_variable(text), dg.to_variable(text_positions))
-    linear_outputs_np = linear_outputs.numpy()[0].T  # (C, T)
-    print("linear_outputs's shape: ", linear_outputs_np.shape)
 
-    denoramlized = np.clip(linear_outputs_np, 0,
-                           1) * (-min_level_db) + min_level_db
+    linear_outputs_np = linear_outputs.numpy()[0].T  # (C, T)
+    wav = spec_to_waveform(linear_outputs_np, min_level_db, ref_level_db,
+                           power, n_iter, win_length, hop_length, preemphasis)
+    alignments_np = alignments.numpy()[0]  # batch_size = 1
+    print("linear_outputs's shape: ", linear_outputs_np.shape)
+    print("alignmnets' shape:", alignments.shape)
+    return wav, alignments_np
+
+
+def spec_to_waveform(spec, min_level_db, ref_level_db, power, n_iter,
+                     win_length, hop_length, preemphasis):
+    """Convert output linear spec to waveform using griffin-lim vocoder.
+
+    Args:
+        spec (ndarray): the output linear spectrogram, shape(C, T), where C means n_fft, T means frames.
+    """
+    denoramlized = np.clip(spec, 0, 1) * (-min_level_db) + min_level_db
     lin_scaled = np.exp((denoramlized + ref_level_db) / 20 * np.log(10))
     wav = librosa.griffinlim(lin_scaled**power,
                              n_iter=n_iter,
                              hop_length=hop_length,
                              win_length=win_length)
-    wav = signal.lfilter([1.], [1., -preemphasis], wav)
-    print("alignmnets' shape:", alignments.shape)
-    alignments_np = alignments.numpy()[0].T
-    return wav, alignments_np
+    if preemphasis > 0:
+        wav = signal.lfilter([1.], [1., -preemphasis], wav)
+    return wav
 
 
 def make_output_tree(output_dir):
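`spec_to_waveform` factors the Griffin-Lim vocoding out of `eval_model` so the training snapshots can reuse it. A self-contained toy sketch of the librosa call it wraps (all sizes and the input signal are illustrative):

```python
import numpy as np
import librosa

sr = 22050
y = 0.01 * np.random.randn(2 * sr)  # stand-in audio
spec = np.abs(librosa.stft(y, n_fft=2048, hop_length=256, win_length=1024))
wav = librosa.griffinlim(spec, n_iter=32, hop_length=256, win_length=1024)
print(wav.shape)
```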
@@ -157,88 +169,89 @@ def make_output_tree(output_dir):
             os.makedirs(p)
 
 
-def plot_alignment(alignment, path, info=None):
+def plot_alignment(alignment, path):
     """
     Plot an attention layer's alignment for a sentence.
-    alignment: shape(T_enc, T_dec), and T_enc is flipped
+    alignment: shape(T_dec, T_enc).
     """
 
-    fig, ax = plt.subplots()
-    im = ax.imshow(alignment,
-                   aspect='auto',
-                   origin='lower',
-                   interpolation='none')
-    fig.colorbar(im, ax=ax)
-    xlabel = 'Decoder timestep'
-    if info is not None:
-        xlabel += '\n\n' + info
-    plt.xlabel(xlabel)
-    plt.ylabel('Encoder timestep')
-    plt.tight_layout()
+    plt.figure()
+    plt.imshow(alignment)
+    plt.colorbar()
+    plt.xlabel('Encoder timestep')
+    plt.ylabel('Decoder timestep')
     plt.savefig(path)
     plt.close()
 
 
-def plot_alignments(alignments, save_dir, global_step):
-    """
-    Plot alignments for a sentence when training, we just pick the first
-    sentence. Each layer is plot separately.
-    alignments: shape(N, T_dec, T_enc)
-    """
-    n_layers = alignments.shape[0]
-    for i, alignment in enumerate(alignments):
-        alignment = alignment.T
-
-        path = os.path.join(save_dir, "layer_{}".format(i))
-        if not os.path.exists(path):
-            os.makedirs(path)
-        fname = os.path.join(path, "step_{:09d}".format(global_step))
-        plot_alignment(alignment, fname)
-
-    average_alignment = np.mean(alignments, axis=0).T
-    path = os.path.join(save_dir, "average")
-    if not os.path.exists(path):
-        os.makedirs(path)
-    fname = os.path.join(path, "step_{:09d}.png".format(global_step))
-    plot_alignment(average_alignment, fname)
-
-
 def save_state(save_dir,
+               writer,
                global_step,
                mel_input=None,
                mel_output=None,
                lin_input=None,
                lin_output=None,
                alignments=None,
-               wav=None):
+               win_length=1024,
+               hop_length=256,
+               min_level_db=-100,
+               ref_level_db=20,
+               power=1.4,
+               n_iter=32,
+               preemphasis=0.97,
+               sample_rate=22050):
+    """Save training intermediate results. Save states for the first sentence in the batch, including
+    mel_spec(predicted, target), lin_spec(predicted, target), attn, waveform.
+
+    Args:
+        save_dir (str): directory to save results.
+        writer (SummaryWriter): tensorboardX summary writer
+        global_step (int): global step.
+        mel_input (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
+        mel_output (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
+        lin_input (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
+        lin_output (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
+        alignments (Variable, optional): Defaults to None. Shape(N, B, T_dec, C_enc)
+        wav ([type], optional): Defaults to None. [description]
+    """
 
     if mel_input is not None and mel_output is not None:
-        path = os.path.join(save_dir, "mel_spec")
-        if not os.path.exists(path):
-            os.makedirs(path)
+        mel_input = mel_input[0].numpy().T
+        mel_output = mel_output[0].numpy().T
 
+        path = os.path.join(save_dir, "mel_spec")
         plt.figure(figsize=(10, 3))
         display.specshow(mel_input)
         plt.colorbar()
         plt.title("mel_input")
         plt.savefig(
             os.path.join(path,
-                         "target_mel_spec_step{:09d}".format(global_step)))
+                         "target_mel_spec_step{:09d}.png".format(global_step)))
         plt.close()
 
+        writer.add_image("target/mel_spec",
+                         cm.viridis(mel_input),
+                         global_step,
+                         dataformats="HWC")
+
         plt.figure(figsize=(10, 3))
         display.specshow(mel_output)
         plt.colorbar()
-        plt.title("mel_input")
+        plt.title("mel_output")
         plt.savefig(
-            os.path.join(path,
-                         "predicted_mel_spec_step{:09d}".format(global_step)))
+            os.path.join(
+                path, "predicted_mel_spec_step{:09d}.png".format(global_step)))
         plt.close()
 
+        writer.add_image("predicted/mel_spec",
+                         cm.viridis(mel_output),
+                         global_step,
+                         dataformats="HWC")
+
     if lin_input is not None and lin_output is not None:
+        lin_input = lin_input[0].numpy().T
+        lin_output = lin_output[0].numpy().T
         path = os.path.join(save_dir, "lin_spec")
-        if not os.path.exists(path):
-            os.makedirs(path)
 
         plt.figure(figsize=(10, 3))
         display.specshow(lin_input)
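A note on the image logging introduced here: `cm.viridis` maps a 2-D array of values in [0, 1] to an `(H, W, 4)` RGBA array, which is why every `writer.add_image` call passes `dataformats="HWC"`. A small sketch (the array shape is illustrative):

```python
import numpy as np
from matplotlib import cm

attn = np.random.rand(32, 64)  # e.g. a (T_dec, T_enc) alignment
img = cm.viridis(attn)         # RGBA float array of shape (32, 64, 4)
print(img.shape)
```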
@@ -246,28 +259,50 @@ def save_state(save_dir,
         plt.title("mel_input")
         plt.savefig(
             os.path.join(path,
-                         "target_lin_spec_step{:09d}".format(global_step)))
+                         "target_lin_spec_step{:09d}.png".format(global_step)))
         plt.close()
 
+        writer.add_image("target/lin_spec",
+                         cm.viridis(lin_input),
+                         global_step,
+                         dataformats="HWC")
+
         plt.figure(figsize=(10, 3))
         display.specshow(lin_output)
         plt.colorbar()
         plt.title("mel_input")
         plt.savefig(
-            os.path.join(path,
-                         "predicted_lin_spec_step{:09d}".format(global_step)))
+            os.path.join(
+                path, "predicted_lin_spec_step{:09d}.png".format(global_step)))
         plt.close()
 
-    if alignments is not None and len(alignments.shape) == 3:
-        path = os.path.join(save_dir, "alignments")
-        if not os.path.exists(path):
-            os.makedirs(path)
-        plot_alignments(alignments, path, global_step)
+        writer.add_image("predicted/lin_spec",
+                         cm.viridis(lin_output),
+                         global_step,
+                         dataformats="HWC")
 
-    if wav is not None:
+    if alignments is not None and len(alignments.shape) == 4:
+        path = os.path.join(save_dir, "alignments")
+        alignments = alignments[:, 0, :, :].numpy()
+        for idx, attn_layer in enumerate(alignments):
+            save_path = os.path.join(
+                path,
+                "train_attn_layer_{}_step_{}.png".format(idx, global_step))
+            plot_alignment(attn_layer, save_path)
+
+            writer.add_image("train_attn/layer_{}".format(idx),
+                             cm.viridis(attn_layer),
+                             global_step,
+                             dataformats="HWC")
+
+    if lin_output is not None:
+        wav = spec_to_waveform(lin_output, min_level_db, ref_level_db, power,
+                               n_iter, win_length, hop_length, preemphasis)
         path = os.path.join(save_dir, "waveform")
-        if not os.path.exists(path):
-            os.makedirs(path)
-        sf.write(
-            os.path.join(path, "sample_step_{:09d}.wav".format(global_step)),
-            wav, 22050)
+        save_path = os.path.join(
+            path, "train_sample_step_{:09d}.wav".format(global_step))
+        sf.write(save_path, wav, sample_rate)
+        writer.add_audio("train_sample",
+                         wav,
+                         global_step,
+                         sample_rate=sample_rate)
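The `if preemphasis > 0` guard wraps the de-emphasis filter: `lfilter([1.], [1., -p], x)` is the exact inverse of the preemphasis FIR `y[n] = x[n] - p * x[n-1]`. A quick check of that identity:

```python
import numpy as np
from scipy import signal

p = 0.97
x = np.random.randn(1000)
pre = signal.lfilter([1.0, -p], [1.0], x)    # preemphasis
rec = signal.lfilter([1.0], [1.0, -p], pre)  # de-emphasis, as in spec_to_waveform
assert np.allclose(rec, x)
```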
@@ -79,25 +79,26 @@ def unfold_adjacent_frames(folded_frames, r):
 
 
 class Decoder(dg.Layer):
-    def __init__(self,
-                 n_speakers,
-                 speaker_dim,
-                 embed_dim,
-                 mel_dim,
-                 r=1,
-                 max_positions=512,
-                 padding_idx=None,
-                 preattention=(ConvSpec(128, 5, 1), ) * 4,
-                 convolutions=(ConvSpec(128, 5, 1), ) * 4,
-                 attention=True,
-                 dropout=0.0,
-                 use_memory_mask=False,
-                 force_monotonic_attention=False,
-                 query_position_rate=1.0,
-                 key_position_rate=1.0,
-                 window_range=WindowRange(-1, 3),
-                 key_projection=True,
-                 value_projection=True):
+    def __init__(
+            self,
+            n_speakers,
+            speaker_dim,
+            embed_dim,
+            mel_dim,
+            r=1,
+            max_positions=512,
+            padding_idx=None,  # remove it!
+            preattention=(ConvSpec(128, 5, 1), ) * 4,
+            convolutions=(ConvSpec(128, 5, 1), ) * 4,
+            attention=True,
+            dropout=0.0,
+            use_memory_mask=False,
+            force_monotonic_attention=False,
+            query_position_rate=1.0,
+            key_position_rate=1.0,
+            window_range=WindowRange(-1, 3),
+            key_projection=True,
+            value_projection=True):
         super(Decoder, self).__init__()
 
         self.dropout = dropout
@@ -109,21 +110,23 @@ class Decoder(dg.Layer):
         self.n_speakers = n_speakers
 
         conv_channels = convolutions[0].out_channels
+        # only when padding idx is 0 can we easilt handle it
         self.embed_keys_positions = PositionEmbedding(max_positions,
                                                       embed_dim,
-                                                      padding_idx=padding_idx)
+                                                      padding_idx=0)
         self.embed_query_positions = PositionEmbedding(max_positions,
                                                        conv_channels,
-                                                       padding_idx=padding_idx)
+                                                       padding_idx=0)
 
         if n_speakers > 1:
-            # CAUTION: mind the sigmoid
             std = np.sqrt((1 - dropout) / speaker_dim)
             self.speaker_proj1 = Linear(speaker_dim,
                                         1,
+                                        act="sigmoid",
                                         param_attr=I.Normal(scale=std))
             self.speaker_proj2 = Linear(speaker_dim,
                                         1,
+                                        act="sigmoid",
                                         param_attr=I.Normal(scale=std))
 
         # prenet
@@ -168,6 +171,7 @@ class Decoder(dg.Layer):
             ] * len(convolutions)
         else:
             self.force_monotonic_attention = force_monotonic_attention
+
         for x, y in zip(self.force_monotonic_attention, self.attention):
             if x is True and y is False:
                 raise ValueError("When not using attention, there is no "
@@ -249,7 +253,7 @@ class Decoder(dg.Layer):
             text_positions (Variable): shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            frame_positions (Variable): shape(B, T_dec // r), dtype:
+            frame_positions (Variable): shape(B, T_mel // r), dtype:
                 int64. Positions indices for each decoder time steps.
             speaker_embed: shape(batch_size, speaker_dim), speaker embedding,
                 only used for multispeaker model.
@@ -287,16 +291,14 @@ class Decoder(dg.Layer):
         if text_positions is not None:
             w = self.key_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj1(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj1(speaker_embed), [-1])
             text_pos_embed = self.embed_keys_positions(text_positions, w)
             keys += text_pos_embed  # (B, T, C)
 
         if frame_positions is not None:
             w = self.query_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj2(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj2(speaker_embed), [-1])
             frame_pos_embed = self.embed_query_positions(frame_positions, w)
         else:
             frame_pos_embed = None
@@ -387,8 +389,7 @@ class Decoder(dg.Layer):
         w = self.key_position_rate
         if self.n_speakers > 1:
             # shape (B, )
-            w = w * F.squeeze(F.sigmoid(self.speaker_proj1(speaker_embed)),
-                              [-1])
+            w = w * F.squeeze(self.speaker_proj1(speaker_embed), [-1])
         text_pos_embed = self.embed_keys_positions(text_positions, w)
         keys += text_pos_embed  # (B, T, C)
 
@@ -417,8 +418,7 @@ class Decoder(dg.Layer):
                                dtype="int64")
             w = self.query_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj2(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj2(speaker_embed), [-1])
             # (B, T=1, C)
             frame_pos_embed = self.embed_query_positions(frame_pos, w)
 
@@ -35,9 +35,11 @@ class Encoder(dg.Layer):
             std = np.sqrt((1 - dropout) / speaker_dim)
             self.sp_proj1 = Linear(speaker_dim,
                                    embed_dim,
+                                   act="softsign",
                                    param_attr=I.Normal(scale=std))
             self.sp_proj2 = Linear(speaker_dim,
                                    embed_dim,
+                                   act="softsign",
                                    param_attr=I.Normal(scale=std))
         self.n_speakers = n_speakers
 
@@ -104,9 +106,7 @@ class Encoder(dg.Layer):
                 speaker_embed,
                 self.dropout,
                 dropout_implementation="upscale_in_train")
-            x = F.elementwise_add(x,
-                                  F.softsign(self.sp_proj1(speaker_embed)),
-                                  axis=0)
+            x = F.elementwise_add(x, self.sp_proj1(speaker_embed), axis=0)
 
         input_embed = x
         for layer in self.convolutions:
@@ -117,9 +117,7 @@ class Encoder(dg.Layer):
             x = layer(x)
 
         if self.n_speakers > 1 and speaker_embed is not None:
-            x = F.elementwise_add(x,
-                                  F.softsign(self.sp_proj2(speaker_embed)),
-                                  axis=0)
+            x = F.elementwise_add(x, self.sp_proj2(speaker_embed), axis=0)
 
         keys = x  # (B, C, T)
         values = F.scale(input_embed + x, scale=np.sqrt(0.5))
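In both the encoder (softsign) and the decoder (sigmoid), this commit fuses the activation into the `Linear` layer via `act=...` instead of applying `F.softsign`/`F.sigmoid` at the call site. A minimal sketch of the fused form under fluid's dygraph API (dimensions are illustrative):

```python
import numpy as np
import paddle.fluid.dygraph as dg

with dg.guard():
    proj = dg.Linear(16, 8, act="softsign")  # activation fused into the layer
    x = dg.to_variable(np.random.randn(4, 16).astype("float32"))
    y = proj(x)  # already softsign-activated; no F.softsign at the call site
```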
@@ -156,8 +156,9 @@ class TTSLoss(object):
                  compute_mel_loss=True,
                  compute_done_loss=True,
                  compute_attn_loss=True):
+        total_loss = 0.
 
         # n_frames # mel_lengths # decoder_lengths
-        # 4 losses: lin(l1, bce, lin), mel(l1, bce, mel), attn, done
         max_frames = lin_hyp.shape[1]
         max_mel_steps = max_frames // self.downsample_factor
         max_decoder_steps = max_mel_steps // self.r
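For reference, the bookkeeping above relates the three time axes of the model; with assumed values `downsample_factor = 4` and `r = 4`:

```python
downsample_factor = 4  # linear-spec frames per mel frame (assumed value)
r = 4                  # mel frames per decoder step (assumed value)

max_frames = 800                                 # linear-spec frames
max_mel_steps = max_frames // downsample_factor  # 200 mel frames
max_decoder_steps = max_mel_steps // r           # 50 decoder steps
```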
@@ -182,6 +183,7 @@ class TTSLoss(object):
             lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
             lin_loss = self.binary_divergence_weight * lin_bce_loss \
                 + (1 - self.binary_divergence_weight) * lin_l1_loss
+            total_loss += lin_loss
 
         if compute_mel_loss:
             mel_hyp = mel_hyp[:, :-self.time_shift, :]
@@ -192,32 +194,28 @@ class TTSLoss(object):
             # print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
             mel_loss = self.binary_divergence_weight * mel_bce_loss \
                 + (1 - self.binary_divergence_weight) * mel_l1_loss
+            total_loss += mel_loss
 
         if compute_attn_loss:
             attn_loss = self.attention_loss(
                 attn_hyp, input_lengths.numpy(),
                 n_frames.numpy() // (self.downsample_factor * self.r))
+            total_loss += attn_loss
 
         if compute_done_loss:
             done_loss = self.done_loss(done_hyp, done_ref)
+            total_loss += done_loss
 
         result = {
-            "mel": mel_loss if compute_mel_loss else None,
-            "mel_l1_loss": mel_l1_loss if compute_mel_loss else None,
-            "mel_bce_loss": mel_bce_loss if compute_mel_loss else None,
-            "lin": lin_loss if compute_lin_loss else None,
-            "lin_l1_loss": lin_l1_loss if compute_lin_loss else None,
-            "lin_bce_loss": lin_bce_loss if compute_lin_loss else None,
+            "loss": total_loss,
+            "mel/mel_loss": mel_loss if compute_mel_loss else None,
+            "mel/l1_loss": mel_l1_loss if compute_mel_loss else None,
+            "mel/bce_loss": mel_bce_loss if compute_mel_loss else None,
+            "lin/lin_loss": lin_loss if compute_lin_loss else None,
+            "lin/l1_loss": lin_l1_loss if compute_lin_loss else None,
+            "lin/bce_loss": lin_bce_loss if compute_lin_loss else None,
             "done": done_loss if compute_done_loss else None,
             "attn": attn_loss if compute_attn_loss else None,
         }
 
         return result
-
-    @staticmethod
-    def compose_loss(result):
-        total_loss = 0.
-        for k in ["mel", "lin", "done", "attn"]:
-            if result[k] is not None:
-                total_loss += result[k]
-        return total_loss
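After this change the training loop consumes the dict directly: it backpropagates `losses["loss"]` and logs each sub-loss under its slash-separated tag, which tensorboard groups into `mel/` and `lin/` panels. The returned dict has this shape (values made up):

```python
result = {
    "loss": 1.30,  # composed total, replacing the removed compose_loss()
    "mel/mel_loss": 0.50, "mel/l1_loss": 0.30, "mel/bce_loss": 0.20,
    "lin/lin_loss": 0.40, "lin/l1_loss": 0.25, "lin/bce_loss": 0.15,
    "done": 0.10, "attn": 0.30,
}
```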
@@ -6,7 +6,6 @@ import paddle.fluid.layers as F
 from parakeet.modules import customized as L
 
 
-# TODO: just use numpy to init weight norm wrappers
 def norm(param, dim, power):
     powered = F.pow(param, power)
     powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
@@ -73,7 +72,7 @@ class WeightNormWrapper(dg.Layer):
             w_g, self.create_parameter(shape=temp.shape, dtype=temp.dtype))
         F.assign(temp, getattr(self, w_g))
 
-        # also set this
+        # also set this when setting up
         setattr(
             self.layer, self.param_name,
             compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,