Spaces:
Sleeping
Sleeping
Dit-document-layout-analysis
/
unilm
/dit
/object_detection
/ditod
/table_evaluation
/data_structure.py
| """ | |
| Data structures used by the evaluation process. | |
| Yu Fang - March 2019 | |
| """ | |
try:
    # The ABC aliases were removed from ``collections`` in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # very old interpreters without collections.abc
    from collections import Iterable

import numpy as np
from shapely.geometry import Polygon
| # helper functions | |
def flatten(lis):
    """Lazily yield the leaf items of an arbitrarily nested iterable.

    Any iterable item is descended into recursively, except ``str``
    instances, which are yielded whole.

    Args:
        lis: an iterable whose items may themselves be iterables.

    Yields:
        The non-iterable (or string) leaves, in depth-first order.
    """
    # Imported locally from collections.abc: the ``collections.Iterable``
    # alias this helper used to rely on was removed in Python 3.10.
    from collections.abc import Iterable

    for item in lis:
        if isinstance(item, Iterable) and not isinstance(item, str):
            yield from flatten(item)
        else:
            yield item
| # derived from https://blog.csdn.net/u012433049/article/details/82909484 | |
def compute_poly_iou(list1, list2):
    """Return the intersection-over-union of two polygons.

    Args:
        list1: flat sequence of integer coordinates ``x1, y1, x2, y2, ...``
            for the first polygon (reshaped into (x, y) pairs).
        list2: same, for the second polygon.

    Returns:
        Intersection area divided by union area, or ``0`` when that
        division raises ``ZeroDivisionError``.
    """
    shapes = []
    for coords in (list1, list2):
        vertices = np.array(coords, dtype=int).reshape(-1, 2)
        # buffer(0) is applied before measuring -- presumably to repair
        # invalid (e.g. self-intersecting) outlines.
        shapes.append(Polygon(vertices).buffer(0))
    first, second = shapes

    try:
        return first.intersection(second).area / first.union(second).area
    except ZeroDivisionError:
        return 0
class Cell(object):
    """One table cell: its row/column span and its bounding box.

    Args:
        table_id: identifier of the table this cell belongs to.
        start_row: start row index of the cell (converted with ``int``).
        start_col: start column index of the cell (converted with ``int``).
        cell_box: bounding box of the cell, kept as the raw coordinate
            string (whitespace-separated ``x,y`` points).
        end_row: end row index of the cell; ``-1`` means "same as start_row".
        end_col: end column index of the cell; ``-1`` means "same as
            start_col".
        content_box: bounding box of the text content within the cell
            (stored but not used by this class).
    """

    def __init__(self, table_id, start_row, start_col, cell_box, end_row, end_col, content_box=""):
        self._start_row = int(start_row)
        self._start_col = int(start_col)
        self._cell_box = cell_box
        self._content_box = content_box
        self._table_id = table_id  # the table_id this cell belongs to
        self._cell_id = id(self)

        # -1 is the "no explicit end index" sentinel: the cell then spans a
        # single row / column.
        self._end_row = self._start_row if end_row == -1 else int(end_row)
        self._end_col = self._start_col if end_col == -1 else int(end_col)

    # The accessors are properties because the rest of the module reads them
    # as plain attributes (e.g. ``cell.start_row``, ``self.cell_box.split()``).
    @property
    def start_row(self):
        return self._start_row

    @property
    def start_col(self):
        return self._start_col

    @property
    def end_row(self):
        return self._end_row

    @property
    def end_col(self):
        return self._end_col

    @property
    def cell_box(self):
        return self._cell_box

    @property
    def content_box(self):
        return self._content_box

    @property
    def cell_id(self):
        return self._cell_id

    @property
    def table_id(self):
        return self._table_id

    def __str__(self):
        return "CELL row=[%d, %d] col=[%d, %d] (coords=%s)" % (
            self.start_row, self.end_row, self.start_col, self.end_col, self.cell_box)

    @staticmethod
    def _parse_points(points):
        """Turn a ``"x1,y1 x2,y2 ..."`` string into a flat list of ints."""
        return [int(value) for point in points.split() for value in point.split(",")]

    def compute_cell_iou(self, another_cell):
        """Return the IoU of this cell's box and ``another_cell``'s box."""
        return compute_poly_iou(self._parse_points(self.cell_box),
                                self._parse_points(another_cell.cell_box))

    def check_same(self, another_cell):
        """Return True when both cells cover the same row and column span."""
        return (self._start_row == another_cell.start_row
                and self._end_row == another_cell.end_row
                and self._start_col == another_cell.start_col
                and self._end_col == another_cell.end_col)
# Note: a relation currently stores the two Cell objects involved; this
# could be replaced by cell ids in a follow-up memory clean-up.
class AdjRelation:
    """A directed adjacency between two cells of a table.

    Args:
        fromText: the Cell object the relation starts from.
        toText: the Cell object the relation points to.
        direction: ``AdjRelation.DIR_HORIZ`` or ``AdjRelation.DIR_VERT``.
    """

    DIR_HORIZ = 1
    DIR_VERT = 2

    def __init__(self, fromText, toText, direction):
        # fromText / toText are Cell objects (may be changed to cell ids in
        # further development).
        self._fromText = fromText
        self._toText = toText
        self._direction = direction

    # Exposed as properties: this class and its users read them as plain
    # attributes (``self.fromText.cell_id``, ``ar1.direction == ...``).
    @property
    def fromText(self):
        return self._fromText

    @property
    def toText(self):
        return self._toText

    @property
    def direction(self):
        return self._direction

    def __str__(self):
        label = "vertical" if self.direction == self.DIR_VERT else "horizontal"
        return 'ADJ_RELATION: ' + str(self._fromText) + ' ' + str(self._toText) + ' ' + label

    def isEqual(self, otherRelation):
        """Return True when both relations link the same cell ids in the same direction."""
        return (self.fromText.cell_id == otherRelation.fromText.cell_id
                and self.toText.cell_id == otherRelation.toText.cell_id
                and self.direction == otherRelation.direction)
class Table:
    """A table parsed from an XML node into a flat list of ``Cell`` objects.

    Args:
        tableNode: the table element. It must provide
            ``getElementsByTagName`` and its ``cell`` children must provide
            ``getAttribute`` / ``hasAttribute`` (presumably an
            ``xml.dom`` element -- confirm against the caller).

    Attributes:
        adj_relations: cached result of ``find_adj_relations()``.
        parsed: True once ``parse_table()`` has completed.
        found: True once ``find_adj_relations()`` has computed its result.
    """

    def __init__(self, tableNode):
        self._root = tableNode
        self._id = id(self)
        self._table_coords = ""
        self._maxRow = 0  # PS: indexing from 0
        self._maxCol = 0
        self._cells = []  # save a table as list of <Cell>s
        self.adj_relations = []  # save the adj_relations for the table
        self.parsed = False
        self.found = False  # check if the find_adj_relations() has been called once
        self.parse_table()

    def __str__(self):
        return "TABLE object - {} row x {} col".format(self._maxRow + 1, self._maxCol + 1)

    # Exposed as properties: the methods below read them as plain attributes
    # (``str(self.id)``, ``self.table_coords.split()``, ``for c in self.table_cells``).
    @property
    def id(self):
        return self._id

    @property
    def table_coords(self):
        return self._table_coords

    @property
    def table_cells(self):
        return self._cells

    def parse_table(self):
        """Read the table bounding box and all cells from the XML node.

        The cell list is rebuilt from scratch, so calling this again does
        not duplicate cells; the cached adjacency relations are marked
        stale at the same time.
        """
        # get the table bbox
        self._table_coords = str(self._root.getElementsByTagName("Coords")[0].getAttribute("points"))

        # get info for each cell
        parsed_cells = []
        max_row = max_col = 0
        for node in self._root.getElementsByTagName("cell"):
            sr = node.getAttribute("start-row")
            sc = node.getAttribute("start-col")
            b_points = str(node.getElementsByTagName("Coords")[0].getAttribute("points"))
            # -1 tells Cell that no explicit end index was given
            er = node.getAttribute("end-row") if node.hasAttribute("end-row") else -1
            ec = node.getAttribute("end-col") if node.hasAttribute("end-col") else -1

            new_cell = Cell(table_id=str(self.id), start_row=sr, start_col=sc, cell_box=b_points,
                            end_row=er, end_col=ec)
            max_row = max(max_row, int(sr), int(er))
            max_col = max(max_col, int(sc), int(ec))
            parsed_cells.append(new_cell)

        self._cells = parsed_cells
        self._maxCol = max_col
        self._maxRow = max_row
        self.parsed = True
        self.found = False  # relations computed from an earlier parse are stale

    def convert_2d(self):
        """Build a row-major grid of the table for adjacency search.

        Returns:
            A list of rows. Each slot holds ``0`` (blank), a single Cell, or
            a list of Cells when several cells claim the same position.
        """
        grid = [[0] * (self._maxCol + 1) for _ in range(self._maxRow + 1)]
        for cell in self._cells:
            for row in range(cell.start_row, cell.end_row + 1):
                for col in range(cell.start_col, cell.end_col + 1):
                    occupant = grid[row][col]
                    if occupant == 0:
                        grid[row][col] = cell
                    elif isinstance(occupant, list):
                        occupant.append(cell)
                    else:
                        grid[row][col] = [occupant, cell]
        return grid

    @staticmethod
    def _line_relations(line, direction):
        """Collect the adjacency relations along one row or column of the grid.

        For every occupied slot, each of its cells is linked to the cells of
        the next slot; when the next slot is blank, it is linked to the
        cells of the first occupied slot further along the line instead.
        """
        relations = []
        for pos in range(len(line) - 1):
            slot = line[pos]
            if slot == 0:
                continue
            shared_source = isinstance(slot, list)
            neighbour = line[pos + 1]
            for cell in (slot if shared_source else [slot]):
                if neighbour != 0:
                    shared_target = isinstance(neighbour, list)
                    for cell_to in (neighbour if shared_target else [neighbour]):
                        if cell is cell_to:
                            # a spanning cell is not adjacent to itself
                            continue
                        # NOTE(review): the same-span filter is applied only
                        # when both slots hold several cells; kept as is.
                        if shared_source and shared_target and cell.check_same(cell_to):
                            continue
                        relations.append(AdjRelation(cell, cell_to, direction))
                else:
                    # find the next non-blank slot, if one exists
                    for later in line[pos + 2:]:
                        if later != 0:
                            for cell_to in (later if isinstance(later, list) else [later]):
                                relations.append(AdjRelation(cell, cell_to, direction))
                            break
        return relations

    def find_adj_relations(self):
        """Return the table's adjacency relations, computing them on first use.

        Horizontal relations (row by row) come first, then vertical ones
        (column by column). Relations with the same direction and the same
        two cell objects are reported once, keeping the first occurrence.
        """
        if self.found:
            return self.adj_relations

        if not self.parsed:
            # fix: cases where there's no cell in table?
            print("table is not parsed for further steps.")
            self.parse_table()

        grid = self.convert_2d()
        candidates = []
        for row in grid:
            candidates.extend(self._line_relations(row, AdjRelation.DIR_HORIZ))
        for column in zip(*grid):
            candidates.extend(self._line_relations(column, AdjRelation.DIR_VERT))

        # eliminate duplicates in one pass; cells are compared by identity
        seen = set()
        unique = []
        for relation in candidates:
            key = (relation.direction, id(relation.fromText), id(relation.toText))
            if key not in seen:
                seen.add(key)
                unique.append(relation)

        self.found = True
        self.adj_relations = unique
        return self.adj_relations

    @staticmethod
    def _parse_points(points):
        """Turn a ``"x1,y1 x2,y2 ..."`` string into a flat list of ints."""
        return [int(value) for point in points.split() for value in point.split(",")]

    def compute_table_iou(self, another_table):
        """Return the IoU of this table's box and ``another_table``'s box."""
        return compute_poly_iou(self._parse_points(self.table_coords),
                                self._parse_points(another_table.table_coords))

    def find_cell_mapping(self, target_table, iou_value):
        """Map each cell of this table to a matching cell of ``target_table``.

        A cell is paired with the first target cell whose IoU with it is at
        least ``iou_value``; cells without such a match are left out.

        Returns:
            dict of ``{own_cell: target_cell}`` (mind the order of the
            tables: the keys come from the table the method is called on).
        """
        mapping = {}
        for own_cell in self.table_cells:
            for other_cell in target_table.table_cells:
                if own_cell.compute_cell_iou(other_cell) >= iou_value:
                    mapping[own_cell] = other_cell
                    break
        return mapping

    @classmethod
    def printCellMapping(cls, dMappedCell):
        """Print a cell mapping as produced by ``find_cell_mapping``."""
        print("-"*25)
        for cell1, cell2 in dMappedCell.items():
            print(" ", cell1, " --> ", cell2)

    @classmethod
    def printAdjacencyRelationList(cls, lAdjRel, title=""):
        """Print a list of adjacency relations under an optional title."""
        print("--- %s "%title + "-"*25)
        for adj in lAdjRel:
            print(adj)
class ResultStructure:
    """Container for the three counts of one evaluation result.

    Args:
        truePos: value reported as "true" by ``__str__`` (presumably the
            number of true positives).
        gtTotal: value reported as "gt" (presumably the ground-truth total).
        resTotal: value reported as "res" (presumably the result total).
    """

    def __init__(self, truePos, gtTotal, resTotal):
        self._truePos, self._gtTotal, self._resTotal = truePos, gtTotal, resTotal

    # NOTE(review): these are plain methods here, while the other classes of
    # this module are read attribute-style -- confirm how callers use them.
    def truePos(self):
        return self._truePos

    def gtTotal(self):
        return self._gtTotal

    def resTotal(self):
        return self._resTotal

    def __str__(self):
        return f"true: {self._truePos}, gt: {self._gtTotal}, res: {self._resTotal}"