fix a bug of transformertts when use data parallel.
This commit is contained in:
parent
6428ce5439
commit
9b86f2008d
|
@ -94,10 +94,16 @@ def main(args):
|
|||
if args.stop_token:
|
||||
writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)
|
||||
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha':model.encoder.alpha.numpy(),
|
||||
'decoder_alpha':model.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
if args.use_data_parallel:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha':model._layers.encoder.alpha.numpy(),
|
||||
'decoder_alpha':model._layers.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
else:
|
||||
writer.add_scalars('alphas', {
|
||||
'encoder_alpha':model.encoder.alpha.numpy(),
|
||||
'decoder_alpha':model.decoder.alpha.numpy(),
|
||||
}, global_step)
|
||||
|
||||
writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)
|
||||
|
||||
|
@ -144,4 +150,4 @@ if __name__ =='__main__':
|
|||
args = parser.parse_args()
|
||||
# Print the whole config setting.
|
||||
pprint(args)
|
||||
main(args)
|
||||
main(args)
|
||||
|
|
Loading…
Reference in New Issue