def __init__(self):
    """Set up the recognizer with the "tsr" model from the deepdoc resources.

    The directory ``rag/res/deepdoc`` under the project base directory is
    tried first. If the parent initialiser raises any ``Exception``, the
    ``InfiniFlow/deepdoc`` snapshot is downloaded into that same directory
    with ``snapshot_download`` and the parent initialiser is called again
    with the value that download returns.
    """

    def resource_dir():
        # Resolved inside the try/except so a failure here is handled the
        # same way as a failed model load.
        return os.path.join(get_project_base_directory(), "rag/res/deepdoc")

    try:
        super().__init__(self.labels, "tsr", resource_dir())
    except Exception:
        # Best-effort fallback: fetch the published files, then retry once.
        fetched_dir = snapshot_download(
            repo_id="InfiniFlow/deepdoc",
            local_dir=resource_dir(),
            local_dir_use_symlinks=False,
        )
        super().__init__(self.labels, "tsr", fetched_dir)
@staticmethod
def is_caption(bx):
    """Tell whether a text box looks like a figure/table caption.

    A box qualifies when its stripped ``text`` begins with a caption pattern
    (one or more of the characters 图/表 followed by at least two characters
    drawn from space, digits, ":" and ":"), or when its optional
    ``layout_type`` contains the substring "caption".

    Args:
        bx: mapping with a "text" entry and, optionally, a "layout_type"
            entry.

    Returns:
        True if the box is considered a caption, otherwise False.
    """
    caption_patterns = (r"[图表]+[ 0-9::]{2,}",)
    stripped = bx["text"].strip()
    for pattern in caption_patterns:
        # re.match anchors at the start of the stripped text only.
        if re.match(pattern, stripped):
            return True
    # No textual match: fall back to the layout label, if one is present.
    return bx.get("layout_type", "").find("caption") != -1
b["bottom"] b["rn"] += 1 rows.append([b]) continue - btm = (btm + b["bottom"]) / 2. + btm = (btm + b["bottom"]) / 2.0 rows[-1].append(b) colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b] @@ -186,14 +200,14 @@ class TableStructureRecognizer(Recognizer): for b in boxes[1:]: b["cn"] = len(cols) - 1 lst_c = cols[-1] - if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][ - "page_number"]) \ - or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col + if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or ( + b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2") + ): # new col right = b["x1"] b["cn"] += 1 cols.append([b]) continue - right = (right + b["x1"]) / 2. + right = (right + b["x1"]) / 2.0 cols[-1].append(b) tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))] @@ -214,10 +228,8 @@ class TableStructureRecognizer(Recognizer): if e > 1: j += 1 continue - f = (j > 0 and tbl[ii][j - 1] and tbl[ii] - [j - 1][0].get("text")) or j == 0 - ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii] - [j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) + f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0 + ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii]) if f and ff: j += 1 continue @@ -228,13 +240,11 @@ class TableStructureRecognizer(Recognizer): if j > 0 and not f: for i in range(len(tbl)): if tbl[i][j - 1]: - left = min(left, np.min( - [bx["x0"] - a["x1"] for a in tbl[i][j - 1]])) + left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]])) if j + 1 < len(tbl[0]) and not ff: for i in range(len(tbl)): if tbl[i][j + 1]: - right = min(right, np.min( - [a["x0"] - bx["x1"] for a in tbl[i][j + 1]])) + right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]])) assert left < 100000 or right < 100000 if left < right: for 
jj in range(j, len(tbl[0])): @@ -260,8 +270,7 @@ class TableStructureRecognizer(Recognizer): for i in range(len(tbl)): tbl[i].pop(j) cols.pop(j) - assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % ( - len(cols), len(tbl[0])) + assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0])) if len(cols) >= 4: # remove single in row @@ -277,10 +286,8 @@ class TableStructureRecognizer(Recognizer): if e > 1: i += 1 continue - f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1] - [jj][0].get("text")) or i == 0 - ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1] - [jj][0].get("text")) or i + 1 >= len(tbl) + f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0 + ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl) if f and ff: i += 1 continue @@ -292,13 +299,11 @@ class TableStructureRecognizer(Recognizer): if i > 0 and not f: for j in range(len(tbl[i - 1])): if tbl[i - 1][j]: - up = min(up, np.min( - [bx["top"] - a["bottom"] for a in tbl[i - 1][j]])) + up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]])) if i + 1 < len(tbl) and not ff: for j in range(len(tbl[i + 1])): if tbl[i + 1][j]: - down = min(down, np.min( - [a["top"] - bx["bottom"] for a in tbl[i + 1][j]])) + down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]])) assert up < 100000 or down < 100000 if up < down: for ii in range(i, len(tbl)): @@ -333,22 +338,15 @@ class TableStructureRecognizer(Recognizer): cnt += 1 if max_type == "Nu" and arr[0]["btype"] == "Nu": continue - if any([a.get("H") for a in arr]) \ - or (max_type == "Nu" and arr[0]["btype"] != "Nu"): + if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"): h += 1 if h / cnt > 0.5: hdset.add(i) if html: - return TableStructureRecognizer.__html_table(cap, hdset, - TableStructureRecognizer.__cal_spans(boxes, rows, - cols, tbl, True) - ) + return 
TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True)) - return TableStructureRecognizer.__desc_table(cap, hdset, - TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, - False), - is_english) + return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english) @staticmethod def __html_table(cap, hdset, tbl): @@ -367,10 +365,8 @@ class TableStructureRecognizer(Recognizer): continue txt = "" if arr: - h = min(np.min([c["bottom"] - c["top"] - for c in arr]) / 2, 10) - txt = " ".join([c["text"] - for c in Recognizer.sort_Y_firstly(arr, h)]) + h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10) + txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)]) txts.append(txt) sp = "" if arr[0].get("colspan"): @@ -436,15 +432,11 @@ class TableStructureRecognizer(Recognizer): if headers[j][k].find(headers[j - 1][k]) >= 0: continue if len(headers[j][k]) > len(headers[j - 1][k]): - headers[j][k] += (de if headers[j][k] - else "") + headers[j - 1][k] + headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k] else: - headers[j][k] = headers[j - 1][k] \ - + (de if headers[j - 1][k] else "") \ - + headers[j][k] + headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k] - logging.debug( - f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") + logging.debug(f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}") row_txt = [] for i in range(rowno): if i in hdr_rowno: @@ -503,14 +495,10 @@ class TableStructureRecognizer(Recognizer): @staticmethod def __cal_spans(boxes, rows, cols, tbl, html=True): # caculate span - clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) - for cln in cols] - crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) - for cln in cols] - rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) - for row in rows] - rbtm = 
def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16):
    """Run table-structure recognition through the Ascend offline model ``tsr.om``.

    Args:
        image_list: iterable of images; every item that is not already a
            ``numpy.ndarray`` is converted with ``np.array``.
        thr: requested confidence threshold. The value handed to
            ``self.postprocess`` is never lower than 0.08.
        batch_size: number of images passed to ``self.preprocess`` at a time.

    Returns:
        A list holding one ``self.postprocess`` result per preprocessed
        input, in input order. An empty ``image_list`` yields ``[]``.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if the model file
            ``rag/res/deepdoc/tsr.om`` under the project base directory does
            not exist.
        ImportError: if ``ais_bench`` is not installed.
    """
    # Reject this up front: zero used to surface as ZeroDivisionError and a
    # negative value silently produced no results.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    images = [im if isinstance(im, np.ndarray) else np.array(im) for im in image_list]
    if not images:
        # Nothing to infer: skip loading the toolkit and the model entirely.
        return []

    # Imported here so the dependency is only required when this path runs.
    from ais_bench.infer.interface import InferSession

    model_file_path = os.path.join(get_project_base_directory(), "rag/res/deepdoc", "tsr.om")
    if not os.path.exists(model_file_path):
        raise ValueError(f"Model file not found: {model_file_path}")

    # NOTE(review): the variable name refers to the layout recognizer;
    # confirm that sharing it with the table-structure model is intended.
    device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))

    # Reuse the session across calls instead of loading the model each time;
    # rebuild it only if the device or the model path changed.
    session_key = (device_id, model_file_path)
    cached = getattr(self, "_ascend_session", None)
    if cached is None or cached[0] != session_key:
        cached = (session_key, InferSession(device_id=device_id, model_path=model_file_path))
        self._ascend_session = cached
    session = cached[1]

    conf_thr = max(thr, 0.08)
    results = []
    for start in range(0, len(images), batch_size):
        for ins in self.preprocess(images[start:start + batch_size]):
            # Prefer the "image" entry; otherwise fall back to the first
            # declared model input name.
            feed = ins["image"] if "image" in ins else ins[self.input_names[0]]
            output_list = session.infer(feeds=[feed], mode="static")
            results.append(self.postprocess(output_list, ins, conf_thr))
    return results