mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-05 11:27:39 +00:00
Feat: add support for the Ascend table structure recognizer (#10110)
### What problem does this PR solve? Add support for the Ascend table structure recognizer. Use the environment variable `TABLE_STRUCTURE_RECOGNIZER_TYPE=ascend` to enable the Ascend table structure recognizer. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
parent
8c00cbc87a
commit
86f6da2f74
@ -23,6 +23,7 @@ from huggingface_hub import snapshot_download
|
|||||||
|
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
from rag.nlp import rag_tokenizer
|
from rag.nlp import rag_tokenizer
|
||||||
|
|
||||||
from .recognizer import Recognizer
|
from .recognizer import Recognizer
|
||||||
|
|
||||||
|
|
||||||
@ -38,31 +39,49 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
super().__init__(self.labels, "tsr", os.path.join(
|
super().__init__(self.labels, "tsr", os.path.join(get_project_base_directory(), "rag/res/deepdoc"))
|
||||||
get_project_base_directory(),
|
|
||||||
"rag/res/deepdoc"))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
super().__init__(self.labels, "tsr", snapshot_download(repo_id="InfiniFlow/deepdoc",
|
super().__init__(
|
||||||
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
self.labels,
|
||||||
local_dir_use_symlinks=False))
|
"tsr",
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="InfiniFlow/deepdoc",
|
||||||
|
local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
|
||||||
|
local_dir_use_symlinks=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self, images, thr=0.2):
|
def __call__(self, images, thr=0.2):
|
||||||
tbls = super().__call__(images, thr)
|
table_structure_recognizer_type = os.getenv("TABLE_STRUCTURE_RECOGNIZER_TYPE", "onnx").lower()
|
||||||
|
if table_structure_recognizer_type not in ["onnx", "ascend"]:
|
||||||
|
raise RuntimeError("Unsupported table structure recognizer type.")
|
||||||
|
|
||||||
|
if table_structure_recognizer_type == "onnx":
|
||||||
|
logging.debug("Using Onnx table structure recognizer", flush=True)
|
||||||
|
tbls = super().__call__(images, thr)
|
||||||
|
else: # ascend
|
||||||
|
logging.debug("Using Ascend table structure recognizer", flush=True)
|
||||||
|
tbls = self._run_ascend_tsr(images, thr)
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
# align left&right for rows, align top&bottom for columns
|
# align left&right for rows, align top&bottom for columns
|
||||||
for tbl in tbls:
|
for tbl in tbls:
|
||||||
lts = [{"label": b["type"],
|
lts = [
|
||||||
|
{
|
||||||
|
"label": b["type"],
|
||||||
"score": b["score"],
|
"score": b["score"],
|
||||||
"x0": b["bbox"][0], "x1": b["bbox"][2],
|
"x0": b["bbox"][0],
|
||||||
"top": b["bbox"][1], "bottom": b["bbox"][-1]
|
"x1": b["bbox"][2],
|
||||||
} for b in tbl]
|
"top": b["bbox"][1],
|
||||||
|
"bottom": b["bbox"][-1],
|
||||||
|
}
|
||||||
|
for b in tbl
|
||||||
|
]
|
||||||
if not lts:
|
if not lts:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
left = [b["x0"] for b in lts if b["label"].find(
|
left = [b["x0"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
|
||||||
"row") > 0 or b["label"].find("header") > 0]
|
right = [b["x1"] for b in lts if b["label"].find("row") > 0 or b["label"].find("header") > 0]
|
||||||
right = [b["x1"] for b in lts if b["label"].find(
|
|
||||||
"row") > 0 or b["label"].find("header") > 0]
|
|
||||||
if not left:
|
if not left:
|
||||||
continue
|
continue
|
||||||
left = np.mean(left) if len(left) > 4 else np.min(left)
|
left = np.mean(left) if len(left) > 4 else np.min(left)
|
||||||
@ -93,11 +112,8 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_caption(bx):
|
def is_caption(bx):
|
||||||
patt = [
|
patt = [r"[图表]+[ 0-9::]{2,}"]
|
||||||
r"[图表]+[ 0-9::]{2,}"
|
if any([re.match(p, bx["text"].strip()) for p in patt]) or bx.get("layout_type", "").find("caption") >= 0:
|
||||||
]
|
|
||||||
if any([re.match(p, bx["text"].strip()) for p in patt]) \
|
|
||||||
or bx.get("layout_type", "").find("caption") >= 0:
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -115,7 +131,7 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
||||||
(r"^[A-Z]*[a-z' -]+$", "En"),
|
(r"^[A-Z]*[a-z' -]+$", "En"),
|
||||||
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
||||||
(r"^.{1}$", "Sg")
|
(r"^.{1}$", "Sg"),
|
||||||
]
|
]
|
||||||
for p, n in patt:
|
for p, n in patt:
|
||||||
if re.search(p, b["text"].strip()):
|
if re.search(p, b["text"].strip()):
|
||||||
@ -156,21 +172,19 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
|
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
|
||||||
rowh = np.min(rowh) if rowh else 0
|
rowh = np.min(rowh) if rowh else 0
|
||||||
boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
|
boxes = Recognizer.sort_R_firstly(boxes, rowh / 2)
|
||||||
#for b in boxes:print(b)
|
# for b in boxes:print(b)
|
||||||
boxes[0]["rn"] = 0
|
boxes[0]["rn"] = 0
|
||||||
rows = [[boxes[0]]]
|
rows = [[boxes[0]]]
|
||||||
btm = boxes[0]["bottom"]
|
btm = boxes[0]["bottom"]
|
||||||
for b in boxes[1:]:
|
for b in boxes[1:]:
|
||||||
b["rn"] = len(rows) - 1
|
b["rn"] = len(rows) - 1
|
||||||
lst_r = rows[-1]
|
lst_r = rows[-1]
|
||||||
if lst_r[-1].get("R", "") != b.get("R", "") \
|
if lst_r[-1].get("R", "") != b.get("R", "") or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")): # new row
|
||||||
or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
|
|
||||||
): # new row
|
|
||||||
btm = b["bottom"]
|
btm = b["bottom"]
|
||||||
b["rn"] += 1
|
b["rn"] += 1
|
||||||
rows.append([b])
|
rows.append([b])
|
||||||
continue
|
continue
|
||||||
btm = (btm + b["bottom"]) / 2.
|
btm = (btm + b["bottom"]) / 2.0
|
||||||
rows[-1].append(b)
|
rows[-1].append(b)
|
||||||
|
|
||||||
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
|
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
|
||||||
@ -186,14 +200,14 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
for b in boxes[1:]:
|
for b in boxes[1:]:
|
||||||
b["cn"] = len(cols) - 1
|
b["cn"] = len(cols) - 1
|
||||||
lst_c = cols[-1]
|
lst_c = cols[-1]
|
||||||
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
|
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1]["page_number"]) or (
|
||||||
"page_number"]) \
|
b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")
|
||||||
or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
|
): # new col
|
||||||
right = b["x1"]
|
right = b["x1"]
|
||||||
b["cn"] += 1
|
b["cn"] += 1
|
||||||
cols.append([b])
|
cols.append([b])
|
||||||
continue
|
continue
|
||||||
right = (right + b["x1"]) / 2.
|
right = (right + b["x1"]) / 2.0
|
||||||
cols[-1].append(b)
|
cols[-1].append(b)
|
||||||
|
|
||||||
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
|
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
|
||||||
@ -214,10 +228,8 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
if e > 1:
|
if e > 1:
|
||||||
j += 1
|
j += 1
|
||||||
continue
|
continue
|
||||||
f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
|
f = (j > 0 and tbl[ii][j - 1] and tbl[ii][j - 1][0].get("text")) or j == 0
|
||||||
[j - 1][0].get("text")) or j == 0
|
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii][j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
|
||||||
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
|
|
||||||
[j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
|
|
||||||
if f and ff:
|
if f and ff:
|
||||||
j += 1
|
j += 1
|
||||||
continue
|
continue
|
||||||
@ -228,13 +240,11 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
if j > 0 and not f:
|
if j > 0 and not f:
|
||||||
for i in range(len(tbl)):
|
for i in range(len(tbl)):
|
||||||
if tbl[i][j - 1]:
|
if tbl[i][j - 1]:
|
||||||
left = min(left, np.min(
|
left = min(left, np.min([bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
|
||||||
[bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
|
|
||||||
if j + 1 < len(tbl[0]) and not ff:
|
if j + 1 < len(tbl[0]) and not ff:
|
||||||
for i in range(len(tbl)):
|
for i in range(len(tbl)):
|
||||||
if tbl[i][j + 1]:
|
if tbl[i][j + 1]:
|
||||||
right = min(right, np.min(
|
right = min(right, np.min([a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
|
||||||
[a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
|
|
||||||
assert left < 100000 or right < 100000
|
assert left < 100000 or right < 100000
|
||||||
if left < right:
|
if left < right:
|
||||||
for jj in range(j, len(tbl[0])):
|
for jj in range(j, len(tbl[0])):
|
||||||
@ -260,8 +270,7 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
for i in range(len(tbl)):
|
for i in range(len(tbl)):
|
||||||
tbl[i].pop(j)
|
tbl[i].pop(j)
|
||||||
cols.pop(j)
|
cols.pop(j)
|
||||||
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
|
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (len(cols), len(tbl[0]))
|
||||||
len(cols), len(tbl[0]))
|
|
||||||
|
|
||||||
if len(cols) >= 4:
|
if len(cols) >= 4:
|
||||||
# remove single in row
|
# remove single in row
|
||||||
@ -277,10 +286,8 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
if e > 1:
|
if e > 1:
|
||||||
i += 1
|
i += 1
|
||||||
continue
|
continue
|
||||||
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
|
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1][jj][0].get("text")) or i == 0
|
||||||
[jj][0].get("text")) or i == 0
|
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1][jj][0].get("text")) or i + 1 >= len(tbl)
|
||||||
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
|
|
||||||
[jj][0].get("text")) or i + 1 >= len(tbl)
|
|
||||||
if f and ff:
|
if f and ff:
|
||||||
i += 1
|
i += 1
|
||||||
continue
|
continue
|
||||||
@ -292,13 +299,11 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
if i > 0 and not f:
|
if i > 0 and not f:
|
||||||
for j in range(len(tbl[i - 1])):
|
for j in range(len(tbl[i - 1])):
|
||||||
if tbl[i - 1][j]:
|
if tbl[i - 1][j]:
|
||||||
up = min(up, np.min(
|
up = min(up, np.min([bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
|
||||||
[bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
|
|
||||||
if i + 1 < len(tbl) and not ff:
|
if i + 1 < len(tbl) and not ff:
|
||||||
for j in range(len(tbl[i + 1])):
|
for j in range(len(tbl[i + 1])):
|
||||||
if tbl[i + 1][j]:
|
if tbl[i + 1][j]:
|
||||||
down = min(down, np.min(
|
down = min(down, np.min([a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
|
||||||
[a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
|
|
||||||
assert up < 100000 or down < 100000
|
assert up < 100000 or down < 100000
|
||||||
if up < down:
|
if up < down:
|
||||||
for ii in range(i, len(tbl)):
|
for ii in range(i, len(tbl)):
|
||||||
@ -333,22 +338,15 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
cnt += 1
|
cnt += 1
|
||||||
if max_type == "Nu" and arr[0]["btype"] == "Nu":
|
if max_type == "Nu" and arr[0]["btype"] == "Nu":
|
||||||
continue
|
continue
|
||||||
if any([a.get("H") for a in arr]) \
|
if any([a.get("H") for a in arr]) or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
|
||||||
or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
|
|
||||||
h += 1
|
h += 1
|
||||||
if h / cnt > 0.5:
|
if h / cnt > 0.5:
|
||||||
hdset.add(i)
|
hdset.add(i)
|
||||||
|
|
||||||
if html:
|
if html:
|
||||||
return TableStructureRecognizer.__html_table(cap, hdset,
|
return TableStructureRecognizer.__html_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, True))
|
||||||
TableStructureRecognizer.__cal_spans(boxes, rows,
|
|
||||||
cols, tbl, True)
|
|
||||||
)
|
|
||||||
|
|
||||||
return TableStructureRecognizer.__desc_table(cap, hdset,
|
return TableStructureRecognizer.__desc_table(cap, hdset, TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl, False), is_english)
|
||||||
TableStructureRecognizer.__cal_spans(boxes, rows, cols, tbl,
|
|
||||||
False),
|
|
||||||
is_english)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __html_table(cap, hdset, tbl):
|
def __html_table(cap, hdset, tbl):
|
||||||
@ -367,10 +365,8 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
continue
|
continue
|
||||||
txt = ""
|
txt = ""
|
||||||
if arr:
|
if arr:
|
||||||
h = min(np.min([c["bottom"] - c["top"]
|
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
|
||||||
for c in arr]) / 2, 10)
|
txt = " ".join([c["text"] for c in Recognizer.sort_Y_firstly(arr, h)])
|
||||||
txt = " ".join([c["text"]
|
|
||||||
for c in Recognizer.sort_Y_firstly(arr, h)])
|
|
||||||
txts.append(txt)
|
txts.append(txt)
|
||||||
sp = ""
|
sp = ""
|
||||||
if arr[0].get("colspan"):
|
if arr[0].get("colspan"):
|
||||||
@ -436,15 +432,11 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
if headers[j][k].find(headers[j - 1][k]) >= 0:
|
if headers[j][k].find(headers[j - 1][k]) >= 0:
|
||||||
continue
|
continue
|
||||||
if len(headers[j][k]) > len(headers[j - 1][k]):
|
if len(headers[j][k]) > len(headers[j - 1][k]):
|
||||||
headers[j][k] += (de if headers[j][k]
|
headers[j][k] += (de if headers[j][k] else "") + headers[j - 1][k]
|
||||||
else "") + headers[j - 1][k]
|
|
||||||
else:
|
else:
|
||||||
headers[j][k] = headers[j - 1][k] \
|
headers[j][k] = headers[j - 1][k] + (de if headers[j - 1][k] else "") + headers[j][k]
|
||||||
+ (de if headers[j - 1][k] else "") \
|
|
||||||
+ headers[j][k]
|
|
||||||
|
|
||||||
logging.debug(
|
logging.debug(f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
|
||||||
f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
|
|
||||||
row_txt = []
|
row_txt = []
|
||||||
for i in range(rowno):
|
for i in range(rowno):
|
||||||
if i in hdr_rowno:
|
if i in hdr_rowno:
|
||||||
@ -503,14 +495,10 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def __cal_spans(boxes, rows, cols, tbl, html=True):
|
def __cal_spans(boxes, rows, cols, tbl, html=True):
|
||||||
# caculate span
|
# caculate span
|
||||||
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
|
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln]) for cln in cols]
|
||||||
for cln in cols]
|
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln]) for cln in cols]
|
||||||
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
|
rtop = [np.mean([c.get("R_top", c["top"]) for c in row]) for row in rows]
|
||||||
for cln in cols]
|
rbtm = [np.mean([c.get("R_btm", c["bottom"]) for c in row]) for row in rows]
|
||||||
rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
|
|
||||||
for row in rows]
|
|
||||||
rbtm = [np.mean([c.get("R_btm", c["bottom"])
|
|
||||||
for c in row]) for row in rows]
|
|
||||||
for b in boxes:
|
for b in boxes:
|
||||||
if "SP" not in b:
|
if "SP" not in b:
|
||||||
continue
|
continue
|
||||||
@ -585,3 +573,40 @@ class TableStructureRecognizer(Recognizer):
|
|||||||
tbl[rowspan[0]][colspan[0]] = arr
|
tbl[rowspan[0]][colspan[0]] = arr
|
||||||
|
|
||||||
return tbl
|
return tbl
|
||||||
|
|
||||||
|
def _run_ascend_tsr(self, image_list, thr=0.2, batch_size=16):
|
||||||
|
import math
|
||||||
|
|
||||||
|
from ais_bench.infer.interface import InferSession
|
||||||
|
|
||||||
|
model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc")
|
||||||
|
model_file_path = os.path.join(model_dir, "tsr.om")
|
||||||
|
|
||||||
|
if not os.path.exists(model_file_path):
|
||||||
|
raise ValueError(f"Model file not found: {model_file_path}")
|
||||||
|
|
||||||
|
device_id = int(os.getenv("ASCEND_LAYOUT_RECOGNIZER_DEVICE_ID", 0))
|
||||||
|
session = InferSession(device_id=device_id, model_path=model_file_path)
|
||||||
|
|
||||||
|
images = [np.array(im) if not isinstance(im, np.ndarray) else im for im in image_list]
|
||||||
|
results = []
|
||||||
|
|
||||||
|
conf_thr = max(thr, 0.08)
|
||||||
|
|
||||||
|
batch_loop_cnt = math.ceil(float(len(images)) / batch_size)
|
||||||
|
for bi in range(batch_loop_cnt):
|
||||||
|
s = bi * batch_size
|
||||||
|
e = min((bi + 1) * batch_size, len(images))
|
||||||
|
batch_images = images[s:e]
|
||||||
|
|
||||||
|
inputs_list = self.preprocess(batch_images)
|
||||||
|
for ins in inputs_list:
|
||||||
|
feeds = []
|
||||||
|
if "image" in ins:
|
||||||
|
feeds.append(ins["image"])
|
||||||
|
else:
|
||||||
|
feeds.append(ins[self.input_names[0]])
|
||||||
|
output_list = session.infer(feeds=feeds, mode="static")
|
||||||
|
bb = self.postprocess(output_list, ins, conf_thr)
|
||||||
|
results.append(bb)
|
||||||
|
return results
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user