fix some bugs of transformer_tts and fastspeech.

This commit is contained in:
lifuchen 2020-06-12 08:54:32 +00:00
parent 8716a1843c
commit 681d34b953
7 changed files with 145 additions and 141 deletions

View File

@ -115,15 +115,10 @@ def alignments(args):
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
mel_lens = mel_input.shape[1]
dec_slf_mask = get_triu_tensor(mel_input,
mel_input).astype(np.float32)
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
dec_slf_mask = fluid.layers.cast(
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
pos_mel = np.arange(1, mel_input.shape[1] + 1)
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
text, mel_input, pos_text, pos_mel, dec_slf_mask)
text, mel_input, pos_text, pos_mel)
mel_input = fluid.layers.concat(
[mel_input, postnet_pred[:, -1:, :]], axis=1)

View File

@ -29,5 +29,5 @@ train:
grad_clip_thresh: 0.1 #the threshold of grad clip.
checkpoint_interval: 1000
max_epochs: 10000
max_iteration: 500000

View File

@ -62,7 +62,8 @@ def main(args):
cfg = yaml.load(f, Loader=yaml.Loader)
global_step = 0
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
place = fluid.CUDAPlace(dg.parallel.Env()
.dev_id) if args.use_gpu else fluid.CPUPlace()
fluid.enable_dygraph(place)
if not os.path.exists(args.output):
@ -88,7 +89,8 @@ def main(args):
cfg['train']['batch_size'],
nranks,
local_rank,
shuffle=True).reader()
shuffle=True).reader
iterator = iter(tqdm(reader))
# Load parameters.
global_step = io.load_parameters(
@ -103,52 +105,53 @@ def main(args):
strategy = dg.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy)
for epoch in range(cfg['train']['max_epochs']):
pbar = tqdm(reader)
while global_step <= cfg['train']['max_iteration']:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(reader))
batch = next(iterator)
for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d' % epoch)
(character, mel, pos_text, pos_mel, alignment) = data
(character, mel, pos_text, pos_mel, alignment) = batch
global_step += 1
global_step += 1
#Forward
result = model(
character, pos_text, mel_pos=pos_mel, length_target=alignment)
mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
mel_loss = layers.mse_loss(mel_output, mel)
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
duration_loss = layers.mean(
layers.abs(
layers.elementwise_sub(duration_predictor_output,
alignment)))
total_loss = mel_loss + mel_postnet_loss + duration_loss
#Forward
result = model(
character, pos_text, mel_pos=pos_mel, length_target=alignment)
mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
mel_loss = layers.mse_loss(mel_output, mel)
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
duration_loss = layers.mean(
layers.abs(
layers.elementwise_sub(duration_predictor_output, alignment)))
total_loss = mel_loss + mel_postnet_loss + duration_loss
if local_rank == 0:
writer.add_scalar('mel_loss', mel_loss.numpy(), global_step)
writer.add_scalar('post_mel_loss',
mel_postnet_loss.numpy(), global_step)
writer.add_scalar('duration_loss',
duration_loss.numpy(), global_step)
writer.add_scalar('learning_rate',
optimizer._learning_rate.step().numpy(),
global_step)
if local_rank == 0:
writer.add_scalar('mel_loss', mel_loss.numpy(), global_step)
writer.add_scalar('post_mel_loss',
mel_postnet_loss.numpy(), global_step)
writer.add_scalar('duration_loss',
duration_loss.numpy(), global_step)
writer.add_scalar('learning_rate',
optimizer._learning_rate.step().numpy(),
global_step)
if parallel:
total_loss = model.scale_loss(total_loss)
total_loss.backward()
model.apply_collective_grads()
else:
total_loss.backward()
optimizer.minimize(total_loss)
model.clear_gradients()
if parallel:
total_loss = model.scale_loss(total_loss)
total_loss.backward()
model.apply_collective_grads()
else:
total_loss.backward()
optimizer.minimize(total_loss)
model.clear_gradients()
# save checkpoint
if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0:
io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step,
model, optimizer)
# save checkpoint
if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0:
io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step, model,
optimizer)
if local_rank == 0:
writer.close()

View File

@ -53,7 +53,7 @@ During synthesis, results are saved in `${output}/samples` and tensorboard log i
TransformerTTS model can be trained by running ``train_transformer.py``.
```bash
python train_trasformer.py \
python train_transformer.py \
--use_gpu=1 \
--data=${DATAPATH} \
--output='./experiment' \

View File

@ -31,7 +31,7 @@ train:
checkpoint_interval: 1000
image_interval: 2000
max_epochs: 10000
max_iteration: 500000

View File

@ -102,105 +102,110 @@ def main(args):
cfg['train']['batch_size'],
nranks,
local_rank,
shuffle=True).reader()
shuffle=True).reader
for epoch in range(cfg['train']['max_epochs']):
pbar = tqdm(reader)
for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d' % epoch)
character, mel, mel_input, pos_text, pos_mel = data
iterator = iter(tqdm(reader))
global_step += 1
global_step += 1
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
character, mel_input, pos_text, pos_mel)
while global_step <= cfg['train']['max_iteration']:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(reader))
batch = next(iterator)
mel_loss = layers.mean(
layers.abs(layers.elementwise_sub(mel_pred, mel)))
post_mel_loss = layers.mean(
layers.abs(layers.elementwise_sub(postnet_pred, mel)))
loss = mel_loss + post_mel_loss
character, mel, mel_input, pos_text, pos_mel = batch
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
character, mel_input, pos_text, pos_mel)
mel_loss = layers.mean(
layers.abs(layers.elementwise_sub(mel_pred, mel)))
post_mel_loss = layers.mean(
layers.abs(layers.elementwise_sub(postnet_pred, mel)))
loss = mel_loss + post_mel_loss
# Note: When used stop token loss the learning did not work.
if cfg['network']['stop_token']:
label = (pos_mel == 0).astype(np.float32)
stop_loss = cross_entropy(stop_preds, label)
loss = loss + stop_loss
if local_rank == 0:
writer.add_scalars('training_loss', {
'mel_loss': mel_loss.numpy(),
'post_mel_loss': post_mel_loss.numpy()
}, global_step)
# Note: When used stop token loss the learning did not work.
if cfg['network']['stop_token']:
label = (pos_mel == 0).astype(np.float32)
stop_loss = cross_entropy(stop_preds, label)
loss = loss + stop_loss
if local_rank == 0:
writer.add_scalars('training_loss', {
'mel_loss': mel_loss.numpy(),
'post_mel_loss': post_mel_loss.numpy()
}, global_step)
if cfg['network']['stop_token']:
writer.add_scalar('stop_loss',
stop_loss.numpy(), global_step)
if parallel:
writer.add_scalars('alphas', {
'encoder_alpha': model._layers.encoder.alpha.numpy(),
'decoder_alpha': model._layers.decoder.alpha.numpy(),
}, global_step)
else:
writer.add_scalars('alphas', {
'encoder_alpha': model.encoder.alpha.numpy(),
'decoder_alpha': model.decoder.alpha.numpy(),
}, global_step)
writer.add_scalar('learning_rate',
optimizer._learning_rate.step().numpy(),
global_step)
if global_step % cfg['train']['image_interval'] == 1:
for i, prob in enumerate(attn_probs):
for j in range(cfg['network']['decoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255)
writer.add_image(
'Attention_%d_0' % global_step,
x,
i * 4 + j,
dataformats="HWC")
for i, prob in enumerate(attn_enc):
for j in range(cfg['network']['encoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255)
writer.add_image(
'Attention_enc_%d_0' % global_step,
x,
i * 4 + j,
dataformats="HWC")
for i, prob in enumerate(attn_dec):
for j in range(cfg['network']['decoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255)
writer.add_image(
'Attention_dec_%d_0' % global_step,
x,
i * 4 + j,
dataformats="HWC")
writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
if parallel:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
writer.add_scalars('alphas', {
'encoder_alpha': model._layers.encoder.alpha.numpy(),
'decoder_alpha': model._layers.decoder.alpha.numpy(),
}, global_step)
else:
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
writer.add_scalars('alphas', {
'encoder_alpha': model.encoder.alpha.numpy(),
'decoder_alpha': model.decoder.alpha.numpy(),
}, global_step)
# save checkpoint
if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0:
io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step,
model, optimizer)
writer.add_scalar('learning_rate',
optimizer._learning_rate.step().numpy(),
global_step)
if global_step % cfg['train']['image_interval'] == 1:
for i, prob in enumerate(attn_probs):
for j in range(cfg['network']['decoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // nranks]) * 255)
writer.add_image(
'Attention_%d_0' % global_step,
x,
i * 4 + j,
dataformats="HWC")
for i, prob in enumerate(attn_enc):
for j in range(cfg['network']['encoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // nranks]) * 255)
writer.add_image(
'Attention_enc_%d_0' % global_step,
x,
i * 4 + j,
dataformats="HWC")
for i, prob in enumerate(attn_dec):
for j in range(cfg['network']['decoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // nranks]) * 255)
writer.add_image(
'Attention_dec_%d_0' % global_step,
x,
i * 4 + j,
dataformats="HWC")
if parallel:
loss = model.scale_loss(loss)
loss.backward()
model.apply_collective_grads()
else:
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
# save checkpoint
if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0:
io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step, model,
optimizer)
global_step += 1
if local_rank == 0:
writer.close()

View File

@ -94,7 +94,8 @@ class LengthRegulator(dg.Layer):
else:
duration_predictor_output = layers.round(duration_predictor_output)
output = self.LR(x, duration_predictor_output, alpha)
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1))
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
np.int64)
mel_pos = layers.unsqueeze(mel_pos, [0])
return output, mel_pos