Update predict.py

This commit is contained in:
tlk1997 2021-11-01 16:44:06 +08:00 committed by GitHub
parent f31e3a2bea
commit 4b791f6e18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -91,9 +91,10 @@ def main(cfg):
max_length=cfg.tgt_max_len, max_len_a=cfg.src_seq_ratio,num_beams=cfg.num_beams, do_sample=False,
repetition_penalty=1, length_penalty=cfg.length_penalty, pad_token_id=1,
restricter=None)
trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, process=process, args=cfg, logger=logger, loss=None, metrics=None, writer=writer)
trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, process=process, args=cfg, logger=logger,
loss=None, metrics=None, writer=writer)
trainer.predict()
if __name__ == "__main__":
main()
main()