Merge branch 'fix' into 'master'

Fix a bug in TransformerTTS when using data parallel: once the model is wrapped in DataParallel, its encoder and decoder must be accessed through model._layers.

See merge request !19
commit 45dd7c619a
Author: liuyibing01
Date:   2020-02-19 21:56:33 +08:00
1 changed file with 11 additions and 5 deletions

@@ -94,10 +94,16 @@ def main(args):
             if args.stop_token:
                 writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
-            writer.add_scalars('alphas', {
-                'encoder_alpha':model.encoder.alpha.numpy(),
-                'decoder_alpha':model.decoder.alpha.numpy(),
-            }, global_step)
+            if args.use_data_parallel:
+                writer.add_scalars('alphas', {
+                    'encoder_alpha':model._layers.encoder.alpha.numpy(),
+                    'decoder_alpha':model._layers.decoder.alpha.numpy(),
+                }, global_step)
+            else:
+                writer.add_scalars('alphas', {
+                    'encoder_alpha':model.encoder.alpha.numpy(),
+                    'decoder_alpha':model.decoder.alpha.numpy(),
+                }, global_step)
             writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)
@@ -144,4 +150,4 @@ if __name__ =='__main__':
     args = parser.parse_args()
     # Print the whole config setting.
     pprint(args)
-    main(args)
+    main(args)
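
For context on why the branch is needed, here is a minimal sketch, assuming the PaddlePaddle 1.x fluid dygraph API this repo used at the time; build_model() and use_data_parallel are hypothetical stand-ins for the training script's own setup, not names from this diff.

    # Sketch, not code from this repo: illustrates why attribute access
    # changes once a Layer is wrapped in fluid's dygraph DataParallel.
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.parallel import DataParallel, prepare_context

    with fluid.dygraph.guard():
        model = build_model()  # hypothetical: returns a fluid.dygraph.Layer
        if use_data_parallel:
            strategy = prepare_context()
            model = DataParallel(model, strategy)

        # DataParallel is itself a Layer that keeps the wrapped network in
        # its `_layers` attribute and does not forward attribute lookups,
        # so `model.encoder` raises AttributeError after wrapping. Unwrap
        # first, as the patched logging code does:
        net = model._layers if use_data_parallel else model
        encoder_alpha = net.encoder.alpha.numpy()

Picking the underlying layer once, as in the last two lines above, would also let the two add_scalars branches in the diff collapse into a single call.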