add an optional to alter the loss and model structure of tacotron2, add an alternative config
This commit is contained in:
parent
4fc86abf5a
commit
263d3eb88b
|
@ -23,8 +23,8 @@ _C.data = CN(
|
|||
n_fft=1024, # fft frame size
|
||||
win_length=1024, # window size
|
||||
hop_length=256, # hop size between ajacent frame
|
||||
f_max=8000, # Hz, max frequency when converting to mel
|
||||
f_min=0, # Hz, min frequency when converting to mel
|
||||
fmax=8000, # Hz, max frequency when converting to mel
|
||||
fmin=0, # Hz, min frequency when converting to mel
|
||||
d_mels=80, # mel bands
|
||||
padding_idx=0, # text embedding's padding index
|
||||
))
|
||||
|
@ -50,7 +50,9 @@ _C.model = CN(
|
|||
p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
|
||||
p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
|
||||
p_postnet_dropout=0.5, # droput probability in decoder postnet
|
||||
guided_attn_loss_sigma=0.2 # sigma in guided attention loss
|
||||
use_stop_token=True, # wherther to use binary classifier to predict when to stop
|
||||
use_guided_attention_loss=False, # whether to use guided attention loss
|
||||
guided_attention_loss_sigma=0.2 # sigma in guided attention loss
|
||||
))
|
||||
|
||||
_C.training = CN(
|
||||
|
|
|
@ -58,7 +58,7 @@ class LJSpeechCollector(object):
|
|||
mels = []
|
||||
text_lens = []
|
||||
mel_lens = []
|
||||
stop_tokens = []
|
||||
|
||||
for data in examples:
|
||||
text, mel = data
|
||||
text = np.array(text, dtype=np.int64)
|
||||
|
@ -66,8 +66,6 @@ class LJSpeechCollector(object):
|
|||
mels.append(mel)
|
||||
texts.append(text)
|
||||
mel_lens.append(mel.shape[1])
|
||||
stop_token = np.zeros([mel.shape[1] - 1], dtype=np.float32)
|
||||
stop_tokens.append(np.append(stop_token, 1.0))
|
||||
|
||||
# Sort by text_len in descending order
|
||||
texts = [
|
||||
|
@ -86,15 +84,13 @@ class LJSpeechCollector(object):
|
|||
for i, _ in sorted(
|
||||
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
mel_lens = np.array(mel_lens, dtype=np.int64)
|
||||
|
||||
|
||||
text_lens = np.array(sorted(text_lens, reverse=True), dtype=np.int64)
|
||||
|
||||
# Pad sequence with largest len of the batch
|
||||
texts, _ = batch_text_id(texts, pad_id=self.padding_idx)
|
||||
mels, _ = np.transpose(
|
||||
batch_spec(
|
||||
mels, pad_value=self.padding_value), axes=(0, 2, 1))
|
||||
mels, _ = batch_spec(mels, pad_value=self.padding_value)
|
||||
mels = np.transpose(mels, axes=(0, 2, 1))
|
||||
|
||||
return texts, mels, text_lens, mel_lens
|
||||
|
|
|
@ -39,8 +39,8 @@ def create_dataset(config, source_path, target_path, verbose=False):
|
|||
n_mels=config.data.d_mels,
|
||||
win_length=config.data.win_length,
|
||||
hop_length=config.data.hop_length,
|
||||
f_max=config.data.f_max,
|
||||
f_min=config.data.f_min)
|
||||
f_max=config.data.fmax,
|
||||
f_min=config.data.fmin)
|
||||
normalizer = LogMagnitude()
|
||||
|
||||
records = []
|
||||
|
|
|
@ -20,6 +20,8 @@ import paddle
|
|||
import parakeet
|
||||
from parakeet.frontend import EnglishCharacter
|
||||
from parakeet.models.tacotron2 import Tacotron2
|
||||
from parakeet.utils import display
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from config import get_cfg_defaults
|
||||
|
||||
|
@ -46,10 +48,13 @@ def main(config, args):
|
|||
for i, sentence in enumerate(sentences):
|
||||
sentence = paddle.to_tensor(frontend(sentence)).unsqueeze(0)
|
||||
|
||||
mel_output, _ = model.predict(sentence)
|
||||
mel_output = mel_output["mel_outputs_postnet"][0].numpy().T
|
||||
outputs = model.infer(sentence)
|
||||
mel_output = outputs["mel_outputs_postnet"][0].numpy().T
|
||||
alignment = outputs["alignments"][0].numpy().T
|
||||
|
||||
np.save(str(output_dir / f"sentence_{i}"), mel_output)
|
||||
display.plot_alignment(alignment)
|
||||
plt.savefig(str(output_dir / f"sentence_{i}.png"))
|
||||
if args.verbose:
|
||||
print("spectrogram saved at {}".format(output_dir /
|
||||
f"sentence_{i}.npy"))
|
||||
|
|
|
@ -34,14 +34,18 @@ from ljspeech import LJSpeech, LJSpeechCollector
|
|||
|
||||
class Experiment(ExperimentBase):
|
||||
def compute_losses(self, inputs, outputs):
|
||||
_, mel_targets, plens, slens, stop_tokens = inputs
|
||||
texts, mel_targets, plens, slens = inputs
|
||||
|
||||
mel_outputs = outputs["mel_output"]
|
||||
mel_outputs_postnet = outputs["mel_outputs_postnet"]
|
||||
attention_weight = outputs["alignments"]
|
||||
if self.config.model.use_stop_token:
|
||||
stop_logits = outputs["stop_logits"]
|
||||
else:
|
||||
stop_logits = None
|
||||
|
||||
losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
|
||||
attention_weight, slens, plens)
|
||||
attention_weight, slens, plens, stop_logits)
|
||||
return losses
|
||||
|
||||
def train_batch(self):
|
||||
|
@ -52,7 +56,7 @@ class Experiment(ExperimentBase):
|
|||
self.optimizer.clear_grad()
|
||||
self.model.train()
|
||||
texts, mels, text_lens, output_lens = batch
|
||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
||||
outputs = self.model(texts, text_lens, mels, output_lens)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
loss = losses["loss"]
|
||||
loss.backward()
|
||||
|
@ -80,7 +84,7 @@ class Experiment(ExperimentBase):
|
|||
valid_losses = defaultdict(list)
|
||||
for i, batch in enumerate(self.valid_loader):
|
||||
texts, mels, text_lens, output_lens = batch
|
||||
outputs = self.model(texts, mels, text_lens, output_lens)
|
||||
outputs = self.model(texts, text_lens, mels, output_lens)
|
||||
losses = self.compute_losses(batch, outputs)
|
||||
for k, v in losses.items():
|
||||
valid_losses[k].append(float(v))
|
||||
|
@ -88,7 +92,7 @@ class Experiment(ExperimentBase):
|
|||
attention_weights = outputs["alignments"]
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_alignments",
|
||||
display.plot_alignment(attention_weights[0].numpy()),
|
||||
display.plot_alignment(attention_weights[0].numpy().T),
|
||||
self.iteration)
|
||||
self.visualizer.add_figure(
|
||||
f"valid_sentence_{i}_target_spectrogram",
|
||||
|
@ -114,9 +118,8 @@ class Experiment(ExperimentBase):
|
|||
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
frontend = EnglishCharacter()
|
||||
model = Tacotron2(
|
||||
frontend,
|
||||
vocab_size=config.model.vocab_size,
|
||||
d_mels=config.data.d_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
|
@ -135,7 +138,8 @@ class Experiment(ExperimentBase):
|
|||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout)
|
||||
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||
use_stop_token=config.model.use_stop_token)
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model)
|
||||
|
@ -148,7 +152,10 @@ class Experiment(ExperimentBase):
|
|||
weight_decay=paddle.regularizer.L2Decay(
|
||||
config.training.weight_decay),
|
||||
grad_clip=grad_clip)
|
||||
criterion = Tacotron2Loss(config.mode.guided_attn_loss_sigma)
|
||||
criterion = Tacotron2Loss(
|
||||
use_stop_token_loss=config.model.use_stop_token,
|
||||
use_guided_attention_loss=config.model.use_guided_attention_loss,
|
||||
sigma=config.model.guided_attention_loss_sigma)
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
|
|
|
@ -23,6 +23,7 @@ from paddle.fluid.layers import sequence_mask
|
|||
from parakeet.modules.conv import Conv1dBatchNorm
|
||||
from parakeet.modules.attention import LocationSensitiveAttention
|
||||
from parakeet.modules.losses import guided_attention_loss
|
||||
from parakeet.utils import checkpoint
|
||||
from tqdm import trange
|
||||
|
||||
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||
|
@ -274,7 +275,7 @@ class Tacotron2Decoder(nn.Layer):
|
|||
d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int,
|
||||
d_attention: int, attention_filters: int,
|
||||
attention_kernel_size: int, p_prenet_dropout: float,
|
||||
p_attention_dropout: float, p_decoder_dropout: float):
|
||||
p_attention_dropout: float, p_decoder_dropout: float, use_stop_token: bool=False):
|
||||
super().__init__()
|
||||
self.d_mels = d_mels
|
||||
self.reduction_factor = reduction_factor
|
||||
|
@ -303,6 +304,10 @@ class Tacotron2Decoder(nn.Layer):
|
|||
d_decoder_rnn)
|
||||
self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
|
||||
d_mels * reduction_factor)
|
||||
|
||||
self.use_stop_token = use_stop_token
|
||||
if use_stop_token:
|
||||
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||
|
||||
# states - temporary attributes
|
||||
self.attention_hidden = None
|
||||
|
@ -379,6 +384,9 @@ class Tacotron2Decoder(nn.Layer):
|
|||
[self.decoder_hidden, self.attention_context], axis=-1)
|
||||
decoder_output = self.linear_projection(
|
||||
decoder_hidden_attention_context)
|
||||
if self.use_stop_token:
|
||||
stop_logit = self.stop_layer(decoder_hidden_attention_context)
|
||||
return decoder_output, self.attention_weights, stop_logit
|
||||
return decoder_output, self.attention_weights
|
||||
|
||||
def forward(self, keys, querys, mask):
|
||||
|
@ -416,16 +424,24 @@ class Tacotron2Decoder(nn.Layer):
|
|||
querys = self.prenet(querys)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_logits = []
|
||||
# Ignore the last time step
|
||||
while len(mel_outputs) < querys.shape[1] - 1:
|
||||
query = querys[:, len(mel_outputs), :]
|
||||
mel_output, attention_weights = self._decode(query)
|
||||
if self.use_stop_token:
|
||||
mel_output, attention_weights, stop_logit = self._decode(query)
|
||||
else:
|
||||
mel_output, attention_weights = self._decode(query)
|
||||
mel_outputs.append(mel_output)
|
||||
alignments.append(attention_weights)
|
||||
if self.use_stop_token:
|
||||
stop_logits.append(stop_logit)
|
||||
|
||||
alignments = paddle.stack(alignments, axis=1)
|
||||
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||
|
||||
if self.use_stop_token:
|
||||
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||
return mel_outputs, alignments, stop_logits
|
||||
return mel_outputs, alignments
|
||||
|
||||
def infer(self, key, max_decoder_steps=1000):
|
||||
|
@ -460,19 +476,30 @@ class Tacotron2Decoder(nn.Layer):
|
|||
first_hit_end = None
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_logits = []
|
||||
for i in trange(max_decoder_steps):
|
||||
query = self.prenet(query)
|
||||
mel_output, alignment = self._decode(query)
|
||||
if self.use_stop_token:
|
||||
mel_output, alignment, stop_logit = self._decode(query)
|
||||
else:
|
||||
mel_output, alignment = self._decode(query)
|
||||
|
||||
mel_outputs.append(mel_output)
|
||||
alignments.append(alignment) # (B=1, T)
|
||||
if self.use_stop_token:
|
||||
stop_logits.append(stop_logit)
|
||||
|
||||
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
|
||||
if first_hit_end is None:
|
||||
first_hit_end = i
|
||||
if first_hit_end is not None and i > (first_hit_end + 10):
|
||||
print("content exhausted!")
|
||||
break
|
||||
if self.use_stop_token:
|
||||
if F.sigmoid(stop_logit) > 0.5:
|
||||
print("hit stop condition!")
|
||||
break
|
||||
else:
|
||||
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
|
||||
if first_hit_end is None:
|
||||
first_hit_end = i
|
||||
elif i > (first_hit_end + 10):
|
||||
print("content exhausted!")
|
||||
break
|
||||
if len(mel_outputs) == max_decoder_steps:
|
||||
print("Warning! Reached max decoder steps!!!")
|
||||
break
|
||||
|
@ -481,7 +508,9 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
alignments = paddle.stack(alignments, axis=1)
|
||||
mel_outputs = paddle.stack(mel_outputs, axis=1)
|
||||
|
||||
if self.use_stop_token:
|
||||
stop_logits = paddle.concat(stop_logits, axis=1)
|
||||
return mel_outputs, alignments, stop_logits
|
||||
return mel_outputs, alignments
|
||||
|
||||
|
||||
|
@ -580,7 +609,8 @@ class Tacotron2(nn.Layer):
|
|||
p_attention_dropout: float = 0.1,
|
||||
p_decoder_dropout: float = 0.1,
|
||||
p_postnet_dropout: float = 0.5,
|
||||
d_global_condition=None):
|
||||
d_global_condition=None,
|
||||
use_stop_token=True):
|
||||
super().__init__()
|
||||
|
||||
std = math.sqrt(2.0 / (vocab_size + d_encoder))
|
||||
|
@ -606,7 +636,7 @@ class Tacotron2(nn.Layer):
|
|||
d_mels, reduction_factor, d_encoder, d_prenet, d_attention_rnn,
|
||||
d_decoder_rnn, d_attention, attention_filters,
|
||||
attention_kernel_size, p_prenet_dropout, p_attention_dropout,
|
||||
p_decoder_dropout)
|
||||
p_decoder_dropout, use_stop_token=use_stop_token)
|
||||
self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor,
|
||||
d_hidden=d_postnet,
|
||||
kernel_size=postnet_kernel_size,
|
||||
|
@ -664,10 +694,14 @@ class Tacotron2(nn.Layer):
|
|||
# [B, T_enc, 1]
|
||||
mask = sequence_mask(text_lens,
|
||||
dtype=encoder_outputs.dtype).unsqueeze(-1)
|
||||
mel_outputs, alignments = self.decoder(encoder_outputs,
|
||||
mels,
|
||||
mask=mask)
|
||||
|
||||
if self.decoder.use_stop_token:
|
||||
mel_outputs, alignments, stop_logits = self.decoder(encoder_outputs,
|
||||
mels,
|
||||
mask=mask)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder(encoder_outputs,
|
||||
mels,
|
||||
mask=mask)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
||||
|
@ -681,6 +715,8 @@ class Tacotron2(nn.Layer):
|
|||
"mel_outputs_postnet": mel_outputs_postnet,
|
||||
"alignments": alignments
|
||||
}
|
||||
if self.decoder.use_stop_token:
|
||||
outputs["stop_logits"] = stop_logits
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -723,8 +759,12 @@ class Tacotron2(nn.Layer):
|
|||
global_condition, [-1, encoder_outputs.shape[1], -1])
|
||||
encoder_outputs = paddle.concat(
|
||||
[encoder_outputs, global_condition], -1)
|
||||
mel_outputs, alignments = self.decoder.infer(
|
||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||
if self.decoder.use_stop_token:
|
||||
mel_outputs, alignments, stop_logits = self.decoder.infer(
|
||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder.infer(
|
||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
|
@ -734,22 +774,72 @@ class Tacotron2(nn.Layer):
|
|||
"mel_outputs_postnet": mel_outputs_postnet,
|
||||
"alignments": alignments
|
||||
}
|
||||
if self.decoder.use_stop_token:
|
||||
outputs["stop_logits"] = stop_logits
|
||||
|
||||
return outputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, config, checkpoint_path):
|
||||
"""Build a Tacotron2 model from a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
|
||||
checkpoint_path: Path or str
|
||||
the path of pretrained model checkpoint, without extension name
|
||||
|
||||
Returns
|
||||
-------
|
||||
ConditionalWaveFlow
|
||||
The model built from pretrained result.
|
||||
"""
|
||||
model = cls(
|
||||
vocab_size=config.model.vocab_size,
|
||||
d_mels=config.data.d_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||
d_prenet=config.model.d_prenet,
|
||||
d_attention_rnn=config.model.d_attention_rnn,
|
||||
d_decoder_rnn=config.model.d_decoder_rnn,
|
||||
attention_filters=config.model.attention_filters,
|
||||
attention_kernel_size=config.model.attention_kernel_size,
|
||||
d_attention=config.model.d_attention,
|
||||
d_postnet=config.model.d_postnet,
|
||||
postnet_kernel_size=config.model.postnet_kernel_size,
|
||||
postnet_conv_layers=config.model.postnet_conv_layers,
|
||||
reduction_factor=config.model.reduction_factor,
|
||||
p_encoder_dropout=config.model.p_encoder_dropout,
|
||||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout,
|
||||
use_stop_token=config.model.use_stop_token)
|
||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
class Tacotron2Loss(nn.Layer):
|
||||
""" Tacotron2 Loss module
|
||||
"""
|
||||
|
||||
def __init__(self, sigma=0.2):
|
||||
def __init__(self,
|
||||
use_stop_token_loss=True,
|
||||
use_guided_attention_loss=False,
|
||||
sigma=0.2):
|
||||
super().__init__()
|
||||
self.spec_criterion = nn.MSELoss()
|
||||
self.use_stop_token_loss = use_stop_token_loss
|
||||
self.use_guided_attention_loss = use_guided_attention_loss
|
||||
self.attn_criterion = guided_attention_loss
|
||||
self.stop_criterion = paddle.nn.BCEWithLogitsLoss()
|
||||
self.sigma = sigma
|
||||
|
||||
def forward(self, mel_outputs, mel_outputs_postnet, mel_targets,
|
||||
attention_weights, slens, plens):
|
||||
attention_weights=None, slens=None, plens=None, stop_logits=None):
|
||||
"""Calculate tacotron2 loss.
|
||||
|
||||
Parameters
|
||||
|
@ -775,13 +865,24 @@ class Tacotron2Loss(nn.Layer):
|
|||
"""
|
||||
mel_loss = self.spec_criterion(mel_outputs, mel_targets)
|
||||
post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
|
||||
gal_loss = self.attn_criterion(attention_weights, slens, plens,
|
||||
total_loss = mel_loss + post_mel_loss
|
||||
if self.use_guided_attention_loss:
|
||||
gal_loss = self.attn_criterion(attention_weights, slens, plens,
|
||||
self.sigma)
|
||||
total_loss = mel_loss + post_mel_loss + gal_loss
|
||||
total_loss += gal_loss
|
||||
if self.use_stop_token_loss:
|
||||
T_dec = mel_targets.shape[1]
|
||||
stop_labels = F.one_hot(slens - 1, num_classes=T_dec)
|
||||
stop_token_loss = self.stop_criterion(stop_logits, stop_labels)
|
||||
total_loss += stop_token_loss
|
||||
|
||||
losses = {
|
||||
"loss": total_loss,
|
||||
"mel_loss": mel_loss,
|
||||
"post_mel_loss": post_mel_loss,
|
||||
"guided_attn_loss": gal_loss
|
||||
"post_mel_loss": post_mel_loss
|
||||
}
|
||||
if self.use_guided_attention_loss:
|
||||
losses["guided_attn_loss"] = gal_loss
|
||||
if self.use_stop_token_loss:
|
||||
losses["stop_loss"] = stop_token_loss
|
||||
return losses
|
||||
|
|
Loading…
Reference in New Issue