refine
This commit is contained in:
parent
c0492e02c7
commit
053cc43d82
|
@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
|
|||
|
||||
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||
|
||||
def forward(self, inputs, others):
|
||||
def forward(self, inputs, targets=None):
|
||||
others = targets[-4:]
|
||||
encoder_word_pos = others[0]
|
||||
gsrm_word_pos = others[1]
|
||||
gsrm_slf_attn_bias1 = others[2]
|
||||
|
|
|
@ -209,14 +209,8 @@ def train(config,
|
|||
lr = optimizer.get_lr()
|
||||
images = batch[0]
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
model_average = True
|
||||
elif model_type == "table":
|
||||
others = batch[1:]
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
preds = model(images, data=batch[1:])
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
avg_loss.backward()
|
||||
|
@ -358,13 +352,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
|
||||
preds = model(images, data=batch[1:])
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
total_time += time.time() - start
|
||||
|
|
Loading…
Reference in New Issue