diff --git a/examples/transformer_tts/train_transformer.py b/examples/transformer_tts/train_transformer.py index d258209..02284f7 100644 --- a/examples/transformer_tts/train_transformer.py +++ b/examples/transformer_tts/train_transformer.py @@ -94,10 +94,16 @@ def main(args): if args.stop_token: writer.add_scalar('stop_loss', stop_loss.numpy(), global_step) - writer.add_scalars('alphas', { - 'encoder_alpha':model.encoder.alpha.numpy(), - 'decoder_alpha':model.decoder.alpha.numpy(), - }, global_step) + if args.use_data_parallel: + writer.add_scalars('alphas', { + 'encoder_alpha':model._layers.encoder.alpha.numpy(), + 'decoder_alpha':model._layers.decoder.alpha.numpy(), + }, global_step) + else: + writer.add_scalars('alphas', { + 'encoder_alpha':model.encoder.alpha.numpy(), + 'decoder_alpha':model.decoder.alpha.numpy(), + }, global_step) writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step) @@ -144,4 +150,4 @@ if __name__ =='__main__': args = parser.parse_args() # Print the whole config setting. pprint(args) - main(args) \ No newline at end of file + main(args)