deepvoice3: update logging functionalities

Author: chenfeiyu
Date: 2020-02-17 14:52:12 +00:00
parent 70e271ed95
commit 2bde79514a
8 changed files with 216 additions and 172 deletions

View File

@@ -18,7 +18,7 @@ You can choose to install via pypi or clone the repository and install manually.
 1. Install via pypi.
 ```bash
-pip install parakeet
+pip install paddle-parakeet
 ```

 2. Install manually.
@@ -102,6 +102,19 @@ optional arguments:
 5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
+
+example script:
+
+```bash
+python train.py --config=./ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
+```
+
+You can monitor training log via tensorboard, using the script below.
+
+```bash
+cd experiment/log
+tensorboard --logdir=.
+```

 ## Synthesis
 ```text
 usage: synthesis.py [-h] [-c CONFIG] [-g DEVICE] checkpoint text output_path
@@ -127,3 +140,9 @@ optional arguments:
 4. `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`) and attention plots (*.png) for each sentence.
 5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
+
+example script:
+
+```bash
+python synthesis.py --config=./ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated
+```

View File

@@ -85,7 +85,6 @@ train:
   batch_size: 16
   epochs: 2000
-  report_interval: 100
   snap_interval: 1000
   eval_interval: 10000
   save_interval: 10000

View File

@@ -22,7 +22,7 @@ from parakeet.models.deepvoice3.loss import TTSLoss
 from parakeet.utils.layer_tools import summary
 from data import LJSpeechMetaData, DataCollector, Transform
-from utils import make_model, eval_model, plot_alignment, plot_alignments, save_state, make_output_tree
+from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
@@ -176,6 +176,11 @@ if __name__ == "__main__":
             parameter_list=dv3.parameters())
         gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)

+        # generation
+        synthesis_config = config["synthesis"]
+        power = synthesis_config["power"]
+        n_iter = synthesis_config["n_iter"]
+
         # =========================link(dataloader, paddle)=========================
         # CAUTION: it does not return a DataLoader
         loader = fluid.io.DataLoader.from_generator(capacity=10,
@@ -198,16 +203,14 @@ if __name__ == "__main__":
         # =========================train=========================
         epoch = train_config["epochs"]
-        report_interval = train_config["report_interval"]
         snap_interval = train_config["snap_interval"]
         save_interval = train_config["save_interval"]
         eval_interval = train_config["eval_interval"]

         global_step = 1
-        average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}

         for j in range(1, 1 + epoch):
-            epoch_loss = {"mel": 0., "lin": 0., "done": 0., "attn": 0.}
+            epoch_loss = 0.
             for i, batch in tqdm.tqdm(enumerate(loader, 1)):
                 dv3.train()  # CAUTION: don't forget to switch to train
                 (text_sequences, text_lengths, text_positions, mel_specs,
@@ -225,7 +228,7 @@ if __name__ == "__main__":
                 losses = criterion(mel_outputs, linear_outputs, done,
                                    alignments, downsampled_mel_specs,
                                    lin_specs, done_flags, text_lengths, frames)
-                l = criterion.compose_loss(losses)
+                l = losses["loss"]
                 l.backward()
                 # record learning rate before updating
                 writer.add_scalar("learning_rate",
@@ -235,41 +238,31 @@ if __name__ == "__main__":
                 optim.clear_gradients()

                 # ==================all kinds of tedious things=================
-                for k in epoch_loss.keys():
-                    epoch_loss[k] += losses[k].numpy()[0]
-                    average_loss[k] += losses[k].numpy()[0]
                 # record step loss into tensorboard
+                epoch_loss += l.numpy()[0]
                 step_loss = {k: v.numpy()[0] for k, v in losses.items()}
+                print(step_loss)
                 for k, v in step_loss.items():
                     writer.add_scalar(k, v, global_step)

                 # TODO: clean code
                 # train state saving, the first sentence in the batch
                 if global_step % snap_interval == 0:
-                    linear_outputs_np = linear_outputs.numpy()[0].T
-                    denoramlized = np.clip(linear_outputs_np, 0, 1) \
-                        * (-min_level_db) \
-                        + min_level_db
-                    lin_scaled = np.exp(
-                        (denoramlized + ref_level_db) / 20 * np.log(10))
-                    synthesis_config = config["synthesis"]
-                    power = synthesis_config["power"]
-                    n_iter = synthesis_config["n_iter"]
-                    wav = librosa.griffinlim(lin_scaled**power,
-                                             n_iter=n_iter,
-                                             hop_length=hop_length,
-                                             win_length=win_length)
                     save_state(state_dir,
+                               writer,
                                global_step,
-                               mel_input=mel_specs.numpy()[0].T,
-                               mel_output=mel_outputs.numpy()[0].T,
-                               lin_input=lin_specs.numpy()[0].T,
-                               lin_output=linear_outputs.numpy()[0].T,
-                               alignments=alignments.numpy()[:, 0, :, :],
-                               wav=wav)
+                               mel_input=downsampled_mel_specs,
+                               mel_output=mel_outputs,
+                               lin_input=lin_specs,
+                               lin_output=linear_outputs,
+                               alignments=alignments,
+                               win_length=win_length,
+                               hop_length=hop_length,
+                               min_level_db=min_level_db,
+                               ref_level_db=ref_level_db,
+                               power=power,
+                               n_iter=n_iter,
+                               preemphasis=preemphasis,
+                               sample_rate=sample_rate)

                 # evaluation
                 if global_step % eval_interval == 0:
@@ -291,28 +284,31 @@ if __name__ == "__main__":
                             state_dir, "waveform",
                             "eval_sample_{:09d}.wav".format(global_step))
                         sf.write(wav_path, wav, sample_rate)
+                        writer.add_audio("eval_sample_{}".format(idx),
+                                         wav,
+                                         global_step,
+                                         sample_rate=sample_rate)
                         attn_path = os.path.join(
                             state_dir, "alignments",
                             "eval_sample_attn_{:09d}.png".format(global_step))
                         plot_alignment(attn, attn_path)
+                        writer.add_image("eval_sample_attn{}".format(idx),
+                                         cm.viridis(attn),
+                                         global_step,
+                                         dataformats="HWC")

                 # save checkpoint
                 if global_step % save_interval == 0:
-                    dg.save_dygraph(dv3.state_dict(),
-                                    os.path.join(ckpt_dir, "dv3"))
-                    dg.save_dygraph(optim.state_dict(),
-                                    os.path.join(ckpt_dir, "dv3"))
-
-                # report average loss
-                if global_step % report_interval == 0:
-                    for k in epoch_loss.keys():
-                        average_loss[k] /= report_interval
-                    print("[average_loss] ",
-                          "global_step: {}".format(global_step), average_loss)
-                    average_loss = {"mel": 0, "lin": 0, "done": 0, "attn": 0}
+                    dg.save_dygraph(
+                        dv3.state_dict(),
+                        os.path.join(ckpt_dir,
+                                     "model_step_{}".format(global_step)))
+                    dg.save_dygraph(
+                        optim.state_dict(),
+                        os.path.join(ckpt_dir,
+                                     "model_step_{}".format(global_step)))

                 global_step += 1

             # epoch report
-            for k in epoch_loss.keys():
-                epoch_loss[k] /= i
-            print("[epoch_loss] ", "epoch: {}".format(j), epoch_loss)
+            writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
+            epoch_loss = 0.

View File

@@ -1,5 +1,6 @@
 import os
 import numpy as np
+from matplotlib import cm
 import matplotlib.pyplot as plt
 import librosa
 from scipy import signal
@@ -125,21 +126,32 @@ def eval_model(model, text, replace_pronounciation_prob, min_level_db,
     model.eval()
     mel_outputs, linear_outputs, alignments, done = model.transduce(
         dg.to_variable(text), dg.to_variable(text_positions))

-    linear_outputs_np = linear_outputs.numpy()[0].T  # (C, T)
-    print("linear_outputs's shape: ", linear_outputs_np.shape)
-
-    denoramlized = np.clip(linear_outputs_np, 0,
-                           1) * (-min_level_db) + min_level_db
+    linear_outputs_np = linear_outputs.numpy()[0].T  # (C, T)
+    wav = spec_to_waveform(linear_outputs_np, min_level_db, ref_level_db,
+                           power, n_iter, win_length, hop_length, preemphasis)
+    alignments_np = alignments.numpy()[0]  # batch_size = 1
+    print("linear_outputs's shape: ", linear_outputs_np.shape)
+    print("alignmnets' shape:", alignments.shape)
+    return wav, alignments_np
+
+
+def spec_to_waveform(spec, min_level_db, ref_level_db, power, n_iter,
+                     win_length, hop_length, preemphasis):
+    """Convert output linear spec to waveform using griffin-lim vocoder.
+
+    Args:
+        spec (ndarray): the output linear spectrogram, shape(C, T), where C means n_fft, T means frames.
+    """
+    denoramlized = np.clip(spec, 0, 1) * (-min_level_db) + min_level_db
     lin_scaled = np.exp((denoramlized + ref_level_db) / 20 * np.log(10))
     wav = librosa.griffinlim(lin_scaled**power,
                              n_iter=n_iter,
                              hop_length=hop_length,
                              win_length=win_length)
-    wav = signal.lfilter([1.], [1., -preemphasis], wav)
-
-    print("alignmnets' shape:", alignments.shape)
-    alignments_np = alignments.numpy()[0].T
-    return wav, alignments_np
+    if preemphasis > 0:
+        wav = signal.lfilter([1.], [1., -preemphasis], wav)
+    return wav
@@ -157,88 +169,89 @@ def make_output_tree(output_dir):
             os.makedirs(p)


-def plot_alignment(alignment, path, info=None):
+def plot_alignment(alignment, path):
     """
     Plot an attention layer's alignment for a sentence.
-    alignment: shape(T_enc, T_dec), and T_enc is flipped
+    alignment: shape(T_dec, T_enc).
     """
-    fig, ax = plt.subplots()
-    im = ax.imshow(alignment,
-                   aspect='auto',
-                   origin='lower',
-                   interpolation='none')
-    fig.colorbar(im, ax=ax)
-    xlabel = 'Decoder timestep'
-    if info is not None:
-        xlabel += '\n\n' + info
-    plt.xlabel(xlabel)
-    plt.ylabel('Encoder timestep')
-    plt.tight_layout()
+    plt.figure()
+    plt.imshow(alignment)
+    plt.colorbar()
+    plt.xlabel('Encoder timestep')
+    plt.ylabel('Decoder timestep')
     plt.savefig(path)
     plt.close()


-def plot_alignments(alignments, save_dir, global_step):
-    """
-    Plot alignments for a sentence when training, we just pick the first
-    sentence. Each layer is plot separately.
-    alignments: shape(N, T_dec, T_enc)
-    """
-    n_layers = alignments.shape[0]
-    for i, alignment in enumerate(alignments):
-        alignment = alignment.T
-        path = os.path.join(save_dir, "layer_{}".format(i))
-        if not os.path.exists(path):
-            os.makedirs(path)
-        fname = os.path.join(path, "step_{:09d}".format(global_step))
-        plot_alignment(alignment, fname)
-
-    average_alignment = np.mean(alignments, axis=0).T
-    path = os.path.join(save_dir, "average")
-    if not os.path.exists(path):
-        os.makedirs(path)
-    fname = os.path.join(path, "step_{:09d}.png".format(global_step))
-    plot_alignment(average_alignment, fname)
-
-
 def save_state(save_dir,
+               writer,
                global_step,
                mel_input=None,
                mel_output=None,
                lin_input=None,
                lin_output=None,
                alignments=None,
-               wav=None):
+               win_length=1024,
+               hop_length=256,
+               min_level_db=-100,
+               ref_level_db=20,
+               power=1.4,
+               n_iter=32,
+               preemphasis=0.97,
+               sample_rate=22050):
+    """Save training intermediate results. Save states for the first sentence in the batch, including
+    mel_spec(predicted, target), lin_spec(predicted, target), attn, waveform.
+
+    Args:
+        save_dir (str): directory to save results.
+        writer (SummaryWriter): tensorboardX summary writer
+        global_step (int): global step.
+        mel_input (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
+        mel_output (Variable, optional): Defaults to None. Shape(B, T_mel, C_mel)
+        lin_input (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
+        lin_output (Variable, optional): Defaults to None. Shape(B, T_lin, C_lin)
+        alignments (Variable, optional): Defaults to None. Shape(N, B, T_dec, C_enc)
+        wav ([type], optional): Defaults to None. [description]
+    """
     if mel_input is not None and mel_output is not None:
-        path = os.path.join(save_dir, "mel_spec")
-        if not os.path.exists(path):
-            os.makedirs(path)
+        mel_input = mel_input[0].numpy().T
+        mel_output = mel_output[0].numpy().T
+        path = os.path.join(save_dir, "mel_spec")

         plt.figure(figsize=(10, 3))
         display.specshow(mel_input)
         plt.colorbar()
         plt.title("mel_input")
         plt.savefig(
             os.path.join(path,
-                         "target_mel_spec_step{:09d}".format(global_step)))
+                         "target_mel_spec_step{:09d}.png".format(global_step)))
         plt.close()
+        writer.add_image("target/mel_spec",
+                         cm.viridis(mel_input),
+                         global_step,
+                         dataformats="HWC")

         plt.figure(figsize=(10, 3))
         display.specshow(mel_output)
         plt.colorbar()
-        plt.title("mel_input")
+        plt.title("mel_output")
         plt.savefig(
-            os.path.join(path,
-                         "predicted_mel_spec_step{:09d}".format(global_step)))
+            os.path.join(
+                path, "predicted_mel_spec_step{:09d}.png".format(global_step)))
         plt.close()
+        writer.add_image("predicted/mel_spec",
+                         cm.viridis(mel_output),
+                         global_step,
+                         dataformats="HWC")

     if lin_input is not None and lin_output is not None:
+        lin_input = lin_input[0].numpy().T
+        lin_output = lin_output[0].numpy().T
         path = os.path.join(save_dir, "lin_spec")
-        if not os.path.exists(path):
-            os.makedirs(path)

         plt.figure(figsize=(10, 3))
         display.specshow(lin_input)
@@ -246,28 +259,50 @@ def save_state(save_dir,
         plt.title("mel_input")
         plt.savefig(
             os.path.join(path,
-                         "target_lin_spec_step{:09d}".format(global_step)))
+                         "target_lin_spec_step{:09d}.png".format(global_step)))
         plt.close()
+        writer.add_image("target/lin_spec",
+                         cm.viridis(lin_input),
+                         global_step,
+                         dataformats="HWC")

         plt.figure(figsize=(10, 3))
         display.specshow(lin_output)
         plt.colorbar()
         plt.title("mel_input")
         plt.savefig(
-            os.path.join(path,
-                         "predicted_lin_spec_step{:09d}".format(global_step)))
+            os.path.join(
+                path, "predicted_lin_spec_step{:09d}.png".format(global_step)))
         plt.close()
+        writer.add_image("predicted/lin_spec",
+                         cm.viridis(lin_output),
+                         global_step,
+                         dataformats="HWC")

-    if alignments is not None and len(alignments.shape) == 3:
+    if alignments is not None and len(alignments.shape) == 4:
         path = os.path.join(save_dir, "alignments")
-        if not os.path.exists(path):
-            os.makedirs(path)
-        plot_alignments(alignments, path, global_step)
-
-    if wav is not None:
+        alignments = alignments[:, 0, :, :].numpy()
+        for idx, attn_layer in enumerate(alignments):
+            save_path = os.path.join(
+                path,
+                "train_attn_layer_{}_step_{}.png".format(idx, global_step))
+            plot_alignment(attn_layer, save_path)
+            writer.add_image("train_attn/layer_{}".format(idx),
+                             cm.viridis(attn_layer),
+                             global_step,
+                             dataformats="HWC")
+
+    if lin_output is not None:
+        wav = spec_to_waveform(lin_output, min_level_db, ref_level_db, power,
+                               n_iter, win_length, hop_length, preemphasis)
         path = os.path.join(save_dir, "waveform")
-        if not os.path.exists(path):
-            os.makedirs(path)
-        sf.write(
-            os.path.join(path, "sample_step_{:09d}.wav".format(global_step)),
-            wav, 22050)
+        save_path = os.path.join(
+            path, "train_sample_step_{:09d}.wav".format(global_step))
+        sf.write(save_path, wav, sample_rate)
+        writer.add_audio("train_sample",
+                         wav,
+                         global_step,
+                         sample_rate=sample_rate)
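A quick way to exercise the new `spec_to_waveform` helper on its own is to feed it a normalized dummy spectrogram with the same default signal parameters that `save_state` now declares. The random array and output filename below are made up for illustration, and the import assumes the example's `utils.py` is on the import path:

```python
import numpy as np
import soundfile as sf
from utils import spec_to_waveform  # assumes the example's utils.py is importable

# dummy normalized linear spectrogram, shape (C, T) with values in [0, 1];
# C = 1 + n_fft // 2 (here n_fft = 2048)
spec = np.random.uniform(0.0, 1.0, size=(1025, 200)).astype("float32")

wav = spec_to_waveform(spec,
                       min_level_db=-100,
                       ref_level_db=20,
                       power=1.4,
                       n_iter=32,
                       win_length=1024,
                       hop_length=256,
                       preemphasis=0.97)
sf.write("griffinlim_check.wav", wav, 22050)  # 22050 Hz matches the default sample_rate above
```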

View File

@@ -79,25 +79,26 @@ def unfold_adjacent_frames(folded_frames, r):

 class Decoder(dg.Layer):
-    def __init__(self,
-                 n_speakers,
-                 speaker_dim,
-                 embed_dim,
-                 mel_dim,
-                 r=1,
-                 max_positions=512,
-                 padding_idx=None,
-                 preattention=(ConvSpec(128, 5, 1), ) * 4,
-                 convolutions=(ConvSpec(128, 5, 1), ) * 4,
-                 attention=True,
-                 dropout=0.0,
-                 use_memory_mask=False,
-                 force_monotonic_attention=False,
-                 query_position_rate=1.0,
-                 key_position_rate=1.0,
-                 window_range=WindowRange(-1, 3),
-                 key_projection=True,
-                 value_projection=True):
+    def __init__(
+            self,
+            n_speakers,
+            speaker_dim,
+            embed_dim,
+            mel_dim,
+            r=1,
+            max_positions=512,
+            padding_idx=None,  # remove it!
+            preattention=(ConvSpec(128, 5, 1), ) * 4,
+            convolutions=(ConvSpec(128, 5, 1), ) * 4,
+            attention=True,
+            dropout=0.0,
+            use_memory_mask=False,
+            force_monotonic_attention=False,
+            query_position_rate=1.0,
+            key_position_rate=1.0,
+            window_range=WindowRange(-1, 3),
+            key_projection=True,
+            value_projection=True):
         super(Decoder, self).__init__()

         self.dropout = dropout
@@ -109,21 +110,23 @@ class Decoder(dg.Layer):
         self.n_speakers = n_speakers
         conv_channels = convolutions[0].out_channels

+        # only when padding idx is 0 can we easilt handle it
         self.embed_keys_positions = PositionEmbedding(max_positions,
                                                       embed_dim,
-                                                      padding_idx=padding_idx)
+                                                      padding_idx=0)
         self.embed_query_positions = PositionEmbedding(max_positions,
                                                        conv_channels,
-                                                       padding_idx=padding_idx)
+                                                       padding_idx=0)

         if n_speakers > 1:
+            # CAUTION: mind the sigmoid
             std = np.sqrt((1 - dropout) / speaker_dim)
             self.speaker_proj1 = Linear(speaker_dim,
                                         1,
+                                        act="sigmoid",
                                         param_attr=I.Normal(scale=std))
             self.speaker_proj2 = Linear(speaker_dim,
                                         1,
+                                        act="sigmoid",
                                         param_attr=I.Normal(scale=std))

         # prenet
@@ -168,6 +171,7 @@ class Decoder(dg.Layer):
             ] * len(convolutions)
         else:
             self.force_monotonic_attention = force_monotonic_attention
+
         for x, y in zip(self.force_monotonic_attention, self.attention):
             if x is True and y is False:
                 raise ValueError("When not using attention, there is no "
@@ -249,7 +253,7 @@ class Decoder(dg.Layer):
             text_positions (Variable): shape(B, T_enc), dtype: int64.
                 Positions indices for text inputs for the encoder, where
                 T_enc means the encoder timesteps.
-            frame_positions (Variable): shape(B, T_dec // r), dtype:
+            frame_positions (Variable): shape(B, T_mel // r), dtype:
                 int64. Positions indices for each decoder time steps.
             speaker_embed: shape(batch_size, speaker_dim), speaker embedding,
                 only used for multispeaker model.
@@ -287,16 +291,14 @@ class Decoder(dg.Layer):
         if text_positions is not None:
             w = self.key_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj1(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj1(speaker_embed), [-1])
             text_pos_embed = self.embed_keys_positions(text_positions, w)
             keys += text_pos_embed  # (B, T, C)

         if frame_positions is not None:
             w = self.query_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj2(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj2(speaker_embed), [-1])
             frame_pos_embed = self.embed_query_positions(frame_positions, w)
         else:
             frame_pos_embed = None
@@ -387,8 +389,7 @@ class Decoder(dg.Layer):
         w = self.key_position_rate
         if self.n_speakers > 1:
             # shape (B, )
-            w = w * F.squeeze(F.sigmoid(self.speaker_proj1(speaker_embed)),
-                              [-1])
+            w = w * F.squeeze(self.speaker_proj1(speaker_embed), [-1])
         text_pos_embed = self.embed_keys_positions(text_positions, w)
         keys += text_pos_embed  # (B, T, C)
@@ -417,8 +418,7 @@ class Decoder(dg.Layer):
                 dtype="int64")
             w = self.query_position_rate
             if self.n_speakers > 1:
-                w = w * F.squeeze(F.sigmoid(self.speaker_proj2(speaker_embed)),
-                                  [-1])
+                w = w * F.squeeze(self.speaker_proj2(speaker_embed), [-1])
             # (B, T=1, C)
             frame_pos_embed = self.embed_query_positions(frame_pos, w)

View File

@@ -35,9 +35,11 @@ class Encoder(dg.Layer):
             std = np.sqrt((1 - dropout) / speaker_dim)
             self.sp_proj1 = Linear(speaker_dim,
                                    embed_dim,
+                                   act="softsign",
                                    param_attr=I.Normal(scale=std))
             self.sp_proj2 = Linear(speaker_dim,
                                    embed_dim,
+                                   act="softsign",
                                    param_attr=I.Normal(scale=std))
         self.n_speakers = n_speakers
@@ -104,9 +106,7 @@ class Encoder(dg.Layer):
                 speaker_embed,
                 self.dropout,
                 dropout_implementation="upscale_in_train")
-            x = F.elementwise_add(x,
-                                  F.softsign(self.sp_proj1(speaker_embed)),
-                                  axis=0)
+            x = F.elementwise_add(x, self.sp_proj1(speaker_embed), axis=0)
         input_embed = x

         for layer in self.convolutions:
@@ -117,9 +117,7 @@ class Encoder(dg.Layer):
             x = layer(x)

         if self.n_speakers > 1 and speaker_embed is not None:
-            x = F.elementwise_add(x,
-                                  F.softsign(self.sp_proj2(speaker_embed)),
-                                  axis=0)
+            x = F.elementwise_add(x, self.sp_proj2(speaker_embed), axis=0)
         keys = x  # (B, C, T)
         values = F.scale(input_embed + x, scale=np.sqrt(0.5))

View File

@@ -156,8 +156,9 @@ class TTSLoss(object):
                 compute_mel_loss=True,
                 compute_done_loss=True,
                 compute_attn_loss=True):
+        total_loss = 0.
+
         # n_frames # mel_lengths # decoder_lengths
-        # 4 losses in total: lin (l1, bce), mel (l1, bce), attn, done
         max_frames = lin_hyp.shape[1]
         max_mel_steps = max_frames // self.downsample_factor
         max_decoder_steps = max_mel_steps // self.r
@@ -182,6 +183,7 @@ class TTSLoss(object):
            lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
            lin_loss = self.binary_divergence_weight * lin_bce_loss \
                + (1 - self.binary_divergence_weight) * lin_l1_loss
+           total_loss += lin_loss

         if compute_mel_loss:
             mel_hyp = mel_hyp[:, :-self.time_shift, :]
@@ -192,32 +194,28 @@ class TTSLoss(object):
             # print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
             mel_loss = self.binary_divergence_weight * mel_bce_loss \
                 + (1 - self.binary_divergence_weight) * mel_l1_loss
+            total_loss += mel_loss

         if compute_attn_loss:
             attn_loss = self.attention_loss(
                 attn_hyp, input_lengths.numpy(),
                 n_frames.numpy() // (self.downsample_factor * self.r))
+            total_loss += attn_loss

         if compute_done_loss:
             done_loss = self.done_loss(done_hyp, done_ref)
+            total_loss += done_loss

         result = {
-            "mel": mel_loss if compute_mel_loss else None,
-            "mel_l1_loss": mel_l1_loss if compute_mel_loss else None,
-            "mel_bce_loss": mel_bce_loss if compute_mel_loss else None,
-            "lin": lin_loss if compute_lin_loss else None,
-            "lin_l1_loss": lin_l1_loss if compute_lin_loss else None,
-            "lin_bce_loss": lin_bce_loss if compute_lin_loss else None,
+            "loss": total_loss,
+            "mel/mel_loss": mel_loss if compute_mel_loss else None,
+            "mel/l1_loss": mel_l1_loss if compute_mel_loss else None,
+            "mel/bce_loss": mel_bce_loss if compute_mel_loss else None,
+            "lin/lin_loss": lin_loss if compute_lin_loss else None,
+            "lin/l1_loss": lin_l1_loss if compute_lin_loss else None,
+            "lin/bce_loss": lin_bce_loss if compute_lin_loss else None,
             "done": done_loss if compute_done_loss else None,
             "attn": attn_loss if compute_attn_loss else None,
         }
         return result
-
-    @staticmethod
-    def compose_loss(result):
-        total_loss = 0.
-        for k in ["mel", "lin", "done", "attn"]:
-            if result[k] is not None:
-                total_loss += result[k]
-        return total_loss
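Since `compose_loss` is gone, callers read the composed objective straight out of the returned dict, which is what the updated train.py above does. Below is a sketch of the consuming side; it is a fragment of the training loop rather than a standalone script, variable names follow train.py, and note that entries for disabled terms come back as `None`:

```python
# inside the training loop (train.py), after the forward pass
losses = criterion(mel_outputs, linear_outputs, done, alignments,
                   downsampled_mel_specs, lin_specs, done_flags,
                   text_lengths, frames)

l = losses["loss"]  # already the sum of every enabled loss term
l.backward()

# component terms use slash-separated tags ("mel/l1_loss", "lin/bce_loss", ...)
# so they group nicely in tensorboard; skip any that are None
for tag, value in losses.items():
    if value is not None:
        writer.add_scalar(tag, value.numpy()[0], global_step)
```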

View File

@@ -6,7 +6,6 @@ import paddle.fluid.layers as F
 from parakeet.modules import customized as L

-# TODO: just use numpy to init weight norm wrappers
 def norm(param, dim, power):
     powered = F.pow(param, power)
     powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
@@ -73,7 +72,7 @@ class WeightNormWrapper(dg.Layer):
             w_g, self.create_parameter(shape=temp.shape, dtype=temp.dtype))
         F.assign(temp, getattr(self, w_g))

-        # also set this
+        # also set this when setting up
         setattr(
             self.layer, self.param_name,
             compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim,