From 7bcabe0f36e5a3ce287a6a041bc81fd1aafa0175 Mon Sep 17 00:00:00 2001
From: MissPenguin <lichenxia1991@163.com>
Date: Tue, 22 Jun 2021 12:39:43 +0000
Subject: [PATCH] refine

---
 tools/program.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tools/program.py b/tools/program.py
index 2bb34835..2d99f296 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -210,7 +210,10 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            preds = model(images, data=batch[1:])
+            if use_srn or model_type == 'table':
+                preds = model(images, data=batch[1:])
+            else:
+                preds = model(images)
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
             avg_loss.backward()
@@ -356,7 +359,10 @@ def eval(model,
                 break
             images = batch[0]
             start = time.time()
-            preds = model(images, data=batch[1:])
+            if use_srn or model_type == 'table':
+                preds = model(images, data=batch[1:])
+            else:
+                preds = model(images)
             batch = [item.numpy() for item in batch]
             # Obtain usable results from post-processing methods
             total_time += time.time() - start