Update predict.py
This commit is contained in:
parent
f31e3a2bea
commit
4b791f6e18
|
@ -91,7 +91,8 @@ def main(cfg):
|
|||
max_length=cfg.tgt_max_len, max_len_a=cfg.src_seq_ratio,num_beams=cfg.num_beams, do_sample=False,
|
||||
repetition_penalty=1, length_penalty=cfg.length_penalty, pad_token_id=1,
|
||||
restricter=None)
|
||||
trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, process=process, args=cfg, logger=logger, loss=None, metrics=None, writer=writer)
|
||||
trainer = Trainer(train_data=None, dev_data=None, test_data=test_dataloader, model=model, process=process, args=cfg, logger=logger,
|
||||
loss=None, metrics=None, writer=writer)
|
||||
trainer.predict()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue