add table eval and predict script
This commit is contained in:
parent
794362481e
commit
eb7ce442a3
|
@ -0,0 +1,247 @@
|
|||
# Copyright 2020 IBM
|
||||
# Author: peter.zhong@au1.ibm.com
|
||||
#
|
||||
# This is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the Apache 2.0 License.
|
||||
#
|
||||
# This software is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# Apache 2.0 License for more details.
|
||||
|
||||
import distance
|
||||
from apted import APTED, Config
|
||||
from apted.helpers import Tree
|
||||
from lxml import etree, html
|
||||
from collections import deque
|
||||
from .parallel import parallel_process
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class TableTree(Tree):
    """Tree node (subclass of apted's ``Tree``) for one HTML table element.

    The node keeps the element tag, the cell spans, the cell content and
    the list of child nodes as plain attributes.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag, self.colspan, self.rowspan = tag, colspan, rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Return the subtree rooted at this node in bracket notation."""
        if self.tag != 'td':
            label = '"tag": %s' % self.tag
        else:
            # Only 'td' nodes include their spans and content in the label.
            label = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content)
        nested = ''.join(child.bracket() for child in self.children)
        return '{' + label + nested + '}'
|
||||
|
||||
|
||||
class CustomConfig(Config):
    """apted ``Config`` whose rename cost compares table-tree nodes."""

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest of the given sequences."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance divided by the longest length."""
        edits = float(distance.levenshtein(*sequences))
        return edits / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``.

        Nodes whose tag, colspan or rowspan differ cost 1. Two 'td' nodes
        with matching spans cost the normalized distance between their
        contents when at least one of them has content. Everything else
        costs 0.
        """
        same_shape = (node1.tag == node2.tag
                      and node1.colspan == node2.colspan
                      and node1.rowspan == node2.rowspan)
        if not same_shape:
            return 1.
        if node1.tag == 'td' and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.
|
||||
|
||||
|
||||
|
||||
class CustomConfig_del_short(Config):
    """apted ``Config`` that replaces short cell contents by a placeholder.

    Any cell content with fewer than three items is swapped for the
    single-item list ``['####']`` before the contents are compared.
    """

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest of the given sequences."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance divided by the longest length."""
        edits = float(distance.levenshtein(*sequences))
        return edits / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``."""
        same_shape = (node1.tag == node2.tag
                      and node1.colspan == node2.colspan
                      and node1.rowspan == node2.rowspan)
        if not same_shape:
            return 1.
        if node1.tag != 'td':
            return 0.
        if not (node1.content or node2.content):
            return 0.
        # Contents shorter than three items are replaced by the placeholder,
        # so two short cells always compare as equal.
        first = node1.content if len(node1.content) >= 3 else ['####']
        second = node2.content if len(node2.content) >= 3 else ['####']
        return self.normalized_distance(first, second)
|
||||
|
||||
class CustomConfig_del_block(Config):
    """apted ``Config`` that ignores ``' '`` items when comparing cells."""

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest of the given sequences."""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance divided by the longest length."""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``.

        Nodes whose tag, colspan or rowspan differ cost 1. For two 'td'
        nodes with matching spans, every ``' '`` item is dropped from both
        contents and the normalized distance of what remains is returned.
        Everything else costs 0.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                # Filter into new lists instead of popping from the nodes'
                # own lists, so computing a cost never alters the trees.
                node1_content = [item for item in (node1.content or []) if item != ' ']
                node2_content = [item for item in (node2.content or []) if item != ' ']
                # Both cells held nothing but spaces: they are equal, and
                # normalizing would divide by a maximum length of zero.
                if not node1_content and not node2_content:
                    return 0.
                return self.normalized_distance(node1_content, node2_content)
        return 0.
|
||||
|
||||
class TEDS(object):
    """Tree Edit Distance based Similarity between two HTML tables."""

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        """Store the evaluation options.

        structure_only: when true, cell contents are left empty so only the
            table structure is compared.
        n_jobs: number of jobs handed to ``parallel_process``; with 1 the
            batch methods run a plain loop instead.
        ignore_nodes: tag names stripped from both tables before scoring.
        """
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the tokens of ``node`` and its descendants to ``self.__tokens__``.

        A node contributes its opening tag, each character of its text, the
        tokens of its children, its closing tag (except for 'unk' nodes) and
        each character of its tail (except for 'td' nodes).
        """
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an element yields its children; getchildren() is deprecated.
        for child in node:
            self.tokenize(child)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an HTML element tree to ``TableTree`` nodes for apted.

        Returns the new root node when ``parent`` is None. Otherwise the new
        node is appended to ``parent.children`` and None is returned.
        """
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing '<td>' and '</td>' tokens.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            # Cells are leaves: their markup is already captured as tokens.
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Compute the TEDS score between a prediction and its ground truth.

        Returns 0.0 when either input is empty or either document has no
        ``body/table`` element. Otherwise returns one minus the tree edit
        distance divided by the larger descendant-element count of the two
        tables.
        """
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both tables are bare <table> elements, hence identical; the
            # normalization below would otherwise divide by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        # Named edit_distance so the imported `distance` module is not shadowed.
        edit_distance = APTED(tree_pred, tree_true,
                              CustomConfig()).compute_edit_distance()
        return 1.0 - (float(edit_distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        """Compute TEDS scores for a batch of samples keyed by file name.

        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}

        Samples missing from ``pred_json`` are scored against an empty
        prediction.
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''),
                                    true_json[filename]['html'])
                      for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(filename, ''),
                       'true': true_json[filename]['html']}
                      for filename in samples]
            # NOTE(review): zipping below assumes parallel_process returns
            # results in input order - confirm against its implementation.
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Compute TEDS scores for paired sequences of HTML strings.

        Predictions and ground truths are paired with ``zip``; the scores
        are returned as produced by the loop or by ``parallel_process``.
        """
        pairs = zip(pred_htmls, true_htmls)
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html)
                      for (pred_html, true_html) in pairs]
        else:
            inputs = [{"pred": pred_html, "true": true_html}
                      for (pred_html, true_html) in pairs]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1)
        return scores
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import json
    import pprint

    # Read the predictions first, then the ground truth.
    loaded = []
    for path in ('sample_pred.json', 'sample_gt.json'):
        with open(path) as handle:
            loaded.append(json.load(handle))
    pred_json, true_json = loaded

    # Score every sample with four jobs and pretty-print the result mapping.
    scores = TEDS(n_jobs=4).batch_evaluate(pred_json, true_json)
    pprint.pprint(scores)
|
Loading…
Reference in New Issue