Fix some bugs in transformer_tts and fastspeech.
parent 8716a1843c
commit 681d34b953
```diff
@@ -115,15 +115,10 @@ def alignments(args):
         mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
         mel_lens = mel_input.shape[1]
 
-        dec_slf_mask = get_triu_tensor(mel_input,
-                                       mel_input).astype(np.float32)
-        dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
-        dec_slf_mask = fluid.layers.cast(
-            dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
         pos_mel = np.arange(1, mel_input.shape[1] + 1)
         pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
         mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
-            text, mel_input, pos_text, pos_mel, dec_slf_mask)
+            text, mel_input, pos_text, pos_mel)
         mel_input = fluid.layers.concat(
             [mel_input, postnet_pred[:, -1:, :]], axis=1)
```
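Note on the hunk above: the decoder self-attention mask is no longer assembled by hand during synthesis; the model is expected to build it internally from its inputs. For reference, a minimal NumPy sketch of what the removed lines computed, assuming `get_triu_tensor` yields a strictly upper-triangular matrix over the mel time axis (its exact semantics live elsewhere in the repo):

```python
import numpy as np

def causal_attention_mask(num_steps):
    """Additive mask that hides future mel frames from the decoder.

    Positions above the diagonal get a huge negative value so they
    vanish after softmax; allowed positions stay at 0.
    """
    # Strictly upper-triangular ones mark the "future" positions.
    triu = np.triu(np.ones((num_steps, num_steps), dtype=np.float32), k=1)
    # Same trick as the removed code: (-2**32 + 1) stands in for -inf.
    mask = (triu != 0).astype(np.float32) * (-2**32 + 1)
    return np.expand_dims(mask, axis=0)  # prepend a batch dimension

print(causal_attention_mask(4).shape)  # (1, 4, 4)
```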
```diff
@@ -29,5 +29,5 @@ train:
   grad_clip_thresh: 0.1 #the threshold of grad clip.
 
   checkpoint_interval: 1000
-  max_epochs: 10000
+  max_iteration: 500000
 
```
```diff
@@ -62,7 +62,8 @@ def main(args):
         cfg = yaml.load(f, Loader=yaml.Loader)
 
     global_step = 0
-    place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
+    place = fluid.CUDAPlace(dg.parallel.Env()
+                            .dev_id) if args.use_gpu else fluid.CPUPlace()
     fluid.enable_dygraph(place)
 
     if not os.path.exists(args.output):
```
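The old line pinned the GPU with `local_rank`, which is only right when rank ids and device ids happen to coincide; `dg.parallel.Env().dev_id` asks the parallel environment which device this worker was actually assigned. A toy sketch of the idea, assuming the launcher exports the assignment in an environment variable (the variable name below is illustrative, not Paddle's real flag):

```python
import os

def pick_device_id(use_gpu):
    """Each worker reads its assigned GPU from its own environment
    instead of assuming device id == local rank."""
    if not use_gpu:
        return None  # fall back to CPU
    # Hypothetical variable; Paddle's launcher sets its own flag.
    return int(os.environ.get("ASSIGNED_GPU_ID", "0"))

print(pick_device_id(use_gpu=True))
```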
```diff
@@ -88,7 +89,8 @@ def main(args):
         cfg['train']['batch_size'],
         nranks,
         local_rank,
-        shuffle=True).reader()
+        shuffle=True).reader
+    iterator = iter(tqdm(reader))
 
     # Load parameters.
     global_step = io.load_parameters(
```
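The subtle part of this hunk is dropping the call parentheses: `.reader()` produces a generator that is exhausted after one pass, while keeping the un-invoked `.reader` lets the loop below build a fresh iterator whenever the previous one runs dry. A self-contained sketch of the one-shot-generator pitfall (names are illustrative):

```python
def batches():
    """Stand-in for a data reader; the generator it returns is one-shot."""
    for i in range(3):
        yield i

gen = batches()
print(list(gen))  # [0, 1, 2]
print(list(gen))  # []  -- reusing the result of a single call yields nothing

# Keeping the factory around lets us restart whenever we need to:
iterator = iter(batches())
print(next(iterator))  # 0
```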
```diff
@@ -103,12 +105,14 @@ def main(args):
         strategy = dg.parallel.prepare_context()
         model = fluid.dygraph.parallel.DataParallel(model, strategy)
 
-    for epoch in range(cfg['train']['max_epochs']):
-        pbar = tqdm(reader)
+    while global_step <= cfg['train']['max_iteration']:
+        try:
+            batch = next(iterator)
+        except StopIteration as e:
+            iterator = iter(tqdm(reader))
+            batch = next(iterator)
 
-        for i, data in enumerate(pbar):
-            pbar.set_description('Processing at epoch %d' % epoch)
-            (character, mel, pos_text, pos_mel, alignment) = data
+        (character, mel, pos_text, pos_mel, alignment) = batch
 
         global_step += 1
 
```
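The epoch loop becomes a step-bounded loop that restarts the data iterator on `StopIteration`, the usual way to train by iteration count rather than epoch count (matching the `max_epochs` → `max_iteration` config change above). The same control flow in a runnable, dependency-free sketch:

```python
data = [("batch", i) for i in range(4)]  # stand-in for the data reader
iterator = iter(data)

global_step = 0
max_iteration = 10  # plays the role of cfg['train']['max_iteration']

while global_step < max_iteration:
    try:
        batch = next(iterator)
    except StopIteration:
        # Dataset exhausted: begin a new pass, as in the patched loop.
        iterator = iter(data)
        batch = next(iterator)

    global_step += 1
    # ... forward/backward/optimizer step would run here ...
    print(global_step, batch)
```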
```diff
@@ -120,8 +124,7 @@ def main(args):
         mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
         duration_loss = layers.mean(
             layers.abs(
-                layers.elementwise_sub(duration_predictor_output,
-                                       alignment)))
+                layers.elementwise_sub(duration_predictor_output, alignment)))
         total_loss = mel_loss + mel_postnet_loss + duration_loss
 
         if local_rank == 0:
```
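The duration loss itself is unchanged by the reflow: it is the mean absolute error between the duration predictor's output and the teacher-extracted alignment. In plain NumPy the quantity is:

```python
import numpy as np

predicted = np.array([2.1, 3.9, 1.0, 5.2])  # predicted per-phoneme durations
alignment = np.array([2.0, 4.0, 1.0, 5.0])  # durations from the teacher model

# mean(|pred - target|), i.e. the L1 loss computed in the hunk above
duration_loss = np.mean(np.abs(predicted - alignment))
print(duration_loss)  # ~0.1
```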
```diff
@@ -147,8 +150,8 @@ def main(args):
         if local_rank == 0 and global_step % cfg['train'][
                 'checkpoint_interval'] == 0:
             io.save_parameters(
-                os.path.join(args.output, 'checkpoints'), global_step,
-                model, optimizer)
+                os.path.join(args.output, 'checkpoints'), global_step, model,
+                optimizer)
 
     if local_rank == 0:
         writer.close()
```
````diff
@@ -53,7 +53,7 @@ During synthesis, results are saved in `${output}/samples` and tensorboard log i
 TransformerTTS model can be trained by running ``train_transformer.py``.
 
 ```bash
-python train_trasformer.py \
+python train_transformer.py \
 --use_gpu=1 \
 --data=${DATAPATH} \
 --output='./experiment' \
````
```diff
@@ -31,7 +31,7 @@ train:
   checkpoint_interval: 1000
   image_interval: 2000
 
-  max_epochs: 10000
+  max_iteration: 500000
 
 
```
```diff
@@ -102,16 +102,21 @@ def main(args):
         cfg['train']['batch_size'],
         nranks,
         local_rank,
-        shuffle=True).reader()
-
-    for epoch in range(cfg['train']['max_epochs']):
-        pbar = tqdm(reader)
-        for i, data in enumerate(pbar):
-            pbar.set_description('Processing at epoch %d' % epoch)
-            character, mel, mel_input, pos_text, pos_mel = data
-
-            global_step += 1
+        shuffle=True).reader
+    iterator = iter(tqdm(reader))
+
+    global_step += 1
+
+    while global_step <= cfg['train']['max_iteration']:
+        try:
+            batch = next(iterator)
+        except StopIteration as e:
+            iterator = iter(tqdm(reader))
+            batch = next(iterator)
+
+        character, mel, mel_input, pos_text, pos_mel = batch
 
         mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
             character, mel_input, pos_text, pos_mel)
```
```diff
@@ -134,8 +139,7 @@ def main(args):
             }, global_step)
 
             if cfg['network']['stop_token']:
-                writer.add_scalar('stop_loss',
-                                  stop_loss.numpy(), global_step)
+                writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
 
             if parallel:
                 writer.add_scalars('alphas', {
```
```diff
@@ -157,7 +161,7 @@ def main(args):
                 for j in range(cfg['network']['decoder_num_head']):
                     x = np.uint8(
                         cm.viridis(prob.numpy()[j * cfg['train'][
-                            'batch_size'] // 2]) * 255)
+                            'batch_size'] // nranks]) * 255)
                     writer.add_image(
                         'Attention_%d_0' % global_step,
                         x,
```
```diff
@@ -168,7 +172,7 @@ def main(args):
                 for j in range(cfg['network']['encoder_num_head']):
                     x = np.uint8(
                         cm.viridis(prob.numpy()[j * cfg['train'][
-                            'batch_size'] // 2]) * 255)
+                            'batch_size'] // nranks]) * 255)
                     writer.add_image(
                         'Attention_enc_%d_0' % global_step,
                         x,
```
```diff
@@ -179,7 +183,7 @@ def main(args):
                 for j in range(cfg['network']['decoder_num_head']):
                     x = np.uint8(
                         cm.viridis(prob.numpy()[j * cfg['train'][
-                            'batch_size'] // 2]) * 255)
+                            'batch_size'] // nranks]) * 255)
                     writer.add_image(
                         'Attention_dec_%d_0' % global_step,
                         x,
```
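These three image-logging hunks share one fix: with data-parallel training the global batch is split across `nranks` processes, so each process only holds `batch_size // nranks` samples, and the hard-coded `// 2` was only correct on exactly two GPUs. Assuming the head dimension is flattened into dim 0 of `attn_probs` (which is what the indexing suggests), the row picked for head `j` works out as:

```python
batch_size = 32  # global batch size from the config
nranks = 4       # number of data-parallel processes
num_head = 4     # decoder_num_head / encoder_num_head

samples_per_rank = batch_size // nranks  # 8 samples on this device

# With heads stacked along dim 0 as (num_head * samples_per_rank, ...),
# head j's first sample sits at row j * samples_per_rank.
for j in range(num_head):
    print("head", j, "-> row", j * samples_per_rank)
```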
```diff
@@ -199,8 +203,9 @@ def main(args):
         if local_rank == 0 and global_step % cfg['train'][
                 'checkpoint_interval'] == 0:
             io.save_parameters(
-                os.path.join(args.output, 'checkpoints'), global_step,
-                model, optimizer)
+                os.path.join(args.output, 'checkpoints'), global_step, model,
+                optimizer)
+        global_step += 1
 
     if local_rank == 0:
         writer.close()
```
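With the switch to a `while` loop, `global_step` is now advanced once at the bottom of each iteration, after logging and checkpointing, and only rank 0 writes checkpoints every `checkpoint_interval` steps. The control flow reduced to a runnable skeleton (`save` stands in for `io.save_parameters`):

```python
def save(step):
    print("checkpoint at step", step)  # stand-in for io.save_parameters

local_rank = 0           # only rank 0 touches the filesystem
checkpoint_interval = 3
max_iteration = 10
global_step = 1

while global_step <= max_iteration:
    # ... training step would run here ...
    if local_rank == 0 and global_step % checkpoint_interval == 0:
        save(global_step)
    global_step += 1
```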
```diff
@@ -94,7 +94,8 @@ class LengthRegulator(dg.Layer):
         else:
             duration_predictor_output = layers.round(duration_predictor_output)
             output = self.LR(x, duration_predictor_output, alpha)
-            mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1))
+            mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
+                np.int64)
             mel_pos = layers.unsqueeze(mel_pos, [0])
             return output, mel_pos
```
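The `.astype(np.int64)` matters because `mel_pos` is used as indices into a positional embedding table, and `np.arange` over Python ints defaults to the platform's native integer (int32 on Windows), so the cast pins the dtype the lookup expects. A quick NumPy check:

```python
import numpy as np

T = 5  # length of the length-regulated mel sequence

# Positions are 1-based; 0 is conventionally reserved for padding.
mel_pos = np.arange(1, T + 1).astype(np.int64)

print(mel_pos, mel_pos.dtype)  # [1 2 3 4 5] int64
```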