fix some bugs of transformer_tts and fastspeech.

This commit is contained in:
lifuchen 2020-06-12 08:54:32 +00:00
parent 8716a1843c
commit 681d34b953
7 changed files with 145 additions and 141 deletions

View File

@ -115,15 +115,10 @@ def alignments(args):
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
mel_lens = mel_input.shape[1]
dec_slf_mask = get_triu_tensor(mel_input,
mel_input).astype(np.float32)
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
dec_slf_mask = fluid.layers.cast(
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
pos_mel = np.arange(1, mel_input.shape[1] + 1)
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
text, mel_input, pos_text, pos_mel, dec_slf_mask)
text, mel_input, pos_text, pos_mel)
mel_input = fluid.layers.concat(
[mel_input, postnet_pred[:, -1:, :]], axis=1)

View File

@ -29,5 +29,5 @@ train:
grad_clip_thresh: 0.1 #the threshold of grad clip.
checkpoint_interval: 1000
max_epochs: 10000
max_iteration: 500000

View File

@ -62,7 +62,8 @@ def main(args):
cfg = yaml.load(f, Loader=yaml.Loader)
global_step = 0
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
place = fluid.CUDAPlace(dg.parallel.Env()
.dev_id) if args.use_gpu else fluid.CPUPlace()
fluid.enable_dygraph(place)
if not os.path.exists(args.output):
@ -88,7 +89,8 @@ def main(args):
cfg['train']['batch_size'],
nranks,
local_rank,
shuffle=True).reader()
shuffle=True).reader
iterator = iter(tqdm(reader))
# Load parameters.
global_step = io.load_parameters(
@ -103,12 +105,14 @@ def main(args):
strategy = dg.parallel.prepare_context()
model = fluid.dygraph.parallel.DataParallel(model, strategy)
for epoch in range(cfg['train']['max_epochs']):
pbar = tqdm(reader)
while global_step <= cfg['train']['max_iteration']:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(reader))
batch = next(iterator)
for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d' % epoch)
(character, mel, pos_text, pos_mel, alignment) = data
(character, mel, pos_text, pos_mel, alignment) = batch
global_step += 1
@ -120,8 +124,7 @@ def main(args):
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
duration_loss = layers.mean(
layers.abs(
layers.elementwise_sub(duration_predictor_output,
alignment)))
layers.elementwise_sub(duration_predictor_output, alignment)))
total_loss = mel_loss + mel_postnet_loss + duration_loss
if local_rank == 0:
@ -147,8 +150,8 @@ def main(args):
if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0:
io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step,
model, optimizer)
os.path.join(args.output, 'checkpoints'), global_step, model,
optimizer)
if local_rank == 0:
writer.close()

View File

@ -53,7 +53,7 @@ During synthesis, results are saved in `${output}/samples` and tensorboard log i
TransformerTTS model can be trained by running ``train_transformer.py``.
```bash
python train_trasformer.py \
python train_transformer.py \
--use_gpu=1 \
--data=${DATAPATH} \
--output='./experiment' \

View File

@ -31,7 +31,7 @@ train:
checkpoint_interval: 1000
image_interval: 2000
max_epochs: 10000
max_iteration: 500000

View File

@ -102,16 +102,21 @@ def main(args):
cfg['train']['batch_size'],
nranks,
local_rank,
shuffle=True).reader()
shuffle=True).reader
for epoch in range(cfg['train']['max_epochs']):
pbar = tqdm(reader)
for i, data in enumerate(pbar):
pbar.set_description('Processing at epoch %d' % epoch)
character, mel, mel_input, pos_text, pos_mel = data
iterator = iter(tqdm(reader))
global_step += 1
while global_step <= cfg['train']['max_iteration']:
try:
batch = next(iterator)
except StopIteration as e:
iterator = iter(tqdm(reader))
batch = next(iterator)
character, mel, mel_input, pos_text, pos_mel = batch
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
character, mel_input, pos_text, pos_mel)
@ -134,8 +139,7 @@ def main(args):
}, global_step)
if cfg['network']['stop_token']:
writer.add_scalar('stop_loss',
stop_loss.numpy(), global_step)
writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
if parallel:
writer.add_scalars('alphas', {
@ -157,7 +161,7 @@ def main(args):
for j in range(cfg['network']['decoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255)
'batch_size'] // nranks]) * 255)
writer.add_image(
'Attention_%d_0' % global_step,
x,
@ -168,7 +172,7 @@ def main(args):
for j in range(cfg['network']['encoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255)
'batch_size'] // nranks]) * 255)
writer.add_image(
'Attention_enc_%d_0' % global_step,
x,
@ -179,7 +183,7 @@ def main(args):
for j in range(cfg['network']['decoder_num_head']):
x = np.uint8(
cm.viridis(prob.numpy()[j * cfg['train'][
'batch_size'] // 2]) * 255)
'batch_size'] // nranks]) * 255)
writer.add_image(
'Attention_dec_%d_0' % global_step,
x,
@ -199,8 +203,9 @@ def main(args):
if local_rank == 0 and global_step % cfg['train'][
'checkpoint_interval'] == 0:
io.save_parameters(
os.path.join(args.output, 'checkpoints'), global_step,
model, optimizer)
os.path.join(args.output, 'checkpoints'), global_step, model,
optimizer)
global_step += 1
if local_rank == 0:
writer.close()

View File

@ -94,7 +94,8 @@ class LengthRegulator(dg.Layer):
else:
duration_predictor_output = layers.round(duration_predictor_output)
output = self.LR(x, duration_predictor_output, alpha)
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1))
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
np.int64)
mel_pos = layers.unsqueeze(mel_pos, [0])
return output, mel_pos