Merge branch 'fix' into 'master'
Fix a bug in TransformerTTS when using data parallel. See merge request !19
commit 45dd7c619a
@@ -94,10 +94,16 @@ def main(args):
             if args.stop_token:
                 writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)

-            writer.add_scalars('alphas', {
-                'encoder_alpha':model.encoder.alpha.numpy(),
-                'decoder_alpha':model.decoder.alpha.numpy(),
-            }, global_step)
+            if args.use_data_parallel:
+                writer.add_scalars('alphas', {
+                    'encoder_alpha':model._layers.encoder.alpha.numpy(),
+                    'decoder_alpha':model._layers.decoder.alpha.numpy(),
+                }, global_step)
+            else:
+                writer.add_scalars('alphas', {
+                    'encoder_alpha':model.encoder.alpha.numpy(),
+                    'decoder_alpha':model.decoder.alpha.numpy(),
+                }, global_step)

             writer.add_scalar('learning_rate', optimizer._learning_rate.step().numpy(), global_step)

@@ -144,4 +150,4 @@ if __name__ =='__main__':
     args = parser.parse_args()
     # Print the whole config setting.
     pprint(args)
     main(args)
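Note on the fix: PaddlePaddle's dygraph DataParallel wraps the original Layer and keeps it in the `_layers` attribute, so values such as `encoder.alpha` and `decoder.alpha` have to be read through `model._layers` when `--use_data_parallel` is enabled. Below is a minimal sketch of that access pattern, assuming a fluid dygraph model; the `unwrap` helper is hypothetical and not part of this commit.

def unwrap(model, use_data_parallel):
    # DataParallel keeps the wrapped Layer in `_layers`; an unwrapped model is returned as-is.
    return model._layers if use_data_parallel else model

# Hypothetical usage that would collapse the two branches added in the diff above:
#   net = unwrap(model, args.use_data_parallel)
#   writer.add_scalars('alphas', {
#       'encoder_alpha': net.encoder.alpha.numpy(),
#       'decoder_alpha': net.decoder.alpha.numpy(),
#   }, global_step)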