248 lines
9.1 KiB
Python
Executable File
248 lines
9.1 KiB
Python
Executable File
# Copyright 2020 IBM
|
|
# Author: peter.zhong@au1.ibm.com
|
|
#
|
|
# This is free software; you can redistribute it and/or modify
|
|
# it under the terms of the Apache 2.0 License.
|
|
#
|
|
# This software is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# Apache 2.0 License for more details.
|
|
|
|
import distance
|
|
from apted import APTED, Config
|
|
from apted.helpers import Tree
|
|
from lxml import etree, html
|
|
from collections import deque
|
|
from .parallel import parallel_process
|
|
from tqdm import tqdm
|
|
|
|
|
|
class TableTree(Tree):
    """Node of an HTML table tree in the form consumed by APTED.

    ``colspan``, ``rowspan`` and ``content`` hold the cell span and the cell's
    token list; nodes that are not cells are built with ``None`` for all three.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = [*children]

    def bracket(self):
        """Show tree using brackets notation

        ``td`` nodes also render their colspan, rowspan and text; every node
        is followed by the bracket rendering of each child, in order.
        """
        if self.tag != 'td':
            head = '"tag": %s' % self.tag
        else:
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag, self.colspan, self.rowspan, self.content)
        rendered_children = ''.join(child.bracket() for child in self.children)
        return '{' + head + rendered_children + '}'
|
|
|
|
|
|
class CustomConfig(Config):
    """APTED cost model for table trees.

    Relabelling a node costs 1 when the tag, colspan or rowspan differ. For
    two matching ``td`` nodes it costs the normalised Levenshtein distance
    between their cell token lists (0 when both cells are empty). Any other
    match is free.
    """

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value

        Returns the length of the longest of ``sequences``.
        """
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1

        Levenshtein distance between ``sequences`` divided by the length of
        the longest one. Returns 0.0 when every sequence is empty, because
        empty sequences are identical; the plain division would raise
        ZeroDivisionError in that case.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.
        return float(distance.levenshtein(*sequences)) / longest

    def rename(self, node1, node2):
        """Compares attributes of trees

        Returns the cost of relabelling ``node1`` as ``node2``:
        1 if tag/colspan/rowspan differ, the normalised text distance for two
        ``td`` nodes with any content, otherwise 0.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.
|
|
|
|
|
|
|
|
class CustomConfig_del_short(Config):
    """APTED cost model that ignores the text of very short cells.

    Relabelling costs 1 when tag, colspan or rowspan differ. When two matching
    ``td`` nodes are compared and at least one has content, any cell token
    list shorter than 3 tokens is replaced by the placeholder ``['####']``
    before the normalised Levenshtein distance is taken. As a result, two
    short cells cost 0. Any other match is free.
    """

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest sequence."""
        return max(len(seq) for seq in sequences)

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance scaled by the longest length (0..1)."""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``."""
        same_shape = (node1.tag == node2.tag
                      and node1.colspan == node2.colspan
                      and node1.rowspan == node2.rowspan)
        if not same_shape:
            return 1.
        if node1.tag != 'td':
            return 0.
        if not (node1.content or node2.content):
            return 0.

        def masked(tokens):
            # Short cells all collapse to the same placeholder token list.
            return ['####'] if len(tokens) < 3 else tokens

        return self.normalized_distance(masked(node1.content),
                                        masked(node2.content))
|
|
|
|
class CustomConfig_del_block(Config):
    """APTED cost model that ignores single-space tokens in cell text.

    Relabelling costs 1 when tag, colspan or rowspan differ. For two matching
    ``td`` nodes, every ``' '`` token is dropped from both cell token lists and
    the cost is the normalised Levenshtein distance of what remains (0 when
    both become empty). Any other match is free.
    """

    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value

        Returns the length of the longest of ``sequences``.
        """
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Get distance from 0 to 1

        Levenshtein distance between ``sequences`` divided by the length of
        the longest one; 0.0 when every sequence is empty (identical), instead
        of raising ZeroDivisionError.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.
        return float(distance.levenshtein(*sequences)) / longest

    @staticmethod
    def _without_spaces(tokens):
        """Return a new list of ``tokens`` with every ``' '`` token dropped."""
        return [token for token in (tokens or ()) if token != ' ']

    def rename(self, node1, node2):
        """Compares attributes of trees

        Returns the cost of relabelling ``node1`` as ``node2``. Cell token
        lists are filtered into copies; the nodes themselves are not modified.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            # Filter into copies: popping from node.content in place would
            # mutate the caller's trees. The emptiness test happens after
            # filtering, so a space-only cell cannot reach normalized_distance
            # with zero length.
            node1_content = self._without_spaces(node1.content)
            node2_content = self._without_spaces(node2.content)
            if node1_content or node2_content:
                return self.normalized_distance(node1_content, node2_content)
        return 0.
|
|
|
|
class TEDS(object):
    ''' Tree Edit Distance based Similarity

    Scores a predicted HTML table against a ground-truth table as
    ``1 - tree_edit_distance / n_nodes``. The edit distance is computed by
    APTED with ``CustomConfig`` costs. ``n_nodes`` is the larger of the two
    tables' descendant-element counts.
    '''

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        '''
        Args:
            structure_only: if True, cell text is not tokenized (every cell
                gets an empty token list), so cells are compared by tag,
                colspan and rowspan only.
            n_jobs: number of parallel jobs for the batch methods; must be an
                int >= 1 (1 evaluates serially in the current process).
            ignore_nodes: optional iterable of tag names removed from both
                tables with ``etree.strip_tags`` before scoring.
        '''
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset before each cell.
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells

        Appends the tokens of ``node`` to ``self.__tokens__``, in this order:

        - an opening ``<tag>`` token;
        - one token per character of the element text;
        - the tokens of every child;
        - a closing ``</tag>`` token, omitted for ``unk`` elements;
        - one token per character of the tail text, omitted for ``td``.
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an lxml element yields its direct children; this replaces
        # the deprecated getchildren().
        for child in node:
            self.tokenize(child)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted

        Builds a TableTree mirroring ``node``. ``td`` elements become leaves
        carrying their colspan/rowspan (default 1) and cell token list. When
        ``parent`` is given, the new node is appended to it and None is
        returned; otherwise the new root node is returned.
        '''
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the leading '<td>' and trailing '</td>' tokens; the
                # slice is already a fresh list.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(node.tag,
                                 int(node.attrib.get('colspan', '1')),
                                 int(node.attrib.get('rowspan', '1')),
                                 cell)
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        ''' Computes TEDS score between the prediction and the ground truth of a
        given sample

        Only the first ``table`` that is a direct child of ``body`` under each
        parsed root is compared. Returns 0.0 when either input is empty or
        either side has no such table.
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath('body/table')
        true_tables = true.xpath('body/table')
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both sides are a bare <table>. The trees are identical single
            # nodes, so return a perfect score instead of dividing by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        # Named edit_distance so it does not shadow the `distance` module.
        edit_distance = APTED(tree_pred, tree_true,
                              CustomConfig()).compute_edit_distance()
        return 1.0 - (float(edit_distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': 'TEDS SCORE', ...}

        Every key of ``true_json`` is scored. A filename missing from
        ``pred_json`` is evaluated against an empty prediction and scores 0.0.
        '''
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_json.get(filename, ''),
                                    true_json[filename]['html'])
                      for filename in tqdm(samples)]
        else:
            inputs = [{'pred': pred_json.get(filename, ''),
                       'true': true_json[filename]['html']}
                      for filename in samples]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs,
                front_num=1)
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        ''' Computes TEDS score between the prediction and the ground truth of
        a batch of samples

        ``pred_htmls`` and ``true_htmls`` are paired positionally (``zip``
        stops at the shorter one). The scores are returned in the same order;
        the serial path returns a list, and the parallel path returns whatever
        ``parallel_process`` produces.
        '''
        if self.n_jobs == 1:
            scores = [self.evaluate(pred_html, true_html)
                      for pred_html, true_html in zip(pred_htmls, true_htmls)]
        else:
            inputs = [{"pred": pred_html, "true": true_html}
                      for pred_html, true_html in zip(pred_htmls, true_htmls)]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs,
                front_num=1)
        return scores
|
|
|
|
|
|
if __name__ == '__main__':
    import json
    import pprint

    # Demo: score sample predictions against sample ground truth.
    # The JSON files are read as UTF-8 explicitly, so non-ASCII cell text
    # does not depend on the platform's locale encoding.
    with open('sample_pred.json', encoding='utf-8') as fp:
        pred_json = json.load(fp)
    with open('sample_gt.json', encoding='utf-8') as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
|