refine
This commit is contained in:
parent
6e1cfb0525
commit
7bcabe0f36
|
@ -210,7 +210,10 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
preds = model(images, data=batch[1:])
|
||||
if use_srn or model_type == 'table':
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
avg_loss.backward()
|
||||
|
@ -356,7 +359,10 @@ def eval(model,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
preds = model(images, data=batch[1:])
|
||||
if use_srn or model_type == 'table':
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
total_time += time.time() - start
|
||||
|
|
Loading…
Reference in New Issue