From 4b791f6e18d7ffc0397c1c1aae2b5875fcf07ff6 Mon Sep 17 00:00:00 2001 From: tlk1997 <56509305+tlk1997@users.noreply.github.com> Date: Mon, 1 Nov 2021 16:44:06 +0800 Subject: [PATCH] Update predict.py --- example/ner/few-shot/predict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example/ner/few-shot/predict.py b/example/ner/few-shot/predict.py index 5da458d..d42bf92 100644 --- a/example/ner/few-shot/predict.py +++ b/example/ner/few-shot/predict.py @@ -91,9 +91,10 @@ def main(cfg): max_length=cfg.tgt_max_len, max_len_a=cfg.src_seq_ratio,num_beams=cfg.num_beams, do_sample=False, repetition_penalty=1, length_penalty=cfg.length_penalty, pad_token_id=1, restricter=None) - trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, process=process, args=cfg, logger=logger, loss=None, metrics=None, writer=writer) + trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, process=process, args=cfg, logger=logger, + loss=None, metrics=None, writer=writer) trainer.predict() if __name__ == "__main__": - main() \ No newline at end of file + main()