Modified FastSpeech to make sure it works on Paddle 1.8
This commit is contained in:
parent 55fa94f15d
commit d1ba42ea68
@@ -186,10 +186,4 @@ def batch_examples(batch):
     mels = np.transpose(
         SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1)) #(B,T,num_mels)
 
-    enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
-    enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
-    dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels, mels).astype(np.float32)
-    dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-
-    return (texts, mels, pos_texts, pos_mels, enc_slf_mask, enc_query_mask,
-            dec_slf_mask, dec_query_slf_mask, alignments)
+    return (texts, mels, pos_texts, pos_mels, alignments)
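Note (illustration only, not part of the diff): after this change batch_examples hands back only the padded arrays plus their 1-based position indices; padded slots keep position 0, which is all the model needs later to rebuild its masks internally. A minimal NumPy sketch of that convention with toy token ids:

import numpy as np

texts = [np.array([5, 3, 9]), np.array([7, 2])]   # toy token id sequences
maxlen = max(len(t) for t in texts)
# zero-pad the tokens and build 1-based positions; padding keeps position 0
padded = np.stack([np.pad(t, (0, maxlen - len(t))) for t in texts])
pos = np.stack([np.pad(np.arange(1, len(t) + 1), (0, maxlen - len(t)))
                for t in texts])
print(padded)   # [[5 3 9] [7 2 0]]
print(pos)      # [[1 2 3] [1 2 0]]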
@@ -28,7 +28,7 @@ from parakeet.models.fastspeech.fastspeech import FastSpeech
 from parakeet.models.transformer_tts.utils import *
 from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
-from parakeet.utils.layer_tools import summary, freeze
+from parakeet.utils.layer_tools import freeze
 from parakeet.utils import io
 
 
@@ -82,22 +82,11 @@ def synthesis(text_input, args):
     text = np.expand_dims(text, axis=0)
     pos_text = np.arange(1, text.shape[1] + 1)
     pos_text = np.expand_dims(pos_text, axis=0)
-    enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
-    enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32)
 
     text = dg.to_variable(text)
     pos_text = dg.to_variable(pos_text)
-    enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
-    enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
 
-    _, mel_output_postnet = model(
-        text,
-        pos_text,
-        alpha=args.alpha,
-        enc_non_pad_mask=enc_non_pad_mask,
-        enc_slf_attn_mask=enc_slf_attn_mask,
-        dec_non_pad_mask=None,
-        dec_slf_attn_mask=None)
+    _, mel_output_postnet = model(text, pos_text, alpha=args.alpha)
 
     result = np.exp(mel_output_postnet.numpy())
     mel_output_postnet = fluid.layers.transpose(
@@ -186,7 +175,6 @@ def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place):
     lmd = config["loss"]["lmd"]
     model = Clarinet(upsample_net, teacher, student, stft,
                      student_log_scale_min, lmd)
-    summary(model)
     io.load_parameters(model=model, checkpoint_path=checkpoint)
 
     if not os.path.exists(args.output):
@@ -79,7 +79,9 @@ def main(args):
             (cfg['train']['warm_up_step'] *
              (cfg['train']['learning_rate']**2)),
             cfg['train']['warm_up_step']),
-        parameter_list=model.parameters())
+        parameter_list=model.parameters(),
+        grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+            'grad_clip_thresh']))
     reader = LJSpeechLoader(
         cfg['audio'],
         place,
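This hunk tracks the Paddle 1.8 clipping API: gradient clipping is attached to the optimizer once, through the grad_clip argument and fluid.clip.GradientClipByGlobalNorm, instead of being passed to optimizer.minimize() on every step (the matching removal from the minimize() call appears in a later hunk). A self-contained toy sketch of the pattern; the Linear layer, constant learning rate, and clip_norm value are stand-ins for the script's FastSpeech model, NoamDecay schedule, and cfg['train']['grad_clip_thresh']:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard():
    model = dg.Linear(4, 4)                       # stand-in for FastSpeech
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=1e-3,                       # stand-in for NoamDecay
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))

    x = dg.to_variable(np.random.randn(2, 4).astype("float32"))
    loss = fluid.layers.reduce_mean(model(x))
    loss.backward()
    optimizer.minimize(loss)                      # no grad_clip argument here
    model.clear_gradients()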
@@ -108,9 +110,7 @@ def main(args):
 
         for i, data in enumerate(pbar):
             pbar.set_description('Processing at epoch %d' % epoch)
-            (character, mel, pos_text, pos_mel, enc_slf_mask,
-             enc_query_mask, dec_slf_mask, dec_query_slf_mask,
-             alignment) = data
+            (character, mel, pos_text, pos_mel, alignment) = data
 
             global_step += 1
 
@@ -119,11 +119,7 @@ def main(args):
                 character,
                 pos_text,
                 mel_pos=pos_mel,
-                length_target=alignment,
-                enc_non_pad_mask=enc_query_mask,
-                enc_slf_attn_mask=enc_slf_mask,
-                dec_non_pad_mask=dec_query_slf_mask,
-                dec_slf_attn_mask=dec_slf_mask)
+                length_target=alignment)
             mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
             mel_loss = layers.mse_loss(mel_output, mel)
             mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
@@ -150,10 +146,7 @@ def main(args):
                 model.apply_collective_grads()
             else:
                 total_loss.backward()
-            optimizer.minimize(
-                total_loss,
-                grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                    'train']['grad_clip_thresh']))
+            optimizer.minimize(total_loss)
             model.clear_gradients()
 
             # save checkpoint
@@ -70,7 +70,7 @@ class Decoder(dg.Layer):
         for i, layer in enumerate(self.layer_stack):
             self.add_sublayer('fft_{}'.format(i), layer)
 
-    def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
+    def forward(self, enc_seq, enc_pos):
         """
         Compute decoder outputs.
 
@@ -79,17 +79,26 @@ class Decoder(dg.Layer):
                 the output of length regulator, where T_mel means the timesteps of input spectrum.
             enc_pos (Variable): shape(B, T_mel), dtype int64,
                 the spectrum position.
-            non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
-            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
-                the mask of mel spectrum. Defaults to None.
 
         Returns:
             dec_output (Variable): shape(B, T_mel, C), the decoder output.
             dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
         """
         dec_slf_attn_list = []
-        if slf_attn_mask:
-            slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
+        if fluid.framework._dygraph_tracer()._train_mode:
+            slf_attn_mask = get_dec_attn_key_pad_mask(enc_pos, self.n_head,
+                                                      enc_seq.dtype)
+
+        else:
+            len_q = enc_seq.shape[1]
+            slf_attn_mask = layers.triu(
+                layers.ones(
+                    shape=[len_q, len_q], dtype=enc_seq.dtype),
+                diagonal=1)
+            slf_attn_mask = layers.cast(
+                slf_attn_mask != 0, dtype=enc_seq.dtype) * -1e30
+
+        non_pad_mask = get_non_pad_mask(enc_pos, 1, enc_seq.dtype)
 
         # -- Forward
         dec_output = enc_seq + self.position_enc(enc_pos)
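For orientation (an illustrative analogue, not the repo's helpers): the decoder now derives everything it needs from enc_pos and enc_seq. Roughly, the non-pad mask marks positions whose index is non-zero, and the inference-time self-attention mask puts a large negative bias on future steps so softmax drives their weights to zero. A small NumPy sketch of those two shapes, assuming the 0-means-padding and additive -1e30 conventions visible in the hunk above (get_dec_attn_key_pad_mask and get_non_pad_mask themselves may add head tiling and other details):

import numpy as np

def non_pad_mask(pos):
    # pos: (B, T) positions with 0 marking padding -> (B, T, 1) float mask
    return (pos != 0).astype("float32")[:, :, None]

def future_additive_mask(T, neg=-1e30):
    # strictly upper-triangular entries (future steps) get a large negative
    # bias, so softmax over the masked logits gives them ~0 attention weight
    return np.triu(np.ones((T, T), dtype="float32"), k=1) * neg

pos = np.array([[1, 2, 3, 0, 0]])              # one sequence, two padded steps
print(non_pad_mask(pos)[..., 0])               # [[1. 1. 1. 0. 0.]]
print(future_additive_mask(3))                 # zeros on and below the diagonal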
@@ -76,7 +76,7 @@ class Encoder(dg.Layer):
         for i, layer in enumerate(self.layer_stack):
             self.add_sublayer('fft_{}'.format(i), layer)
 
-    def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
+    def forward(self, character, text_pos):
         """
         Encode text sequence.
 
@@ -84,22 +84,21 @@ class Encoder(dg.Layer):
             character (Variable): shape(B, T_text), dtype float32, the input text characters,
                 where T_text means the timesteps of input characters,
             text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
-            non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
-            slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
-                the mask of input characters. Defaults to None.
 
         Returns:
             enc_output (Variable): shape(B, T_text, C), the encoder output.
-            non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
             enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
         """
         enc_slf_attn_list = []
-        slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
 
         # -- Forward
         enc_output = self.src_word_emb(character) + self.position_enc(
             text_pos) #(N, T, C)
+
+        slf_attn_mask = get_attn_key_pad_mask(text_pos, self.n_head,
+                                              enc_output.dtype)
+        non_pad_mask = get_non_pad_mask(text_pos, 1, enc_output.dtype)
 
         for enc_layer in self.layer_stack:
             enc_output, enc_slf_attn = enc_layer(
                 enc_output,
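The encoder does the same internalization from text_pos. As a hedged guess at what get_attn_key_pad_mask produces, the sketch below blocks keys at padded positions for every query row and then tiles the mask over heads, mirroring the layers.expand(..., [self.n_head, 1, 1]) the old code applied externally; the real helper's exact output format may differ:

import numpy as np

def attn_key_pad_mask(pos, n_head, neg=-1e30):
    # pos: (B, T) positions, 0 marks padding
    B, T = pos.shape
    key_pad = (pos == 0).astype("float32")                  # (B, T)
    mask = np.repeat(key_pad[:, None, :], T, axis=1) * neg  # (B, T, T)
    return np.tile(mask, (n_head, 1, 1))                    # (n_head * B, T, T)

pos = np.array([[1, 2, 0]])
print(attn_key_pad_mask(pos, n_head=2).shape)   # (2, 3, 3)
print(attn_key_pad_mask(pos, n_head=2)[0])      # last key column is masked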
@@ -86,11 +86,7 @@ class FastSpeech(dg.Layer):
     def forward(self,
                 character,
                 text_pos,
-                enc_non_pad_mask,
-                dec_non_pad_mask,
                 mel_pos=None,
-                enc_slf_attn_mask=None,
-                dec_slf_attn_mask=None,
                 length_target=None,
                 alpha=1.0):
         """
@@ -102,12 +98,6 @@ class FastSpeech(dg.Layer):
             text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
             mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
                 where T_mel means the timesteps of input spectrum,
-            enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
-            dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
-            enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
-                the mask of input characters. Defaults to None.
-            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
-                the mask of mel spectrum. Defaults to None.
             length_target (Variable, optional): shape(B, T_text), dtype int64,
                 the duration of phoneme compute from pretrained transformerTTS. Defaults to None.
             alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
@@ -121,19 +111,12 @@ class FastSpeech(dg.Layer):
             dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
         """
 
-        encoder_output, enc_slf_attn_list = self.encoder(
-            character,
-            text_pos,
-            enc_non_pad_mask,
-            slf_attn_mask=enc_slf_attn_mask)
+        encoder_output, enc_slf_attn_list = self.encoder(character, text_pos)
         if fluid.framework._dygraph_tracer()._train_mode:
             length_regulator_output, duration_predictor_output = self.length_regulator(
                 encoder_output, target=length_target, alpha=alpha)
             decoder_output, dec_slf_attn_list = self.decoder(
-                length_regulator_output,
-                mel_pos,
-                dec_non_pad_mask,
-                slf_attn_mask=dec_slf_attn_mask)
+                length_regulator_output, mel_pos)
 
             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output
@@ -142,19 +125,8 @@ class FastSpeech(dg.Layer):
         else:
             length_regulator_output, decoder_pos = self.length_regulator(
                 encoder_output, alpha=alpha)
-            slf_attn_mask = get_triu_tensor(
-                decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
-            slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
-            slf_attn_mask = fluid.layers.cast(
-                dg.to_variable(slf_attn_mask != 0), np.float32) * (-2**32 + 1)
-            slf_attn_mask = dg.to_variable(slf_attn_mask)
-            dec_non_pad_mask = fluid.layers.unsqueeze(
-                (decoder_pos != 0).astype(np.float32), [-1])
-            decoder_output, _ = self.decoder(
-                length_regulator_output,
-                decoder_pos,
-                dec_non_pad_mask,
-                slf_attn_mask=slf_attn_mask)
+            decoder_output, _ = self.decoder(length_regulator_output,
+                                             decoder_pos)
             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output
 
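With the masks internalized, callers just pass tokens and positions: the training loop above calls model(character, pos_text, mel_pos=pos_mel, length_target=alignment) and synthesis calls model(text, pos_text, alpha=args.alpha). Which branch of forward() runs is decided by fluid.framework._dygraph_tracer()._train_mode rather than by which masks were supplied. A minimal sketch of that flag, assuming Layer.train()/Layer.eval() toggle the tracer's train mode in Paddle 1.8 dygraph (toy Linear layer, not shown in this commit):

import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard():
    layer = dg.Linear(2, 2)
    layer.train()   # expected to print True: the branch that uses length_target
    print(fluid.framework._dygraph_tracer()._train_mode)
    layer.eval()    # expected to print False: the branch that predicts durations and uses alpha
    print(fluid.framework._dygraph_tracer()._train_mode)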