modified transformer_tts to make sure it works on paddle 1.8

This commit is contained in:
lifuchen 2020-05-07 02:23:50 +00:00
parent 45c07fa42e
commit 46e254d1f8
8 changed files with 58 additions and 130 deletions

View File

@ -141,7 +141,6 @@ def batch_examples(batch):
texts = [] texts = []
mels = [] mels = []
mel_inputs = [] mel_inputs = []
mel_lens = []
text_lens = [] text_lens = []
pos_texts = [] pos_texts = []
pos_mels = [] pos_mels = []
@ -151,7 +150,6 @@ def batch_examples(batch):
np.concatenate( np.concatenate(
[np.zeros([mel.shape[0], 1], np.float32), mel[:, :-1]], [np.zeros([mel.shape[0], 1], np.float32), mel[:, :-1]],
axis=-1)) axis=-1))
mel_lens.append(mel.shape[1])
text_lens.append(len(text)) text_lens.append(len(text))
pos_texts.append(np.arange(1, len(text) + 1)) pos_texts.append(np.arange(1, len(text) + 1))
pos_mels.append(np.arange(1, mel.shape[1] + 1)) pos_mels.append(np.arange(1, mel.shape[1] + 1))
@ -174,11 +172,6 @@ def batch_examples(batch):
for i, _ in sorted( for i, _ in sorted(
zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True) zip(mel_inputs, text_lens), key=lambda x: x[1], reverse=True)
] ]
mel_lens = [
i
for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
]
pos_texts = [ pos_texts = [
i i
for i, _ in sorted( for i, _ in sorted(
@ -200,18 +193,7 @@ def batch_examples(batch):
mel_inputs = np.transpose( mel_inputs = np.transpose(
SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels) SpecBatcher(pad_value=0.)(mel_inputs), axes=(0, 2, 1)) #(B,T,num_mels)
enc_slf_mask = get_attn_key_pad_mask(pos_texts).astype(np.float32) return (texts, mels, mel_inputs, pos_texts, pos_mels)
enc_query_mask = get_non_pad_mask(pos_texts).astype(np.float32)
dec_slf_mask = get_dec_attn_key_pad_mask(pos_mels,
mel_inputs).astype(np.float32)
enc_dec_mask = get_attn_key_pad_mask(enc_query_mask[:, :, 0]).astype(
np.float32)
dec_query_slf_mask = get_non_pad_mask(pos_mels).astype(np.float32)
dec_query_mask = get_non_pad_mask(pos_mels).astype(np.float32)
return (texts, mels, mel_inputs, pos_texts, pos_mels, np.array(text_lens),
np.array(mel_lens), enc_slf_mask, enc_query_mask, dec_slf_mask,
enc_dec_mask, dec_query_slf_mask, dec_query_mask)
def batch_examples_vocoder(batch): def batch_examples_vocoder(batch):

View File

@ -101,15 +101,10 @@ def synthesis(text_input, args):
pbar = tqdm(range(args.max_len)) pbar = tqdm(range(args.max_len))
for i in pbar: for i in pbar:
dec_slf_mask = get_triu_tensor(
mel_input.numpy(), mel_input.numpy()).astype(np.float32)
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
dec_slf_mask = fluid.layers.cast(
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
pos_mel = np.arange(1, mel_input.shape[1] + 1) pos_mel = np.arange(1, mel_input.shape[1] + 1)
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
text, mel_input, pos_text, pos_mel, dec_slf_mask) text, mel_input, pos_text, pos_mel)
mel_input = fluid.layers.concat( mel_input = fluid.layers.concat(
[mel_input, postnet_pred[:, -1:, :]], axis=1) [mel_input, postnet_pred[:, -1:, :]], axis=1)

View File

@ -79,7 +79,9 @@ def main(args):
(cfg['train']['warm_up_step'] * (cfg['train']['warm_up_step'] *
(cfg['train']['learning_rate']**2)), (cfg['train']['learning_rate']**2)),
cfg['train']['warm_up_step']), cfg['train']['warm_up_step']),
parameter_list=model.parameters()) parameter_list=model.parameters(),
grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
'grad_clip_thresh']))
# Load parameters. # Load parameters.
global_step = io.load_parameters( global_step = io.load_parameters(
@ -107,21 +109,12 @@ def main(args):
pbar = tqdm(reader) pbar = tqdm(reader)
for i, data in enumerate(pbar): for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d' % epoch) pbar.set_description('Processing at epoch %d' % epoch)
character, mel, mel_input, pos_text, pos_mel, text_length, _, enc_slf_mask, enc_query_mask, dec_slf_mask, enc_dec_mask, dec_query_slf_mask, dec_query_mask = data character, mel, mel_input, pos_text, pos_mel = data
global_step += 1 global_step += 1
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
character, character, mel_input, pos_text, pos_mel)
mel_input,
pos_text,
pos_mel,
dec_slf_mask=dec_slf_mask,
enc_slf_mask=enc_slf_mask,
enc_query_mask=enc_query_mask,
enc_dec_mask=enc_dec_mask,
dec_query_slf_mask=dec_query_slf_mask,
dec_query_mask=dec_query_mask)
mel_loss = layers.mean( mel_loss = layers.mean(
layers.abs(layers.elementwise_sub(mel_pred, mel))) layers.abs(layers.elementwise_sub(mel_pred, mel)))
@ -202,10 +195,7 @@ def main(args):
model.apply_collective_grads() model.apply_collective_grads()
else: else:
loss.backward() loss.backward()
optimizer.minimize( optimizer.minimize(loss)
loss,
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
'train']['grad_clip_thresh']))
model.clear_gradients() model.clear_gradients()
# save checkpoint # save checkpoint

View File

@ -74,7 +74,9 @@ def main(args):
(cfg['train']['warm_up_step'] * (cfg['train']['warm_up_step'] *
(cfg['train']['learning_rate']**2)), (cfg['train']['learning_rate']**2)),
cfg['train']['warm_up_step']), cfg['train']['warm_up_step']),
parameter_list=model.parameters()) parameter_list=model.parameters(),
grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
'grad_clip_thresh']))
# Load parameters. # Load parameters.
global_step = io.load_parameters( global_step = io.load_parameters(
@ -117,10 +119,7 @@ def main(args):
model.apply_collective_grads() model.apply_collective_grads()
else: else:
loss.backward() loss.backward()
optimizer.minimize( optimizer.minimize(loss)
loss,
grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
'train']['grad_clip_thresh']))
model.clear_gradients() model.clear_gradients()
if local_rank == 0: if local_rank == 0:

View File

@ -113,15 +113,7 @@ class Decoder(dg.Layer):
outputs_per_step=outputs_per_step, outputs_per_step=outputs_per_step,
use_cudnn=True) use_cudnn=True)
def forward(self, def forward(self, key, value, query, positional, c_mask):
key,
value,
query,
positional,
mask,
m_mask=None,
m_self_mask=None,
zero_mask=None):
""" """
Compute decoder outputs. Compute decoder outputs.
@ -132,11 +124,7 @@ class Decoder(dg.Layer):
query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder, query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
where T_mel means the timesteps of input spectrum, where T_mel means the timesteps of input spectrum,
positional (Variable): shape(B, T_mel), dtype int64, the spectrum position. positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention. c_mask (Variable): shape(B, T_text, 1), dtype float32, query mask returned from encoder.
m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
Returns: Returns:
mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection. mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
out (Variable): shape(B, T_mel, C), the decoder output after post mel network. out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
@ -148,15 +136,20 @@ class Decoder(dg.Layer):
# get decoder mask with triangular matrix # get decoder mask with triangular matrix
if fluid.framework._dygraph_tracer()._train_mode: if fluid.framework._dygraph_tracer()._train_mode:
m_mask = layers.expand(m_mask, [self.num_head, 1, key.shape[1]]) mask = get_dec_attn_key_pad_mask(positional, self.num_head,
m_self_mask = layers.expand(m_self_mask, query.dtype)
[self.num_head, 1, query.shape[1]]) m_mask = get_non_pad_mask(positional, self.num_head, query.dtype)
mask = layers.expand(mask, [self.num_head, 1, 1]) zero_mask = layers.cast(c_mask == 0, dtype=query.dtype) * -1e30
zero_mask = layers.expand(zero_mask, [self.num_head, 1, 1]) zero_mask = layers.transpose(zero_mask, perm=[0, 2, 1])
else: else:
mask = layers.expand(mask, [self.num_head, 1, 1]) len_q = query.shape[1]
m_mask, m_self_mask, zero_mask = None, None, None mask = layers.triu(
layers.ones(
shape=[len_q, len_q], dtype=query.dtype),
diagonal=1)
mask = layers.cast(mask != 0, dtype=query.dtype) * -1e30
m_mask, zero_mask = None, None
# Decoder pre-network # Decoder pre-network
query = self.decoder_prenet(query) query = self.decoder_prenet(query)
@ -179,7 +172,7 @@ class Decoder(dg.Layer):
for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers, for selfattn, attn, ffn in zip(self.selfattn_layers, self.attn_layers,
self.ffns): self.ffns):
query, attn_dec = selfattn( query, attn_dec = selfattn(
query, query, query, mask=mask, query_mask=m_self_mask) query, query, query, mask=mask, query_mask=m_mask)
query, attn_dot = attn( query, attn_dot = attn(
key, value, query, mask=zero_mask, query_mask=m_mask) key, value, query, mask=zero_mask, query_mask=m_mask)
query = ffn(query) query = ffn(query)

View File

@ -64,7 +64,7 @@ class Encoder(dg.Layer):
for i, layer in enumerate(self.ffns): for i, layer in enumerate(self.ffns):
self.add_sublayer("ffns_{}".format(i), layer) self.add_sublayer("ffns_{}".format(i), layer)
def forward(self, x, positional, mask=None, query_mask=None): def forward(self, x, positional):
""" """
Encode text sequence. Encode text sequence.
@ -72,24 +72,22 @@ class Encoder(dg.Layer):
x (Variable): shape(B, T_text), dtype float32, the input character, x (Variable): shape(B, T_text), dtype float32, the input character,
where T_text means the timesteps of input text, where T_text means the timesteps of input text,
positional (Variable): shape(B, T_text), dtype int64, the characters position. positional (Variable): shape(B, T_text), dtype int64, the characters position.
mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
Returns: Returns:
x (Variable): shape(B, T_text, C), the encoder output. x (Variable): shape(B, T_text, C), the encoder output.
attentions (list[Variable]): len(n_layers), the encoder self attention list. attentions (list[Variable]): len(n_layers), the encoder self attention list.
""" """
if fluid.framework._dygraph_tracer()._train_mode:
seq_len_key = x.shape[1]
query_mask = layers.expand(query_mask,
[self.num_head, 1, seq_len_key])
mask = layers.expand(mask, [self.num_head, 1, 1])
else:
query_mask, mask = None, None
# Encoder pre_network # Encoder pre_network
x = self.encoder_prenet(x) x = self.encoder_prenet(x)
if fluid.framework._dygraph_tracer()._train_mode:
mask = get_attn_key_pad_mask(positional, self.num_head, x.dtype)
query_mask = get_non_pad_mask(positional, self.num_head, x.dtype)
else:
query_mask, mask = None, None
# Get positional encoding # Get positional encoding
positional = self.pos_emb(positional) positional = self.pos_emb(positional)
@ -105,4 +103,4 @@ class Encoder(dg.Layer):
x = ffn(x) x = ffn(x)
attentions.append(attention) attentions.append(attention)
return x, attentions return x, attentions, query_mask

View File

@ -45,17 +45,7 @@ class TransformerTTS(dg.Layer):
self.decoder = Decoder(num_hidden, n_mels, outputs_per_step, self.decoder = Decoder(num_hidden, n_mels, outputs_per_step,
decoder_num_head, decoder_n_layers) decoder_num_head, decoder_n_layers)
def forward(self, def forward(self, characters, mel_input, pos_text, pos_mel):
characters,
mel_input,
pos_text,
pos_mel,
dec_slf_mask,
enc_slf_mask=None,
enc_query_mask=None,
enc_dec_mask=None,
dec_query_slf_mask=None,
dec_query_mask=None):
""" """
TransformerTTS network. TransformerTTS network.
@ -65,13 +55,6 @@ class TransformerTTS(dg.Layer):
mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder, mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
where T_mel means the timesteps of input spectrum, where T_mel means the timesteps of input spectrum,
pos_text (Variable): shape(B, T_text), dtype int64, the characters position. pos_text (Variable): shape(B, T_text), dtype int64, the characters position.
dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
Returns: Returns:
mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection. mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
@ -81,16 +64,8 @@ class TransformerTTS(dg.Layer):
attns_enc (list[Variable]): len(n_layers), the encoder self attention list. attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
attns_dec (list[Variable]): len(n_layers), the decoder self attention list. attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
""" """
key, attns_enc = self.encoder( key, attns_enc, query_mask = self.encoder(characters, pos_text)
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder( mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder(
key, key, key, mel_input, pos_mel, query_mask)
key,
mel_input,
pos_mel,
mask=dec_slf_mask,
zero_mask=enc_dec_mask,
m_self_mask=dec_query_slf_mask,
m_mask=dec_query_mask)
return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec

View File

@ -50,41 +50,37 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
return sinusoid_table return sinusoid_table
def get_non_pad_mask(seq): def get_non_pad_mask(seq, num_head, dtype):
mask = (seq != 0).astype(np.float32) mask = layers.cast(seq != 0, dtype=dtype)
mask = np.expand_dims(mask, axis=-1) mask = layers.unsqueeze(mask, axes=[-1])
mask = layers.expand(mask, [num_head, 1, 1])
return mask return mask
def get_attn_key_pad_mask(seq_k): def get_attn_key_pad_mask(seq_k, num_head, dtype):
''' For masking out the padding part of key sequence. ''' ''' For masking out the padding part of key sequence. '''
# Expand to fit the shape of key query attention matrix. # Expand to fit the shape of key query attention matrix.
padding_mask = (seq_k != 0).astype(np.float32) padding_mask = layers.cast(seq_k == 0, dtype=dtype) * -1e30
padding_mask = np.expand_dims(padding_mask, axis=1) padding_mask = layers.unsqueeze(padding_mask, axes=[1])
padding_mask = ( padding_mask = layers.expand(padding_mask, [num_head, 1, 1])
padding_mask == 0).astype(np.float32) * -1e30 #* (-2**32 + 1)
return padding_mask return padding_mask
def get_dec_attn_key_pad_mask(seq_k, seq_q): def get_dec_attn_key_pad_mask(seq_k, num_head, dtype):
''' For masking out the padding part of key sequence. ''' ''' For masking out the padding part of key sequence. '''
# Expand to fit the shape of key query attention matrix. # Expand to fit the shape of key query attention matrix.
padding_mask = (seq_k == 0).astype(np.float32) padding_mask = layers.cast(seq_k == 0, dtype=dtype)
padding_mask = np.expand_dims(padding_mask, axis=1) padding_mask = layers.unsqueeze(padding_mask, axes=[1])
triu_tensor = get_triu_tensor(seq_q, seq_q)
padding_mask = padding_mask + triu_tensor
padding_mask = (
padding_mask != 0).astype(np.float32) * -1e30 #* (-2**32 + 1)
return padding_mask
def get_triu_tensor(seq_k, seq_q):
''' For make a triu tensor '''
len_k = seq_k.shape[1] len_k = seq_k.shape[1]
len_q = seq_q.shape[1] triu = layers.triu(
triu_tensor = np.triu(np.ones([len_k, len_q]), 1) layers.ones(
return triu_tensor shape=[len_k, len_k], dtype=dtype), diagonal=1)
padding_mask = padding_mask + triu
padding_mask = layers.cast(
padding_mask != 0, dtype=dtype) * -1e30 #* (-2**32 + 1)
padding_mask = layers.expand(padding_mask, [num_head, 1, 1])
return padding_mask
def guided_attention(N, T, g=0.2): def guided_attention(N, T, g=0.2):