Fix some bugs in transformer_tts and fastspeech.

This commit is contained in:
lifuchen 2020-06-12 08:54:32 +00:00
parent 8716a1843c
commit 681d34b953
7 changed files with 145 additions and 141 deletions

View File

@ -115,15 +115,10 @@ def alignments(args):
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0]) mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
mel_lens = mel_input.shape[1] mel_lens = mel_input.shape[1]
dec_slf_mask = get_triu_tensor(mel_input,
mel_input).astype(np.float32)
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
dec_slf_mask = fluid.layers.cast(
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
pos_mel = np.arange(1, mel_input.shape[1] + 1) pos_mel = np.arange(1, mel_input.shape[1] + 1)
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
text, mel_input, pos_text, pos_mel, dec_slf_mask) text, mel_input, pos_text, pos_mel)
mel_input = fluid.layers.concat( mel_input = fluid.layers.concat(
[mel_input, postnet_pred[:, -1:, :]], axis=1) [mel_input, postnet_pred[:, -1:, :]], axis=1)

View File

@ -29,5 +29,5 @@ train:
grad_clip_thresh: 0.1 #the threshold of grad clip. grad_clip_thresh: 0.1 #the threshold of grad clip.
checkpoint_interval: 1000 checkpoint_interval: 1000
max_epochs: 10000 max_iteration: 500000

View File

@ -62,7 +62,8 @@ def main(args):
cfg = yaml.load(f, Loader=yaml.Loader) cfg = yaml.load(f, Loader=yaml.Loader)
global_step = 0 global_step = 0
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(dg.parallel.Env()
.dev_id) if args.use_gpu else fluid.CPUPlace()
fluid.enable_dygraph(place) fluid.enable_dygraph(place)
if not os.path.exists(args.output): if not os.path.exists(args.output):
@ -88,7 +89,8 @@ def main(args):
cfg['train']['batch_size'], cfg['train']['batch_size'],
nranks, nranks,
local_rank, local_rank,
shuffle=True).reader() shuffle=True).reader
iterator = iter(tqdm(reader))
# Load parameters. # Load parameters.
global_step = io.load_parameters( global_step = io.load_parameters(
@ -103,12 +105,14 @@ def main(args):
strategy = dg.parallel.prepare_context() strategy = dg.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy) model = fluid.dygraph.parallel.DataParallel(model, strategy)
for epoch in range(cfg['train']['max_epochs']): while global_step <= cfg['train']['max_iteration']:
pbar = tqdm(reader) try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(reader))
batch = next(iterator)
for i, data in enumerate(pbar): (character, mel, pos_text, pos_mel, alignment) = batch
pbar.set_description('Processing at epoch %d' % epoch)
(character, mel, pos_text, pos_mel, alignment) = data
global_step += 1 global_step += 1
@ -120,8 +124,7 @@ def main(args):
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel) mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
duration_loss = layers.mean( duration_loss = layers.mean(
layers.abs( layers.abs(
layers.elementwise_sub(duration_predictor_output, layers.elementwise_sub(duration_predictor_output, alignment)))
alignment)))
total_loss = mel_loss + mel_postnet_loss + duration_loss total_loss = mel_loss + mel_postnet_loss + duration_loss
if local_rank == 0: if local_rank == 0:
@ -147,8 +150,8 @@ def main(args):
if local_rank == 0 and global_step % cfg['train'][ if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0: 'checkpoint_interval'] == 0:
io.save_parameters( io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step, os.path.join(args.output, 'checkpoints'), global_step, model,
model, optimizer) optimizer)
if local_rank == 0: if local_rank == 0:
writer.close() writer.close()

View File

@ -53,7 +53,7 @@ During synthesis, results are saved in `${output}/samples` and tensorboard log i
TransformerTTS model can be trained by running ``train_transformer.py``. TransformerTTS model can be trained by running ``train_transformer.py``.
```bash ```bash
python train_trasformer.py \ python train_transformer.py \
--use_gpu=1 \ --use_gpu=1 \
--data=${DATAPATH} \ --data=${DATAPATH} \
--output='./experiment' \ --output='./experiment' \

View File

@ -31,7 +31,7 @@ train:
checkpoint_interval: 1000 checkpoint_interval: 1000
image_interval: 2000 image_interval: 2000
max_epochs: 10000 max_iteration: 500000

View File

@ -102,16 +102,21 @@ def main(args):
cfg['train']['batch_size'], cfg['train']['batch_size'],
nranks, nranks,
local_rank, local_rank,
shuffle=True).reader() shuffle=True).reader
for epoch in range(cfg['train']['max_epochs']): iterator = iter(tqdm(reader))
pbar = tqdm(reader)
for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d' % epoch)
character, mel, mel_input, pos_text, pos_mel = data
global_step += 1 global_step += 1
while global_step <= cfg['train']['max_iteration']:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(reader))
batch = next(iterator)
character, mel, mel_input, pos_text, pos_mel = batch
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
character, mel_input, pos_text, pos_mel) character, mel_input, pos_text, pos_mel)
@ -134,8 +139,7 @@ def main(args):
}, global_step) }, global_step)
if cfg['network']['stop_token']: if cfg['network']['stop_token']:
writer.add_scalar('stop_loss', writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
stop_loss.numpy(), global_step)
if parallel: if parallel:
writer.add_scalars('alphas', { writer.add_scalars('alphas', {
@ -157,7 +161,7 @@ def main(args):
for j in range(cfg['network']['decoder_num_head']): for j in range(cfg['network']['decoder_num_head']):
x = np.uint8( x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][ cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255) 'batch_size'] // nranks]) * 255)
writer.add_image( writer.add_image(
'Attention_%d_0' % global_step, 'Attention_%d_0' % global_step,
x, x,
@ -168,7 +172,7 @@ def main(args):
for j in range(cfg['network']['encoder_num_head']): for j in range(cfg['network']['encoder_num_head']):
x = np.uint8( x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][ cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255) 'batch_size'] // nranks]) * 255)
writer.add_image( writer.add_image(
'Attention_enc_%d_0' % global_step, 'Attention_enc_%d_0' % global_step,
x, x,
@ -179,7 +183,7 @@ def main(args):
for j in range(cfg['network']['decoder_num_head']): for j in range(cfg['network']['decoder_num_head']):
x = np.uint8( x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][ cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255) 'batch_size'] // nranks]) * 255)
writer.add_image( writer.add_image(
'Attention_dec_%d_0' % global_step, 'Attention_dec_%d_0' % global_step,
x, x,
@ -199,8 +203,9 @@ def main(args):
if local_rank == 0 and global_step % cfg['train'][ if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0: 'checkpoint_interval'] == 0:
io.save_parameters( io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step, os.path.join(args.output, 'checkpoints'), global_step, model,
model, optimizer) optimizer)
global_step += 1
if local_rank == 0: if local_rank == 0:
writer.close() writer.close()

View File

@ -94,7 +94,8 @@ class LengthRegulator(dg.Layer):
else: else:
duration_predictor_output = layers.round(duration_predictor_output) duration_predictor_output = layers.round(duration_predictor_output)
output = self.LR(x, duration_predictor_output, alpha) output = self.LR(x, duration_predictor_output, alpha)
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)) mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
np.int64)
mel_pos = layers.unsqueeze(mel_pos, [0]) mel_pos = layers.unsqueeze(mel_pos, [0])
return output, mel_pos return output, mel_pos