fix some bugs of transformer_tts and fastspeech.
This commit is contained in:
parent
8716a1843c
commit
681d34b953
|
@ -115,15 +115,10 @@ def alignments(args):
|
||||||
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
|
mel_input = fluid.layers.unsqueeze(dg.to_variable(mel_input), [0])
|
||||||
mel_lens = mel_input.shape[1]
|
mel_lens = mel_input.shape[1]
|
||||||
|
|
||||||
dec_slf_mask = get_triu_tensor(mel_input,
|
|
||||||
mel_input).astype(np.float32)
|
|
||||||
dec_slf_mask = np.expand_dims(dec_slf_mask, axis=0)
|
|
||||||
dec_slf_mask = fluid.layers.cast(
|
|
||||||
dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1)
|
|
||||||
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
pos_mel = np.arange(1, mel_input.shape[1] + 1)
|
||||||
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0])
|
||||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||||
text, mel_input, pos_text, pos_mel, dec_slf_mask)
|
text, mel_input, pos_text, pos_mel)
|
||||||
mel_input = fluid.layers.concat(
|
mel_input = fluid.layers.concat(
|
||||||
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
[mel_input, postnet_pred[:, -1:, :]], axis=1)
|
||||||
|
|
||||||
|
|
|
@ -29,5 +29,5 @@ train:
|
||||||
grad_clip_thresh: 0.1 #the threshold of grad clip.
|
grad_clip_thresh: 0.1 #the threshold of grad clip.
|
||||||
|
|
||||||
checkpoint_interval: 1000
|
checkpoint_interval: 1000
|
||||||
max_epochs: 10000
|
max_iteration: 500000
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,8 @@ def main(args):
|
||||||
cfg = yaml.load(f, Loader=yaml.Loader)
|
cfg = yaml.load(f, Loader=yaml.Loader)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
place = fluid.CUDAPlace(local_rank) if args.use_gpu else fluid.CPUPlace()
|
place = fluid.CUDAPlace(dg.parallel.Env()
|
||||||
|
.dev_id) if args.use_gpu else fluid.CPUPlace()
|
||||||
fluid.enable_dygraph(place)
|
fluid.enable_dygraph(place)
|
||||||
|
|
||||||
if not os.path.exists(args.output):
|
if not os.path.exists(args.output):
|
||||||
|
@ -88,7 +89,8 @@ def main(args):
|
||||||
cfg['train']['batch_size'],
|
cfg['train']['batch_size'],
|
||||||
nranks,
|
nranks,
|
||||||
local_rank,
|
local_rank,
|
||||||
shuffle=True).reader()
|
shuffle=True).reader
|
||||||
|
iterator = iter(tqdm(reader))
|
||||||
|
|
||||||
# Load parameters.
|
# Load parameters.
|
||||||
global_step = io.load_parameters(
|
global_step = io.load_parameters(
|
||||||
|
@ -103,12 +105,14 @@ def main(args):
|
||||||
strategy = dg.parallel.prepare_context()
|
strategy = dg.parallel.prepare_context()
|
||||||
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
model = fluid.dygraph.parallel.DataParallel(model, strategy)
|
||||||
|
|
||||||
for epoch in range(cfg['train']['max_epochs']):
|
while global_step <= cfg['train']['max_iteration']:
|
||||||
pbar = tqdm(reader)
|
try:
|
||||||
|
batch = next(iterator)
|
||||||
|
except StopIteration as e:
|
||||||
|
iterator = iter(tqdm(reader))
|
||||||
|
batch = next(iterator)
|
||||||
|
|
||||||
for i, data in enumerate(pbar):
|
(character, mel, pos_text, pos_mel, alignment) = batch
|
||||||
pbar.set_description('Processing at epoch %d' % epoch)
|
|
||||||
(character, mel, pos_text, pos_mel, alignment) = data
|
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
@ -120,8 +124,7 @@ def main(args):
|
||||||
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
|
mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
|
||||||
duration_loss = layers.mean(
|
duration_loss = layers.mean(
|
||||||
layers.abs(
|
layers.abs(
|
||||||
layers.elementwise_sub(duration_predictor_output,
|
layers.elementwise_sub(duration_predictor_output, alignment)))
|
||||||
alignment)))
|
|
||||||
total_loss = mel_loss + mel_postnet_loss + duration_loss
|
total_loss = mel_loss + mel_postnet_loss + duration_loss
|
||||||
|
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
|
@ -147,8 +150,8 @@ def main(args):
|
||||||
if local_rank == 0 and global_step % cfg['train'][
|
if local_rank == 0 and global_step % cfg['train'][
|
||||||
'checkpoint_interval'] == 0:
|
'checkpoint_interval'] == 0:
|
||||||
io.save_parameters(
|
io.save_parameters(
|
||||||
os.path.join(args.output, 'checkpoints'), global_step,
|
os.path.join(args.output, 'checkpoints'), global_step, model,
|
||||||
model, optimizer)
|
optimizer)
|
||||||
|
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
|
@ -53,7 +53,7 @@ During synthesis, results are saved in `${output}/samples` and tensorboard log i
|
||||||
TransformerTTS model can be trained by running ``train_transformer.py``.
|
TransformerTTS model can be trained by running ``train_transformer.py``.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python train_trasformer.py \
|
python train_transformer.py \
|
||||||
--use_gpu=1 \
|
--use_gpu=1 \
|
||||||
--data=${DATAPATH} \
|
--data=${DATAPATH} \
|
||||||
--output='./experiment' \
|
--output='./experiment' \
|
||||||
|
|
|
@ -31,7 +31,7 @@ train:
|
||||||
checkpoint_interval: 1000
|
checkpoint_interval: 1000
|
||||||
image_interval: 2000
|
image_interval: 2000
|
||||||
|
|
||||||
max_epochs: 10000
|
max_iteration: 500000
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -102,16 +102,21 @@ def main(args):
|
||||||
cfg['train']['batch_size'],
|
cfg['train']['batch_size'],
|
||||||
nranks,
|
nranks,
|
||||||
local_rank,
|
local_rank,
|
||||||
shuffle=True).reader()
|
shuffle=True).reader
|
||||||
|
|
||||||
for epoch in range(cfg['train']['max_epochs']):
|
iterator = iter(tqdm(reader))
|
||||||
pbar = tqdm(reader)
|
|
||||||
for i, data in enumerate(pbar):
|
|
||||||
pbar.set_description('Processing at epoch %d' % epoch)
|
|
||||||
character, mel, mel_input, pos_text, pos_mel = data
|
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
|
while global_step <= cfg['train']['max_iteration']:
|
||||||
|
try:
|
||||||
|
batch = next(iterator)
|
||||||
|
except StopIteration as e:
|
||||||
|
iterator = iter(tqdm(reader))
|
||||||
|
batch = next(iterator)
|
||||||
|
|
||||||
|
character, mel, mel_input, pos_text, pos_mel = batch
|
||||||
|
|
||||||
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
|
||||||
character, mel_input, pos_text, pos_mel)
|
character, mel_input, pos_text, pos_mel)
|
||||||
|
|
||||||
|
@ -134,8 +139,7 @@ def main(args):
|
||||||
}, global_step)
|
}, global_step)
|
||||||
|
|
||||||
if cfg['network']['stop_token']:
|
if cfg['network']['stop_token']:
|
||||||
writer.add_scalar('stop_loss',
|
writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
|
||||||
stop_loss.numpy(), global_step)
|
|
||||||
|
|
||||||
if parallel:
|
if parallel:
|
||||||
writer.add_scalars('alphas', {
|
writer.add_scalars('alphas', {
|
||||||
|
@ -157,7 +161,7 @@ def main(args):
|
||||||
for j in range(cfg['network']['decoder_num_head']):
|
for j in range(cfg['network']['decoder_num_head']):
|
||||||
x = np.uint8(
|
x = np.uint8(
|
||||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||||
'batch_size'] // 2]) * 255)
|
'batch_size'] // nranks]) * 255)
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_%d_0' % global_step,
|
'Attention_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
|
@ -168,7 +172,7 @@ def main(args):
|
||||||
for j in range(cfg['network']['encoder_num_head']):
|
for j in range(cfg['network']['encoder_num_head']):
|
||||||
x = np.uint8(
|
x = np.uint8(
|
||||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||||
'batch_size'] // 2]) * 255)
|
'batch_size'] // nranks]) * 255)
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_enc_%d_0' % global_step,
|
'Attention_enc_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
|
@ -179,7 +183,7 @@ def main(args):
|
||||||
for j in range(cfg['network']['decoder_num_head']):
|
for j in range(cfg['network']['decoder_num_head']):
|
||||||
x = np.uint8(
|
x = np.uint8(
|
||||||
cm.viridis(prob.numpy()[j * cfg['train'][
|
cm.viridis(prob.numpy()[j * cfg['train'][
|
||||||
'batch_size'] // 2]) * 255)
|
'batch_size'] // nranks]) * 255)
|
||||||
writer.add_image(
|
writer.add_image(
|
||||||
'Attention_dec_%d_0' % global_step,
|
'Attention_dec_%d_0' % global_step,
|
||||||
x,
|
x,
|
||||||
|
@ -199,8 +203,9 @@ def main(args):
|
||||||
if local_rank == 0 and global_step % cfg['train'][
|
if local_rank == 0 and global_step % cfg['train'][
|
||||||
'checkpoint_interval'] == 0:
|
'checkpoint_interval'] == 0:
|
||||||
io.save_parameters(
|
io.save_parameters(
|
||||||
os.path.join(args.output, 'checkpoints'), global_step,
|
os.path.join(args.output, 'checkpoints'), global_step, model,
|
||||||
model, optimizer)
|
optimizer)
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
if local_rank == 0:
|
if local_rank == 0:
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
|
@ -94,7 +94,8 @@ class LengthRegulator(dg.Layer):
|
||||||
else:
|
else:
|
||||||
duration_predictor_output = layers.round(duration_predictor_output)
|
duration_predictor_output = layers.round(duration_predictor_output)
|
||||||
output = self.LR(x, duration_predictor_output, alpha)
|
output = self.LR(x, duration_predictor_output, alpha)
|
||||||
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1))
|
mel_pos = dg.to_variable(np.arange(1, output.shape[1] + 1)).astype(
|
||||||
|
np.int64)
|
||||||
mel_pos = layers.unsqueeze(mel_pos, [0])
|
mel_pos = layers.unsqueeze(mel_pos, [0])
|
||||||
return output, mel_pos
|
return output, mel_pos
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue