modified fastspeech to make sure it works on paddle 1.8

lifuchen 2020-05-08 03:58:45 +00:00
parent 55fa94f15d
commit d1ba42ea68
6 changed files with 33 additions and 78 deletions

View File

@@ -186,10 +186,4 @@ def batch_examples(batch):
     mels = np.transpose(
         SpecBatcher(pad_value=0.)(mels), axes=(0, 2, 1))  #(B,T,num_mels)
-    enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32)
-    enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
-    dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels, mels).astype(np.float32)
-    dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
-    return (texts, mels, pos_texts, pos_mels, enc_slf_mask, enc_query_mask,
-            dec_slf_mask, dec_query_slf_mask, alignments)
+    return (texts, mels, pos_texts, pos_mels, alignments)
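After this change the collate function returns only the padded data tensors; the attention masks that were precomputed here are now derived inside the model from the position tensors. A minimal sketch of a collate in this style, with a hypothetical pad_to_max helper standing in for Parakeet's TextIDBatcher/SpecBatcher:

import numpy as np

def pad_to_max(seqs, pad_value=0):
    # Hypothetical helper (not Parakeet's): right-pad 1-D integer sequences
    # to a common length with pad_value.
    max_len = max(len(s) for s in seqs)
    out = np.full((len(seqs), max_len), pad_value, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out

def collate(batch):
    # batch: list of text-id sequences. No masks are built here any more;
    # the model derives them from the 1-based position tensors (0 = padding).
    texts = pad_to_max([np.asarray(item) for item in batch])
    pos_texts = pad_to_max([np.arange(1, len(item) + 1) for item in batch])
    return texts, pos_texts

texts, pos_texts = collate([[5, 2, 9], [7, 1]])
print(pos_texts)  # [[1 2 3]
                  #  [1 2 0]]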

View File

@@ -28,7 +28,7 @@ from parakeet.models.fastspeech.fastspeech import FastSpeech
 from parakeet.models.transformer_tts.utils import *
 from parakeet.models.wavenet import WaveNet, UpsampleNet
 from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
-from parakeet.utils.layer_tools import summary, freeze
+from parakeet.utils.layer_tools import freeze
 from parakeet.utils import io
@@ -82,22 +82,11 @@ def synthesis(text_input, args):
         text = np.expand_dims(text, axis=0)
         pos_text = np.arange(1, text.shape[1] + 1)
         pos_text = np.expand_dims(pos_text, axis=0)
-        enc_non_pad_mask = get_non_pad_mask(pos_text).astype(np.float32)
-        enc_slf_attn_mask = get_attn_key_pad_mask(pos_text).astype(np.float32)
         text = dg.to_variable(text)
         pos_text = dg.to_variable(pos_text)
-        enc_non_pad_mask = dg.to_variable(enc_non_pad_mask)
-        enc_slf_attn_mask = dg.to_variable(enc_slf_attn_mask)
-        _, mel_output_postnet = model(
-            text,
-            pos_text,
-            alpha=args.alpha,
-            enc_non_pad_mask=enc_non_pad_mask,
-            enc_slf_attn_mask=enc_slf_attn_mask,
-            dec_non_pad_mask=None,
-            dec_slf_attn_mask=None)
+        _, mel_output_postnet = model(text, pos_text, alpha=args.alpha)
         result = np.exp(mel_output_postnet.numpy())
         mel_output_postnet = fluid.layers.transpose(
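The synthesis entry point now only prepares the text ids and their 1-based positions; no mask tensors are fed to the model. A small sketch of that input preparation (the ids and alpha value are placeholders, and the model call is left commented out because it needs a constructed FastSpeech):

import numpy as np
import paddle.fluid.dygraph as dg

with dg.guard():
    text = np.array([[3, 7, 12, 5]], dtype=np.int64)   # placeholder ids, shape (1, T_text)
    pos_text = np.arange(1, text.shape[1] + 1)         # 1-based positions; 0 is reserved for padding
    pos_text = np.expand_dims(pos_text, axis=0)

    text = dg.to_variable(text)
    pos_text = dg.to_variable(pos_text)
    # _, mel_output_postnet = model(text, pos_text, alpha=1.0)   # masks are built inside the model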
@@ -186,7 +175,6 @@ def synthesis_with_clarinet(config_path, checkpoint, mel_spectrogram, place):
     lmd = config["loss"]["lmd"]
     model = Clarinet(upsample_net, teacher, student, stft,
                      student_log_scale_min, lmd)
-    summary(model)
     io.load_parameters(model=model, checkpoint_path=checkpoint)
     if not os.path.exists(args.output):

View File

@@ -79,7 +79,9 @@ def main(args):
             (cfg['train']['warm_up_step'] *
              (cfg['train']['learning_rate']**2)),
             cfg['train']['warm_up_step']),
-        parameter_list=model.parameters())
+        parameter_list=model.parameters(),
+        grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
+            'grad_clip_thresh']))
     reader = LJSpeechLoader(
         cfg['audio'],
         place,
@@ -108,9 +110,7 @@ def main(args):
             for i, data in enumerate(pbar):
                 pbar.set_description('Processing at epoch %d' % epoch)
-                (character, mel, pos_text, pos_mel, enc_slf_mask,
-                 enc_query_mask, dec_slf_mask, dec_query_slf_mask,
-                 alignment) = data
+                (character, mel, pos_text, pos_mel, alignment) = data
                 global_step += 1
@@ -119,11 +119,7 @@ def main(args):
                     character,
                     pos_text,
                     mel_pos=pos_mel,
-                    length_target=alignment,
-                    enc_non_pad_mask=enc_query_mask,
-                    enc_slf_attn_mask=enc_slf_mask,
-                    dec_non_pad_mask=dec_query_slf_mask,
-                    dec_slf_attn_mask=dec_slf_mask)
+                    length_target=alignment)
                 mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
                 mel_loss = layers.mse_loss(mel_output, mel)
                 mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
@@ -150,10 +146,7 @@ def main(args):
                     model.apply_collective_grads()
                 else:
                     total_loss.backward()
-                optimizer.minimize(
-                    total_loss,
-                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
-                        'train']['grad_clip_thresh']))
+                optimizer.minimize(total_loss)
                 model.clear_gradients()
                 # save checkpoint
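Taken together, the two train.py hunks reflect the Paddle 1.8 gradient-clipping API: the clip strategy moves from the minimize() call (fluid.dygraph_grad_clip.GradClipByGlobalNorm) to the optimizer constructor (fluid.clip.GradientClipByGlobalNorm via grad_clip). A minimal dygraph sketch of the new pattern, assuming Paddle 1.8; the Linear layer, learning rate and clip threshold are placeholders:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard():
    layer = dg.Linear(4, 4)                                    # placeholder model
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=1e-3,
        parameter_list=layer.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))    # 1.8 style: clip set on the optimizer

    x = dg.to_variable(np.random.randn(2, 4).astype(np.float32))
    loss = fluid.layers.reduce_mean(layer(x))
    loss.backward()
    optimizer.minimize(loss)                                   # no grad_clip argument here any more
    layer.clear_gradients()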

View File

@@ -70,7 +70,7 @@ class Decoder(dg.Layer):
         for i, layer in enumerate(self.layer_stack):
             self.add_sublayer('fft_{}'.format(i), layer)

-    def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
+    def forward(self, enc_seq, enc_pos):
         """
         Compute decoder outputs.
@@ -79,17 +79,26 @@ class Decoder(dg.Layer):
                 the output of length regulator, where T_mel means the timesteps of input spectrum.
             enc_pos (Variable): shape(B, T_mel), dtype int64,
                 the spectrum position.
-            non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
-            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
-                the mask of mel spectrum. Defaults to None.
         Returns:
             dec_output (Variable): shape(B, T_mel, C), the decoder output.
             dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
         """
         dec_slf_attn_list = []
-        if slf_attn_mask:
-            slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
+        if fluid.framework._dygraph_tracer()._train_mode:
+            slf_attn_mask = get_dec_attn_key_pad_mask(enc_pos, self.n_head,
+                                                      enc_seq.dtype)
+        else:
+            len_q = enc_seq.shape[1]
+            slf_attn_mask = layers.triu(
+                layers.ones(
+                    shape=[len_q, len_q], dtype=enc_seq.dtype),
+                diagonal=1)
+            slf_attn_mask = layers.cast(
+                slf_attn_mask != 0, dtype=enc_seq.dtype) * -1e30
+        non_pad_mask = get_non_pad_mask(enc_pos, 1, enc_seq.dtype)

         # -- Forward
         dec_output = enc_seq + self.position_enc(enc_pos)
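At inference time there is no padding to mask, so the decoder builds a purely causal mask: triu with diagonal=1 marks future positions, and those entries are pushed to -1e30 so they contribute nothing after softmax. A NumPy illustration of the same additive-mask trick (the softmax helper is local to the example):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

len_q = 4
scores = np.random.randn(len_q, len_q).astype(np.float32)           # raw attention logits

causal = np.triu(np.ones((len_q, len_q), dtype=np.float32), k=1)    # 1s strictly above the diagonal
additive_mask = (causal != 0).astype(np.float32) * -1e30            # large negative on future steps

weights = softmax(scores + additive_mask)
print(np.round(weights, 3))   # upper-triangular entries are ~0: no attention to future frames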

View File

@@ -76,7 +76,7 @@ class Encoder(dg.Layer):
         for i, layer in enumerate(self.layer_stack):
             self.add_sublayer('fft_{}'.format(i), layer)

-    def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
+    def forward(self, character, text_pos):
         """
         Encode text sequence.
@@ -84,22 +84,21 @@ class Encoder(dg.Layer):
             character (Variable): shape(B, T_text), dtype float32, the input text characters,
                 where T_text means the timesteps of input characters,
             text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
-            non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
-            slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
-                the mask of input characters. Defaults to None.
         Returns:
             enc_output (Variable): shape(B, T_text, C), the encoder output.
             non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
             enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
         """
         enc_slf_attn_list = []
-        slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])

         # -- Forward
         enc_output = self.src_word_emb(character) + self.position_enc(
             text_pos)  #(N, T, C)
+        slf_attn_mask = get_attn_key_pad_mask(text_pos, self.n_head,
+                                              enc_output.dtype)
+        non_pad_mask = get_non_pad_mask(text_pos, 1, enc_output.dtype)

         for enc_layer in self.layer_stack:
             enc_output, enc_slf_attn = enc_layer(
                 enc_output,
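The encoder now derives both masks from text_pos alone, relying on the convention that position 0 marks padding. A rough NumPy sketch of what helpers like get_non_pad_mask and get_attn_key_pad_mask are expected to compute, under the usual Transformer convention (the Parakeet implementations may differ in dtype and broadcasting details):

import numpy as np

def non_pad_mask(pos):
    # (B, T) integer positions -> (B, T, 1) float mask: 1.0 for real tokens, 0.0 for padding.
    return (pos != 0).astype(np.float32)[:, :, None]

def attn_key_pad_mask(pos, num_head):
    # (B, T) -> (num_head * B, T, T) additive mask: a large negative value on padded keys
    # so softmax gives them ~zero attention weight.
    B, T = pos.shape
    key_pad = (pos == 0).astype(np.float32)[:, None, :]     # (B, 1, T)
    mask = np.repeat(key_pad, T, axis=1) * -1e30             # (B, T, T)
    return np.tile(mask, (num_head, 1, 1))                   # (num_head * B, T, T)

pos = np.array([[1, 2, 3, 0, 0]])                            # one sequence, two padded slots
print(non_pad_mask(pos).squeeze(-1))                         # [[1. 1. 1. 0. 0.]]
print(attn_key_pad_mask(pos, num_head=2).shape)              # (2, 5, 5)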

View File

@@ -86,11 +86,7 @@ class FastSpeech(dg.Layer):
     def forward(self,
                 character,
                 text_pos,
-                enc_non_pad_mask,
-                dec_non_pad_mask,
                 mel_pos=None,
-                enc_slf_attn_mask=None,
-                dec_slf_attn_mask=None,
                 length_target=None,
                 alpha=1.0):
         """
@@ -102,12 +98,6 @@ class FastSpeech(dg.Layer):
             text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
             mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
                 where T_mel means the timesteps of input spectrum,
-            enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
-            dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
-            enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
-                the mask of input characters. Defaults to None.
-            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
-                the mask of mel spectrum. Defaults to None.
             length_target (Variable, optional): shape(B, T_text), dtype int64,
                 the duration of phoneme compute from pretrained transformerTTS. Defaults to None.
             alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
@@ -121,19 +111,12 @@ class FastSpeech(dg.Layer):
             dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
         """
-        encoder_output, enc_slf_attn_list = self.encoder(
-            character,
-            text_pos,
-            enc_non_pad_mask,
-            slf_attn_mask=enc_slf_attn_mask)
+        encoder_output, enc_slf_attn_list = self.encoder(character, text_pos)
         if fluid.framework._dygraph_tracer()._train_mode:
             length_regulator_output, duration_predictor_output = self.length_regulator(
                 encoder_output, target=length_target, alpha=alpha)
-            decoder_output, dec_slf_attn_list = self.decoder(
-                length_regulator_output,
-                mel_pos,
-                dec_non_pad_mask,
-                slf_attn_mask=dec_slf_attn_mask)
+            decoder_output, dec_slf_attn_list = self.decoder(
+                length_regulator_output, mel_pos)
             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output
@@ -142,19 +125,8 @@ class FastSpeech(dg.Layer):
         else:
             length_regulator_output, decoder_pos = self.length_regulator(
                 encoder_output, alpha=alpha)
-            slf_attn_mask = get_triu_tensor(
-                decoder_pos.numpy(), decoder_pos.numpy()).astype(np.float32)
-            slf_attn_mask = np.expand_dims(slf_attn_mask, axis=0)
-            slf_attn_mask = fluid.layers.cast(
-                dg.to_variable(slf_attn_mask != 0), np.float32) * (-2**32 + 1)
-            slf_attn_mask = dg.to_variable(slf_attn_mask)
-            dec_non_pad_mask = fluid.layers.unsqueeze(
-                (decoder_pos != 0).astype(np.float32), [-1])
-            decoder_output, _ = self.decoder(
-                length_regulator_output,
-                decoder_pos,
-                dec_non_pad_mask,
-                slf_attn_mask=slf_attn_mask)
+            decoder_output, _ = self.decoder(length_regulator_output,
+                                             decoder_pos)
             mel_output = self.mel_linear(decoder_output)
             mel_output_postnet = self.postnet(mel_output) + mel_output
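Both FastSpeech.forward and Decoder.forward now branch on the dygraph tracer's _train_mode flag instead of taking externally supplied masks. A tiny sketch of how that flag is expected to be driven in Paddle 1.8, assuming model.train()/model.eval() toggle the same private attribute the code above reads (the Linear layer is a placeholder for the FastSpeech model):

import paddle.fluid as fluid
import paddle.fluid.dygraph as dg

with dg.guard():
    layer = dg.Linear(4, 4)      # placeholder for the FastSpeech model
    layer.train()                # training path: key-padding masks built from positions
    print(fluid.framework._dygraph_tracer()._train_mode)   # expected: True
    layer.eval()                 # inference path: causal triu mask, no padding
    print(fluid.framework._dygraph_tracer()._train_mode)   # expected: False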